/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost.predict;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.util.FVec;
import hex.DataInfo;
import hex.Model;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.predict.MutableOneHotEncoderFVec;
import hex.tree.xgboost.predict.PredictorFactory;
import water.DKV;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.VecUtils;

public abstract class AssignLeafNodeTask
extends MRTask<AssignLeafNodeTask> {
    protected final Predictor _p;
    protected final String[] _names;
    private final DataInfo _di;
    private final boolean _sparse;
    private byte _resultType;

    protected AssignLeafNodeTask(DataInfo di, XGBoostOutput output, byte[] boosterBytes, byte resultType) {
        this._p = PredictorFactory.makePredictor(boosterBytes, false);
        this._di = di;
        this._sparse = output._sparse;
        this._names = this.makeNames(output._ntrees, output.nclasses());
        this._resultType = resultType;
    }

    protected abstract void assignNodes(FVec var1, NewChunk[] var2);

    private String[] makeNames(int ntrees, int nclass) {
        nclass = nclass > 2 ? nclass : 1;
        String[] names = new String[ntrees * nclass];
        for (int t = 0; t < ntrees; ++t) {
            for (int c = 0; c < nclass; ++c) {
                names[t * nclass + c] = "T" + (t + 1) + ".C" + (c + 1);
            }
        }
        return names;
    }

    public void map(Chunk[] chks, NewChunk[] idx) {
        MutableOneHotEncoderFVec inputVec = new MutableOneHotEncoderFVec(this._di, this._sparse);
        double[] input = new double[chks.length];
        for (int row = 0; row < chks[0]._len; ++row) {
            for (int i = 0; i < chks.length; ++i) {
                input[i] = chks[i].atd(row);
            }
            inputVec.setInput(input);
            this.assignNodes(inputVec, idx);
        }
    }

    public Frame execute(Frame adaptFrm, Key<Frame> destKey) {
        return ((AssignLeafNodeTask)this.doAll(this._names.length, this._resultType, adaptFrm)).outputFrame(destKey, this._names, null);
    }

    public static AssignLeafNodeTask make(DataInfo di, XGBoostOutput output, byte[] boosterBytes, Model.LeafNodeAssignment.LeafNodeAssignmentType type) {
        switch (type) {
            case Path: {
                return new AssignTreePathTask(di, output, boosterBytes);
            }
            case Node_ID: {
                return new AssignLeafNodeIdTask(di, output, boosterBytes);
            }
        }
        throw new UnsupportedOperationException("Unknown leaf node assignment type: " + type);
    }

    static class AssignLeafNodeIdTask
    extends AssignLeafNodeTask {
        public AssignLeafNodeIdTask(DataInfo di, XGBoostOutput output, byte[] boosterBytes) {
            super(di, output, boosterBytes, (byte)3);
        }

        @Override
        protected void assignNodes(FVec input, NewChunk[] outs) {
            int[] leafIdx = this._p.getBooster().predictLeaf(input, 0);
            for (int i = 0; i < leafIdx.length; ++i) {
                outs[i].addNum((double)leafIdx[i]);
            }
        }
    }

    static class AssignTreePathTask
    extends AssignLeafNodeTask {
        public AssignTreePathTask(DataInfo di, XGBoostOutput output, byte[] boosterBytes) {
            super(di, output, boosterBytes, (byte)2);
        }

        @Override
        protected void assignNodes(FVec input, NewChunk[] outs) {
            String[] leafPaths = this._p.predictLeafPath(input);
            for (int i = 0; i < leafPaths.length; ++i) {
                outs[i].addStr((Object)leafPaths[i]);
            }
        }

        @Override
        public Frame execute(Frame adaptFrm, Key<Frame> destKey) {
            Frame res = super.execute(adaptFrm, destKey);
            Vec[] nvecs = new Vec[res.vecs().length];
            for (int c = 0; c < res.vecs().length; ++c) {
                Vec vv = res.vec(c);
                try {
                    nvecs[c] = vv.toCategoricalVec();
                    continue;
                }
                catch (Exception e) {
                    VecUtils.deleteVecs((Vec[])nvecs, (int)c);
                    throw e;
                }
            }
            res.delete();
            res = new Frame(destKey, this._names, nvecs);
            DKV.put((Keyed)res);
            return res;
        }
    }
}

