/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.mojos.runtime.h2o3;

import ai.h2o.mojos.runtime.frame.MojoColumnFloat64;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import ai.h2o.mojos.runtime.frame.MojoFrameMeta;
import ai.h2o.mojos.runtime.transforms.ShapCapableTransform;
import ai.h2o.mojos.runtime.utils.Debug;
import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.easy.prediction.BinomialModelPrediction;

/**
 * Mojo transform that delegates scoring to an H2O-3 model through an
 * {@link EasyPredictModelWrapper}.
 *
 * <p>The input columns selected by {@code iindices} are read as strings and handed to the
 * wrapper row by row as {@link RowData}. Predictions are written into the
 * {@link MojoColumnFloat64} columns selected by {@code oindices}.
 */
public class H2O3Transform
extends ShapCapableTransform {
    private final GenModel genModel;
    private final EasyPredictModelWrapper easyPredictModelWrapper;

    /**
     * @param meta frame metadata; not used by this constructor
     * @param iindices frame indices of the input columns
     * @param oindices frame indices of the output columns
     * @param easyPredictModelWrapper wrapper around the H2O-3 model used for scoring; its
     *     underlying {@link GenModel} is used to determine the model category
     */
    H2O3Transform(MojoFrameMeta meta, int[] iindices, int[] oindices, EasyPredictModelWrapper easyPredictModelWrapper) {
        super(iindices, oindices);
        this.easyPredictModelWrapper = easyPredictModelWrapper;
        this.genModel = easyPredictModelWrapper.m;
    }

    /**
     * Scores every row of {@code frame} with the wrapped H2O-3 model and writes the results
     * into the output columns.
     *
     * <p>Supported model categories:
     * <ul>
     *   <li>Binomial and Multinomial: one class probability per output column;</li>
     *   <li>AutoEncoder: one reconstructed value per output column;</li>
     *   <li>Regression: the predicted value, written to the first output column.</li>
     * </ul>
     * Null input values are not put into the row passed to the model.
     *
     * @param frame frame to read inputs from and write predictions into
     * @throws UnsupportedOperationException if the model category is unsupported, or if
     *     scoring a row fails; in the latter case the original failure is attached as the cause
     */
    public void transform(MojoFrame frame) {
        ModelCategory modelCategory = this.genModel.getModelCategory();
        int colCount = this.iindices.length;
        int rowCount = frame.getNrows();
        // Column names and string data do not depend on the row, so resolve them once.
        String[] names = new String[colCount];
        String[][] columns = new String[colCount][];
        for (int j = 0; j < colCount; ++j) {
            int iidx = this.iindices[j];
            names[j] = frame.getColumnName(iidx);
            columns[j] = frame.getColumn(iidx).getDataAsStrings();
        }
        for (int rowIdx = 0; rowIdx < rowCount; ++rowIdx) {
            RowData rowData = new RowData();
            for (int colIdx = 0; colIdx < colCount; ++colIdx) {
                String value = columns[colIdx][rowIdx];
                // Nulls are skipped; presumably the wrapper treats absent keys as missing values — confirm.
                if (value != null) {
                    rowData.put(names[colIdx], value);
                }
            }
            try {
                // Each branch reads the field it needs straight from the call result, because
                // the wrapper returns a different prediction class per model category.
                switch (modelCategory) {
                    case Binomial: {
                        this.setPrediction(frame, rowIdx, this.easyPredictModelWrapper.predictBinomial(rowData).classProbabilities);
                        break;
                    }
                    case Multinomial: {
                        this.setPrediction(frame, rowIdx, this.easyPredictModelWrapper.predictMultinomial(rowData).classProbabilities);
                        break;
                    }
                    case AutoEncoder: {
                        this.setPrediction(frame, rowIdx, this.easyPredictModelWrapper.predictAutoEncoder(rowData).reconstructed);
                        break;
                    }
                    case Regression: {
                        double value = this.easyPredictModelWrapper.predictRegression(rowData).value;
                        MojoColumnFloat64 col = (MojoColumnFloat64) frame.getColumn(this.oindices[0]);
                        double[] darr = (double[]) col.getData();
                        darr[rowIdx] = value;
                        break;
                    }
                    default: {
                        throw new UnsupportedOperationException("Unsupported ModelCategory: " + modelCategory.toString());
                    }
                }
            }
            catch (UnsupportedOperationException e) {
                // Already carries a meaningful message; do not wrap it again.
                throw e;
            }
            catch (PredictException e) {
                if (Debug.getPrintH2O3Exceptions()) {
                    e.printStackTrace();
                }
                throw new UnsupportedOperationException(String.format("%s failed: %s", modelCategory, e.getMessage()), e);
            }
            catch (Exception e) {
                if (Debug.getPrintH2O3Exceptions()) {
                    e.printStackTrace();
                }
                throw new UnsupportedOperationException(String.format("%s failed with %s: %s", modelCategory, e.getClass().getName(), e.getMessage()), e);
            }
        }
    }

    /**
     * Writes {@code values[i]} into row {@code rowIdx} of the output column {@code oindices[i]},
     * for every output column.
     *
     * @param frame frame holding the output columns (expected to be {@link MojoColumnFloat64})
     * @param rowIdx row to write
     * @param values per-output-column values (class probabilities or reconstructed values);
     *     must contain at least {@code oindices.length} entries
     */
    private void setPrediction(MojoFrame frame, int rowIdx, double[] values) {
        for (int outputColIdx = 0; outputColIdx < this.oindices.length; ++outputColIdx) {
            int oidx = this.oindices[outputColIdx];
            MojoColumnFloat64 col = (MojoColumnFloat64) frame.getColumn(oidx);
            double[] darr = (double[]) col.getData();
            darr[rowIdx] = values[outputColIdx];
        }
    }

    /**
     * Computes feature contributions for a single input row. The results come from
     * {@link EasyPredictModelWrapper#predictContributions}.
     *
     * <p>Only Binomial and Regression models are supported. The row built from
     * {@code inputs} is scored once for every entry of {@code shapContribs}. Each entry's
     * leading elements are overwritten with the returned contributions.
     *
     * @param inputs input values, one per entry of {@code iindices}
     * @param shapContribs destination arrays; each must hold at least as many values as
     *     {@code predictContributions} returns (asserted to be {@code inputs.length + 1})
     * @throws UnsupportedOperationException if the model category is unsupported, or if the
     *     contribution computation fails; in the latter case the original failure is attached
     *     as the cause
     */
    public void computeShap(double[] inputs, double[][] shapContribs) {
        ModelCategory modelCategory = this.genModel.getModelCategory();
        try {
            switch (modelCategory) {
                case Binomial: 
                case Regression: {
                    // NOTE(review): here iindices index the model's feature-name array, while
                    // transform() uses them as frame column indices — confirm both agree.
                    String[] features = this.genModel.features();
                    int colCount = this.iindices.length;
                    for (double[] contribs : shapContribs) {
                        RowData rowData = new RowData();
                        for (int colIdx = 0; colIdx < colCount; ++colIdx) {
                            String key = features[this.iindices[colIdx]];
                            rowData.put(key, String.valueOf(inputs[colIdx]));
                        }
                        float[] contribPreds = this.easyPredictModelWrapper.predictContributions(rowData);
                        for (int j = 0; j < contribPreds.length; ++j) {
                            contribs[j] = contribPreds[j];
                        }
                        assert (inputs.length + 1 == contribPreds.length);
                    }
                    break;
                }
                default: {
                    throw new UnsupportedOperationException("Unsupported ModelCategory: " + modelCategory.toString());
                }
            }
        }
        catch (UnsupportedOperationException e) {
            // Same policy as transform(): do not re-wrap an already descriptive failure.
            throw e;
        }
        catch (Exception e) {
            if (Debug.getPrintH2O3Exceptions()) {
                e.printStackTrace();
            }
            throw new UnsupportedOperationException(String.format("%s failed with %s: %s", modelCategory, e.getClass().getName(), e.getMessage()), e);
        }
    }
}

