/*
 * Decompiled with CFR 0.152.
 */
package hex.tree;

import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.tree.SharedTreeModel;
import hex.tree.TreeVisitor;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Map;
import java.util.Random;
import water.Key;
import water.Keyed;
import water.util.IcedBitSet;
import water.util.SB;

/**
 * A single decision tree stored in a compact byte encoding and identified by a system
 * {@link Key} of the form {@code tree_<treeId>_<clazz>_<random>}.
 *
 * <p>Scoring, decision-path extraction and aux-info handling are delegated to
 * {@link SharedTreeMojoModel}, which interprets the {@link #_bits} layout.
 */
public class CompressedTree extends Keyed<CompressedTree> {
    /** Prefix of every key built by {@link #makeTreeKey}; also used to parse the key back. */
    private static final String KEY_PREFIX = "tree_";
    /** Encoded tree nodes, interpreted by {@link SharedTreeMojoModel}. */
    final byte[] _bits;
    /** Seed from which per-chunk random generators are derived (see {@link #rngForChunk}). */
    final long _seed;

    /**
     * Creates a tree with a freshly generated key for the given tree/class coordinates.
     *
     * @param bits encoded tree
     * @param seed seed for {@link #rngForChunk}
     * @param tid  tree index, embedded in the key
     * @param cls  class index, embedded in the key
     */
    public CompressedTree(byte[] bits, long seed, int tid, int cls) {
        super(CompressedTree.makeTreeKey(tid, cls));
        this._bits = bits;
        this._seed = seed;
    }

    /** Creates a tree that reuses an existing key (used when deriving an updated copy). */
    private CompressedTree(Key<CompressedTree> key, byte[] bits, long seed) {
        super(key);
        this._bits = bits;
        this._seed = seed;
    }

    /**
     * Scores one row with this tree.
     *
     * @param row     feature values of the row
     * @param domains categorical domains of the columns
     * @return the value returned by {@link SharedTreeMojoModel#scoreTree} for this row
     */
    public double score(double[] row, String[][] domains) {
        return SharedTreeMojoModel.scoreTree(this._bits, row, false, domains);
    }

    /**
     * Returns the decision path taken by {@code row}, rendered as a string.
     *
     * @param row     feature values of the row
     * @param domains categorical domains of the columns
     * @return the path description produced by {@link SharedTreeMojoModel#getDecisionPath}
     * @deprecated use
     *     {@link #getDecisionPath(double[], String[][], SharedTreeMojoModel.DecisionPathTracker)}
     */
    @Deprecated
    public String getDecisionPath(double[] row, String[][] domains) {
        double encodedPath = SharedTreeMojoModel.scoreTree(this._bits, row, true, domains);
        return SharedTreeMojoModel.getDecisionPath(encodedPath);
    }

    /**
     * Returns the decision path taken by {@code row}, decoded through a caller-supplied tracker.
     *
     * @param row     feature values of the row
     * @param domains categorical domains of the columns
     * @param tr      tracker that builds the result from the decoded path
     * @param <T>     result type produced by the tracker
     * @return the tracker's result
     */
    public <T> T getDecisionPath(double[] row, String[][] domains, SharedTreeMojoModel.DecisionPathTracker<T> tr) {
        // With the flag set to true, scoreTree returns an encoded path rather than a prediction.
        double encodedPath = SharedTreeMojoModel.scoreTree(this._bits, row, true, domains);
        return SharedTreeMojoModel.getDecisionPath(encodedPath, tr);
    }

    /**
     * Reads the aux info stored in this tree's bits.
     *
     * @return aux info keyed by node id
     */
    public Map<Integer, SharedTreeMojoModel.AuxInfo> toAuxInfos() {
        return SharedTreeMojoModel.readAuxInfos(this._bits);
    }

    /**
     * Returns the largest node id in this tree.
     *
     * @return the value reported by {@link SharedTreeMojoModel#findMaxNodeId}
     */
    public int findMaxNodeId() {
        return SharedTreeMojoModel.findMaxNodeId(this._bits);
    }

