/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.mojos.runtime.h2o3;

import ai.h2o.mojos.runtime.MojoPipeline;
import ai.h2o.mojos.runtime.frame.MojoColumn;
import ai.h2o.mojos.runtime.frame.MojoColumnFloat64;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import ai.h2o.mojos.runtime.frame.MojoFrameBuilder;
import ai.h2o.mojos.runtime.frame.MojoFrameMeta;
import ai.h2o.mojos.runtime.utils.Debug;
import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.easy.prediction.BinomialModelPrediction;
import hex.genmodel.easy.prediction.MultinomialModelPrediction;
import hex.genmodel.easy.prediction.RegressionModelPrediction;
import java.util.Collections;
import org.joda.time.DateTime;

/**
 * {@link MojoPipeline} implementation backed by an H2O-3 {@link MojoModel}.
 *
 * <p>The wrapped model is scored row by row through an {@link EasyPredictModelWrapper}.
 * Only the {@code Binomial}, {@code Multinomial} and {@code Regression} model categories
 * are supported; any other category is rejected with an
 * {@link UnsupportedOperationException} at construction time.
 *
 * <p>Input columns are typed {@code Float64} when the model reports no domain for them and
 * {@code Str} otherwise. Classification models expose one {@code Float64} output column per
 * response class, named {@code <response>.<class label>}; regression models expose a single
 * {@code Float64} column named after the response.
 */
