/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.targetencoding;

import ai.h2o.targetencoding.TargetEncoder;
import ai.h2o.targetencoding.TargetEncoderModel;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import java.util.Arrays;
import java.util.Map;
import water.Key;
import water.Scope;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.IcedHashMapGeneric;
import water.util.Log;

public class TargetEncoderBuilder
extends ModelBuilder<TargetEncoderModel, TargetEncoderModel.TargetEncoderParameters, TargetEncoderModel.TargetEncoderOutput> {
    private TargetEncoderModel _targetEncoderModel;

    public TargetEncoderModel getTargetEncoderModel() {
        assert (this._targetEncoderModel != null) : "Training phase of the TargetEncoderBuilder did not take place yet. TargetEncoderModel is not available.";
        return this._targetEncoderModel;
    }

    public TargetEncoderBuilder(TargetEncoderModel.TargetEncoderParameters parms) {
        super((Model.Parameters)parms);
        super.init(false);
    }

    public TargetEncoderBuilder(boolean startupOnce) {
        super((Model.Parameters)new TargetEncoderModel.TargetEncoderParameters(), startupOnce);
    }

    public void init(boolean expensive) {
        super.init(expensive);
    }

    protected void ignoreInvalidColumns(int npredictors, boolean expensive) {
        new ModelBuilder.FilterCols(npredictors){

            protected boolean filter(Vec v) {
                return !v.isCategorical();
            }
        }.doIt(this.train(), "Removing non-categorical columns found in the list of encoded columns.", expensive);
    }

    public boolean nFoldCV() {
        return false;
    }

    protected ModelBuilder.Driver trainModelImpl() {
        return new TargetEncoderDriver();
    }

    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.TargetEncoder};
    }

    public boolean isSupervised() {
        return true;
    }

    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Stable;
    }

    public boolean haveMojo() {
        return true;
    }

    public String getName() {
        return "targetencoder";
    }

    private class TargetEncoderDriver
    extends ModelBuilder.Driver {
        private TargetEncoderDriver() {
            super((ModelBuilder)TargetEncoderBuilder.this);
        }

        public void computeImpl() {
            int numColsRemoved = TargetEncoderBuilder.this.hasFoldCol() ? 2 : 1;
            String[] encodedColumns = Arrays.copyOf(TargetEncoderBuilder.this.train().names(), TargetEncoderBuilder.this.train().names().length - numColsRemoved);
            TargetEncoder tec = new TargetEncoder(encodedColumns);
            Scope.untrack((Key[])TargetEncoderBuilder.this.train().keys());
            IcedHashMapGeneric<String, Frame> _targetEncodingMap = tec.prepareEncodingMap(TargetEncoderBuilder.this.train(), ((TargetEncoderModel.TargetEncoderParameters)TargetEncoderBuilder.this._parms)._response_column, ((TargetEncoderModel.TargetEncoderParameters)TargetEncoderBuilder.this._parms)._fold_column);
            double priorMean = tec.calculatePriorMean((Frame)((Map.Entry)_targetEncodingMap.entrySet().iterator().next()).getValue());
            for (Map.Entry entry : _targetEncodingMap.entrySet()) {
                Frame frameWithEncodingMap = (Frame)entry.getValue();
                Scope.untrack((Key[])frameWithEncodingMap.keys());
            }
            this.disableIgnoreConstColsFeature();
            TargetEncoderModel.TargetEncoderOutput output = new TargetEncoderModel.TargetEncoderOutput(TargetEncoderBuilder.this, _targetEncodingMap, priorMean);
            TargetEncoderBuilder.this._targetEncoderModel = new TargetEncoderModel((Key<TargetEncoderModel>)TargetEncoderBuilder.this._job._result, (TargetEncoderModel.TargetEncoderParameters)TargetEncoderBuilder.this._parms, output, tec);
            TargetEncoderBuilder.this._targetEncoderModel.write_lock(TargetEncoderBuilder.this._job);
            TargetEncoderBuilder.this._targetEncoderModel.unlock(TargetEncoderBuilder.this._job);
        }

        private void disableIgnoreConstColsFeature() {
            ((TargetEncoderModel.TargetEncoderParameters)TargetEncoderBuilder.this._parms)._ignore_const_cols = false;
            Log.info((Object[])new Object[]{"We don't want to ignore any columns during target encoding transformation therefore `_ignore_const_cols` parameter was set to `false`"});
        }
    }
}