    /**
     * Returns a copy of this tree whose per-node left/right weights are recomputed from
     * {@code leafNodeWeights}.
     *
     * <p>Every node's weights are first reset to zero. Each node then adds the
     * {@code leafNodeWeights} entries of its two children and pushes its total into the matching
     * side of its parent. Nodes are processed in descending parent-id order.
     * NOTE(review): this only accumulates bottom-up if a child's node id is always larger than its
     * parent's id; confirm that invariant of the encoding.
     *
     * @param leafNodeWeights weights indexed by node id
     * @return a new tree that shares this tree's key and seed and carries the updated aux info
     */
    public CompressedTree updateLeafNodeWeights(double[] leafNodeWeights) {
        Map<Integer, SharedTreeMojoModel.AuxInfo> nodeIdToAuxInfo = SharedTreeMojoModel.readAuxInfos(this._bits);
        ArrayList<SharedTreeMojoModel.AuxInfo> auxInfos = new ArrayList<SharedTreeMojoModel.AuxInfo>(nodeIdToAuxInfo.values());
        auxInfos.sort(Comparator.comparingInt(info -> -info.pid));
        for (SharedTreeMojoModel.AuxInfo auxInfo : auxInfos) {
            auxInfo.weightL = 0.0f;
            auxInfo.weightR = 0.0f;
        }
        for (SharedTreeMojoModel.AuxInfo auxInfo : auxInfos) {
            auxInfo.weightL += (float) leafNodeWeights[auxInfo.nidL];
            auxInfo.weightR += (float) leafNodeWeights[auxInfo.nidR];
            if (auxInfo.pid < 0) {
                continue; // root: nothing to propagate into
            }
            SharedTreeMojoModel.AuxInfo parentInfo = nodeIdToAuxInfo.get(auxInfo.pid);
            float nodeWeight = auxInfo.weightL + auxInfo.weightR;
            if (parentInfo.nidL == auxInfo.nid) {
                parentInfo.weightL += nodeWeight;
            } else {
                parentInfo.weightR += nodeWeight;
            }
        }
        // The rewritten encoding is assumed to have the same size as the original bits.
        ByteBuffer bb = ByteBuffer.allocate(this._bits.length).order(ByteOrder.nativeOrder());
        SharedTreeMojoModel.writeUpdatedAuxInfos(this._bits, nodeIdToAuxInfo, bb);
        byte[] bits = bb.array();
        return new CompressedTree(this._key, bits, this._seed);
    }

    /**
     * Reports whether any node of this tree has a zero left or right weight.
     *
     * @return {@code true} if some aux info has {@code weightL == 0} or {@code weightR == 0}
     */
    public boolean hasZeroWeight() {
        return SharedTreeMojoModel.readAuxInfos(this._bits).values().stream()
                .anyMatch(auxInfo -> auxInfo.weightL == 0.0f || auxInfo.weightR == 0.0f);
    }

    /**
     * Converts this tree into a {@link SharedTreeSubgraph}.
     *
     * @param auxTreeInfo tree whose bits supply the aux info for the graph
     * @param colNames    column names used to label the splits
     * @param domains     categorical domains; the last entry is passed to the tree-name builder
     *                    (presumably the response domain)
     * @return the graph built by {@link SharedTreeMojoModel#computeTreeGraph}
     */
    public SharedTreeSubgraph toSharedTreeSubgraph(CompressedTree auxTreeInfo, String[] colNames, String[][] domains) {
        TreeCoords tc = this.getTreeCoords();
        String treeName = SharedTreeMojoModel.treeName(tc._treeId, tc._clazz, domains[domains.length - 1]);
        return SharedTreeMojoModel.computeTreeGraph(tc._treeId, treeName, this._bits, auxTreeInfo._bits, colNames, domains);
    }

    /**
     * Derives a deterministic random generator for the given chunk index.
     *
     * <p>A {@link Random} seeded with {@link #_seed} is advanced by {@code cidx} calls to
     * {@link Random#nextLong()}. The next long then seeds the returned generator, so a given
     * {@code (seed, cidx)} pair always yields the same sequence. Negative indices behave like 0.
     *
     * @param cidx chunk index
     * @return a new generator specific to that chunk
     */
    public Random rngForChunk(int cidx) {
        Random rand = new Random(this._seed);
        for (int i = 0; i < cidx; ++i) {
            rand.nextLong();
        }
        long seed = rand.nextLong();
        return new Random(seed);
    }

