/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.spark.models.svm;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ModelMetrics;
import java.util.Arrays;
import java.util.HashSet;
import org.apache.spark.SparkContext;
import org.apache.spark.h2o.H2OContext;
import org.apache.spark.ml.FrameMLUtils;
import org.apache.spark.ml.spark.ProgressListener;
import org.apache.spark.ml.spark.models.MissingValuesHandling;
import org.apache.spark.ml.spark.models.svm.SVMModel;
import org.apache.spark.ml.spark.models.svm.SVMParameters;
import org.apache.spark.mllib.classification.SVMWithSGD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;
import org.apache.spark.scheduler.SparkListenerInterface;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.storage.RDDInfo;
import scala.Tuple2;
import scala.collection.Iterable;
import scala.collection.JavaConversions;
import water.DKV;
import water.Key;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;

public class SVM
extends ModelBuilder<SVMModel, SVMParameters, SVMModel.SVMOutput> {
    private final transient H2OContext hc;

    public SVM(boolean startup_once, H2OContext hc) {
        super((Model.Parameters)new SVMParameters(), startup_once);
        this.hc = hc;
    }

    public SVM(SVMParameters parms, H2OContext hc) {
        super((Model.Parameters)parms);
        this.init(false);
        this.hc = hc;
    }

    protected ModelBuilder.Driver trainModelImpl() {
        return new SVMDriver();
    }

    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Binomial, ModelCategory.Regression};
    }

    public boolean isSupervised() {
        return true;
    }

    public void init(boolean expensive) {
        super.init(expensive);
        ((SVMParameters)this._parms).validate(this);
        if (this._train == null) {
            return;
        }
        if (null != ((SVMParameters)this._parms)._initial_weights) {
            Frame user_points = (Frame)((SVMParameters)this._parms)._initial_weights.get();
            if (user_points.numCols() != this._train.numCols() - this.numSpecialCols()) {
                this.error("_initial_weights", "The user-specified initial weights must have the same number of columns (" + (this._train.numCols() - this.numSpecialCols()) + ") as the training observations");
            }
            if (user_points.hasNAs()) {
                this.error("_initial_weights", "Initial weights cannot contain missing values.");
            }
        }
        if (MissingValuesHandling.NotAllowed == ((SVMParameters)this._parms)._missing_values_handling) {
            for (int i = 0; i < this._train.numCols(); ++i) {
                Vec vec = this._train.vec(i);
                String vecName = this._train.name(i);
                if (vec.naCnt() <= 0L || null != ((SVMParameters)this._parms)._ignored_columns && Arrays.binarySearch(((SVMParameters)this._parms)._ignored_columns, vecName) >= 0) continue;
                this.error("_train", "Training frame cannot contain any missing values [" + vecName + "].");
            }
        }
        HashSet<String> ignoredCols = null != ((SVMParameters)this._parms)._ignored_columns ? new HashSet<String>(Arrays.asList(((SVMParameters)this._parms)._ignored_columns)) : new HashSet();
        for (int i = 0; i < this._train.vecs().length; ++i) {
            Vec vec = this._train.vec(i);
            if (ignoredCols.contains(this._train.name(i)) || vec.isNumeric() || vec.isCategorical()) continue;
            this.error("_train", "SVM supports only frames with numeric/categorical values (except for result column). But a " + vec.get_type_str() + " was found.");
        }
        if (null != ((SVMParameters)this._parms)._response_column && null == this._train.vec(((SVMParameters)this._parms)._response_column)) {
            this.error("_train", "Training frame has to contain the response column.");
        }
        if (this._train != null && ((SVMParameters)this._parms)._response_column != null) {
            String[] responseDomains = this.responseDomains();
            if (null == responseDomains) {
                if (!Double.isNaN(((SVMParameters)this._parms)._threshold)) {
                    this.error("_threshold", "Threshold cannot be set for regression SVM. Set the threshold to NaN or modify the response column to an enum.");
                }
                if (!this._train.vec(((SVMParameters)this._parms)._response_column).isNumeric()) {
                    this.error("_response_column", "Regression SVM requires the response column type to be numeric.");
                }
            } else {
                if (Double.isNaN(((SVMParameters)this._parms)._threshold)) {
                    this.error("_threshold", "Threshold has to be set for binomial SVM. Set the threshold to a numeric value or change the response column type.");
                }
                if (responseDomains.length != 2) {
                    this.error("_response_column", "SVM requires the response column's domain to be of size 2.");
                }
            }
        }
    }

    private String[] responseDomains() {
        int idx = ((SVMParameters)this._parms).train().find(((SVMParameters)this._parms)._response_column);
        if (idx == -1) {
            return null;
        }
        return ((SVMParameters)this._parms).train().domains()[idx];
    }

    public int numSpecialCols() {
        return (this.hasOffsetCol() ? 1 : 0) + (this.hasWeightCol() ? 1 : 0) + (this.hasFoldCol() ? 1 : 0) + 1;
    }

    private final class SVMDriver
    extends ModelBuilder.Driver {
        private transient SparkContext sc;
        private transient H2OContext h2oContext;
        private transient SQLContext sqlContext;

        private SVMDriver() {
            super((ModelBuilder)SVM.this);
            this.sc = SVM.this.hc.sparkContext();
            this.h2oContext = SVM.this.hc;
            this.sqlContext = SQLContext.getOrCreate((SparkContext)this.sc);
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        public void computeImpl() {
            SVM.this.init(true);
            SVMModel model = new SVMModel((Key<SVMModel>)SVM.this.dest(), (SVMParameters)SVM.this._parms, new SVMModel.SVMOutput(SVM.this));
            try {
                model.delete_and_lock(SVM.this._job);
                Tuple2<RDD<LabeledPoint>, double[]> points = FrameMLUtils.toLabeledPoints(SVM.this._train, ((SVMParameters)SVM.this._parms)._response_column, ((SVMModel.SVMOutput)model._output).nfeatures(), ((SVMParameters)SVM.this._parms)._missing_values_handling, this.h2oContext, this.sqlContext);
                RDD training = (RDD)points._1();
                training.cache();
                if (training.count() == 0L && MissingValuesHandling.Skip == ((SVMParameters)SVM.this._parms)._missing_values_handling) {
                    throw new H2OIllegalArgumentException("No rows left in the dataset after filtering out rows with missing values. Ignore columns with many NAs or set missing_values_handling to 'MeanImputation'.");
                }
                SVMWithSGD svm = new SVMWithSGD();
                svm.setIntercept(((SVMParameters)SVM.this._parms)._add_intercept);
                svm.optimizer().setNumIterations(((SVMParameters)SVM.this._parms)._max_iterations);
                svm.optimizer().setStepSize(((SVMParameters)SVM.this._parms)._step_size);
                svm.optimizer().setRegParam(((SVMParameters)SVM.this._parms)._reg_param);
                svm.optimizer().setMiniBatchFraction(((SVMParameters)SVM.this._parms)._mini_batch_fraction);
                svm.optimizer().setConvergenceTol(((SVMParameters)SVM.this._parms)._convergence_tol);
                svm.optimizer().setGradient(((SVMParameters)SVM.this._parms)._gradient.get());
                svm.optimizer().setUpdater(((SVMParameters)SVM.this._parms)._updater.get());
                ProgressListener progressBar = new ProgressListener(this.sc, SVM.this._job, RDDInfo.fromRdd((RDD)training), (Iterable<String>)JavaConversions.iterableAsScalaIterable(Arrays.asList("treeAggregate")));
                this.sc.addSparkListener((SparkListenerInterface)progressBar);
                org.apache.spark.mllib.classification.SVMModel trainedModel = null == ((SVMParameters)SVM.this._parms)._initial_weights ? (org.apache.spark.mllib.classification.SVMModel)svm.run(training) : (org.apache.spark.mllib.classification.SVMModel)svm.run(training, this.vec2vec(((SVMParameters)SVM.this._parms).initialWeights().vecs()));
                training.unpersist(false);
                this.sc.listenerBus().listeners().remove((Object)progressBar);
                ((SVMModel.SVMOutput)model._output).weights_$eq(trainedModel.weights().toArray());
                ((SVMModel.SVMOutput)model._output).iterations_$eq(((SVMParameters)SVM.this._parms)._max_iterations);
                ((SVMModel.SVMOutput)model._output).interceptor_$eq(trainedModel.intercept());
                ((SVMModel.SVMOutput)model._output).numMeans_$eq((double[])points._2());
                Frame train = (Frame)DKV.getGet((Key)((SVMParameters)SVM.this._parms)._train);
                model.score(train).delete();
                ((SVMModel.SVMOutput)model._output)._training_metrics = ModelMetrics.getFromDKV((Model)model, (Frame)train);
                model.update(SVM.this._job);
                if (SVM.this._valid != null) {
                    model.score(((SVMParameters)SVM.this._parms).valid()).delete();
                    ((SVMModel.SVMOutput)model._output)._validation_metrics = ModelMetrics.getFromDKV((Model)model, (Frame)((SVMParameters)SVM.this._parms).valid());
                    model.update(SVM.this._job);
                }
                ((SVMModel.SVMOutput)model._output).interceptor_$eq(trainedModel.intercept());
                Log.info((Object[])new Object[]{((SVMModel.SVMOutput)model._output)._model_summary});
            }
            finally {
                model.unlock(SVM.this._job);
            }
        }

        private Vector vec2vec(Vec[] vals) {
            double[] dense = new double[vals.length];
            for (int i = 0; i < vals.length; ++i) {
                dense[i] = vals[i].at(0L);
            }
            return Vectors.dense((double[])dense);
        }
    }
}

