/*
 * Decompiled with CFR 0.152.
 */
package hex.tree;

import ai.h2o.algos.tree.INode;
import hex.ContributionsWithBackgroundFrameTask;
import hex.DistributionFactory;
import hex.Model;
import hex.genmodel.algos.tree.ContributionComposer;
import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.genmodel.algos.tree.TreeSHAP;
import hex.genmodel.algos.tree.TreeSHAPEnsemble;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import hex.tree.SharedTreeModel;
import java.util.ArrayList;
import java.util.Arrays;
import water.DKV;
import water.Job;
import water.JobUpdatePostMap;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.util.ArrayUtils;
import water.util.Log;

/**
 * Shared-tree model that can explain its predictions by computing per-row feature
 * contributions (SHAP values) with {@link TreeSHAP}.
 *
 * <p>Three scoring modes are provided:
 * <ul>
 *   <li>plain contributions: one column per feature plus a trailing {@code BiasTerm} column;</li>
 *   <li>sorted contributions: only the top/bottom N features, emitted as (feature id, value) column pairs;</li>
 *   <li>interventional contributions against a background frame, optionally expanded per categorical level.</li>
 * </ul>
 * Models with more than two classes are rejected in every mode.
 */
public abstract class SharedTreeModelWithContributions<M extends SharedTreeModel<M, P, O>, P extends SharedTreeModel.SharedTreeParameters, O extends SharedTreeModel.SharedTreeOutput>
extends SharedTreeModel<M, P, O>
implements Model.Contributions {
    /** Name of the trailing output column that holds the bias term. */
    private static final String BIAS_TERM = "BiasTerm";

    private static final String MULTINOMIAL_UNSUPPORTED_MSG =
            "Calculating contributions is currently not supported for multinomial models.";

    public SharedTreeModelWithContributions(Key<M> selfKey, P parms, O output) {
        super(selfKey, parms, output);
    }

    /** Scores plain contributions without job progress tracking. */
    public Frame scoreContributions(Frame frame, Key<Frame> destination_key) {
        return this.scoreContributions(frame, destination_key, null);
    }

    /**
     * Adapts {@code frame} to the training layout and drops the response, fold, weights and
     * offset columns, leaving only model features.
     *
     * @param frame input frame; it is not modified (a shallow copy is adapted)
     * @return a new frame containing only feature columns
     */
    protected Frame removeSpecialColumns(Frame frame) {
        Frame adaptFrm = new Frame(frame);
        this.adaptTestForTrain(adaptFrm, true, false);
        adaptFrm.remove(this._parms._response_column);
        adaptFrm.remove(this._parms._fold_column);
        adaptFrm.remove(this._parms._weights_column);
        adaptFrm.remove(this._parms._offset_column);
        return adaptFrm;
    }

    /**
     * Like {@link #removeSpecialColumns(Frame)} but additionally drops every non-numeric column.
     */
    protected Frame removeSpecialNNonNumericColumns(Frame frame) {
        Frame adaptFrm = this.removeSpecialColumns(frame);
        // Walk backwards so that removing a column does not shift the indices still to be visited.
        for (int index = adaptFrm.numCols() - 1; index >= 0; --index) {
            if (!adaptFrm.vec(index).isNumeric()) {
                adaptFrm.remove(index);
            }
        }
        return adaptFrm;
    }

    /**
     * Computes one contribution column per feature plus a trailing {@code BiasTerm} column.
     *
     * @throws UnsupportedOperationException for models with more than two classes
     */
    public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Frame> j) {
        this.checkContributionsSupported();
        Frame adaptFrm = this.removeSpecialColumns(frame);
        String[] outputNames = ArrayUtils.append(adaptFrm.names(), new String[]{BIAS_TERM});
        // All output columns share type code 3 (presumably Vec.T_NUM).
        return ((ScoreContributionsTask) this.getScoreContributionsTask(this)
                .withPostMapAction((MRTask.PostMapAction) JobUpdatePostMap.forJob(j))
                .doAll(outputNames.length, (byte) 3, adaptFrm))
                .outputFrame(destination_key, outputNames, null);
    }

    /**
     * Computes contributions, reduced to the top/bottom N features when {@code options} asks
     * for sorting; otherwise delegates to {@link #scoreContributions(Frame, Key, Job)}.
     *
     * @throws UnsupportedOperationException for models with more than two classes
     */
    public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Frame> j, Model.Contributions.ContributionsOptions options) {
        this.checkContributionsSupported();
        if (!options.isSortingRequired()) {
            return this.scoreContributions(frame, destination_key, j);
        }
        Frame adaptFrm = this.removeSpecialColumns(frame);
        String[] featureNames = adaptFrm.names();
        int nFeatures = featureNames.length;
        String[] contribNames = ArrayUtils.append(featureNames, new String[]{BIAS_TERM});
        ContributionComposer contributionComposer = new ContributionComposer();
        int topNAdjusted = contributionComposer.checkAndAdjustInput(options._topN, nFeatures);
        int bottomNAdjusted = contributionComposer.checkAndAdjustInput(options._bottomN, nFeatures);
        // Every reported feature takes two columns (feature id, contribution), capped at all features.
        int outputSize = Math.min((topNAdjusted + bottomNAdjusted) * 2, nFeatures * 2);
        String[] names = new String[outputSize + 1];
        byte[] types = new byte[outputSize + 1];
        String[][] domains = new String[outputSize + 1][contribNames.length];
        this.composeScoreContributionTaskMetadata(names, types, domains, featureNames, options);
        return ((ScoreContributionsTask) this.getScoreContributionsSoringTask(this, options)
                .withPostMapAction((MRTask.PostMapAction) JobUpdatePostMap.forJob(j))
                .doAll(types, adaptFrm))
                .outputFrame(destination_key, names, domains);
    }

    protected abstract ScoreContributionsWithBackgroundTask getScoreContributionsWithBackgroundTask(SharedTreeModel var1, Frame var2, Frame var3, boolean var4, int[] var5, Model.Contributions.ContributionsOptions var6);

    /**
     * Computes interventional contributions of the rows of {@code frame} against the rows of
     * {@code backgroundFrame}. Without a background frame this delegates to
     * {@link #scoreContributions(Frame, Key, Job, Model.Contributions.ContributionsOptions)}.
     *
     * <p>In the Compact output format (or when the model has no domain information) there is one
     * column per feature. Otherwise categorical features are expanded to one column per level plus
     * a {@code .missing(NA)} column. Adapted copies of both frames are published to DKV for the task
     * and are always removed afterwards, including when scoring fails.
     *
     * @throws UnsupportedOperationException for models with more than two classes
     */
    public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Frame> j, Model.Contributions.ContributionsOptions options, Frame backgroundFrame) {
        if (backgroundFrame == null) {
            return this.scoreContributions(frame, destination_key, j, options);
        }
        assert !options.isSortingRequired();
        this.checkContributionsSupported();
        Log.info("Starting contributions calculation for " + this._key + "...");
        Frame adaptedFrame = null;
        Frame adaptedBgFrame = null;
        try {
            adaptedFrame = this.removeSpecialColumns(frame);
            adaptedBgFrame = this.removeSpecialColumns(backgroundFrame);
            DKV.put(adaptedFrame);
            DKV.put(adaptedBgFrame);
            if (options._outputFormat == Model.Contributions.ContributionsOutputFormat.Compact || this._output._domains == null) {
                String[] outputNames = ArrayUtils.append(adaptedFrame.names(), new String[]{BIAS_TERM});
                return this.getScoreContributionsWithBackgroundTask(this, adaptedFrame, adaptedBgFrame, false, null, options)
                        .runAndGetOutput(j, destination_key, outputNames);
            }
            assert Model.Parameters.CategoricalEncodingScheme.Enum.equals(this._parms._categorical_encoding) : "Unsupported categorical encoding. Only enum is supported.";
            int[] catOffsets = this.computeCategoricalOffsets();
            String[] outputNames = this.expandedContributionNames(catOffsets[catOffsets.length - 1]);
            return this.getScoreContributionsWithBackgroundTask(this, adaptedFrame, adaptedBgFrame, true, catOffsets, options)
                    .runAndGetOutput(j, destination_key, outputNames);
        } finally {
            if (adaptedFrame != null) {
                Frame.deleteTempFrameAndItsNonSharedVecs(adaptedFrame, frame);
            }
            if (adaptedBgFrame != null) {
                Frame.deleteTempFrameAndItsNonSharedVecs(adaptedBgFrame, backgroundFrame);
            }
            Log.info("Finished contributions calculation for " + this._key + "...");
        }
    }

    /** Throws if this model has more than two classes, which contribution scoring does not support. */
    private void checkContributionsSupported() {
        if (this._output.nclasses() > 2) {
            throw new UnsupportedOperationException(MULTINOMIAL_UNSUPPORTED_MSG);
        }
    }

    /** True when {@code name} is the response, fold, weights or offset column of this model. */
    private boolean isSpecialColumn(String name) {
        return name.equals(this._parms._response_column)
                || name.equals(this._parms._fold_column)
                || name.equals(this._parms._weights_column)
                || name.equals(this._parms._offset_column);
    }

    /**
     * Computes the cumulative start offset of every feature in the expanded contribution vector.
     * A numeric feature occupies 1 slot and a categorical one occupies (levels + 1) slots, the
     * extra slot being for NA. Offsets are indexed by feature position, which skips special columns.
     * The last entry is the total expanded width.
     */
    private int[] computeCategoricalOffsets() {
        String[] names = this._output._names;
        String[][] domains = this._output._domains;
        int[] offsets = new int[domains.length + 1];
        int nCols = 1;
        for (int i = 0; i < domains.length; ++i) {
            if (this.isSpecialColumn(names[i])) {
                continue;
            }
            int width = domains[i] == null ? 1 : domains[i].length + 1;
            // Index by the compacted feature counter, not by i, so special columns cannot leave gaps.
            offsets[nCols] = offsets[nCols - 1] + width;
            ++nCols;
        }
        return Arrays.copyOf(offsets, nCols);
    }

    /**
     * Builds the expanded output column names. A numeric feature keeps its name, and each
     * categorical level becomes {@code name.level}, followed by {@code name.missing(NA)}.
     * {@code BiasTerm} is placed last.
     *
     * @param nContribs total expanded width (the last entry of {@link #computeCategoricalOffsets()})
     */
    private String[] expandedContributionNames(int nContribs) {
        String[] names = this._output._names;
        String[][] domains = this._output._domains;
        String[] outputNames = new String[nContribs + 1];
        outputNames[nContribs] = BIAS_TERM;
        int l = 0;
        for (int i = 0; i < names.length; ++i) {
            if (this.isSpecialColumn(names[i])) {
                continue;
            }
            if (domains[i] == null) {
                outputNames[l++] = names[i];
                continue;
            }
            for (String level : domains[i]) {
                outputNames[l++] = names[i] + "." + level;
            }
            outputNames[l++] = names[i] + ".missing(NA)";
        }
        return outputNames;
    }

    /**
     * Builds a TreeSHAP ensemble over every non-null tree key of {@code output}, using
     * {@code output._init_f} as the ensemble's initial value.
     */
    private static TreeSHAPPredictor<double[]> buildTreeSHAPEnsemble(SharedTreeModel model, SharedTreeModel.SharedTreeOutput output) {
        ArrayList<TreeSHAP> treeSHAPs = new ArrayList<TreeSHAP>(output._ntrees);
        for (int treeIdx = 0; treeIdx < output._ntrees; ++treeIdx) {
            for (int treeClass = 0; treeClass < output._treeKeys[treeIdx].length; ++treeClass) {
                if (output._treeKeys[treeIdx][treeClass] == null) {
                    continue;
                }
                SharedTreeSubgraph tree = model.getSharedTreeSubgraph(treeIdx, treeClass);
                SharedTreeNode[] nodes = tree.getNodes();
                treeSHAPs.add(new TreeSHAP((INode[]) nodes));
            }
        }
        // Only non-multinomial models reach here, so exactly one tree is expected per iteration.
        assert treeSHAPs.size() == output._ntrees;
        return new TreeSHAPEnsemble(treeSHAPs, (float) output._init_f);
    }

    protected abstract ScoreContributionsTask getScoreContributionsTask(SharedTreeModel var1);

    protected abstract ScoreContributionsTask getScoreContributionsSoringTask(SharedTreeModel var1, Model.Contributions.ContributionsOptions var2);

    /**
     * Computes interventional contributions of every input row against every background row.
     * Each (row, background row) pair emits one output row.
     */
    public class ScoreContributionsWithBackgroundTask
    extends ContributionsWithBackgroundFrameTask<ScoreContributionsWithBackgroundTask> {
        protected final Key<SharedTreeModel> _modelKey;
        protected transient SharedTreeModel _model;
        protected transient SharedTreeModel.SharedTreeOutput _output;
        protected transient TreeSHAPPredictor<double[]> _treeSHAP;
        protected boolean _expand;
        protected boolean _outputSpace;
        protected int[] _catOffsets;

        public ScoreContributionsWithBackgroundTask(Key<Frame> frKey, Key<Frame> backgroundFrameKey, boolean perReference, SharedTreeModel model, boolean expand, int[] catOffsets, boolean outputSpace) {
            super(frKey, backgroundFrameKey, perReference);
            this._modelKey = model._key;
            this._expand = expand;
            this._catOffsets = catOffsets;
            this._outputSpace = outputSpace;
        }

        /** Fetches the model from DKV and builds the TreeSHAP ensemble on this node. */
        protected void setupLocal() {
            this._model = (SharedTreeModel) this._modelKey.get();
            assert this._model != null;
            this._output = (SharedTreeModel.SharedTreeOutput) this._model._output;
            assert this._output != null;
            this._treeSHAP = SharedTreeModelWithContributions.buildTreeSHAPEnsemble(this._model, this._output);
        }

        /** Copies row {@code row} of {@code chks} into {@code input} as doubles. */
        protected void fillInput(Chunk[] chks, int row, double[] input) {
            for (int i = 0; i < chks.length; ++i) {
                input[i] = chks[i].atd(row);
            }
        }

        @Override
        public void map(Chunk[] cs, Chunk[] bgCs, NewChunk[] nc) {
            assert cs.length <= nc.length - 1;
            double[] input = MemoryManager.malloc8d(cs.length);
            double[] inputBg = MemoryManager.malloc8d(bgCs.length);
            double[] contribs = MemoryManager.malloc8d(nc.length);
            for (int row = 0; row < cs[0]._len; ++row) {
                this.fillInput(cs, row, input);
                for (int bgRow = 0; bgRow < bgCs[0]._len; ++bgRow) {
                    Arrays.fill(contribs, 0.0);
                    this.fillInput(bgCs, bgRow, inputBg);
                    this._treeSHAP.calculateInterventionalContributions(input, inputBg, contribs, this._catOffsets, this._expand);
                    this.doModelSpecificComputation(contribs);
                    this.addContribToNewChunk(contribs, nc);
                }
            }
        }

        /** Hook for subclasses to post-process raw contributions; the default does nothing. */
        protected void doModelSpecificComputation(double[] contribs) {
        }

        /**
         * Appends one output row. When {@code _outputSpace} is set, the feature contributions are
         * rescaled from link space by {@code (linkInv(sum) - linkInv(bias)) / (sum - bias)}, where
         * {@code sum} is the sum of all entries. The ratio is 0 when the difference is below 1e-6.
         * The bias is then reported as {@code linkInv(bias)}.
         */
        protected void addContribToNewChunk(double[] contribs, NewChunk[] nc) {
            double transformationRatio = 1.0;
            double biasTerm = contribs[contribs.length - 1];
            if (this._outputSpace) {
                double linkSpaceX = Arrays.stream(contribs).sum();
                double linkSpaceBg = biasTerm;
                double outSpaceX = DistributionFactory.getDistribution(SharedTreeModelWithContributions.this._parms).linkInv(linkSpaceX);
                double outSpaceBg = DistributionFactory.getDistribution(SharedTreeModelWithContributions.this._parms).linkInv(linkSpaceBg);
                transformationRatio = Math.abs(linkSpaceX - linkSpaceBg) < 1.0E-6 ? 0.0 : (outSpaceX - outSpaceBg) / (linkSpaceX - linkSpaceBg);
                biasTerm = outSpaceBg;
            }
            for (int i = 0; i < nc.length - 1; ++i) {
                nc[i].addNum(contribs[i] * transformationRatio);
            }
            nc[nc.length - 1].addNum(biasTerm);
        }
    }

    /**
     * Computes contributions and keeps only the top/bottom N features per row. The features are
     * emitted as (feature id, contribution) column pairs, followed by the bias term.
     */
    public class ScoreContributionsSortingTask
    extends ScoreContributionsTask {
        private final int _topN;
        private final int _bottomN;
        private final boolean _compareAbs;

        public ScoreContributionsSortingTask(SharedTreeModel model, Model.Contributions.ContributionsOptions options) {
            super(model);
            this._topN = options._topN;
            this._bottomN = options._bottomN;
            this._compareAbs = options._compareAbs;
        }

        /** Fills {@code input}, zeroes {@code contribs} and resets {@code contribNameIds} to 0..n-1. */
        protected void fillInput(Chunk[] chks, int row, double[] input, float[] contribs, int[] contribNameIds) {
            super.fillInput(chks, row, input, contribs);
            for (int i = 0; i < contribNameIds.length; ++i) {
                contribNameIds[i] = i;
            }
        }

        @Override
        public void map(Chunk[] chks, NewChunk[] nc) {
            double[] input = MemoryManager.malloc8d(chks.length);
            float[] contribs = MemoryManager.malloc4f(chks.length + 1);
            int[] contribNameIds = MemoryManager.malloc4(chks.length + 1);
            TreeSHAPPredictor.Workspace workspace = this._treeSHAP.makeWorkspace();
            // Created once per chunk rather than once per row.
            ContributionComposer contributionComposer = new ContributionComposer();
            for (int row = 0; row < chks[0]._len; ++row) {
                this.fillInput(chks, row, input, contribs, contribNameIds);
                this._treeSHAP.calculateContributions(input, contribs, 0, -1, workspace);
                this.doModelSpecificComputation(contribs);
                int[] contribNameIdsSorted = contributionComposer.composeContributions(contribNameIds, contribs, this._topN, this._bottomN, this._compareAbs);
                this.addContribToNewChunk(contribs, contribNameIdsSorted, nc);
            }
        }

        /** Writes (id, value) pairs for the selected features, then the bias term in the last column. */
        protected void addContribToNewChunk(float[] contribs, int[] contribNameIdsSorted, NewChunk[] nc) {
            int inputPointer = 0;
            for (int i = 0; i < nc.length - 1; i += 2) {
                int featureId = contribNameIdsSorted[inputPointer++];
                nc[i].addNum((double) featureId);
                nc[i + 1].addNum((double) contribs[featureId]);
            }
            nc[nc.length - 1].addNum((double) contribs[contribs.length - 1]);
        }
    }

    /** Computes one contribution column per feature plus a trailing bias column for every row. */
    public class ScoreContributionsTask
    extends MRTask<ScoreContributionsTask> {
        protected final Key<SharedTreeModel> _modelKey;
        protected transient SharedTreeModel _model;
        protected transient SharedTreeModel.SharedTreeOutput _output;
        protected transient TreeSHAPPredictor<double[]> _treeSHAP;

        public ScoreContributionsTask(SharedTreeModel model) {
            this._modelKey = model._key;
        }

        /** Fetches the model from DKV and builds the TreeSHAP ensemble on this node. */
        protected void setupLocal() {
            this._model = (SharedTreeModel) this._modelKey.get();
            assert this._model != null;
            this._output = (SharedTreeModel.SharedTreeOutput) this._model._output;
            assert this._output != null;
            this._treeSHAP = SharedTreeModelWithContributions.buildTreeSHAPEnsemble(this._model, this._output);
        }

        /** Copies row {@code row} of {@code chks} into {@code input} and zeroes {@code contribs}. */
        protected void fillInput(Chunk[] chks, int row, double[] input, float[] contribs) {
            for (int i = 0; i < chks.length; ++i) {
                input[i] = chks[i].atd(row);
            }
            Arrays.fill(contribs, 0.0f);
        }

        public void map(Chunk[] chks, NewChunk[] nc) {
            assert chks.length == nc.length - 1;
            double[] input = MemoryManager.malloc8d(chks.length);
            float[] contribs = MemoryManager.malloc4f(nc.length);
            TreeSHAPPredictor.Workspace workspace = this._treeSHAP.makeWorkspace();
            for (int row = 0; row < chks[0]._len; ++row) {
                this.fillInput(chks, row, input, contribs);
                this._treeSHAP.calculateContributions(input, contribs, 0, -1, workspace);
                this.doModelSpecificComputation(contribs);
                this.addContribToNewChunk(contribs, nc);
            }
        }

        /** Hook for subclasses to post-process raw contributions; the default does nothing. */
        protected void doModelSpecificComputation(float[] contribs) {
        }

        /** Appends one output row, one value per output column. */
        protected void addContribToNewChunk(float[] contribs, NewChunk[] nc) {
            for (int i = 0; i < nc.length; ++i) {
                nc[i].addNum((double) contribs[i]);
            }
        }
    }
}

