/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.automl.preprocessing;

import ai.h2o.automl.AutoML;
import ai.h2o.automl.AutoMLBuildSpec;
import ai.h2o.automl.events.EventLogEntry;
import ai.h2o.automl.preprocessing.PreprocessingConfig;
import ai.h2o.automl.preprocessing.PreprocessingStep;
import ai.h2o.automl.preprocessing.PreprocessingStepDefinition;
import ai.h2o.targetencoding.TargetEncoder;
import ai.h2o.targetencoding.TargetEncoderModel;
import ai.h2o.targetencoding.TargetEncoderPreprocessor;
import hex.Model;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Predicate;
import water.DKV;
import water.Key;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.ast.prims.advmath.AstKFold;
import water.util.ArrayUtils;

/**
 * AutoML preprocessing step that trains a {@link TargetEncoderModel} on the AutoML training
 * frame and exposes it to model builders through a {@link TargetEncoderPreprocessor}.
 *
 * <p>Lifecycle as implemented below:
 * <ul>
 *   <li>{@link #prepare()} selects the columns to encode and trains the target encoder
 *       (nothing is trained when no column is selected);</li>
 *   <li>{@link #apply(Model.Parameters, PreprocessingConfig)} wires the preprocessor into the
 *       parameters of a model about to be trained;</li>
 *   <li>{@link #dispose()} releases the temporary fold column/frame created by {@code prepare};</li>
 *   <li>{@link #remove()} removes the preprocessor and drops the references held by this step.</li>
 * </ul>
 */