public class MojoPipelineH2O3Impl
extends MojoPipeline {
    /** Prefix of the message used whenever a model category cannot be handled. */
    private static final String UNSUPPORTED_CATEGORY = "Unsupported ModelCategory: ";

    private final EasyPredictModelWrapper _model;
    private final GenModel genModel;
    private final MojoFrameMeta _inputMeta;
    private final MojoFrameMeta _outputMeta;

    /**
     * Wraps the given H2O-3 model and derives the input/output frame metadata from it.
     *
     * <p>The superclass is initialised with the model UUID, a fixed 1970-01-01 00:00
     * timestamp and an empty string (the meaning of the latter two is defined by
     * {@link MojoPipeline}'s constructor, which is not visible from this class).
     *
     * @param model the H2O-3 MOJO model to expose as a pipeline
     * @throws UnsupportedOperationException if the model category is not Binomial,
     *         Multinomial or Regression
     */
    MojoPipelineH2O3Impl(MojoModel model) {
        super(model.getUUID(), new DateTime(1970, 1, 1, 0, 0), "");
        this._model = MojoPipelineH2O3Impl.wrapModelForPrediction(model);
        this.genModel = this._model.m;

        // Fail fast on unsupported categories before any metadata is built.
        ModelCategory category = this.genModel.getModelCategory();
        switch (category) {
            case Binomial:
            case Multinomial:
            case Regression: {
                break;
            }
            default: {
                throw new UnsupportedOperationException(UNSUPPORTED_CATEGORY + category);
            }
        }

        // Input schema: a column with a domain is categorical and is fed as a string.
        String[] modelNames = this.genModel.getNames();
        String[] inames = new String[this.genModel.getNumCols()];
        MojoColumn.Type[] itypes = new MojoColumn.Type[inames.length];
        for (int i = 0; i < inames.length; ++i) {
            inames[i] = modelNames[i];
            itypes[i] = this.genModel.getDomainValues(i) == null ? MojoColumn.Type.Float64 : MojoColumn.Type.Str;
        }
        this._inputMeta = new MojoFrameMeta(inames, itypes);
        this._outputMeta = MojoPipelineH2O3Impl.buildOutputMeta(this.genModel, category);
    }

    /**
     * Builds the output frame metadata for the given model category.
     *
     * @param model the model whose response description is used
     * @param category the model's category
     * @return one {@code Float64} column per response class for classification models,
     *         or a single {@code Float64} column for regression models
     * @throws UnsupportedOperationException if the category is not supported
     */
    private static MojoFrameMeta buildOutputMeta(GenModel model, ModelCategory category) {
        switch (category) {
            case Binomial:
            case Multinomial: {
                String[] onames = new String[model.getNumResponseClasses()];
                MojoColumn.Type[] otypes = new MojoColumn.Type[onames.length];
                String responseName = model.getResponseName();
                String[] responseDomain = model.getDomainValues(model.getResponseIdx());
                for (int i = 0; i < onames.length; ++i) {
                    onames[i] = responseName + "." + responseDomain[i];
                    otypes[i] = MojoColumn.Type.Float64;
                }
                return new MojoFrameMeta(onames, otypes);
            }
            case Regression: {
                String[] onames = new String[]{model.getResponseName()};
                MojoColumn.Type[] otypes = new MojoColumn.Type[]{MojoColumn.Type.Float64};
                return new MojoFrameMeta(onames, otypes);
            }
            default: {
                throw new UnsupportedOperationException(UNSUPPORTED_CATEGORY + category);
            }
        }
    }

    /**
     * Creates a frame builder for the given column kind.
     *
     * @param kind the kind of frame to build ({@code Feature} or {@code Output})
     * @return a builder over {@link #getMeta(MojoColumn.Kind)} with empty auxiliary collections
     * @throws UnsupportedOperationException if no metadata exists for {@code kind}
     */
    protected MojoFrameBuilder getFrameBuilder(MojoColumn.Kind kind) {
        return new MojoFrameBuilder(this.getMeta(kind), Collections.emptyList(), Collections.emptyMap());
    }

    /**
     * Returns the frame metadata for the given column kind.
     *
     * @param kind {@code Feature} for the input schema, {@code Output} for the output schema
     * @return the corresponding metadata
     * @throws UnsupportedOperationException for any other kind
     */
    protected MojoFrameMeta getMeta(MojoColumn.Kind kind) {
        switch (kind) {
            case Feature: {
                return this._inputMeta;
            }
            case Output: {
                return this._outputMeta;
            }
        }
        throw new UnsupportedOperationException("Cannot generate meta for interim frame");
    }

    /**
     * Scores every row of {@code inputFrame} and writes the predictions into {@code outputFrame}.
     *
     * <p>Each input cell is passed to the model as a string keyed by its column name;
     * {@code null} cells are simply left out of the row. Classification models write one
     * class probability per output column, regression models write the predicted value
     * into output column 0.
     *
     * @param inputFrame frame holding the feature columns
     * @param outputFrame frame whose {@code Float64} columns receive the predictions; it is
     *        written in place at the same row indices as the input
     * @return {@code outputFrame}
     * @throws UnsupportedOperationException if the model category is unsupported or if
     *         prediction fails for a row (the original exception is attached as the cause)
     */
    public MojoFrame transform(MojoFrame inputFrame, MojoFrame outputFrame) {
        ModelCategory modelCategory = this.genModel.getModelCategory();
        int colCount = inputFrame.getNcols();
        int rowCount = inputFrame.getNrows();
        String[] columnNames = inputFrame.getColumnNames();

        // Materialise every column as strings once, so rows can be assembled by index.
        String[][] columns = new String[colCount][];
        for (int j = 0; j < colCount; ++j) {
            columns[j] = inputFrame.getColumn(j).getDataAsStrings();
        }

        for (int rowIdx = 0; rowIdx < rowCount; ++rowIdx) {
            RowData rowData = new RowData();
            for (int colIdx = 0; colIdx < colCount; ++colIdx) {
                String value = columns[colIdx][rowIdx];
                if (value == null) {
                    // Missing cells are omitted from the row rather than stored as null.
                    continue;
                }
                rowData.put(columnNames[colIdx], value);
            }
            try {
                switch (modelCategory) {
                    case Binomial: {
                        BinomialModelPrediction p = this._model.predictBinomial(rowData);
                        this.setPrediction(outputFrame, rowIdx, p.classProbabilities);
                        break;
                    }
                    case Multinomial: {
                        MultinomialModelPrediction p = this._model.predictMultinomial(rowData);
                        this.setPrediction(outputFrame, rowIdx, p.classProbabilities);
                        break;
                    }
                    case Regression: {
                        RegressionModelPrediction p = this._model.predictRegression(rowData);
                        MojoColumnFloat64 col = (MojoColumnFloat64)outputFrame.getColumn(0);
                        double[] darr = (double[])col.getData();
                        darr[rowIdx] = p.value;
                        break;
                    }
                    default: {
                        throw new UnsupportedOperationException(UNSUPPORTED_CATEGORY + modelCategory);
                    }
                }
            }
            catch (UnsupportedOperationException e) {
                // Already the exception type callers expect; do not re-wrap.
                throw e;
            }
            catch (PredictException e) {
                if (Debug.getPrintH2O3Exceptions()) {
                    e.printStackTrace();
                }
                throw new UnsupportedOperationException(String.format("%s failed: %s", modelCategory, e.getMessage()), e);
            }
            catch (Exception e) {
                if (Debug.getPrintH2O3Exceptions()) {
                    e.printStackTrace();
                }
                throw new UnsupportedOperationException(String.format("%s failed with %s: %s", modelCategory, e.getClass().getName(), e.getMessage()), e);
            }
        }
        return outputFrame;
    }

    /**
     * Stores the class probabilities of one row into the output frame.
     *
     * @param outputFrame frame whose first {@code getNumResponseClasses()} columns are
     *        {@code Float64} probability columns
     * @param rowIdx index of the row being written
     * @param classProbabilities probabilities indexed by response class
     */
    private void setPrediction(MojoFrame outputFrame, int rowIdx, double[] classProbabilities) {
        int classCount = this.genModel.getNumResponseClasses();
        for (int outputColIdx = 0; outputColIdx < classCount; ++outputColIdx) {
            MojoColumnFloat64 col = (MojoColumnFloat64)outputFrame.getColumn(outputColIdx);
            double[] darr = (double[])col.getData();
            darr[rowIdx] = classProbabilities[outputColIdx];
        }
    }

    /**
     * Wraps the model in an {@link EasyPredictModelWrapper} configured to convert unknown
     * categorical levels and invalid numbers to NA.
     *
     * @param model the model to wrap
     * @return the configured prediction wrapper
     */
    private static EasyPredictModelWrapper wrapModelForPrediction(MojoModel model) {
        EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config()
                .setModel(model)
                .setConvertUnknownCategoricalLevelsToNa(true)
                .setConvertInvalidNumbersToNa(true);
        return new EasyPredictModelWrapper(config);
    }
}