    /**
     * Renders this tree as indented pseudo-code, one split condition or {@code return} per line.
     *
     * @param tm model output providing the column names ({@code _names})
     * @return the textual representation of the tree
     */
    public String toString(SharedTreeModel.SharedTreeOutput tm) {
        final String[] names = tm._names;
        final SB sb = new SB();
        new TreeVisitor<RuntimeException>(this) {

            @Override
            protected void pre(int col, float fcmp, IcedBitSet gcmp, int equal, int naSplitDirInt) {
                // Indent once per line and append each piece directly. The previous code built the
                // prefix as `"..." + sb.i()...`, which stringified (and so duplicated) the whole buffer.
                // It also indented a second time in the middle of the line.
                sb.i();
                if (naSplitDirInt == DhnasdNaVsRest) {
                    sb.p("!Double.isNaN(").p(names[col]).p(")");
                } else if (naSplitDirInt == DhnasdNaLeft) {
                    sb.p("Double.isNaN(").p(names[col]).p(") || ");
                } else if (equal == 1) {
                    sb.p("!Double.isNaN(").p(names[col]).p(") && ");
                }
                // A NaN-vs-rest split is fully described by the NaN check alone.
                if (naSplitDirInt != DhnasdNaVsRest) {
                    sb.p(names[col]).p(' ');
                    if (equal == 0) {
                        sb.p("< ").pj(fcmp);
                    } else if (equal == 1) {
                        sb.p("!=").pj(fcmp);
                    } else {
                        sb.p("in ").p(gcmp);
                    }
                }
                sb.ii(1).nl();
            }

            @Override
            protected void post(int col, float fcmp, int equal) {
                sb.di(1);
            }

            @Override
            protected void leaf(float pred) {
                sb.i().p("return ").pj(pred).nl();
            }
        }.visit();
        return sb.toString();
    }

    /**
     * Builds a new system key for a tree: {@code tree_<treeId>_<clazz>_<Key.rand()>}.
     *
     * @param treeId tree index
     * @param clazz  class index
     * @return the new key
     */
    public static Key<CompressedTree> makeTreeKey(int treeId, int clazz) {
        return Key.makeSystem(KEY_PREFIX + treeId + "_" + clazz + "_" + Key.rand());
    }

    /** Recovers the tree/class coordinates encoded in this tree's key. */
    TreeCoords getTreeCoords() {
        return TreeCoords.parseTreeCoords(this._key);
    }

    @Override
    protected long checksum_impl() {
        throw new UnsupportedOperationException();
    }

    /** Tree index and class index parsed from a key built by {@link #makeTreeKey}. */
    static class TreeCoords {
        int _treeId;
        int _clazz;

        TreeCoords() {
        }

        /**
         * Parses {@code <treeId>_<clazz>} following {@link #KEY_PREFIX} in the key's string form.
         *
         * @param ctKey key of a compressed tree
         * @return the parsed coordinates
         * @throws IllegalStateException if the key does not follow the expected structure
         */
        private static TreeCoords parseTreeCoords(Key<CompressedTree> ctKey) {
            String key = ctKey.toString();
            int prefixIdx = key.indexOf(CompressedTree.KEY_PREFIX);
            if (prefixIdx < 0) {
                throw new IllegalStateException("Unexpected structure of a CompressedTree key=" + key);
            }
            // The remainder is "<treeId>_<clazz>_<random>"; the limit keeps the random suffix whole.
            String[] keyParts = key.substring(prefixIdx + CompressedTree.KEY_PREFIX.length()).split("_", 3);
            if (keyParts.length < 2) {
                throw new IllegalStateException("Unexpected structure of a CompressedTree key=" + key);
            }
            TreeCoords tc = new TreeCoords();
            tc._treeId = Integer.parseInt(keyParts[0]);
            tc._clazz = Integer.parseInt(keyParts[1]);
            return tc;
        }
    }
}