public class TargetEncoding
implements PreprocessingStep {
    /** Config key read by {@link #apply}; when it resolves to {@code false} (default {@code true}) the step is skipped. */
    public static String CONFIG_ENABLED = "target_encoding_enabled";
    /**
     * Config key read by {@link #apply}; when it resolves to {@code true} (default {@code false}) the
     * preprocessor is not appended to the model parameters, only the fold column handling is performed.
     */
    public static String CONFIG_PREPARE_CV_ONLY = "target_encoding_prepare_cv_only";
    /** Suffix appended to the response column name to name the fold column generated in {@link #prepare()}. */
    static String TE_FOLD_COLUMN_SUFFIX = "_te_fold";
    /** Shared completer that does nothing. */
    private static final PreprocessingStep.Completer NOOP = () -> {};

    private final AutoML _aml;
    private TargetEncoderPreprocessor _tePreprocessor;
    private TargetEncoderModel _teModel;
    /** Cleanup actions registered by {@link #prepare()} and executed by {@link #dispose()}. */
    private final List<PreprocessingStep.Completer> _disposables = new ArrayList<>();
    private TargetEncoderModel.TargetEncoderParameters _defaultParams;
    /** When true, the cardinality-based column selection is bypassed. */
    private boolean _encodeAllColumns = false;
    /** Minimal cardinality a column must have to be selected for encoding. */
    private int _columnCardinalityThreshold = 25;

    /**
     * @param aml the AutoML instance providing the build spec, training frame, event log and key factory.
     */
    public TargetEncoding(AutoML aml) {
        this._aml = aml;
    }

    @Override
    public String getType() {
        return PreprocessingStepDefinition.Type.TargetEncoding.name();
    }

    /**
     * Trains the target encoder on the AutoML training frame.
     *
     * <p>Returns without training anything when no column is selected for encoding. When
     * cross-validation is enabled, the KFold leakage handling strategy is used; if the user did
     * not provide a fold column, a Modulo fold column is generated, added to a copy of the
     * training frame, and both are scheduled for removal in {@link #dispose()}.
     */
    @Override
    public void prepare() {
        AutoMLBuildSpec.AutoMLInput amlInput = this._aml.getBuildSpec().input_spec;
        AutoMLBuildSpec.AutoMLBuildControl amlBuild = this._aml.getBuildSpec().build_control;
        Frame amlTrain = this._aml.getTrainingFrame();

        // Work on a clone so that the defaults are not modified by this run.
        TargetEncoderModel.TargetEncoderParameters params = (TargetEncoderModel.TargetEncoderParameters)this.getDefaultParams().clone();
        params._train = amlTrain._key;
        params._response_column = amlInput.response_column;
        params._seed = amlBuild.stopping_criteria.seed();

        Set<String> teColumns = this.selectColumnsToEncode(amlTrain, params);
        if (teColumns.isEmpty()) {
            return;
        }
        this._aml.eventLog().warn(EventLogEntry.Stage.FeatureCreation, "Target Encoding integration in AutoML is in an experimental stage, the models obtained with this feature can not yet be downloaded as MOJO for production.");

        if (this._aml.isCVEnabled()) {
            params._data_leakage_handling = TargetEncoderModel.DataLeakageHandlingStrategy.KFold;
            params._fold_column = amlInput.fold_column;
            if (params._fold_column == null) {
                // No user-provided fold column: generate one and train on a frame copy that contains it.
                Frame train = new Frame(params.train());
                Vec foldColumn = TargetEncoding.createFoldColumn(params.train(), Model.Parameters.FoldAssignmentScheme.Modulo, amlBuild.nfolds, params._response_column, params._seed);
                DKV.put(foldColumn);
                params._fold_column = params._response_column + TE_FOLD_COLUMN_SUFFIX;
                train.add(params._fold_column, foldColumn);
                TargetEncoding.register(train, params._train.toString(), true);
                params._train = train._key;
                this._disposables.add(() -> {
                    foldColumn.remove();
                    DKV.remove(train._key);
                });
            }
        }

        // Ignore every column that is neither selected for encoding nor a non-predictor of the encoder.
        String[] keep = params.getNonPredictors();
        params._ignored_columns = Arrays.stream(amlTrain.names())
                .filter(col -> !teColumns.contains(col) && !ArrayUtils.contains(keep, col))
                .toArray(String[]::new);

        TargetEncoder te = new TargetEncoder(params, this._aml.makeKey(this.getType(), null, false));
        this._teModel = (TargetEncoderModel)te.trainModel().get();
        this._tePreprocessor = new TargetEncoderPreprocessor(this._teModel);
    }

    /**
     * Wires the trained target encoder into the given model parameters.
     *
     * <p>Does nothing when no encoder was trained or when {@link #CONFIG_ENABLED} resolves to
     * {@code false}. Otherwise the preprocessor key is appended to {@code params._preprocessors}
     * (unless {@link #CONFIG_PREPARE_CV_ONLY} is set). If the encoder was trained with a fold
     * column that is missing from the model's training frame, a copy of that frame with the
     * encoder's fold column is registered under a new key and {@code params} is pointed to it,
     * with {@code _fold_column} set, {@code _nfolds} reset to 0 and {@code _fold_assignment}
     * set to AUTO.
     *
     * @param params the parameters of the model about to be trained; modified in place.
     * @param config the preprocessing configuration for that model.
     * @return a completer removing the temporary training frame key from the DKV, or a no-op
     *         completer when no temporary frame was created.
     */
    @Override
    public PreprocessingStep.Completer apply(Model.Parameters params, PreprocessingConfig config) {
        if (this._tePreprocessor == null || !config.get(CONFIG_ENABLED, true)) {
            return NOOP;
        }
        if (!config.get(CONFIG_PREPARE_CV_ONLY, false)) {
            params._preprocessors = ArrayUtils.append(params._preprocessors, this._tePreprocessor._key);
        }

        Frame train = new Frame(params.train());
        TargetEncoderModel.TargetEncoderParameters teParams = (TargetEncoderModel.TargetEncoderParameters)this._teModel._parms;
        String foldColumn = teParams._fold_column;
        boolean addFoldColumn = foldColumn != null && train.find(foldColumn) < 0;
        if (!addFoldColumn) {
            // Nothing was registered, so there is nothing to clean up afterwards.
            return NOOP;
        }

        // Reuse the fold column the encoder was trained with, taken from the encoder's training frame.
        train.add(foldColumn, ((Frame)teParams._train.get()).vec(foldColumn));
        TargetEncoding.register(train, params._train.toString(), true);
        params._train = train._key;
        params._fold_column = foldColumn;
        params._nfolds = 0;
        params._fold_assignment = Model.Parameters.FoldAssignmentScheme.AUTO;
        return () -> DKV.remove(train._key);
    }

    /**
     * Runs the cleanup actions registered by {@link #prepare()}.
     * The actions are discarded afterwards, so calling this method again does not re-run them.
     */
    @Override
    public void dispose() {
        for (PreprocessingStep.Completer disposable : this._disposables) {
            disposable.run();
        }
        this._disposables.clear();
    }

    /**
     * Removes the preprocessor (with the cascade flag set) and drops the references to the
     * preprocessor and the encoder model. Does nothing when no encoder was trained.
     */
    @Override
    public void remove() {
        if (this._tePreprocessor != null) {
            this._tePreprocessor.remove(true);
            this._tePreprocessor = null;
            this._teModel = null;
        }
    }

    /**
     * @param defaultParams the parameters cloned by {@link #prepare()} as a starting point;
     *                      when {@code null}, built-in defaults are created on first use.
     */
    public void setDefaultParams(TargetEncoderModel.TargetEncoderParameters defaultParams) {
        this._defaultParams = defaultParams;
    }

    /**
     * @param encodeAllColumns if true, every column of the training frame (except the weights,
     *                         fold and response columns) is selected for encoding.
     */
    public void setEncodeAllColumns(boolean encodeAllColumns) {
        this._encodeAllColumns = encodeAllColumns;
    }

    /**
     * @param threshold the minimal cardinality a column must have to be selected for encoding
     *                  (ignored when all columns are encoded).
     */
    public void setColumnCardinalityThreshold(int threshold) {
        this._columnCardinalityThreshold = threshold;
    }

    /**
     * Returns the default encoder parameters, lazily creating them when none were provided:
     * original categorical columns not kept, blending enabled with inflection point 5 and
     * smoothing 10, no noise.
     */
    private TargetEncoderModel.TargetEncoderParameters getDefaultParams() {
        if (this._defaultParams == null) {
            TargetEncoderModel.TargetEncoderParameters defaults = new TargetEncoderModel.TargetEncoderParameters();
            defaults._keep_original_categorical_columns = false;
            defaults._blending = true;
            defaults._inflection_point = 5.0;
            defaults._smoothing = 10.0;
            defaults._noise = 0.0;
            this._defaultParams = defaults;
        }
        return this._defaultParams;
    }

    /**
     * Selects the columns of {@code fr} to target-encode.
     *
     * <p>Unless all columns are requested, a column is selected when its cardinality reaches the
     * configured threshold and, if blending is enabled, when the ratio rows/cardinality is
     * strictly greater than the inflection point. The AutoML weights, fold and response columns
     * are always excluded.
     *
     * @param fr     the frame whose columns are inspected.
     * @param params the encoder parameters providing {@code _blending} and {@code _inflection_point}.
     * @return the (possibly empty) set of column names to encode.
     */
    private Set<String> selectColumnsToEncode(Frame fr, TargetEncoderModel.TargetEncoderParameters params) {
        Set<String> selected = new HashSet<>();
        String[] names = fr.names();
        if (this._encodeAllColumns) {
            selected.addAll(Arrays.asList(names));
        } else {
            Predicate<Vec> cardinalityLargeEnough = v -> v.cardinality() >= this._columnCardinalityThreshold;
            Predicate<Vec> cardinalityNotTooLarge = params._blending
                    ? v -> (double)fr.numRows() / (double)v.cardinality() > params._inflection_point
                    : v -> true;
            for (int i = 0; i < names.length; ++i) {
                Vec vec = fr.vec(i);
                if (cardinalityLargeEnough.test(vec) && cardinalityNotTooLarge.test(vec)) {
                    selected.add(fr.name(i));
                }
            }
        }
        AutoMLBuildSpec.AutoMLInput amlInput = this._aml.getBuildSpec().input_spec;
        // Entries may be null when the corresponding column is not configured; removal tolerates that.
        List<String> nonPredictors = Arrays.asList(amlInput.weights_column, amlInput.fold_column, amlInput.response_column);
        selected.removeAll(nonPredictors);
        return selected;
    }

    TargetEncoderPreprocessor getTEPreprocessor() {
        return this._tePreprocessor;
    }

    TargetEncoderModel getTEModel() {
        return this._teModel;
    }

    /**
     * Puts the frame into the DKV, assigning it a key first when it has none or when {@code force} is set.
     *
     * @param fr        the frame to register.
     * @param keyPrefix prefix of the generated key (followed by "_" and a random part), or
     *                  {@code null} for a default generated key.
     * @param force     if true, a new key is always assigned and the frame's previous key (if any)
     *                  is removed from the DKV.
     */
    private static void register(Frame fr, String keyPrefix, boolean force) {
        Key<Frame> previousKey = fr._key;
        if (previousKey == null || force) {
            fr._key = keyPrefix == null ? Key.make() : Key.make(keyPrefix + "_" + Key.rand());
        }
        // Only remove a previous key when there actually was one.
        if (force && previousKey != null) {
            DKV.remove(previousKey);
        }
        DKV.put(fr);
    }

    /**
     * Creates a fold assignment column for the given frame.
     *
     * @param fr              the frame for which the fold column is created.
     * @param fold_assignment the scheme to use: Modulo and Stratified have dedicated
     *                        implementations, any other value uses the seeded k-fold column.
     * @param nfolds          the number of folds.
     * @param responseColumn  name of the response column; only read for the Stratified scheme.
     * @param seed            the seed; not used by the Modulo scheme.
     * @return the new fold column (not added to {@code fr}).
     */
    public static Vec createFoldColumn(Frame fr, Model.Parameters.FoldAssignmentScheme fold_assignment, int nfolds, String responseColumn, long seed) {
        switch (fold_assignment) {
            case Modulo:
                return AstKFold.moduloKfoldColumn(fr.anyVec().makeZero(), nfolds);
            case Stratified:
                return AstKFold.stratifiedKFoldColumn(fr.vec(responseColumn), nfolds, seed);
            default:
                return AstKFold.kfoldColumn(fr.anyVec().makeZero(), nfolds, seed);
        }
    }
}

