/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.tree.impl;

import java.io.Serializable;
import org.apache.spark.SparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.internal.Logging;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.DecisionTreeRegressor;
import org.apache.spark.ml.tree.impl.TimeTracker;
import org.apache.spark.mllib.tree.configuration.Algo$;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
import org.apache.spark.mllib.tree.configuration.Strategy;
import org.apache.spark.mllib.tree.impurity.Variance$;
import org.apache.spark.mllib.tree.loss.Loss;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.RDD$;
import org.apache.spark.rdd.util.PeriodicRDDCheckpointer;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.storage.StorageLevel$;
import org.slf4j.Logger;
import scala.Array$;
import scala.Enumeration;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.GenIterable;
import scala.collection.IterableLike;
import scala.collection.TraversableLike;
import scala.collection.immutable.IndexedSeq;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.Range;
import scala.collection.mutable.ArrayOps;
import scala.math.Ordering;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.ObjectRef;
import scala.runtime.ScalaRunTime$;
import scala.runtime.java8.JFunction1;
import scala.runtime.java8.JFunction2;

/**
 * Singleton ({@code MODULE$}) implementing gradient-boosted tree training for Spark ML, as
 * recovered by the CFR decompiler. The naming and {@code MODULE$} pattern indicate it is the
 * compiled form of a Scala {@code object GradientBoostedTrees}.
 *
 * <p>Provides:
 * <ul>
 *   <li>the core boosting loop ({@link #boost});</li>
 *   <li>regression/classification entry points ({@link #run}, {@link #runWithValidation});</li>
 *   <li>helpers that compute and evaluate ensemble predictions and losses.</li>
 * </ul>
 *
 * <p>NOTE(review): this is decompiler output, not hand-written source. Several constructs are
 * not valid Java as written, so this file is not expected to compile without cleanup:
 * <ul>
 *   <li>intersection-typed local declarations such as
 *       {@code Function1 & Serializable & scala.Serializable x$1 = ...};</li>
 *   <li>references to captured variables under synthetic names like {@code predError$1} and
 *       {@code m$1};</li>
 *   <li>specialized tuple classes spelled {@code Tuple2.mcDD.sp}.</li>
 * </ul>
 */
