/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.tree.impl;

import java.io.Serializable;
import java.util.HashMap;
import org.apache.spark.SparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.internal.LogEntry;
import org.apache.spark.internal.LogEntry$;
import org.apache.spark.internal.LogKey;
import org.apache.spark.internal.LogKeys;
import org.apache.spark.internal.Logging;
import org.apache.spark.internal.MDC;
import org.apache.spark.ml.feature.Instance;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.tree.Split;
import org.apache.spark.ml.tree.impl.BaggedPoint;
import org.apache.spark.ml.tree.impl.BaggedPoint$;
import org.apache.spark.ml.tree.impl.DecisionTreeMetadata;
import org.apache.spark.ml.tree.impl.DecisionTreeMetadata$;
import org.apache.spark.ml.tree.impl.RandomForest$;
import org.apache.spark.ml.tree.impl.TimeTracker;
import org.apache.spark.ml.tree.impl.TreePoint;
import org.apache.spark.ml.tree.impl.TreePoint$;
import org.apache.spark.ml.util.Instrumentation;
import org.apache.spark.mllib.tree.configuration.Algo$;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
import org.apache.spark.mllib.tree.configuration.Strategy;
import org.apache.spark.mllib.tree.impurity.Variance$;
import org.apache.spark.mllib.tree.loss.Loss;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.util.PeriodicRDDCheckpointer;
import org.apache.spark.storage.StorageLevel$;
import org.slf4j.Logger;
import scala.Array$;
import scala.Enumeration;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.ArrayOps$;
import scala.collection.IterableOnce;
import scala.collection.Iterator;
import scala.collection.immutable.Seq;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;
import scala.runtime.java8.JFunction1;
import scala.runtime.java8.JFunction2;

/**
 * Gradient-boosted trees training routines. This appears to be the Scala
 * {@code object GradientBoostedTrees} as compiled to a module class and decompiled by CFR:
 * {@link #MODULE$} is the singleton instance, and the class mixes in Spark's {@code Logging}
 * trait, whose concrete methods are forwarded to the trait's static implementations below.
 *
 * <p>Public entry points:
 * <ul>
 *   <li>{@code run} / {@code runWithValidation}: dispatch on the tree strategy's algo and call
 *       {@code boost}.</li>
 *   <li>{@code boost}: the boosting loop itself.</li>
 *   <li>{@code computeWeightedError} / {@code evaluateEachIteration}: weighted-loss evaluation
 *       helpers.</li>
 * </ul>
 *
 * <p>NOTE(review): this is decompiler output and does not compile as Java as-is. For example:
 * <ul>
 *   <li>{@code v0} in {@code boost} is never declared.</li>
 *   <li>Lambdas refer to captured variables under synthetic names ({@code err1$1},
 *       {@code predError$1}, {@code m$1}).</li>
 *   <li>{@code numTrees} is assigned inside a call argument.</li>
 *   <li>Locals are declared with intersection types such as {@code Function2 & Serializable}.</li>
 *   <li>{@code Tuple2.mcDD.sp} presumably denotes the specialized tuple class.</li>
 * </ul>
 * Treat this file as a reading aid for the compiled behavior, not as buildable source.
 */
