/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.mllib.classification;

import java.util.List;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;
import org.junit.Assert;
import org.junit.Test;

/**
 * Java-API tests for {@link LogisticRegressionWithSGD}.
 *
 * <p>Each test trains a model on a synthetic data set (seed 42), scores it against a
 * separately generated validation set (seed 17) of the same size, and requires that more
 * than 4/5 of the validation points are classified correctly.
 */
public class JavaLogisticRegressionSuite
extends SharedSparkSession {

    /**
     * Counts how many validation points the model labels correctly.
     *
     * @param validationData labelled points to score
     * @param model trained model whose {@code predict} output is compared against each label
     * @return number of points whose prediction is exactly equal to the point's label
     */
    int validatePrediction(List<LabeledPoint> validationData, LogisticRegressionModel model) {
        int numAccurate = 0;
        for (LabeledPoint point : validationData) {
            // Exact double comparison: assumes predict() returns a thresholded class label
            // using the same encoding as the generated labels — NOTE(review): confirm.
            double prediction = model.predict(point.features());
            if (prediction == point.label()) {
                numAccurate++;
            }
        }
        return numAccurate;
    }

    /**
     * Fails the test unless strictly more than 4/5 of {@code nPoints} were classified correctly.
     *
     * @param numAccurate number of correctly classified validation points
     * @param nPoints total number of validation points
     */
    private static void assertMostlyAccurate(int numAccurate, int nPoints) {
        double threshold = nPoints * 4.0 / 5.0;
        Assert.assertTrue(
            "expected more than " + threshold + " of " + nPoints
                + " predictions to be accurate, but got " + numAccurate,
            numAccurate > threshold);
    }

    /** Trains through an explicitly configured {@link LogisticRegressionWithSGD} instance. */
    @Test
    public void runLRUsingConstructor() {
        int nPoints = 10000;
        double A = 2.0;
        double B = -1.5;

        JavaRDD<LabeledPoint> testRDD = this.jsc.parallelize(
            LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
        List<LabeledPoint> validationData =
            LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);

        LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD();
        lrImpl.setIntercept(true);
        lrImpl.optimizer()
            .setStepSize(1.0)
            .setRegParam(1.0)
            .setNumIterations(100);
        LogisticRegressionModel model = lrImpl.run(testRDD.rdd());

        int numAccurate = validatePrediction(validationData, model);
        assertMostlyAccurate(numAccurate, nPoints);
    }

    /** Trains through the static {@code LogisticRegressionWithSGD.train} entry point. */
    @Test
    public void runLRUsingStaticMethods() {
        int nPoints = 10000;
        double A = 0.0;
        double B = -2.5;

        JavaRDD<LabeledPoint> testRDD = this.jsc.parallelize(
            LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
        List<LabeledPoint> validationData =
            LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);

        // Arguments after the RDD are presumably (numIterations, stepSize, miniBatchFraction);
        // NOTE(review): confirm against the Spark API for the targeted version.
        LogisticRegressionModel model =
            LogisticRegressionWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0);

        int numAccurate = validatePrediction(validationData, model);
        assertMostlyAccurate(numAccurate, nPoints);
    }
}

