/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.tree.impl;

import java.io.Serializable;
import java.util.Map;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkFunSuite;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.NominalAttribute$;
import org.apache.spark.ml.attribute.NumericAttribute;
import org.apache.spark.ml.attribute.NumericAttribute$;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.tree.DecisionTreeModel;
import org.apache.spark.ml.tree.InternalNode;
import org.apache.spark.ml.tree.LeafNode;
import org.apache.spark.ml.tree.Node;
import org.apache.spark.ml.tree.Split;
import org.apache.spark.ml.tree.TreeEnsembleModel;
import org.apache.spark.ml.tree.impl.TreeTests$;
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.SparkSession$;
import org.apache.spark.sql.types.Metadata;
import org.scalactic.Bool;
import org.scalactic.Bool$;
import org.scalactic.Equality$;
import org.scalactic.Prettifier$;
import org.scalactic.TripleEqualsSupport;
import org.scalactic.source.Position;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.Predef;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.GenIterable;
import scala.collection.JavaConverters$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.mutable.ArrayOps;
import scala.package$;
import scala.reflect.ClassTag$;
import scala.reflect.api.JavaUniverse;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.TypeTags;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

public final class TreeTests$
extends SparkFunSuite {
    public static TreeTests$ MODULE$;
    private final scala.collection.immutable.Map<String, Object> allParamSettings;

    static {
        new TreeTests$();
    }

    public Dataset<Row> setMetadata(RDD<LabeledPoint> data, scala.collection.immutable.Map<Object, Object> categoricalFeatures, int numClasses) {
        SparkSession spark = SparkSession$.MODULE$.builder().sparkContext(data.sparkContext()).getOrCreate();
        JavaUniverse $u = scala.reflect.runtime.package$.MODULE$.universe();
        JavaUniverse.JavaMirror $m = scala.reflect.runtime.package$.MODULE$.universe().runtimeMirror(this.getClass().getClassLoader());
        public final class Org_apache_spark_ml_tree_impl_TreeTests$$typecreator5$1
        extends TypeCreator {
            public <U extends Universe> Types.TypeApi apply(Mirror<U> $m$untyped) {
                Universe $u = $m$untyped.universe();
                Mirror<U> $m = $m$untyped;
                return $m.staticClass("org.apache.spark.ml.feature.LabeledPoint").asType().toTypeConstructor();
            }

            public Org_apache_spark_ml_tree_impl_TreeTests$$typecreator5$1() {
            }
        }
        Dataset df = spark.implicits().rddToDatasetHolder(data, spark.implicits().newProductEncoder(((TypeTags)$u).TypeTag().apply((Mirror)$m, (TypeCreator)new Org_apache_spark_ml_tree_impl_TreeTests$$typecreator5$1()))).toDF();
        int numFeatures = ((LabeledPoint)data.first()).features().size();
        Attribute[] featuresAttributes = (Attribute[])((TraversableOnce)package$.MODULE$.Range().apply(0, numFeatures).map((Function1 & Serializable & scala.Serializable)feature -> TreeTests$.$anonfun$setMetadata$1(categoricalFeatures, BoxesRunTime.unboxToInt((Object)feature)), IndexedSeq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(Attribute.class));
        Metadata featuresMetadata = new AttributeGroup("features", featuresAttributes).toMetadata();
        NumericAttribute labelAttribute = numClasses == 0 ? NumericAttribute$.MODULE$.defaultAttr().withName("label") : NominalAttribute$.MODULE$.defaultAttr().withName("label").withNumValues(numClasses);
        Metadata labelMetadata = labelAttribute.toMetadata();
        return df.select((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{df.apply("features").as("features", featuresMetadata), df.apply("label").as("label", labelMetadata)}));
    }

    public Dataset<Row> setMetadata(JavaRDD<LabeledPoint> data, Map<Integer, Integer> categoricalFeatures, int numClasses) {
        return this.setMetadata((RDD<LabeledPoint>)data.rdd(), (scala.collection.immutable.Map<Object, Object>)((TraversableOnce)JavaConverters$.MODULE$.mapAsScalaMapConverter(categoricalFeatures).asScala()).toMap(Predef$.MODULE$.$conforms()), numClasses);
    }

    public Dataset<Row> setMetadata(Dataset<Row> data, int numClasses, String labelColName, String featuresColName) {
        NumericAttribute labelAttribute = numClasses == 0 ? NumericAttribute$.MODULE$.defaultAttr().withName(labelColName) : NominalAttribute$.MODULE$.defaultAttr().withName(labelColName).withNumValues(numClasses);
        Metadata labelMetadata = labelAttribute.toMetadata();
        return data.select((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{data.apply(featuresColName), data.apply(labelColName).as(labelColName, labelMetadata)}));
    }

    public void checkEqual(DecisionTreeModel a, DecisionTreeModel b) {
        try {
            this.checkEqual(a.rootNode(), b.rootNode());
        }
        catch (Exception ex) {
            throw new AssertionError(new StringBuilder(76).append("checkEqual failed since the two trees were not identical.\nTREE A:\n").append(a.toDebugString()).append("\n").append("TREE B:\n").append(b.toDebugString()).append("\n").toString(), ex);
        }
    }

    private void checkEqual(Node a, Node b) {
        block4: {
            block3: {
                Tuple2 tuple2;
                while (true) {
                    TripleEqualsSupport.Equalizer $org_scalatest_assert_macro_left = this.convertToEqualizer(BoxesRunTime.boxToDouble((double)a.prediction()));
                    double $org_scalatest_assert_macro_right = b.prediction();
                    Bool $org_scalatest_assert_macro_expr = Bool$.MODULE$.binaryMacroBool((Object)$org_scalatest_assert_macro_left, "===", (Object)BoxesRunTime.boxToDouble((double)$org_scalatest_assert_macro_right), $org_scalatest_assert_macro_left.$eq$eq$eq((Object)BoxesRunTime.boxToDouble((double)$org_scalatest_assert_macro_right), Equality$.MODULE$.default()), Prettifier$.MODULE$.default());
                    this.assertionsHelper().macroAssert($org_scalatest_assert_macro_expr, (Object)"", Prettifier$.MODULE$.default(), new Position("TreeTests.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 127));
                    TripleEqualsSupport.Equalizer $org_scalatest_assert_macro_left2 = this.convertToEqualizer(BoxesRunTime.boxToDouble((double)a.impurity()));
                    double $org_scalatest_assert_macro_right2 = b.impurity();
                    Bool $org_scalatest_assert_macro_expr2 = Bool$.MODULE$.binaryMacroBool((Object)$org_scalatest_assert_macro_left2, "===", (Object)BoxesRunTime.boxToDouble((double)$org_scalatest_assert_macro_right2), $org_scalatest_assert_macro_left2.$eq$eq$eq((Object)BoxesRunTime.boxToDouble((double)$org_scalatest_assert_macro_right2), Equality$.MODULE$.default()), Prettifier$.MODULE$.default());
                    this.assertionsHelper().macroAssert($org_scalatest_assert_macro_expr2, (Object)"", Prettifier$.MODULE$.default(), new Position("TreeTests.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 128));
                    tuple2 = new Tuple2((Object)a, (Object)b);
                    if (tuple2 == null) break;
                    Node aye = (Node)tuple2._1();
                    Node bee = (Node)tuple2._2();
                    if (!(aye instanceof InternalNode)) break;
                    InternalNode internalNode = (InternalNode)aye;
                    if (!(bee instanceof InternalNode)) break;
                    InternalNode internalNode2 = (InternalNode)bee;
                    TripleEqualsSupport.Equalizer $org_scalatest_assert_macro_left3 = this.convertToEqualizer(internalNode.split());
                    Split $org_scalatest_assert_macro_right3 = internalNode2.split();
                    Bool $org_scalatest_assert_macro_expr3 = Bool$.MODULE$.binaryMacroBool((Object)$org_scalatest_assert_macro_left3, "===", (Object)$org_scalatest_assert_macro_right3, $org_scalatest_assert_macro_left3.$eq$eq$eq((Object)$org_scalatest_assert_macro_right3, Equality$.MODULE$.default()), Prettifier$.MODULE$.default());
                    this.assertionsHelper().macroAssert($org_scalatest_assert_macro_expr3, (Object)"", Prettifier$.MODULE$.default(), new Position("TreeTests.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 131));
                    this.checkEqual(internalNode.leftChild(), internalNode2.leftChild());
                    b = internalNode2.rightChild();
                    a = internalNode.rightChild();
                }
                if (tuple2 == null) break block3;
                Node aye = (Node)tuple2._1();
                Node bee = (Node)tuple2._2();
                if (aye instanceof LeafNode && bee instanceof LeafNode) break block4;
            }
            throw new AssertionError((Object)"Found mismatched nodes");
        }
        BoxedUnit boxedUnit = BoxedUnit.UNIT;
    }

    public <M extends DecisionTreeModel> void checkEqual(TreeEnsembleModel<M> a, TreeEnsembleModel<M> b) {
        try {
            new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])a.trees())).zip((GenIterable)Predef$.MODULE$.wrapRefArray((Object[])b.trees()), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class))))).foreach((Function1 & Serializable & scala.Serializable)x0$1 -> {
                TreeTests$.$anonfun$checkEqual$1(x0$1);
                return BoxedUnit.UNIT;
            });
            TripleEqualsSupport.Equalizer $org_scalatest_assert_macro_left = this.convertToEqualizer(a.treeWeights());
            double[] $org_scalatest_assert_macro_right = b.treeWeights();
            Bool $org_scalatest_assert_macro_expr = Bool$.MODULE$.binaryMacroBool((Object)$org_scalatest_assert_macro_left, "===", (Object)$org_scalatest_assert_macro_right, $org_scalatest_assert_macro_left.$eq$eq$eq((Object)$org_scalatest_assert_macro_right, Equality$.MODULE$.default()), Prettifier$.MODULE$.default());
            this.assertionsHelper().macroAssert($org_scalatest_assert_macro_expr, (Object)"", Prettifier$.MODULE$.default(), new Position("TreeTests.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 149));
        }
        catch (Exception ex) {
            throw new AssertionError((Object)"checkEqual failed since the two tree ensembles were not identical");
        }
    }

    public Node buildParentNode(Node left, Node right, Split split) {
        ImpurityCalculator leftImp = left.impurityStats();
        ImpurityCalculator rightImp = right.impurityStats();
        ImpurityCalculator parentImp = leftImp.copy().add(rightImp);
        double leftWeight = (double)leftImp.count() / (double)parentImp.count();
        double rightWeight = (double)rightImp.count() / (double)parentImp.count();
        double gain = parentImp.calculate() - (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate());
        double pred = parentImp.predict();
        return new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp);
    }

    public RDD<LabeledPoint> featureImportanceData(SparkContext sc) {
        return sc.parallelize((Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new LabeledPoint[]{new LabeledPoint(0.0, Vectors$.MODULE$.dense(1.0, (Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{0.0, 0.0, 0.0, 1.0}))), new LabeledPoint(1.0, Vectors$.MODULE$.dense(1.0, (Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{1.0, 0.0, 1.0, 0.0}))), new LabeledPoint(1.0, Vectors$.MODULE$.dense(1.0, (Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{1.0, 0.0, 0.0, 0.0}))), new LabeledPoint(0.0, Vectors$.MODULE$.dense(1.0, (Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{0.0, 0.0, 0.0, 0.0}))), new LabeledPoint(1.0, Vectors$.MODULE$.dense(1.0, (Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{1.0, 0.0, 0.0, 0.0})))})), sc.parallelize$default$2(), ClassTag$.MODULE$.apply(LabeledPoint.class));
    }

    public RDD<LabeledPoint> varianceData(SparkContext sc) {
        return sc.parallelize((Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new LabeledPoint[]{new LabeledPoint(1.0, Vectors$.MODULE$.dense(new double[]{0.0})), new LabeledPoint(2.0, Vectors$.MODULE$.dense(new double[]{1.0})), new LabeledPoint(3.0, Vectors$.MODULE$.dense(new double[]{2.0})), new LabeledPoint(10.0, Vectors$.MODULE$.dense(new double[]{3.0})), new LabeledPoint(12.0, Vectors$.MODULE$.dense(new double[]{4.0})), new LabeledPoint(14.0, Vectors$.MODULE$.dense(new double[]{5.0}))})), sc.parallelize$default$2(), ClassTag$.MODULE$.apply(LabeledPoint.class));
    }

    public scala.collection.immutable.Map<String, Object> allParamSettings() {
        return this.allParamSettings;
    }

    public RDD<LabeledPoint> getTreeReadWriteData(SparkContext sc) {
        LabeledPoint[] arr = (LabeledPoint[])((Object[])new LabeledPoint[]{new LabeledPoint(0.0, Vectors$.MODULE$.dense(0.0, (Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{0.0}))), new LabeledPoint(1.0, Vectors$.MODULE$.dense(0.0, (Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{1.0}))), new LabeledPoint(0.0, Vectors$.MODULE$.dense(0.0, (Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{0.0}))), new LabeledPoint(0.0, Vectors$.MODULE$.dense(0.0, (Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{2.0}))), new LabeledPoint(0.0, Vectors$.MODULE$.dense(1.0, (Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{0.0}))), new LabeledPoint(1.0, Vectors$.MODULE$.dense(1.0, (Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{1.0}))), new LabeledPoint(1.0, Vectors$.MODULE$.dense(1.0, (Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{0.0}))), new LabeledPoint(1.0, Vectors$.MODULE$.dense(1.0, (Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{2.0})))});
        return sc.parallelize((Seq)Predef$.MODULE$.wrapRefArray((Object[])arr), sc.parallelize$default$2(), ClassTag$.MODULE$.apply(LabeledPoint.class));
    }

    private Object readResolve() {
        return MODULE$;
    }

    public static final /* synthetic */ Attribute $anonfun$setMetadata$1(scala.collection.immutable.Map categoricalFeatures$1, int feature) {
        return categoricalFeatures$1.contains((Object)BoxesRunTime.boxToInteger((int)feature)) ? NominalAttribute$.MODULE$.defaultAttr().withIndex(feature).withNumValues(BoxesRunTime.unboxToInt((Object)categoricalFeatures$1.apply((Object)BoxesRunTime.boxToInteger((int)feature)))) : NumericAttribute$.MODULE$.defaultAttr().withIndex(feature);
    }

    public static final /* synthetic */ void $anonfun$checkEqual$1(Tuple2 x0$1) {
        Tuple2 tuple2 = x0$1;
        if (tuple2 == null) {
            throw new MatchError((Object)tuple2);
        }
        DecisionTreeModel treeA = (DecisionTreeModel)tuple2._1();
        DecisionTreeModel treeB = (DecisionTreeModel)tuple2._2();
        MODULE$.checkEqual(treeA, treeB);
        BoxedUnit boxedUnit = BoxedUnit.UNIT;
    }

    private TreeTests$() {
        MODULE$ = this;
        this.allParamSettings = (scala.collection.immutable.Map)Predef$.MODULE$.Map().apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"checkpointInterval"), (Object)BoxesRunTime.boxToInteger((int)7)), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"seed"), (Object)BoxesRunTime.boxToLong((long)543L)), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"maxDepth"), (Object)BoxesRunTime.boxToInteger((int)2)), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"maxBins"), (Object)BoxesRunTime.boxToInteger((int)20)), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"minInstancesPerNode"), (Object)BoxesRunTime.boxToInteger((int)2)), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"minInfoGain"), (Object)BoxesRunTime.boxToDouble((double)1.0E-14)), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"maxMemoryInMB"), (Object)BoxesRunTime.boxToInteger((int)257)), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"cacheNodeIds"), (Object)BoxesRunTime.boxToBoolean((boolean)true))}));
    }
}