public final class GradientBoostedTrees$
implements Logging {
    /** Singleton instance of this module; all lambdas below call back through it. */
    public static final GradientBoostedTrees$ MODULE$ = new GradientBoostedTrees$();
    /** Storage for the Logging trait's logger; read and written via the mangled accessor pair below. */
    private static transient Logger org$apache$spark$internal$Logging$$log_;

    // Runs the Logging trait's initializer against the singleton instance.
    static {
        Logging.$init$((Logging)MODULE$);
    }

    // ---- Logging trait forwarders: each method below just delegates to the corresponding
    // ---- static implementation on Logging, passing this module as the receiver.

    public String logName() {
        return Logging.logName$((Logging)this);
    }

    public Logger log() {
        return Logging.log$((Logging)this);
    }

    public Logging.LogStringContext LogStringContext(StringContext sc) {
        return Logging.LogStringContext$((Logging)this, (StringContext)sc);
    }

    public void withLogContext(HashMap<String, String> context, Function0<BoxedUnit> body) {
        Logging.withLogContext$((Logging)this, context, body);
    }

    public void logInfo(Function0<String> msg) {
        Logging.logInfo$((Logging)this, msg);
    }

    public void logInfo(LogEntry entry) {
        Logging.logInfo$((Logging)this, (LogEntry)entry);
    }

    public void logInfo(LogEntry entry, Throwable throwable) {
        Logging.logInfo$((Logging)this, (LogEntry)entry, (Throwable)throwable);
    }

    public void logDebug(Function0<String> msg) {
        Logging.logDebug$((Logging)this, msg);
    }

    public void logDebug(LogEntry entry) {
        Logging.logDebug$((Logging)this, (LogEntry)entry);
    }

    public void logDebug(LogEntry entry, Throwable throwable) {
        Logging.logDebug$((Logging)this, (LogEntry)entry, (Throwable)throwable);
    }

    public void logTrace(Function0<String> msg) {
        Logging.logTrace$((Logging)this, msg);
    }

    public void logTrace(LogEntry entry) {
        Logging.logTrace$((Logging)this, (LogEntry)entry);
    }

    public void logTrace(LogEntry entry, Throwable throwable) {
        Logging.logTrace$((Logging)this, (LogEntry)entry, (Throwable)throwable);
    }

    public void logWarning(Function0<String> msg) {
        Logging.logWarning$((Logging)this, msg);
    }

    public void logWarning(LogEntry entry) {
        Logging.logWarning$((Logging)this, (LogEntry)entry);
    }

    public void logWarning(LogEntry entry, Throwable throwable) {
        Logging.logWarning$((Logging)this, (LogEntry)entry, (Throwable)throwable);
    }

    public void logError(Function0<String> msg) {
        Logging.logError$((Logging)this, msg);
    }

    public void logError(LogEntry entry) {
        Logging.logError$((Logging)this, (LogEntry)entry);
    }

    public void logError(LogEntry entry, Throwable throwable) {
        Logging.logError$((Logging)this, (LogEntry)entry, (Throwable)throwable);
    }

    public void logInfo(Function0<String> msg, Throwable throwable) {
        Logging.logInfo$((Logging)this, msg, (Throwable)throwable);
    }

    public void logDebug(Function0<String> msg, Throwable throwable) {
        Logging.logDebug$((Logging)this, msg, (Throwable)throwable);
    }

    public void logTrace(Function0<String> msg, Throwable throwable) {
        Logging.logTrace$((Logging)this, msg, (Throwable)throwable);
    }

    public void logWarning(Function0<String> msg, Throwable throwable) {
        Logging.logWarning$((Logging)this, msg, (Throwable)throwable);
    }

    public void logError(Function0<String> msg, Throwable throwable) {
        Logging.logError$((Logging)this, msg, (Throwable)throwable);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$((Logging)this);
    }

    public void initializeLogIfNecessary(boolean isInterpreter) {
        Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter);
    }

    public boolean initializeLogIfNecessary(boolean isInterpreter, boolean silent) {
        return Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$((Logging)this);
    }

    public void initializeForcefully(boolean isInterpreter, boolean silent) {
        Logging.initializeForcefully$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    // Mangled getter/setter for the logger field above (names presumably generated by the Scala
    // compiler for the Logging trait's private field).
    public Logger org$apache$spark$internal$Logging$$log_() {
        return org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger x$1) {
        org$apache$spark$internal$Logging$$log_ = x$1;
    }

    /**
     * Trains a gradient-boosted tree ensemble without validation-based early stopping.
     *
     * <p>For {@code Regression} the input is used as-is. For {@code Classification} every label
     * is remapped with {@code label * 2 - 1} (so labels 0 and 1 become -1 and +1) before
     * boosting. In both cases the training RDD is also passed as the validation RDD; it is not
     * used because {@code validate} is {@code false}.
     *
     * @param input training instances
     * @param boostingStrategy boosting configuration; {@code treeStrategy().algo()} selects the branch
     * @param seed random seed forwarded to {@code boost}
     * @param featureSubsetStrategy feature subset strategy forwarded to {@code boost}
     * @param instr optional instrumentation forwarded to {@code boost}
     * @return the (trees, tree weights) pair produced by {@code boost}
     * @throws IllegalArgumentException if the algo is neither Regression nor Classification
     */
    public Tuple2<DecisionTreeRegressionModel[], double[]> run(RDD<Instance> input, BoostingStrategy boostingStrategy, long seed, String featureSubsetStrategy, Option<Instrumentation> instr) {
        Enumeration.Value algo;
        Enumeration.Value value = algo = boostingStrategy.treeStrategy().algo();
        Enumeration.Value value2 = Algo$.MODULE$.Regression();
        Enumeration.Value value3 = value;
        // Decompiled null-safe equality: true iff Regression equals algo (or both are null).
        if (!(value2 != null ? !value2.equals(value3) : value3 != null)) {
            return this.boost(input, input, boostingStrategy, false, seed, featureSubsetStrategy, instr);
        }
        Enumeration.Value value4 = Algo$.MODULE$.Classification();
        Enumeration.Value value5 = value;
        if (!(value4 != null ? !value4.equals(value5) : value5 != null)) {
            // Map labels {0, 1} to {-1, +1}; weight and features are carried over unchanged.
            RDD remappedInput = input.map((Function1 & Serializable)x -> new Instance(x.label() * (double)2 - 1.0, x.weight(), x.features()), ClassTag$.MODULE$.apply(Instance.class));
            return this.boost((RDD<Instance>)remappedInput, (RDD<Instance>)remappedInput, boostingStrategy, false, seed, featureSubsetStrategy, instr);
        }
        throw new IllegalArgumentException(algo + " is not supported by gradient boosting.");
    }

    /** Default value ({@code None}) of {@code run}'s {@code instr} parameter. */
    public Option<Instrumentation> run$default$5() {
        return None$.MODULE$;
    }

    /**
     * Trains a gradient-boosted tree ensemble, using {@code validationInput} for early stopping.
     *
     * <p>For {@code Classification} both the training and validation labels are remapped with
     * {@code label * 2 - 1} before boosting. The returned arrays are truncated by {@code boost}
     * to the best-validating prefix.
     *
     * @param input training instances
     * @param validationInput held-out instances used to decide when to stop
     * @param boostingStrategy boosting configuration; {@code treeStrategy().algo()} selects the branch
     * @param seed random seed forwarded to {@code boost}
     * @param featureSubsetStrategy feature subset strategy forwarded to {@code boost}
     * @param instr optional instrumentation forwarded to {@code boost}
     * @return the (trees, tree weights) pair produced by {@code boost}
     * @throws IllegalArgumentException if the algo is neither Regression nor Classification
     */
    public Tuple2<DecisionTreeRegressionModel[], double[]> runWithValidation(RDD<Instance> input, RDD<Instance> validationInput, BoostingStrategy boostingStrategy, long seed, String featureSubsetStrategy, Option<Instrumentation> instr) {
        Enumeration.Value algo;
        Enumeration.Value value = algo = boostingStrategy.treeStrategy().algo();
        Enumeration.Value value2 = Algo$.MODULE$.Regression();
        Enumeration.Value value3 = value;
        // Decompiled null-safe equality: true iff Regression equals algo (or both are null).
        if (!(value2 != null ? !value2.equals(value3) : value3 != null)) {
            return this.boost(input, validationInput, boostingStrategy, true, seed, featureSubsetStrategy, instr);
        }
        Enumeration.Value value4 = Algo$.MODULE$.Classification();
        Enumeration.Value value5 = value;
        if (!(value4 != null ? !value4.equals(value5) : value5 != null)) {
            // Same {0, 1} -> {-1, +1} label remapping as run(), applied to both RDDs.
            RDD remappedInput = input.map((Function1 & Serializable)x -> new Instance(x.label() * (double)2 - 1.0, x.weight(), x.features()), ClassTag$.MODULE$.apply(Instance.class));
            RDD remappedValidationInput = validationInput.map((Function1 & Serializable)x -> new Instance(x.label() * (double)2 - 1.0, x.weight(), x.features()), ClassTag$.MODULE$.apply(Instance.class));
            return this.boost((RDD<Instance>)remappedInput, (RDD<Instance>)remappedValidationInput, boostingStrategy, true, seed, featureSubsetStrategy, instr);
        }
        // NOTE(review): wording differs slightly from run()'s message ("by the gradient boosting").
        throw new IllegalArgumentException(algo + " is not supported by the gradient boosting.");
    }

    /** Default value ({@code None}) of {@code runWithValidation}'s {@code instr} parameter. */
    public Option<Instrumentation> runWithValidation$default$6() {
        return None$.MODULE$;
    }

    /**
     * Computes the (prediction, error) pair for each point after the first tree.
     *
     * <p>The prediction is {@code 0.0 + initTreeWeight * initTree(point)}, evaluated on the
     * point's binned features with the broadcast splits. The error is
     * {@code loss.computeError(prediction, label)}.
     *
     * @param data binned training (or validation) points
     * @param initTreeWeight weight applied to the first tree's output
     * @param initTree the first fitted tree
     * @param loss loss used to compute the per-point error
     * @param bcSplits broadcast splits used to evaluate binned features
     * @return an RDD of (prediction, error) pairs, one per element of {@code data}
     */
    public RDD<Tuple2<Object, Object>> computeInitialPredictionAndError(RDD<TreePoint> data, double initTreeWeight, DecisionTreeRegressionModel initTree, Loss loss, Broadcast<Split[][]> bcSplits) {
        return data.map((Function1 & Serializable)treePoint -> {
            double pred = MODULE$.updatePrediction((TreePoint)treePoint, 0.0, initTree, initTreeWeight, (Split[][])bcSplits.value());
            double error = loss.computeError(pred, treePoint.label());
            return new Tuple2.mcDD.sp(pred, error);
        }, ClassTag$.MODULE$.apply(Tuple2.class));
    }

    /**
     * Adds one tree's contribution to every running prediction and recomputes the errors.
     *
     * <p>Assumes {@code data} and {@code predictionAndError} line up element-for-element, as
     * required by {@code RDD.zip}. Only the prediction half of each input pair is read; the old
     * error is discarded.
     *
     * @param data binned points
     * @param predictionAndError current (prediction, error) pairs aligned with {@code data}
     * @param treeWeight weight applied to the new tree's output
     * @param tree the newly fitted tree
     * @param loss loss used to recompute the per-point error
     * @param bcSplits broadcast splits used to evaluate binned features
     * @return an RDD of updated (prediction, error) pairs
     */
    public RDD<Tuple2<Object, Object>> updatePredictionError(RDD<TreePoint> data, RDD<Tuple2<Object, Object>> predictionAndError, double treeWeight, DecisionTreeRegressionModel tree, Loss loss, Broadcast<Split[][]> bcSplits) {
        return data.zip(predictionAndError, ClassTag$.MODULE$.apply(Tuple2.class)).map((Function1 & Serializable)x0$1 -> {
            // Decompiled pattern match on (treePoint, (pred, _)).
            Tuple2 tuple2 = x0$1;
            if (tuple2 != null) {
                TreePoint treePoint = (TreePoint)tuple2._1();
                Tuple2 tuple22 = (Tuple2)tuple2._2();
                if (tuple22 != null) {
                    double pred = tuple22._1$mcD$sp();
                    double newPred = MODULE$.updatePrediction(treePoint, pred, tree, treeWeight, (Split[][])bcSplits.value());
                    double newError = loss.computeError(newPred, treePoint.label());
                    return new Tuple2.mcDD.sp(newPred, newError);
                }
            }
            throw new MatchError((Object)tuple2);
        }, ClassTag$.MODULE$.apply(Tuple2.class));
    }

    /**
     * Returns {@code prediction + weight * tree(binnedFeatures)}, where the tree is evaluated on
     * the point's binned features with the given splits.
     */
    public double updatePrediction(TreePoint treePoint, double prediction, DecisionTreeRegressionModel tree, double weight, Split[][] splits) {
        return prediction + tree.rootNode().predictBinned(treePoint.binnedFeatures(), splits).prediction() * weight;
    }

    /**
     * Returns {@code prediction + weight * tree(features)}, where the tree is evaluated on the
     * raw feature vector.
     */
    public double updatePrediction(Vector features, double prediction, DecisionTreeRegressionModel tree, double weight) {
        return prediction + tree.rootNode().predictImpl(features).prediction() * weight;
    }

    /**
     * Computes the weighted mean loss of the full ensemble on raw instances.
     *
     * <p>Each instance's prediction is {@code sum_i treeWeights[i] * trees[i](features)}. The
     * prediction is folded over {@code trees zip treeWeights}, so any extra entries in the
     * longer array are ignored. The result is
     * {@code sum(weight * loss(prediction, label)) / sum(weight)}.
     *
     * <p>NOTE(review): there is no guard for a zero total weight; the final double division then
     * yields NaN or Infinity.
     *
     * @param data instances to evaluate (labels used as-is; no remapping here)
     * @param trees ensemble members
     * @param treeWeights per-tree weights, paired with {@code trees} by index
     * @param loss loss used for the per-instance error
     * @return the weighted mean error
     */
    public double computeWeightedError(RDD<Instance> data, DecisionTreeRegressionModel[] trees, double[] treeWeights, Loss loss) {
        int x$2;
        Function2 & Serializable x$1;
        // Per instance: (weight * error, weight).
        RDD qual$1 = data.map((Function1 & Serializable)x0$1 -> {
            Instance instance = x0$1;
            if (instance != null) {
                double label = instance.label();
                double weight = instance.weight();
                Vector features = instance.features();
                double predicted = BoxesRunTime.unboxToDouble((Object)ArrayOps$.MODULE$.foldLeft$extension(Predef$.MODULE$.refArrayOps((Object[])ArrayOps$.MODULE$.zip$extension(Predef$.MODULE$.refArrayOps((Object[])trees), (IterableOnce)Predef$.MODULE$.wrapDoubleArray(treeWeights))), (Object)BoxesRunTime.boxToDouble((double)0.0), (Function2 & Serializable)(x0$2, x1$1) -> BoxesRunTime.boxToDouble((double)GradientBoostedTrees$.$anonfun$computeWeightedError$2(features, BoxesRunTime.unboxToDouble((Object)x0$2), x1$1))));
                return new Tuple2.mcDD.sp(loss.computeError(predicted, label) * weight, weight);
            }
            throw new MatchError((Object)instance);
        }, ClassTag$.MODULE$.apply(Tuple2.class));
        // Sum both components with a tree-shaped reduction (default depth).
        Tuple2 tuple2 = (Tuple2)qual$1.treeReduce((Function2)(x$1 = (Function2 & Serializable)(x0$3, x1$2) -> {
            Tuple2 tuple2 = new Tuple2(x0$3, x1$2);
            if (tuple2 != null) {
                Tuple2 tuple22 = (Tuple2)tuple2._1();
                Tuple2 tuple23 = (Tuple2)tuple2._2();
                if (tuple22 != null) {
                    double err1 = tuple22._1$mcD$sp();
                    double weight1 = tuple22._2$mcD$sp();
                    if (tuple23 != null) {
                        double err2 = tuple23._1$mcD$sp();
                        double weight2 = tuple23._2$mcD$sp();
                        return new Tuple2.mcDD.sp(err1 + err2, weight1 + weight2);
                    }
                }
            }
            throw new MatchError((Object)tuple2);
        }), x$2 = qual$1.treeReduce$default$2());
        if (tuple2 == null) {
            throw new MatchError((Object)tuple2);
        }
        double errSum = tuple2._1$mcD$sp();
        double weightSum = tuple2._2$mcD$sp();
        // Decompiled destructuring: sp2 just re-wraps the same two sums.
        Tuple2.mcDD.sp sp2 = new Tuple2.mcDD.sp(errSum, weightSum);
        double errSum2 = sp2._1$mcD$sp();
        double weightSum2 = sp2._2$mcD$sp();
        return errSum2 / weightSum2;
    }

    /**
     * Computes the weighted mean of precomputed per-point errors.
     *
     * <p>The result is {@code sum(treePoint.weight * error) / sum(treePoint.weight)}. The error
     * is the second component of each {@code predError} pair. Assumes {@code data} and
     * {@code predError} line up element-for-element, as required by {@code RDD.zip}.
     *
     * <p>NOTE(review): as in the other overload, a zero total weight yields NaN or Infinity.
     *
     * @param data binned points supplying the weights
     * @param predError (prediction, error) pairs aligned with {@code data}
     * @return the weighted mean error
     */
    public double computeWeightedError(RDD<TreePoint> data, RDD<Tuple2<Object, Object>> predError) {
        int x$2;
        Function2 & Serializable x$1;
        // Per point: (error * weight, weight).
        RDD qual$1 = data.zip(predError, ClassTag$.MODULE$.apply(Tuple2.class)).map((Function1 & Serializable)x0$1 -> {
            Tuple2 tuple2 = x0$1;
            if (tuple2 != null) {
                TreePoint treePoint = (TreePoint)tuple2._1();
                Tuple2 tuple22 = (Tuple2)tuple2._2();
                if (tuple22 != null) {
                    double err = tuple22._2$mcD$sp();
                    return new Tuple2.mcDD.sp(err * treePoint.weight(), treePoint.weight());
                }
            }
            throw new MatchError((Object)tuple2);
        }, ClassTag$.MODULE$.apply(Tuple2.class));
        // Sum both components with a tree-shaped reduction (default depth).
        Tuple2 tuple2 = (Tuple2)qual$1.treeReduce((Function2)(x$1 = (Function2 & Serializable)(x0$2, x1$1) -> {
            Tuple2 tuple2 = new Tuple2(x0$2, x1$1);
            if (tuple2 != null) {
                Tuple2 tuple22 = (Tuple2)tuple2._1();
                Tuple2 tuple23 = (Tuple2)tuple2._2();
                if (tuple22 != null) {
                    double err1 = tuple22._1$mcD$sp();
                    double weight1 = tuple22._2$mcD$sp();
                    if (tuple23 != null) {
                        double err2 = tuple23._1$mcD$sp();
                        double weight2 = tuple23._2$mcD$sp();
                        return new Tuple2.mcDD.sp(err1 + err2, weight1 + weight2);
                    }
                }
            }
            throw new MatchError((Object)tuple2);
        }), x$2 = qual$1.treeReduce$default$2());
        if (tuple2 == null) {
            throw new MatchError((Object)tuple2);
        }
        double errSum = tuple2._1$mcD$sp();
        double weightSum = tuple2._2$mcD$sp();
        // Decompiled destructuring: sp2 just re-wraps the same two sums.
        Tuple2.mcDD.sp sp2 = new Tuple2.mcDD.sp(errSum, weightSum);
        double errSum2 = sp2._1$mcD$sp();
        double weightSum2 = sp2._2$mcD$sp();
        return errSum2 / weightSum2;
    }

    /**
     * Computes the weighted mean loss for every prefix of the ensemble in a single pass.
     *
     * <p>Entry {@code i} of the result is {@code sum(weight * loss(F_i(x), label)) / sum(weight)},
     * where {@code F_i} is the sum of {@code treeWeights[j] * trees[j](features)} for
     * {@code j <= i}. For {@code Classification} labels are first remapped with
     * {@code label * 2 - 1}; any other algo uses the labels as-is.
     *
     * @param data instances to evaluate
     * @param trees ensemble members in boosting order
     * @param treeWeights per-tree weights, read at the same indices as {@code trees}
     * @param loss loss used for the per-instance error
     * @param algo the algo; Classification triggers label remapping
     * @return an array of length {@code trees.length} holding the per-prefix weighted errors
     */
    public double[] evaluateEachIteration(RDD<Instance> data, DecisionTreeRegressionModel[] trees, double[] treeWeights, Loss loss, Enumeration.Value algo) {
        int x$2;
        Function2 & Serializable x$1;
        int numTrees;
        Enumeration.Value value = algo;
        Enumeration.Value value2 = Algo$.MODULE$.Classification();
        Enumeration.Value value3 = value;
        // Remap {0, 1} -> {-1, +1} only when algo is Classification (null-safe equality check).
        RDD remappedData = !(value2 != null ? !value2.equals(value3) : value3 != null) ? data.map((Function1 & Serializable)x -> new Instance(x.label() * (double)2 - 1.0, x.weight(), x.features()), ClassTag$.MODULE$.apply(Instance.class)) : data;
        // Per partition, map each instance to (per-prefix weighted errors, weight).
        // NOTE(review): numTrees is assigned inline here (decompiler artifact) and is also captured
        // by the reduce lambda below.
        RDD qual$1 = remappedData.mapPartitions(arg_0 -> GradientBoostedTrees$.$anonfun$evaluateEachIteration$2(numTrees = trees.length, trees, treeWeights, loss, arg_0), remappedData.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class));
        Tuple2 tuple2 = (Tuple2)qual$1.treeReduce((Function2)(x$1 = (Function2 & Serializable)(x0$2, x1$1) -> {
            Tuple2 tuple2 = new Tuple2(x0$2, x1$1);
            if (tuple2 != null) {
                Tuple2 tuple22 = (Tuple2)tuple2._1();
                Tuple2 tuple23 = (Tuple2)tuple2._2();
                if (tuple22 != null) {
                    double[] err1 = (double[])tuple22._1();
                    double weight1 = tuple22._2$mcD$sp();
                    if (tuple23 != null) {
                        double[] err2 = (double[])tuple23._1();
                        double weight2 = tuple23._2$mcD$sp();
                        // Element-wise accumulate err2 into err1. err1$1 is the lambda's synthetic
                        // name for err1, so the left operand's array is mutated in place and reused.
                        RichInt$.MODULE$.until$extension(Predef$.MODULE$.intWrapper(0), numTrees).foreach$mVc$sp((Function1)(JFunction1.mcVI.sp & Serializable)i -> {
                            err1$1[i] = err1[i] + err2[i];
                        });
                        return new Tuple2((Object)err1, (Object)BoxesRunTime.boxToDouble((double)(weight1 + weight2)));
                    }
                }
            }
            throw new MatchError((Object)tuple2);
        }), x$2 = qual$1.treeReduce$default$2());
        if (tuple2 == null) {
            throw new MatchError((Object)tuple2);
        }
        double[] errSum = (double[])tuple2._1();
        double weightSum = tuple2._2$mcD$sp();
        // Decompiled destructuring: tuple22 just re-wraps the same values.
        Tuple2 tuple22 = new Tuple2((Object)errSum, (Object)BoxesRunTime.boxToDouble((double)weightSum));
        double[] errSum2 = (double[])tuple22._1();
        double weightSum2 = tuple22._2$mcD$sp();
        // Normalize each prefix's summed error by the total weight.
        return (double[])ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.doubleArrayOps(errSum2), (Function1)(JFunction1.mcDD.sp & Serializable)x$6 -> x$6 / weightSum2, (ClassTag)ClassTag$.MODULE$.Double());
    }

    /**
     * Core gradient-boosting loop shared by {@code run} and {@code runWithValidation}.
     *
     * <p>Steps, as implemented below:
     * <ol>
     *   <li>Copy {@code boostingStrategy.treeStrategy()} and force it to Regression with Variance
     *       impurity. Require that bootstrap sampling is off, then validate both strategies.</li>
     *   <li>Build tree metadata and candidate splits from {@code input}. Broadcast the splits and
     *       bin the input into {@code TreePoint}s, persisted at MEMORY_AND_DISK.</li>
     *   <li>Fit tree 0 on the original labels with weight 1.0, using subsample counts drawn with
     *       {@code seed}.</li>
     *   <li>For {@code m = 1 .. numIterations - 1}:
     *       <ul>
     *         <li>relabel each point with {@code -loss.gradient(prediction, label)};</li>
     *         <li>draw subsample counts with {@code seed + m};</li>
     *         <li>fit a regression tree with weight {@code learningRate};</li>
     *         <li>update the running (prediction, error) RDD, which is checkpointed periodically.</li>
     *       </ul></li>
     *   <li>If {@code validate}, track the validation error after each tree and stop once it
     *       improves by less than {@code validationTol * max(currentError, 0.01)}. Only the first
     *       {@code bestM} trees are returned, where {@code bestM} advances each time the error
     *       improves without triggering the stop.</li>
     * </ol>
     *
     * @param input training instances; run/runWithValidation remap classification labels
     *        before calling this
     * @param validationInput held-out instances, only read when {@code validate} is true; binned
     *        with the splits computed from {@code input}
     * @param boostingStrategy boosting configuration (iterations, loss, learning rate,
     *        validationTol, tree strategy)
     * @param validate whether to use {@code validationInput} for early stopping
     * @param seed base random seed; iteration {@code m} uses {@code seed + m}
     * @param featureSubsetStrategy forwarded to {@code RandomForest.runBagged}
     * @param instr instrumentation, passed to {@code runBagged} for tree 0 only
     * @return (trees, weights). When {@code validate} is true these are copies truncated to
     *         length {@code bestM}; otherwise they are the full arrays of length
     *         {@code numIterations}.
     */
    public Tuple2<DecisionTreeRegressionModel[], double[]> boost(RDD<Instance> input, RDD<Instance> validationInput, BoostingStrategy boostingStrategy, boolean validate, long seed, String featureSubsetStrategy, Option<Instrumentation> instr) {
        RDD firstBagged;
        TimeTracker timer = new TimeTracker();
        timer.start("total");
        timer.start("init");
        SparkContext sc = input.sparkContext();
        boostingStrategy.assertValid();
        int numIterations = boostingStrategy.numIterations();
        DecisionTreeRegressionModel[] baseLearners = new DecisionTreeRegressionModel[numIterations];
        double[] baseLearnerWeights = new double[numIterations];
        Loss loss = boostingStrategy.loss();
        double learningRate = boostingStrategy.learningRate();
        // Work on a copy so the caller's tree strategy is not mutated by the two setters below.
        Strategy treeStrategy = boostingStrategy.treeStrategy().copy();
        double validationTol = boostingStrategy.validationTol();
        // Each base learner fits real-valued targets, so force regression with variance impurity.
        treeStrategy.algo_$eq(Algo$.MODULE$.Regression());
        treeStrategy.impurity_$eq(Variance$.MODULE$);
        Predef$.MODULE$.require(!treeStrategy.bootstrap(), (Function0 & Serializable)() -> "GradientBoostedTrees does not need bootstrap sampling");
        treeStrategy.assertValid();
        PeriodicRDDCheckpointer predErrorCheckpointer = new PeriodicRDDCheckpointer(treeStrategy.getCheckpointInterval(), sc, StorageLevel$.MODULE$.MEMORY_AND_DISK());
        timer.stop("init");
        this.logDebug((Function0<String>)(Function0 & Serializable)() -> "##########");
        this.logDebug((Function0<String>)(Function0 & Serializable)() -> "Building tree 0");
        this.logDebug((Function0<String>)(Function0 & Serializable)() -> "##########");
        timer.start("building tree 0");
        RDD retaggedInput = input.retag(Instance.class);
        timer.start("buildMetadata");
        DecisionTreeMetadata metadata = DecisionTreeMetadata$.MODULE$.buildMetadata((RDD<Instance>)retaggedInput, treeStrategy, 1, featureSubsetStrategy);
        timer.stop("buildMetadata");
        timer.start("findSplits");
        Split[][] splits = RandomForest$.MODULE$.findSplits((RDD<Instance>)retaggedInput, metadata, seed);
        timer.stop("findSplits");
        Broadcast bcSplits = sc.broadcast((Object)splits, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(ScalaRunTime$.MODULE$.arrayClass(Split.class))));
        // Binned training points are reused by every iteration, so they are cached.
        RDD treePoints = TreePoint$.MODULE$.convertToTreeRDD((RDD<Instance>)retaggedInput, splits, metadata).persist(StorageLevel$.MODULE$.MEMORY_AND_DISK()).setName("binned tree points");
        // Subsample counts for tree 0 (one count per point, see $anonfun$boost$6).
        RDD firstCounts = BaggedPoint$.MODULE$.convertToBaggedRDD(treePoints, treeStrategy.subsamplingRate(), 1, treeStrategy.bootstrap(), (Function1 & Serializable)tp -> BoxesRunTime.boxToDouble((double)tp.weight()), seed).map((Function1 & Serializable)bagged -> BoxesRunTime.boxToInteger((int)GradientBoostedTrees$.$anonfun$boost$6(bagged)), (ClassTag)ClassTag$.MODULE$.Int()).persist(StorageLevel$.MODULE$.MEMORY_AND_DISK()).setName("firstCounts at iter=0");
        // Re-attach each count to its original TreePoint (original labels) for fitting tree 0.
        RDD x$1 = firstBagged = treePoints.zip(firstCounts, (ClassTag)ClassTag$.MODULE$.Int()).map((Function1 & Serializable)x0$1 -> {
            Tuple2 tuple2 = x0$1;
            if (tuple2 != null) {
                TreePoint treePoint = (TreePoint)tuple2._1();
                int count = tuple2._2$mcI$sp();
                return new BaggedPoint<TreePoint>(treePoint, new int[]{count}, treePoint.weight());
            }
            throw new MatchError((Object)tuple2);
        }, ClassTag$.MODULE$.apply(BaggedPoint.class));
        // Decompiled named-argument temporaries for runBagged.
        // NOTE(review): x$5 is declared but not referenced in the call below (decompiler residue).
        DecisionTreeMetadata x$2 = metadata;
        Broadcast x$3 = bcSplits;
        Strategy x$4 = treeStrategy;
        boolean x$5 = true;
        String x$6 = featureSubsetStrategy;
        long x$7 = seed;
        Option<Instrumentation> x$8 = instr;
        None$ x$9 = None$.MODULE$;
        boolean x$10 = RandomForest$.MODULE$.runBagged$default$9();
        // Fit exactly one tree and take it as a regression model.
        DecisionTreeRegressionModel firstTreeModel = (DecisionTreeRegressionModel)ArrayOps$.MODULE$.head$extension(Predef$.MODULE$.refArrayOps((Object[])RandomForest$.MODULE$.runBagged((RDD<BaggedPoint<TreePoint>>)x$1, x$2, (Broadcast<Split[][]>)x$3, x$4, 1, x$6, x$7, x$8, x$10, (Option<String>)x$9)));
        firstCounts.unpersist(firstCounts.unpersist$default$1());
        double firstTreeWeight = 1.0;
        baseLearners[0] = firstTreeModel;
        baseLearnerWeights[0] = firstTreeWeight;
        // Running (prediction, error) per training point; held in an ObjectRef because lambdas capture it.
        ObjectRef predError = ObjectRef.create(this.computeInitialPredictionAndError((RDD<TreePoint>)treePoints, firstTreeWeight, firstTreeModel, loss, (Broadcast<Split[][]>)bcSplits));
        predErrorCheckpointer.update((Object)((RDD)predError.elem));
        // NOTE(review): evaluating this message thunk runs a full distributed pass over treePoints.
        this.logDebug((Function0<String>)(Function0 & Serializable)() -> "error of gbt = " + MODULE$.computeWeightedError((RDD<TreePoint>)treePoints, (RDD<Tuple2<Object, Object>>)((RDD)predError$1.elem)));
        timer.stop("building tree 0");
        RDD validationTreePoints = null;
        RDD<Tuple2<Object, Object>> validatePredError = null;
        PeriodicRDDCheckpointer validatePredErrorCheckpointer = null;
        double bestValidateError = 0.0;
        if (validate) {
            timer.start("init validation");
            // Validation points are binned with the splits learned from the training input.
            validationTreePoints = TreePoint$.MODULE$.convertToTreeRDD((RDD<Instance>)validationInput.retag(Instance.class), splits, metadata).persist(StorageLevel$.MODULE$.MEMORY_AND_DISK());
            validatePredError = this.computeInitialPredictionAndError((RDD<TreePoint>)validationTreePoints, firstTreeWeight, firstTreeModel, loss, (Broadcast<Split[][]>)bcSplits);
            validatePredErrorCheckpointer = new PeriodicRDDCheckpointer(treeStrategy.getCheckpointInterval(), sc, StorageLevel$.MODULE$.MEMORY_AND_DISK());
            validatePredErrorCheckpointer.update(validatePredError);
            bestValidateError = this.computeWeightedError((RDD<TreePoint>)validationTreePoints, validatePredError);
            // NOTE(review): v0 is an undeclared decompiler temporary; it appears to hold the
            // discarded value of this if/else in the original Scala.
            v0 = BoxesRunTime.boxToDouble((double)timer.stop("init validation"));
        } else {
            v0 = BoxedUnit.UNIT;
        }
        // bestM = number of leading trees to keep when validating (tree 0 alone initially).
        int bestM = 1;
        IntRef m = IntRef.create((int)1);
        boolean doneLearning = false;
        while (m.elem < numIterations && !doneLearning) {
            RDD bagged2;
            timer.start("building tree " + m.elem);
            this.logDebug((Function0<String>)(Function0 & Serializable)() -> "###################################################");
            this.logDebug((Function0<String>)(Function0 & Serializable)() -> "Gradient boosting tree iteration " + m$1.elem);
            this.logDebug((Function0<String>)(Function0 & Serializable)() -> "###################################################");
            // Per point: (pseudo-residual = -gradient of the loss at the current prediction,
            // subsample count drawn with seed + m).
            RDD labelWithCounts = BaggedPoint$.MODULE$.convertToBaggedRDD(treePoints, treeStrategy.subsamplingRate(), 1, treeStrategy.bootstrap(), (Function1 & Serializable)tp -> BoxesRunTime.boxToDouble((double)tp.weight()), seed + (long)m.elem).zip((RDD)predError.elem, ClassTag$.MODULE$.apply(Tuple2.class)).map((Function1 & Serializable)x0$2 -> {
                Tuple2 tuple2 = x0$2;
                if (tuple2 != null) {
                    BaggedPoint bagged = (BaggedPoint)tuple2._1();
                    Tuple2 tuple22 = (Tuple2)tuple2._2();
                    if (tuple22 != null) {
                        double pred = tuple22._1$mcD$sp();
                        Predef$.MODULE$.require(bagged.subsampleCounts().length == 1);
                        Predef$.MODULE$.require(bagged.sampleWeight() == ((TreePoint)bagged.datum()).weight());
                        double newLabel = -loss.gradient(pred, ((TreePoint)bagged.datum()).label());
                        return new Tuple2.mcDI.sp(newLabel, BoxesRunTime.unboxToInt((Object)ArrayOps$.MODULE$.head$extension(Predef$.MODULE$.intArrayOps(bagged.subsampleCounts()))));
                    }
                }
                throw new MatchError((Object)tuple2);
            }, ClassTag$.MODULE$.apply(Tuple2.class)).persist(StorageLevel$.MODULE$.MEMORY_AND_DISK()).setName("labelWithCounts at iter=" + m.elem);
            // Rebuild points with the pseudo-residual as the label, keeping binned features and weight.
            RDD x$11 = bagged2 = treePoints.zip(labelWithCounts, ClassTag$.MODULE$.apply(Tuple2.class)).map((Function1 & Serializable)x0$3 -> {
                Tuple2 tuple2 = x0$3;
                if (tuple2 != null) {
                    TreePoint treePoint = (TreePoint)tuple2._1();
                    Tuple2 tuple22 = (Tuple2)tuple2._2();
                    if (tuple22 != null) {
                        double newLabel = tuple22._1$mcD$sp();
                        int count = tuple22._2$mcI$sp();
                        TreePoint newTreePoint = new TreePoint(newLabel, treePoint.binnedFeatures(), treePoint.weight());
                        return new BaggedPoint<TreePoint>(newTreePoint, new int[]{count}, treePoint.weight());
                    }
                }
                throw new MatchError((Object)tuple2);
            }, ClassTag$.MODULE$.apply(BaggedPoint.class));
            // Decompiled named-argument temporaries. Unlike tree 0, instrumentation is not passed
            // (x$18 is None). NOTE(review): x$15 is unused (decompiler residue).
            DecisionTreeMetadata x$12 = metadata;
            Broadcast x$13 = bcSplits;
            Strategy x$14 = treeStrategy;
            boolean x$15 = true;
            String x$16 = featureSubsetStrategy;
            long x$17 = seed + (long)m.elem;
            None$ x$18 = None$.MODULE$;
            None$ x$19 = None$.MODULE$;
            boolean x$20 = RandomForest$.MODULE$.runBagged$default$9();
            DecisionTreeRegressionModel model = (DecisionTreeRegressionModel)ArrayOps$.MODULE$.head$extension(Predef$.MODULE$.refArrayOps((Object[])RandomForest$.MODULE$.runBagged((RDD<BaggedPoint<TreePoint>>)x$11, x$12, (Broadcast<Split[][]>)x$13, x$14, 1, x$16, x$17, (Option<Instrumentation>)x$18, x$20, (Option<String>)x$19)));
            labelWithCounts.unpersist(labelWithCounts.unpersist$default$1());
            timer.stop("building tree " + m.elem);
            // Every tree after the first is shrunk by the learning rate.
            baseLearners[m.elem] = model;
            baseLearnerWeights[m.elem] = learningRate;
            predError.elem = this.updatePredictionError((RDD<TreePoint>)treePoints, (RDD<Tuple2<Object, Object>>)((RDD)predError.elem), baseLearnerWeights[m.elem], baseLearners[m.elem], loss, (Broadcast<Split[][]>)bcSplits);
            // Periodic checkpointing keeps the lineage of the iteratively updated RDD bounded.
            predErrorCheckpointer.update((Object)((RDD)predError.elem));
            // NOTE(review): evaluating this message thunk runs a full distributed pass over treePoints.
            this.logDebug((Function0<String>)(Function0 & Serializable)() -> "error of gbt = " + MODULE$.computeWeightedError((RDD<TreePoint>)treePoints, (RDD<Tuple2<Object, Object>>)((RDD)predError$1.elem)));
            if (validate) {
                validatePredError = this.updatePredictionError((RDD<TreePoint>)validationTreePoints, validatePredError, baseLearnerWeights[m.elem], baseLearners[m.elem], loss, (Broadcast<Split[][]>)bcSplits);
                validatePredErrorCheckpointer.update(validatePredError);
                double currentValidateError = this.computeWeightedError((RDD<TreePoint>)validationTreePoints, validatePredError);
                // Stop once the improvement over the best error so far is below validationTol
                // times the current error, with 0.01 as a floor. This also stops when the error
                // got worse (negative improvement).
                if (bestValidateError - currentValidateError < validationTol * Math.max(currentValidateError, 0.01)) {
                    doneLearning = true;
                } else if (currentValidateError < bestValidateError) {
                    bestValidateError = currentValidateError;
                    bestM = m.elem + 1;
                }
            }
            ++m.elem;
        }
        timer.stop("total");
        this.logInfo((Function0<String>)(Function0 & Serializable)() -> "Internal timing for DecisionTree:");
        // Structured log line carrying the timer under the TIMER log key.
        this.logInfo(LogEntry$.MODULE$.from((Function0 & Serializable)() -> MODULE$.LogStringContext(new StringContext((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new String[]{"", ""}))).log((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new MDC[]{new MDC((LogKey)LogKeys.TIMER$.MODULE$, (Object)timer)}))));
        // Release the broadcast splits, cached points, and checkpointed prediction/error RDDs.
        bcSplits.destroy();
        treePoints.unpersist(treePoints.unpersist$default$1());
        predErrorCheckpointer.unpersistDataSet();
        predErrorCheckpointer.deleteAllCheckpoints();
        if (validate) {
            RDD qual$1 = validationTreePoints;
            boolean x$21 = qual$1.unpersist$default$1();
            qual$1.unpersist(x$21);
            validatePredErrorCheckpointer.unpersistDataSet();
            validatePredErrorCheckpointer.deleteAllCheckpoints();
        }
        // With validation, return only the best prefix; slots past the stopping point were never filled.
        if (validate) {
            return new Tuple2(ArrayOps$.MODULE$.slice$extension(Predef$.MODULE$.refArrayOps((Object[])baseLearners), 0, bestM), ArrayOps$.MODULE$.slice$extension(Predef$.MODULE$.doubleArrayOps(baseLearnerWeights), 0, bestM));
        }
        return new Tuple2((Object)baseLearners, (Object)baseLearnerWeights);
    }

    /** Default value ({@code None}) of {@code boost}'s {@code instr} parameter. */
    public Option<Instrumentation> boost$default$7() {
        return None$.MODULE$;
    }

    /**
     * Fold step used by {@code computeWeightedError(RDD<Instance>, ...)}: given the accumulated
     * prediction and a (model, weight) pair, returns {@code acc + weight * model(features$1)}.
     */
    public static final /* synthetic */ double $anonfun$computeWeightedError$2(Vector features$1, double x0$2, Tuple2 x1$1) {
        Tuple2 tuple2 = new Tuple2((Object)BoxesRunTime.boxToDouble((double)x0$2), (Object)x1$1);
        if (tuple2 != null) {
            double acc = tuple2._1$mcD$sp();
            Tuple2 tuple22 = (Tuple2)tuple2._2();
            if (tuple22 != null) {
                DecisionTreeRegressionModel model = (DecisionTreeRegressionModel)tuple22._1();
                double weight = tuple22._2$mcD$sp();
                return MODULE$.updatePrediction(features$1, acc, model, weight);
            }
        }
        throw new MatchError((Object)tuple2);
    }

    /**
     * Per-partition body of {@code evaluateEachIteration}. Each instance is mapped to
     * (err, weight), where {@code err[i] = weight * loss(sum_{j<=i} treeWeights$2[j] * trees$2[j](x), label)}.
     * The cumulative sums come from {@code scanLeft(0.0)(_ + _)} with the leading 0.0 dropped.
     */
    public static final /* synthetic */ Iterator $anonfun$evaluateEachIteration$2(int numTrees$1, DecisionTreeRegressionModel[] trees$2, double[] treeWeights$2, Loss loss$4, Iterator iter) {
        return iter.map((Function1 & Serializable)x0$1 -> {
            Instance instance = x0$1;
            if (instance != null) {
                double label = instance.label();
                double weight = instance.weight();
                Vector features = instance.features();
                double[] pred = (double[])Array$.MODULE$.tabulate(numTrees$1, (Function1)(JFunction1.mcDI.sp & Serializable)i -> trees$2[i].rootNode().predictImpl(features).prediction() * treeWeights$2[i], (ClassTag)ClassTag$.MODULE$.Double());
                double[] err = (double[])ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.doubleArrayOps((double[])ArrayOps$.MODULE$.drop$extension(Predef$.MODULE$.doubleArrayOps((double[])ArrayOps$.MODULE$.scanLeft$extension(Predef$.MODULE$.doubleArrayOps(pred), (Object)BoxesRunTime.boxToDouble((double)0.0), (Function2)(JFunction2.mcDDD.sp & Serializable)(x$3, x$4) -> x$3 + x$4, (ClassTag)ClassTag$.MODULE$.Double())), 1)), (Function1)(JFunction1.mcDD.sp & Serializable)p -> loss$4.computeError(p, label) * weight, (ClassTag)ClassTag$.MODULE$.Double());
                return new Tuple2((Object)err, (Object)BoxesRunTime.boxToDouble((double)weight));
            }
            throw new MatchError((Object)instance);
        });
    }

    /**
     * Extracts the single subsample count for tree 0. Requires exactly one count per point and a
     * sample weight equal to the underlying TreePoint's weight.
     */
    public static final /* synthetic */ int $anonfun$boost$6(BaggedPoint bagged) {
        Predef$.MODULE$.require(bagged.subsampleCounts().length == 1);
        Predef$.MODULE$.require(bagged.sampleWeight() == ((TreePoint)bagged.datum()).weight());
        return BoxesRunTime.unboxToInt((Object)ArrayOps$.MODULE$.head$extension(Predef$.MODULE$.intArrayOps(bagged.subsampleCounts())));
    }

    // Private constructor: the only instance is MODULE$.
    private GradientBoostedTrees$() {
    }
}

