/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.rtree;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.time.OffsetDateTime;
import java.util.ArrayDeque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SplittableRandom;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.common.tree.AbstractCARTTrainer;
import org.tribuo.common.tree.AbstractTrainingNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.TreeModel;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.rtree.IndependentRegressionTreeModel;
import org.tribuo.regression.rtree.impl.RegressorTrainingNode;
import org.tribuo.regression.rtree.impurity.MeanSquaredError;
import org.tribuo.regression.rtree.impurity.RegressorImpurity;
import org.tribuo.util.Util;

public final class CARTRegressionTrainer
extends AbstractCARTTrainer<Regressor> {
    @Config(description="Regression impurity measure used to determine split quality.")
    private RegressorImpurity impurity = new MeanSquaredError();

    public CARTRegressionTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease, float fractionFeaturesInSplit, boolean useRandomSplitPoints, RegressorImpurity impurity, long seed) {
        super(maxDepth, minChildWeight, minImpurityDecrease, fractionFeaturesInSplit, useRandomSplitPoints, seed);
        this.impurity = impurity;
        this.postConfig();
    }

    public CARTRegressionTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease, float fractionFeaturesInSplit, RegressorImpurity impurity, long seed) {
        this(maxDepth, minChildWeight, minImpurityDecrease, fractionFeaturesInSplit, false, impurity, seed);
    }

    public CARTRegressionTrainer() {
        this(Integer.MAX_VALUE);
    }

    public CARTRegressionTrainer(int maxDepth) {
        this(maxDepth, 5.0f, 0.0f, 1.0f, false, new MeanSquaredError(), 12345L);
    }

    protected AbstractTrainingNode<Regressor> mkTrainingNode(Dataset<Regressor> examples, AbstractTrainingNode.LeafDeterminer leafDeterminer) {
        throw new IllegalStateException("Shouldn't reach here.");
    }

    public TreeModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) {
        return this.train(examples, runProvenance, -1);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public TreeModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance, int invocationCount) {
        TrainerProvenance trainerProvenance;
        SplittableRandom localRNG;
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        CARTRegressionTrainer cARTRegressionTrainer = this;
        synchronized (cARTRegressionTrainer) {
            if (invocationCount != -1) {
                this.setInvocationCount(invocationCount);
            }
            localRNG = this.rng.split();
            trainerProvenance = this.getProvenance();
            ++this.trainInvocationCounter;
        }
        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        ImmutableOutputInfo outputIDInfo = examples.getOutputIDInfo();
        Set domain = outputIDInfo.getDomain();
        int numFeaturesInSplit = Math.min(Math.round(this.fractionFeaturesInSplit * (float)featureIDMap.size()), featureIDMap.size());
        int[] originalIndices = new int[featureIDMap.size()];
        for (int i = 0; i < originalIndices.length; ++i) {
            originalIndices[i] = i;
        }
        int[] indices = numFeaturesInSplit != featureIDMap.size() ? new int[numFeaturesInSplit] : originalIndices;
        float weightSum = 0.0f;
        for (Example e : examples) {
            weightSum += e.getWeight();
        }
        float scaledMinImpurityDecrease = this.getMinImpurityDecrease() * weightSum;
        AbstractTrainingNode.LeafDeterminer leafDeterminer = new AbstractTrainingNode.LeafDeterminer(this.maxDepth, this.minChildWeight, scaledMinImpurityDecrease);
        RegressorTrainingNode.InvertedData data = RegressorTrainingNode.invertData(examples);
        HashMap<String, Node<Regressor>> nodeMap = new HashMap<String, Node<Regressor>>();
        for (Regressor r : domain) {
            String dimName = r.getNames()[0];
            int dimIdx = outputIDInfo.getID((Output)r);
            RegressorTrainingNode root = new RegressorTrainingNode(this.impurity, data, dimIdx, dimName, examples.size(), featureIDMap, (ImmutableOutputInfo<Regressor>)outputIDInfo, leafDeterminer);
            ArrayDeque<AbstractTrainingNode> queue = new ArrayDeque<AbstractTrainingNode>();
            queue.add(root);
            while (!queue.isEmpty()) {
                AbstractTrainingNode node = (AbstractTrainingNode)queue.poll();
                if (!(node.getImpurity() > 0.0) || node.getDepth() >= this.maxDepth || !(node.getWeightSum() >= this.minChildWeight)) continue;
                if (numFeaturesInSplit != featureIDMap.size()) {
                    Util.randpermInPlace((int[])originalIndices, (SplittableRandom)localRNG);
                    System.arraycopy(originalIndices, 0, indices, 0, numFeaturesInSplit);
                }
                List nodes = node.buildTree(indices, localRNG, this.getUseRandomSplitPoints());
                for (AbstractTrainingNode newNode : nodes) {
                    queue.addFirst(newNode);
                }
            }
            nodeMap.put(dimName, (Node<Regressor>)root.convertTree());
        }
        ModelProvenance provenance = new ModelProvenance(TreeModel.class.getName(), OffsetDateTime.now(), (DatasetProvenance)examples.getProvenance(), trainerProvenance, runProvenance);
        return new IndependentRegressionTreeModel("cart-tree", provenance, featureIDMap, (ImmutableOutputInfo<Regressor>)outputIDInfo, false, nodeMap);
    }

    public String toString() {
        StringBuilder buffer = new StringBuilder();
        buffer.append("CARTRegressionTrainer(maxDepth=");
        buffer.append(this.maxDepth);
        buffer.append(",minChildWeight=");
        buffer.append(this.minChildWeight);
        buffer.append(",minImpurityDecrease=");
        buffer.append(this.minImpurityDecrease);
        buffer.append(",fractionFeaturesInSplit=");
        buffer.append(this.fractionFeaturesInSplit);
        buffer.append(",useRandomSplitPoints=");
        buffer.append(this.useRandomSplitPoints);
        buffer.append(",impurity=");
        buffer.append(this.impurity.toString());
        buffer.append(",seed=");
        buffer.append(this.seed);
        buffer.append(")");
        return buffer.toString();
    }

    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl((Trainer)this);
    }
}

