/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.targetencoding;

import ai.h2o.targetencoding.TargetEncoderHelper;
import ai.h2o.targetencoding.TargetEncoderModel;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import water.DKV;
import water.Key;
import water.Keyed;
import water.Scope;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.logging.Logger;
import water.logging.LoggerFactory;
import water.util.IcedHashMap;

public class TargetEncoder
extends ModelBuilder<TargetEncoderModel, TargetEncoderModel.TargetEncoderParameters, TargetEncoderModel.TargetEncoderOutput> {
    private static final Logger logger = LoggerFactory.getLogger(TargetEncoder.class);
    private TargetEncoderModel _targetEncoderModel;
    private String[] _columnsToEncode;

    public TargetEncoder(TargetEncoderModel.TargetEncoderParameters parms) {
        super((Model.Parameters)parms);
        this.init(false);
    }

    public TargetEncoder(TargetEncoderModel.TargetEncoderParameters parms, Key<TargetEncoderModel> key) {
        super((Model.Parameters)parms, key);
        this.init(false);
    }

    public TargetEncoder(boolean startupOnce) {
        super((Model.Parameters)new TargetEncoderModel.TargetEncoderParameters(), startupOnce);
    }

    public void init(boolean expensive) {
        this.disableIgnoreConstColsFeature(expensive);
        super.init(expensive);
        assert (((TargetEncoderModel.TargetEncoderParameters)this._parms)._nfolds == 0) : "nfolds usage forbidden in TargetEncoder";
        if (expensive) {
            if (((TargetEncoderModel.TargetEncoderParameters)this._parms)._data_leakage_handling == null) {
                ((TargetEncoderModel.TargetEncoderParameters)this._parms)._data_leakage_handling = TargetEncoderModel.DataLeakageHandlingStrategy.None;
            }
            if (((TargetEncoderModel.TargetEncoderParameters)this._parms)._data_leakage_handling == TargetEncoderModel.DataLeakageHandlingStrategy.KFold && ((TargetEncoderModel.TargetEncoderParameters)this._parms)._fold_column == null) {
                this.error("_fold_column", "Fold column is required when using KFold leakage handling strategy.");
            }
            List<String> colsToIgnore = Arrays.asList(((TargetEncoderModel.TargetEncoderParameters)this._parms)._response_column, ((TargetEncoderModel.TargetEncoderParameters)this._parms)._fold_column, ((TargetEncoderModel.TargetEncoderParameters)this._parms)._weights_column, ((TargetEncoderModel.TargetEncoderParameters)this._parms)._offset_column);
            Frame train = this.train();
            ArrayList<String> columnsToEncode = new ArrayList<String>(train.numCols());
            for (int i = 0; i < train.numCols(); ++i) {
                String colName = train.name(i);
                if (colsToIgnore.contains(colName)) continue;
                if (!train.vec(i).isCategorical()) {
                    this.warn("_train", "Column `" + colName + "` is not categorical and will therefore be ignored by target encoder.");
                    continue;
                }
                columnsToEncode.add(colName);
            }
            this._columnsToEncode = columnsToEncode.toArray(new String[0]);
        }
    }

    private void disableIgnoreConstColsFeature(boolean expensive) {
        ((TargetEncoderModel.TargetEncoderParameters)this._parms)._ignore_const_cols = false;
        if (expensive && logger.isInfoEnabled()) {
            logger.info("We don't want to ignore any columns during target encoding transformation therefore `_ignore_const_cols` parameter was set to `false`");
        }
    }

    protected void ignoreInvalidColumns(int npredictors, boolean expensive) {
        new ModelBuilder.FilterCols(npredictors){

            protected boolean filter(Vec v) {
                return !v.isCategorical();
            }
        }.doIt(this.train(), "Removing non-categorical columns found in the list of encoded columns.", expensive);
    }

    public boolean nFoldCV() {
        return false;
    }

    protected ModelBuilder.Driver trainModelImpl() {
        return new TargetEncoderDriver();
    }

    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.TargetEncoder};
    }

    public boolean isSupervised() {
        return true;
    }

    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Stable;
    }

    public boolean haveMojo() {
        return true;
    }

    private class TargetEncoderDriver
    extends ModelBuilder.Driver {
        private TargetEncoderDriver() {
            super((ModelBuilder)TargetEncoder.this);
        }

        public void computeImpl() {
            TargetEncoder.this._targetEncoderModel = null;
            try {
                TargetEncoder.this.init(true);
                if (TargetEncoder.this.error_count() > 0) {
                    throw H2OModelBuilderIllegalArgumentException.makeFromBuilder((ModelBuilder)TargetEncoder.this);
                }
                TargetEncoderModel.TargetEncoderOutput emptyOutput = new TargetEncoderModel.TargetEncoderOutput(TargetEncoder.this, (IcedHashMap<String, Frame>)new IcedHashMap());
                TargetEncoderModel model = new TargetEncoderModel((Key<TargetEncoderModel>)TargetEncoder.this.dest(), (TargetEncoderModel.TargetEncoderParameters)TargetEncoder.this._parms, emptyOutput);
                TargetEncoder.this._targetEncoderModel = (TargetEncoderModel)model.delete_and_lock(TargetEncoder.this._job);
                IcedHashMap<String, Frame> _targetEncodingMap = this.prepareEncodingMap();
                for (Map.Entry entry : _targetEncodingMap.entrySet()) {
                    Frame encodings = (Frame)entry.getValue();
                    Scope.untrack((Frame[])new Frame[]{encodings});
                }
                ((TargetEncoder)TargetEncoder.this)._targetEncoderModel._output = new TargetEncoderModel.TargetEncoderOutput(TargetEncoder.this, _targetEncodingMap);
                TargetEncoder.this._job.update(1L);
            }
            catch (Exception e) {
                if (TargetEncoder.this._targetEncoderModel != null) {
                    Scope.track_generic((Keyed)TargetEncoder.this._targetEncoderModel);
                }
                throw e;
            }
            finally {
                if (TargetEncoder.this._targetEncoderModel != null) {
                    TargetEncoder.this._targetEncoderModel.update(TargetEncoder.this._job);
                    TargetEncoder.this._targetEncoderModel.unlock(TargetEncoder.this._job);
                }
            }
        }

        private Frame filterOutNAsFromTargetColumn(Frame data, int targetColumnIndex) {
            return TargetEncoderHelper.filterOutNAsInColumn(data, targetColumnIndex);
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        private IcedHashMap<String, Frame> prepareEncodingMap() {
            Frame workingFrame = null;
            try {
                int targetIdx = TargetEncoder.this.train().find(((TargetEncoderModel.TargetEncoderParameters)TargetEncoder.this._parms)._response_column);
                int foldColIdx = ((TargetEncoderModel.TargetEncoderParameters)TargetEncoder.this._parms)._fold_column == null ? -1 : TargetEncoder.this.train().find(((TargetEncoderModel.TargetEncoderParameters)TargetEncoder.this._parms)._fold_column);
                workingFrame = this.filterOutNAsFromTargetColumn(TargetEncoder.this.train(), targetIdx);
                IcedHashMap columnToEncodings = new IcedHashMap();
                for (IcedHashMap columnToEncode : TargetEncoder.this._columnsToEncode) {
                    int colIdx = workingFrame.find((String)columnToEncode);
                    TargetEncoderHelper.imputeCategoricalColumn(workingFrame, colIdx, (String)columnToEncode + TargetEncoderHelper.NA_POSTFIX);
                    Frame encodings = TargetEncoderHelper.buildEncodingsFrame(workingFrame, colIdx, targetIdx, foldColIdx, TargetEncoder.this.nclasses());
                    Frame finalEncodings = this.applyLeakageStrategyToEncodings(encodings, (String)columnToEncode, ((TargetEncoderModel.TargetEncoderParameters)TargetEncoder.this._parms)._data_leakage_handling, ((TargetEncoderModel.TargetEncoderParameters)TargetEncoder.this._parms)._fold_column);
                    encodings.delete();
                    encodings = finalEncodings;
                    if (encodings._key != null) {
                        DKV.remove((Key)encodings._key);
                    }
                    encodings._key = Key.make((String)(TargetEncoder.this._result.toString() + "_encodings_" + (String)columnToEncode));
                    DKV.put((Keyed)encodings);
                    columnToEncodings.put((Object)columnToEncode, (Object)encodings);
                }
                IcedHashMap icedHashMap = columnToEncodings;
                return icedHashMap;
            }
            finally {
                if (workingFrame != null) {
                    workingFrame.delete();
                }
            }
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        private Frame applyLeakageStrategyToEncodings(Frame encodings, String columnToEncode, TargetEncoderModel.DataLeakageHandlingStrategy leakageHandlingStrategy, String foldColumn) {
            Frame groupedEncodings = null;
            int encodingsTEColIdx = encodings.find(columnToEncode);
            try {
                Scope.enter();
                switch (leakageHandlingStrategy) {
                    case KFold: {
                        long[] foldValues;
                        for (long foldValue : foldValues = TargetEncoderHelper.getUniqueColumnValues(encodings, encodings.find(foldColumn))) {
                            Frame outOfFoldEncodings = this.getOutOfFoldEncodings(encodings, foldColumn, foldValue);
                            Scope.track((Frame[])new Frame[]{outOfFoldEncodings});
                            Frame tmpEncodings = TargetEncoderHelper.register(TargetEncoderHelper.groupEncodingsByCategory(outOfFoldEncodings, encodingsTEColIdx));
                            Scope.track((Frame[])new Frame[]{tmpEncodings});
                            TargetEncoderHelper.addCon(tmpEncodings, foldColumn, foldValue);
                            if (groupedEncodings == null) {
                                groupedEncodings = tmpEncodings;
                            } else {
                                Frame newHoldoutEncodings = TargetEncoderHelper.rBind(groupedEncodings, tmpEncodings);
                                groupedEncodings.delete();
                                groupedEncodings = newHoldoutEncodings;
                            }
                            Scope.track((Frame[])new Frame[]{groupedEncodings});
                        }
                        break;
                    }
                    case LeaveOneOut: 
                    case None: {
                        groupedEncodings = TargetEncoderHelper.groupEncodingsByCategory(encodings, encodingsTEColIdx, foldColumn != null);
                        break;
                    }
                    default: {
                        throw new IllegalStateException("null or unsupported leakageHandlingStrategy");
                    }
                }
                Scope.untrack((Frame[])new Frame[]{groupedEncodings});
            }
            finally {
                Scope.exit((Key[])new Key[0]);
            }
            return groupedEncodings;
        }

        private Frame getOutOfFoldEncodings(Frame encodingsFrame, String foldColumn, long foldValue) {
            int foldColumnIdx = encodingsFrame.find(foldColumn);
            return TargetEncoderHelper.filterNotByValue(encodingsFrame, foldColumnIdx, foldValue);
        }
    }
}

