/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.classification;

import java.io.IOException;
import java.util.List;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.classification.Classifier;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionSuite;
import org.apache.spark.ml.classification.OneVsRest;
import org.apache.spark.ml.classification.OneVsRestModel;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.Assert;
import org.junit.Test;
import scala.collection.JavaConverters;

/**
 * Java API test for {@link OneVsRest}: fits a one-vs-rest wrapper around
 * {@link LogisticRegression} on a tiny generated dataset and checks default
 * column parameters on both the estimator and the fitted model.
 */
public class JavaOneVsRestSuite
extends SharedSparkSession {
    /** Training/prediction data as a DataFrame built from {@link #datasetRDD}. */
    private transient Dataset<Row> dataset;
    /** Generated labeled points, distributed over 2 partitions. */
    private transient JavaRDD<LabeledPoint> datasetRDD;

    /**
     * Builds the shared fixture: generates a small synthetic multinomial
     * logistic dataset with a fixed seed and registers it as a DataFrame.
     *
     * @throws IOException if the parent session setup fails
     */
    @Override
    public void setUp() throws IOException {
        super.setUp();
        int nPoints = 3;
        // 10 coefficients over 4 features (see xMean/xVariance) with an intercept;
        // presumably this encodes 3 classes (2 x (4 + 1)) — confirm against
        // LogisticRegressionSuite.generateMultinomialLogisticInput.
        double[] coefficients = new double[]{-0.57997, 0.912083, -0.371077, -0.819866, 2.688191, -0.16624, -0.84355, -0.048509, -0.301789, 4.170682};
        double[] xMean = new double[]{5.843, 3.057, 3.758, 1.199};
        double[] xVariance = new double[]{0.6856, 0.1899, 3.116, 0.581};
        // Fixed seed (42) keeps the generated points deterministic across runs.
        List<LabeledPoint> points = JavaConverters.seqAsJavaListConverter(
            LogisticRegressionSuite.generateMultinomialLogisticInput(
                coefficients, xMean, xVariance, true, nPoints, 42)).asJava();
        this.datasetRDD = this.jsc.parallelize(points, 2);
        this.dataset = this.spark.createDataFrame(this.datasetRDD, LabeledPoint.class);
    }

    /**
     * Verifies that {@link OneVsRest} and the fitted {@link OneVsRestModel} use the
     * default {@code "label"} / {@code "prediction"} column names, and that
     * transforming the training data yields one prediction row per input row.
     */
    @Test
    public void oneVsRestDefaultParams() {
        OneVsRest ova = new OneVsRest();
        ova.setClassifier(new LogisticRegression());
        // JUnit convention: expected value first, actual value second.
        Assert.assertEquals("label", ova.getLabelCol());
        Assert.assertEquals("prediction", ova.getPredictionCol());

        OneVsRestModel ovaModel = ova.fit(this.dataset);
        Dataset<Row> predictions = ovaModel.transform(this.dataset).select("label", "prediction");
        // Collecting forces the lazy transform to run; transform must not drop rows.
        List<Row> rows = predictions.collectAsList();
        Assert.assertEquals(this.dataset.count(), rows.size());

        Assert.assertEquals("label", ovaModel.getLabelCol());
        Assert.assertEquals("prediction", ovaModel.getPredictionCol());
    }
}