public final class GradientBoostedTrees$
implements Logging {
    // The single instance, assigned by the private constructor.
    public static GradientBoostedTrees$ MODULE$;
    // Storage for the Logging trait's logger; read and written via the accessors below.
    private transient Logger org$apache$spark$internal$Logging$$log_;

    // Runs the private constructor, which publishes the instance into MODULE$.
    static {
        new GradientBoostedTrees$();
    }

    // ---- org.apache.spark.internal.Logging trait forwarders ----
    // Each method below delegates to the trait's static implementation (Logging.xxx$),
    // passing this singleton as the receiver.

    public String logName() {
        return Logging.logName$((Logging)this);
    }

    public Logger log() {
        return Logging.log$((Logging)this);
    }

    public void logInfo(Function0<String> msg) {
        Logging.logInfo$((Logging)this, msg);
    }

    public void logDebug(Function0<String> msg) {
        Logging.logDebug$((Logging)this, msg);
    }

    public void logTrace(Function0<String> msg) {
        Logging.logTrace$((Logging)this, msg);
    }

    public void logWarning(Function0<String> msg) {
        Logging.logWarning$((Logging)this, msg);
    }

    public void logError(Function0<String> msg) {
        Logging.logError$((Logging)this, msg);
    }

    public void logInfo(Function0<String> msg, Throwable throwable) {
        Logging.logInfo$((Logging)this, msg, (Throwable)throwable);
    }

    public void logDebug(Function0<String> msg, Throwable throwable) {
        Logging.logDebug$((Logging)this, msg, (Throwable)throwable);
    }

    public void logTrace(Function0<String> msg, Throwable throwable) {
        Logging.logTrace$((Logging)this, msg, (Throwable)throwable);
    }

    public void logWarning(Function0<String> msg, Throwable throwable) {
        Logging.logWarning$((Logging)this, msg, (Throwable)throwable);
    }

    public void logError(Function0<String> msg, Throwable throwable) {
        Logging.logError$((Logging)this, msg, (Throwable)throwable);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$((Logging)this);
    }

    public void initializeLogIfNecessary(boolean isInterpreter) {
        Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter);
    }

    public boolean initializeLogIfNecessary(boolean isInterpreter, boolean silent) {
        return Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$((Logging)this);
    }

    // Getter for the trait's logger field.
    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    // Setter for the trait's logger field.
    public void org$apache$spark$internal$Logging$$log__$eq(Logger x$1) {
        this.org$apache$spark$internal$Logging$$log_ = x$1;
    }

    /**
     * Trains a gradient-boosted ensemble without validation-based early stopping.
     *
     * <p>The tree strategy's {@code algo} selects how the input is prepared:
     * <ul>
     *   <li>{@code Algo.Regression}: the input is used as-is.</li>
     *   <li>{@code Algo.Classification}: each label is remapped with {@code label * 2 - 1}
     *       (0 becomes -1, 1 becomes +1) before boosting. Presumably labels are expected to be
     *       binary 0/1; confirm against callers.</li>
     * </ul>
     *
     * <p>The same RDD is passed to {@link #boost} as both the training and the validation
     * input, with {@code validate = false}.
     *
     * @param input                 training points
     * @param boostingStrategy      boosting parameters; its tree strategy's {@code algo} selects the branch
     * @param seed                  random seed forwarded to {@link #boost}
     * @param featureSubsetStrategy feature-subset strategy forwarded to {@link #boost}
     * @return (trees, tree weights) as produced by {@link #boost}
     * @throws IllegalArgumentException if {@code algo} is neither Regression nor Classification
     */
    public Tuple2<DecisionTreeRegressionModel[], double[]> run(RDD<LabeledPoint> input, BoostingStrategy boostingStrategy, long seed, String featureSubsetStrategy) {
        Tuple2<DecisionTreeRegressionModel[], double[]> tuple2;
        Enumeration.Value algo;
        Enumeration.Value value = algo = boostingStrategy.treeStrategy().algo();
        Enumeration.Value value2 = Algo$.MODULE$.Regression();
        Enumeration.Value value3 = value;
        // Decompiled null-safe equality: true iff value2 and value3 are both null or value2.equals(value3).
        if (!(value2 != null ? !value2.equals(value3) : value3 != null)) {
            tuple2 = this.boost(input, input, boostingStrategy, false, seed, featureSubsetStrategy);
        } else {
            Enumeration.Value value4 = Algo$.MODULE$.Classification();
            Enumeration.Value value5 = value;
            if (!(value4 != null ? !value4.equals(value5) : value5 != null)) {
                // Remap labels: 0 -> -1.0, 1 -> +1.0.
                RDD remappedInput = input.map((Function1 & Serializable & scala.Serializable)x -> new LabeledPoint(x.label() * (double)2 - 1.0, x.features()), ClassTag$.MODULE$.apply(LabeledPoint.class));
                tuple2 = this.boost((RDD<LabeledPoint>)remappedInput, (RDD<LabeledPoint>)remappedInput, boostingStrategy, false, seed, featureSubsetStrategy);
            } else {
                throw new IllegalArgumentException(new StringBuilder(39).append(algo).append(" is not supported by gradient boosting.").toString());
            }
        }
        return tuple2;
    }

    /**
     * Trains a gradient-boosted ensemble with validation-based early stopping
     * ({@code validate = true} in {@link #boost}).
     *
     * <p>Label handling matches {@link #run}: regression inputs are used as-is. For
     * classification, both {@code input} and {@code validationInput} are remapped with
     * {@code label * 2 - 1}.
     *
     * @param input                 training points
     * @param validationInput       points used to measure validation error after each tree
     * @param boostingStrategy      boosting parameters; its tree strategy's {@code algo} selects the branch
     * @param seed                  random seed forwarded to {@link #boost}
     * @param featureSubsetStrategy feature-subset strategy forwarded to {@link #boost}
     * @return (trees, tree weights), truncated by {@link #boost} to the best validation prefix
     * @throws IllegalArgumentException if {@code algo} is neither Regression nor Classification
     */
    public Tuple2<DecisionTreeRegressionModel[], double[]> runWithValidation(RDD<LabeledPoint> input, RDD<LabeledPoint> validationInput, BoostingStrategy boostingStrategy, long seed, String featureSubsetStrategy) {
        Tuple2<DecisionTreeRegressionModel[], double[]> tuple2;
        Enumeration.Value algo;
        Enumeration.Value value = algo = boostingStrategy.treeStrategy().algo();
        Enumeration.Value value2 = Algo$.MODULE$.Regression();
        Enumeration.Value value3 = value;
        // Decompiled null-safe equality check (see run).
        if (!(value2 != null ? !value2.equals(value3) : value3 != null)) {
            tuple2 = this.boost(input, validationInput, boostingStrategy, true, seed, featureSubsetStrategy);
        } else {
            Enumeration.Value value4 = Algo$.MODULE$.Classification();
            Enumeration.Value value5 = value;
            if (!(value4 != null ? !value4.equals(value5) : value5 != null)) {
                // Remap labels of both training and validation data: 0 -> -1.0, 1 -> +1.0.
                RDD remappedInput = input.map((Function1 & Serializable & scala.Serializable)x -> new LabeledPoint(x.label() * (double)2 - 1.0, x.features()), ClassTag$.MODULE$.apply(LabeledPoint.class));
                RDD remappedValidationInput = validationInput.map((Function1 & Serializable & scala.Serializable)x -> new LabeledPoint(x.label() * (double)2 - 1.0, x.features()), ClassTag$.MODULE$.apply(LabeledPoint.class));
                tuple2 = this.boost((RDD<LabeledPoint>)remappedInput, (RDD<LabeledPoint>)remappedValidationInput, boostingStrategy, true, seed, featureSubsetStrategy);
            } else {
                throw new IllegalArgumentException(new StringBuilder(43).append(algo).append(" is not supported by the gradient boosting.").toString());
            }
        }
        return tuple2;
    }

    /**
     * For every point, computes the prediction of a single initial tree scaled by
     * {@code initTreeWeight} (starting from 0.0), together with the loss of that prediction.
     *
     * @param data           points to predict
     * @param initTreeWeight weight applied to the tree's raw prediction
     * @param initTree       the first tree of the ensemble
     * @param loss           loss used to compute each point's error against its label
     * @return lazily evaluated RDD of {@code (prediction, error)} double pairs, one per point
     */
    public RDD<Tuple2<Object, Object>> computeInitialPredictionAndError(RDD<LabeledPoint> data, double initTreeWeight, DecisionTreeRegressionModel initTree, Loss loss) {
        return data.map((Function1 & Serializable & scala.Serializable)lp -> {
            double pred = MODULE$.updatePrediction(lp.features(), 0.0, initTree, initTreeWeight);
            double error = loss.computeError(pred, lp.label());
            return new Tuple2.mcDD.sp(pred, error);
        }, ClassTag$.MODULE$.apply(Tuple2.class));
    }

    /**
     * Adds one more weighted tree to previously computed predictions and recomputes each loss.
     *
     * <p>{@code data} is zipped element-wise with {@code predictionAndError}. Only the prior
     * prediction from each pair is read; the prior error is ignored.
     *
     * <p>NOTE(review): relies on RDD.zip alignment, meaning both RDDs must line up
     * partition-by-partition. In {@link #boost} the prediction RDD is derived from the same
     * input via map/mapPartitions, which presumably satisfies this; confirm for other callers.
     *
     * <p>The mapping function throws {@code scala.MatchError} if a zipped element or its inner
     * tuple is null. This surfaces lazily, when the RDD is evaluated.
     *
     * @param data               points, in the same order as {@code predictionAndError}
     * @param predictionAndError current {@code (prediction, error)} pairs
     * @param treeWeight         weight applied to {@code tree}'s prediction
     * @param tree               tree whose weighted prediction is added
     * @param loss               loss used to recompute each point's error
     * @return RDD of updated {@code (prediction, error)} pairs
     */
    public RDD<Tuple2<Object, Object>> updatePredictionError(RDD<LabeledPoint> data, RDD<Tuple2<Object, Object>> predictionAndError, double treeWeight, DecisionTreeRegressionModel tree, Loss loss) {
        RDD qual$1 = data.zip(predictionAndError, ClassTag$.MODULE$.apply(Tuple2.class));
        // Per-partition transform: destructure ((point, (pred, _))) and emit (newPred, newError).
        Function1 & Serializable & scala.Serializable x$1 = (Function1 & Serializable & scala.Serializable)iter -> iter.map((Function1 & Serializable & scala.Serializable)x0$1 -> {
            Tuple2 tuple2;
            LabeledPoint lp;
            block3: {
                Tuple2 tuple22;
                block2: {
                    tuple22 = x0$1;
                    if (tuple22 == null) break block2;
                    lp = (LabeledPoint)tuple22._1();
                    tuple2 = (Tuple2)tuple22._2();
                    if (tuple2 != null) break block3;
                }
                throw new MatchError((Object)tuple22);
            }
            double pred = tuple2._1$mcD$sp();
            double newPred = MODULE$.updatePrediction(lp.features(), pred, tree, treeWeight);
            double newError = loss.computeError(newPred, lp.label());
            Tuple2.mcDD.sp sp2 = new Tuple2.mcDD.sp(newPred, newError);
            return sp2;
        });
        // Default value of mapPartitions' preservesPartitioning argument.
        boolean x$2 = qual$1.mapPartitions$default$2();
        RDD newPredError = qual$1.mapPartitions((Function1)x$1, x$2, ClassTag$.MODULE$.apply(Tuple2.class));
        return newPredError;
    }

    /**
     * Adds the weighted prediction of a single tree to a running prediction.
     *
     * @return {@code prediction + tree.rootNode().predictImpl(features).prediction() * weight}
     */
    public double updatePrediction(Vector features, double prediction, DecisionTreeRegressionModel tree, double weight) {
        return prediction + tree.rootNode().predictImpl(features).prediction() * weight;
    }

    /**
     * Computes the mean loss of the full ensemble over {@code data}.
     *
     * <p>Each point's prediction is the sum, over trees paired element-wise (zip) with
     * {@code treeWeights}, of the tree's prediction times its weight. The prediction starts from
     * 0.0; the work is done in {@code $anonfun$computeError$1}/{@code $2}. The per-point losses
     * are then averaged with {@code mean()}.
     *
     * @param data        points to evaluate; labels are used as-is (no remapping here)
     * @param trees       ensemble members
     * @param treeWeights weight for each tree
     * @param loss        loss function
     * @return mean per-point loss
     */
    public double computeError(RDD<LabeledPoint> data, DecisionTreeRegressionModel[] trees, double[] treeWeights, Loss loss) {
        return RDD$.MODULE$.doubleRDDToDoubleRDDFunctions(data.map((Function1 & Serializable & scala.Serializable)lp -> BoxesRunTime.boxToDouble((double)GradientBoostedTrees$.$anonfun$computeError$1(trees, treeWeights, loss, lp)), ClassTag$.MODULE$.Double())).mean();
    }

    /**
     * Computes the mean loss after each boosting iteration. Element {@code i} of the result is
     * the mean loss of the ensemble made of trees {@code 0..i}.
     *
     * <p>For {@code Algo.Classification}, labels are first remapped with {@code label * 2 - 1},
     * matching {@link #run}.
     *
     * <p>The computation proceeds in four steps:
     * <ol>
     *   <li>The trees are broadcast once; the broadcast is destroyed before returning.</li>
     *   <li>Per point, the weighted tree predictions are prefix-summed (scanLeft from 0.0,
     *       dropping the leading 0.0) and each prefix is converted to a loss.</li>
     *   <li>The losses are summed across points per index with {@code aggregate}.</li>
     *   <li>Each sum is divided by the point count.</li>
     * </ol>
     *
     * <p>NOTE(review): if the data is empty, dataCount is 0 and every entry becomes 0.0 / 0
     * (NaN).
     *
     * <p>NOTE(review): remappedData is not cached here but is traversed twice: once for
     * {@code count()} and once for the map/aggregate.
     *
     * <p>NOTE(review): {@code treeWeights} is indexed by tree index, so it presumably must be at
     * least as long as {@code trees}.
     *
     * @param data        points to evaluate
     * @param trees       ensemble members, in boosting order
     * @param treeWeights weight for each tree
     * @param loss        loss function
     * @param algo        Regression or Classification; only Classification triggers label remapping
     * @return array of length {@code trees.length} with the mean loss after each iteration
     */
    public double[] evaluateEachIteration(RDD<LabeledPoint> data, DecisionTreeRegressionModel[] trees, double[] treeWeights, Loss loss, Enumeration.Value algo) {
        SparkContext sc = data.sparkContext();
        Enumeration.Value value = algo;
        Enumeration.Value value2 = Algo$.MODULE$.Classification();
        Enumeration.Value value3 = value;
        RDD rDD = !(value2 != null ? !value2.equals(value3) : value3 != null) ? data.map((Function1 & Serializable & scala.Serializable)x -> new LabeledPoint(x.label() * (double)2 - 1.0, x.features()), ClassTag$.MODULE$.apply(LabeledPoint.class)) : data;
        RDD remappedData = rDD;
        Broadcast broadcastTrees = sc.broadcast((Object)trees, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(DecisionTreeRegressionModel.class)));
        double[] localTreeWeights = treeWeights;
        // Range 0 until trees.length.
        Range treesIndices = new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])trees)).indices();
        long dataCount = remappedData.count();
        // Per point: cumulative weighted predictions -> per-prefix losses; then summed over
        // points with aggregate (seqOp/combOp both add element-wise), then divided by dataCount.
        IndexedSeq evaluation = (IndexedSeq)((TraversableLike)remappedData.map((Function1 & Serializable & scala.Serializable)point -> (IndexedSeq)((TraversableLike)((IterableLike)((TraversableLike)treesIndices.map((Function1)(JFunction1.mcDI.sp & Serializable & scala.Serializable)idx -> {
            double prediction = ((DecisionTreeRegressionModel[])broadcastTrees.value())[idx].rootNode().predictImpl(point.features()).prediction();
            return prediction * localTreeWeights[idx];
        }, IndexedSeq$.MODULE$.canBuildFrom())).scanLeft((Object)BoxesRunTime.boxToDouble((double)0.0), (Function2)(JFunction2.mcDDD.sp & Serializable & scala.Serializable)(x$1, x$2) -> x$1 + x$2, IndexedSeq$.MODULE$.canBuildFrom())).drop(1)).map((Function1)(JFunction1.mcDD.sp & Serializable & scala.Serializable)prediction -> loss.computeError(prediction, point.label()), IndexedSeq$.MODULE$.canBuildFrom()), ClassTag$.MODULE$.apply(IndexedSeq.class)).aggregate(treesIndices.map((Function1)(JFunction1.mcDI.sp & Serializable & scala.Serializable)x$3 -> 0.0, IndexedSeq$.MODULE$.canBuildFrom()), (Function2 & Serializable & scala.Serializable)(aggregated, row) -> (IndexedSeq)treesIndices.map((Function1)(JFunction1.mcDI.sp & Serializable & scala.Serializable)idx -> BoxesRunTime.unboxToDouble((Object)aggregated.apply(idx)) + BoxesRunTime.unboxToDouble((Object)row.apply(idx)), IndexedSeq$.MODULE$.canBuildFrom()), (Function2 & Serializable & scala.Serializable)(a, b) -> (IndexedSeq)treesIndices.map((Function1)(JFunction1.mcDI.sp & Serializable & scala.Serializable)idx -> BoxesRunTime.unboxToDouble((Object)a.apply(idx)) + BoxesRunTime.unboxToDouble((Object)b.apply(idx)), IndexedSeq$.MODULE$.canBuildFrom()), ClassTag$.MODULE$.apply(IndexedSeq.class))).map((Function1)(JFunction1.mcDD.sp & Serializable & scala.Serializable)x$4 -> x$4 / (double)dataCount, IndexedSeq$.MODULE$.canBuildFrom());
        broadcastTrees.destroy(false);
        return (double[])evaluation.toArray(ClassTag$.MODULE$.Double());
    }

    /**
     * Core gradient-boosting loop.
     *
     * <p>Tree 0 is trained on the input labels and gets weight 1.0. Each later tree {@code m}
     * is trained on the negated loss gradient at the current predictions, gets weight
     * {@code learningRate}, and uses seed {@code seed + m}. All trees are regression trees with
     * variance impurity, trained on a copy of the tree strategy. Running predictions and errors
     * are kept as an RDD of {@code (prediction, error)} pairs and handed to a
     * PeriodicRDDCheckpointer after each update.
     *
     * <p>When {@code validate} is true:
     * <ul>
     *   <li>the validation error is recomputed after each tree;</li>
     *   <li>training stops once {@code bestValidateError - currentValidateError} falls below
     *       {@code validationTol * max(currentValidateError, 0.01)};</li>
     *   <li>the returned arrays are truncated to the first {@code bestM} trees, the prefix with
     *       the lowest validation error seen.</li>
     * </ul>
     * When {@code validate} is false, all {@code numIterations} trees are built and the
     * full arrays are returned.
     *
     * <p>Side effects:
     * <ul>
     *   <li>persists {@code input} at MEMORY_AND_DISK if its storage level was NONE, and
     *       unpersists it again before returning in that case;</li>
     *   <li>unpersists and deletes all checkpoints created by its two checkpointers;</li>
     *   <li>logs timing information at INFO.</li>
     * </ul>
     *
     * <p>NOTE(review): {@code baseLearners[0]} is written unconditionally. This presumably relies
     * on {@code boostingStrategy.assertValid()} rejecting {@code numIterations < 1}; confirm.
     *
     * @param input                 training points (labels already remapped for classification)
     * @param validationInput       points for validation error; only evaluated when {@code validate}
     * @param boostingStrategy      number of iterations, loss, learning rate, tree strategy, tolerance
     * @param validate              whether to track validation error and stop early
     * @param seed                  base random seed for tree training
     * @param featureSubsetStrategy feature-subset strategy passed to each tree
     * @return (trees, tree weights)
     */
    public Tuple2<DecisionTreeRegressionModel[], double[]> boost(RDD<LabeledPoint> input, RDD<LabeledPoint> validationInput, BoostingStrategy boostingStrategy, boolean validate, long seed, String featureSubsetStrategy) {
        boolean bl;
        TimeTracker timer = new TimeTracker();
        timer.start("total");
        timer.start("init");
        boostingStrategy.assertValid();
        int numIterations = boostingStrategy.numIterations();
        DecisionTreeRegressionModel[] baseLearners = new DecisionTreeRegressionModel[numIterations];
        double[] baseLearnerWeights = new double[numIterations];
        Loss loss = boostingStrategy.loss();
        double learningRate = boostingStrategy.learningRate();
        Strategy treeStrategy = boostingStrategy.treeStrategy().copy();
        double validationTol = boostingStrategy.validationTol();
        // Base learners are always regression trees with variance impurity, whatever the caller's algo.
        treeStrategy.algo_$eq(Algo$.MODULE$.Regression());
        treeStrategy.impurity_$eq(Variance$.MODULE$);
        treeStrategy.assertValid();
        // Persist the input only if the caller has not; remember it so it can be undone at the end.
        StorageLevel storageLevel = input.getStorageLevel();
        StorageLevel storageLevel2 = StorageLevel$.MODULE$.NONE();
        if (!(storageLevel != null ? !storageLevel.equals(storageLevel2) : storageLevel2 != null)) {
            input.persist(StorageLevel$.MODULE$.MEMORY_AND_DISK());
            bl = true;
        } else {
            bl = false;
        }
        boolean persistedInput = bl;
        PeriodicRDDCheckpointer predErrorCheckpointer = new PeriodicRDDCheckpointer(treeStrategy.getCheckpointInterval(), input.sparkContext());
        PeriodicRDDCheckpointer validatePredErrorCheckpointer = new PeriodicRDDCheckpointer(treeStrategy.getCheckpointInterval(), input.sparkContext());
        timer.stop("init");
        this.logDebug((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> "##########");
        this.logDebug((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> "Building tree 0");
        this.logDebug((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> "##########");
        timer.start("building tree 0");
        // Tree 0: fit directly to the labels, weight 1.0.
        DecisionTreeRegressor firstTree = new DecisionTreeRegressor().setSeed(seed);
        DecisionTreeRegressionModel firstTreeModel = firstTree.train(input, treeStrategy, featureSubsetStrategy);
        double firstTreeWeight = 1.0;
        baseLearners[0] = firstTreeModel;
        baseLearnerWeights[0] = firstTreeWeight;
        // ObjectRef is the Scala compiler's box for a captured mutable var; the debug-log
        // lambdas below read it (as predError$1.elem in this decompiled form).
        ObjectRef predError = ObjectRef.create(this.computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss));
        predErrorCheckpointer.update((Object)((RDD)predError.elem));
        this.logDebug((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> new StringBuilder(15).append("error of gbt = ").append(RDD$.MODULE$.doubleRDDToDoubleRDDFunctions(RDD$.MODULE$.rddToPairRDDFunctions((RDD)predError$1.elem, ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.Double(), (Ordering)Ordering.Double$.MODULE$).values()).mean()).toString());
        timer.stop("building tree 0");
        // Built even when validate is false, but only used (checkpointed/evaluated) in the validate branches.
        RDD<Tuple2<Object, Object>> validatePredError = this.computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss);
        if (validate) {
            validatePredErrorCheckpointer.update(validatePredError);
        }
        double bestValidateError = validate ? RDD$.MODULE$.doubleRDDToDoubleRDDFunctions(RDD$.MODULE$.rddToPairRDDFunctions(validatePredError, ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.Double(), (Ordering)Ordering.Double$.MODULE$).values()).mean() : 0.0;
        // bestM = number of leading trees in the best-validating ensemble seen so far.
        int bestM = 1;
        IntRef m = IntRef.create((int)1);
        boolean doneLearning = false;
        while (m.elem < numIterations && !doneLearning) {
            // Training targets for tree m: -gradient of the loss at the current prediction (pseudo-residuals).
            RDD data = ((RDD)predError.elem).zip(input, ClassTag$.MODULE$.apply(LabeledPoint.class)).map((Function1 & Serializable & scala.Serializable)x0$1 -> {
                LabeledPoint point;
                Tuple2 tuple2;
                block3: {
                    Tuple2 tuple22;
                    block2: {
                        tuple22 = x0$1;
                        if (tuple22 == null) break block2;
                        tuple2 = (Tuple2)tuple22._1();
                        point = (LabeledPoint)tuple22._2();
                        if (tuple2 != null) break block3;
                    }
                    throw new MatchError((Object)tuple22);
                }
                double pred = tuple2._1$mcD$sp();
                LabeledPoint labeledPoint = new LabeledPoint(-loss.gradient(pred, point.label()), point.features());
                return labeledPoint;
            }, ClassTag$.MODULE$.apply(LabeledPoint.class));
            timer.start(new StringBuilder(14).append("building tree ").append(m.elem).toString());
            this.logDebug((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> "###################################################");
            this.logDebug((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> new StringBuilder(33).append("Gradient boosting tree iteration ").append(m$1.elem).toString());
            this.logDebug((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> "###################################################");
            // Per-iteration seed so trees do not share the same randomness.
            DecisionTreeRegressor dt = new DecisionTreeRegressor().setSeed(seed + (long)m.elem);
            DecisionTreeRegressionModel model = dt.train((RDD<LabeledPoint>)data, treeStrategy, featureSubsetStrategy);
            timer.stop(new StringBuilder(14).append("building tree ").append(m.elem).toString());
            baseLearners[m.elem] = model;
            baseLearnerWeights[m.elem] = learningRate;
            // Fold the new tree into the running training predictions/errors.
            predError.elem = this.updatePredictionError(input, (RDD<Tuple2<Object, Object>>)((RDD)predError.elem), baseLearnerWeights[m.elem], baseLearners[m.elem], loss);
            predErrorCheckpointer.update((Object)((RDD)predError.elem));
            this.logDebug((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> new StringBuilder(15).append("error of gbt = ").append(RDD$.MODULE$.doubleRDDToDoubleRDDFunctions(RDD$.MODULE$.rddToPairRDDFunctions((RDD)predError$1.elem, ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.Double(), (Ordering)Ordering.Double$.MODULE$).values()).mean()).toString());
            if (validate) {
                validatePredError = this.updatePredictionError(validationInput, validatePredError, baseLearnerWeights[m.elem], baseLearners[m.elem], loss);
                validatePredErrorCheckpointer.update(validatePredError);
                double currentValidateError = RDD$.MODULE$.doubleRDDToDoubleRDDFunctions(RDD$.MODULE$.rddToPairRDDFunctions(validatePredError, ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.Double(), (Ordering)Ordering.Double$.MODULE$).values()).mean();
                // Early stopping: stop once improvement over the best error is below a relative
                // tolerance (floored at 0.01), including when the error got worse. Otherwise
                // record a new best prefix of m + 1 trees.
                if (bestValidateError - currentValidateError < validationTol * Math.max(currentValidateError, 0.01)) {
                    doneLearning = true;
                } else if (currentValidateError < bestValidateError) {
                    bestValidateError = currentValidateError;
                    bestM = m.elem + 1;
                }
            }
            ++m.elem;
        }
        timer.stop("total");
        this.logInfo((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> "Internal timing for DecisionTree:");
        this.logInfo((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> String.valueOf(timer));
        predErrorCheckpointer.unpersistDataSet();
        predErrorCheckpointer.deleteAllCheckpoints();
        validatePredErrorCheckpointer.unpersistDataSet();
        validatePredErrorCheckpointer.deleteAllCheckpoints();
        // Undo our own persist() only; `object` is unused (decompiler artifact of a Unit-valued if).
        Object object = persistedInput ? input.unpersist(input.unpersist$default$1()) : BoxedUnit.UNIT;
        // With validation, keep only the best prefix; trees built after it (including the one that
        // triggered early stopping) are discarded.
        return validate ? new Tuple2(new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])baseLearners)).slice(0, bestM), new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(baseLearnerWeights)).slice(0, bestM)) : new Tuple2((Object)baseLearners, (Object)baseLearnerWeights);
    }

    /**
     * Synthetic body of the foldLeft step in {@link #computeError}. It destructures
     * {@code (acc, (model, weight))} and returns {@code acc} plus the model's prediction for
     * {@code lp$1} times {@code weight}.
     *
     * @throws scala.MatchError if the (model, weight) tuple is null
     */
    public static final /* synthetic */ double $anonfun$computeError$2(LabeledPoint lp$1, double x0$1, Tuple2 x1$1) {
        Tuple2 tuple2;
        double acc;
        block3: {
            Tuple2 tuple22;
            block2: {
                tuple22 = new Tuple2((Object)BoxesRunTime.boxToDouble((double)x0$1), (Object)x1$1);
                if (tuple22 == null) break block2;
                acc = tuple22._1$mcD$sp();
                tuple2 = (Tuple2)tuple22._2();
                if (tuple2 != null) break block3;
            }
            throw new MatchError((Object)tuple22);
        }
        DecisionTreeRegressionModel model = (DecisionTreeRegressionModel)tuple2._1();
        double weight = tuple2._2$mcD$sp();
        double d = MODULE$.updatePrediction(lp$1.features(), acc, model, weight);
        return d;
    }

    /**
     * Synthetic body of the per-point lambda in {@link #computeError}. It zips
     * {@code trees$1} with {@code treeWeights$1}, folds their weighted predictions for
     * {@code lp} starting from 0.0, and returns the loss of the resulting prediction.
     */
    public static final /* synthetic */ double $anonfun$computeError$1(DecisionTreeRegressionModel[] trees$1, double[] treeWeights$1, Loss loss$3, LabeledPoint lp) {
        double predicted = BoxesRunTime.unboxToDouble((Object)new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])trees$1)).zip((GenIterable)Predef$.MODULE$.wrapDoubleArray(treeWeights$1), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class))))).foldLeft((Object)BoxesRunTime.boxToDouble((double)0.0), (Function2 & Serializable & scala.Serializable)(x0$1, x1$1) -> BoxesRunTime.boxToDouble((double)GradientBoostedTrees$.$anonfun$computeError$2(lp, BoxesRunTime.unboxToDouble((Object)x0$1), x1$1))));
        return loss$3.computeError(predicted, lp.label());
    }

    // Publishes the singleton into MODULE$ and runs the Logging trait's initializer.
    private GradientBoostedTrees$() {
        MODULE$ = this;
        Logging.$init$((Logging)this);
    }
}

