/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.rtree;

import com.oracle.labs.mlrg.olcut.config.Config;
import org.tribuo.Dataset;
import org.tribuo.Trainer;
import org.tribuo.common.tree.AbstractCARTTrainer;
import org.tribuo.common.tree.AbstractTrainingNode;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.rtree.impl.JointRegressorTrainingNode;
import org.tribuo.regression.rtree.impurity.MeanSquaredError;
import org.tribuo.regression.rtree.impurity.RegressorImpurity;

import java.util.Objects;

/**
 * A CART decision-tree trainer for {@link Regressor} outputs whose tree is grown from
 * {@link JointRegressorTrainingNode}s.
 * <p>
 * NOTE(review): judging by the node type's name this presumably fits a single tree over all
 * output dimensions jointly rather than one tree per dimension; confirm against
 * {@link JointRegressorTrainingNode}.
 * <p>
 * The {@code impurity} and {@code normalize} fields are exposed as {@link Config} options, and
 * the no-argument constructor supplies defaults for every hyperparameter.
 */
public class CARTJointRegressionTrainer extends AbstractCARTTrainer<Regressor> {

    /**
     * Impurity measure handed to every {@link JointRegressorTrainingNode}; defaults to
     * {@link MeanSquaredError}.
     */
    @Config(description="The regression impurity to use.")
    private RegressorImpurity impurity = new MeanSquaredError();

    /**
     * Whether the output of each leaf is normalized so it sums to one; defaults to {@code false}.
     */
    @Config(description="Normalize the output of each leaf so it sums to one.")
    private boolean normalize = false;

    /**
     * Creates a trainer with every hyperparameter specified explicitly.
     * <p>
     * All other constructors delegate here; this one finishes by calling {@code postConfig()}.
     *
     * @param maxDepth                The maximum depth of the tree.
     * @param minChildWeight          The minimum weight required for a child node.
     * @param minImpurityDecrease     The minimum impurity decrease required to split a node.
     * @param fractionFeaturesInSplit The fraction of features considered at each split.
     * @param useRandomSplitPoints    Whether split points are chosen at random.
     * @param impurity                The impurity measure used by the training nodes; must not
     *                                be {@code null}.
     * @param normalize               Whether each leaf's output is normalized to sum to one.
     * @param seed                    The seed passed to the parent trainer.
     * @throws NullPointerException If {@code impurity} is {@code null}.
     */
    public CARTJointRegressionTrainer(
            int maxDepth,
            float minChildWeight,
            float minImpurityDecrease,
            float fractionFeaturesInSplit,
            boolean useRandomSplitPoints,
            RegressorImpurity impurity,
            boolean normalize,
            long seed) {
        super(maxDepth, minChildWeight, minImpurityDecrease, fractionFeaturesInSplit, useRandomSplitPoints, seed);
        // Fail fast: a null impurity would otherwise only surface later, inside toString()
        // or when the first training node is built.
        this.impurity = Objects.requireNonNull(impurity, "impurity must not be null");
        this.normalize = normalize;
        this.postConfig();
    }

    /**
     * Creates a trainer that uses deterministic (non-random) split points.
     *
     * @param maxDepth                The maximum depth of the tree.
     * @param minChildWeight          The minimum weight required for a child node.
     * @param minImpurityDecrease     The minimum impurity decrease required to split a node.
     * @param fractionFeaturesInSplit The fraction of features considered at each split.
     * @param impurity                The impurity measure used by the training nodes; must not
     *                                be {@code null}.
     * @param normalize               Whether each leaf's output is normalized to sum to one.
     * @param seed                    The seed passed to the parent trainer.
     * @throws NullPointerException If {@code impurity} is {@code null}.
     */
    public CARTJointRegressionTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease, float fractionFeaturesInSplit, RegressorImpurity impurity, boolean normalize, long seed) {
        this(maxDepth, minChildWeight, minImpurityDecrease, fractionFeaturesInSplit, false, impurity, normalize, seed);
    }

    /**
     * Creates a trainer with default hyperparameters and no depth limit
     * ({@code maxDepth = Integer.MAX_VALUE}).
     */
    public CARTJointRegressionTrainer() {
        this(Integer.MAX_VALUE);
    }

    /**
     * Creates a trainer with the given depth limit, default hyperparameters and no leaf
     * normalization.
     *
     * @param maxDepth The maximum depth of the tree.
     */
    public CARTJointRegressionTrainer(int maxDepth) {
        this(maxDepth, false);
    }

    /**
     * Creates a trainer with the given depth limit and normalization setting.
     * <p>
     * This is the single place that spells out the defaults: min child weight 5, min impurity
     * decrease 0, all features considered per split, deterministic split points,
     * {@link MeanSquaredError} impurity and seed 12345.
     *
     * @param maxDepth  The maximum depth of the tree.
     * @param normalize Whether each leaf's output is normalized to sum to one.
     */
    public CARTJointRegressionTrainer(int maxDepth, boolean normalize) {
        this(maxDepth, 5.0f, 0.0f, 1.0f, false, new MeanSquaredError(), normalize, 12345L);
    }

    /**
     * Builds a {@link JointRegressorTrainingNode} over {@code examples}, configured with this
     * trainer's impurity and normalization setting.
     *
     * @param examples       The dataset the node is built from.
     * @param leafDeterminer Passed through unchanged to the node.
     * @return A new training node.
     */
    protected AbstractTrainingNode<Regressor> mkTrainingNode(Dataset<Regressor> examples, AbstractTrainingNode.LeafDeterminer leafDeterminer) {
        return new JointRegressorTrainingNode(impurity, examples, normalize, leafDeterminer);
    }

    /**
     * Describes this trainer and all of its hyperparameters.
     *
     * @return A string of the form {@code CARTJointRegressionTrainer(maxDepth=...,...,seed=...)}.
     */
    public String toString() {
        // append(Object) renders the impurity via its toString() and stays null-safe.
        return new StringBuilder()
                .append("CARTJointRegressionTrainer(maxDepth=").append(maxDepth)
                .append(",minChildWeight=").append(minChildWeight)
                .append(",minImpurityDecrease=").append(minImpurityDecrease)
                .append(",fractionFeaturesInSplit=").append(fractionFeaturesInSplit)
                .append(",useRandomSplitPoints=").append(useRandomSplitPoints)
                .append(",impurity=").append(impurity)
                .append(",normalize=").append(normalize)
                .append(",seed=").append(seed)
                .append(")")
                .toString();
    }

    /**
     * Returns a provenance object describing this trainer.
     *
     * @return A new {@link TrainerProvenanceImpl} for this trainer.
     */
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}

