/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost.predict;

import biz.k11i.xgboost.util.FVec;
import hex.DataInfo;
import hex.genmodel.GenModel;

/**
 * A reusable {@link FVec} that presents one input row as a one-hot encoded feature vector.
 *
 * <p>The row is laid out as in {@link DataInfo}: categorical columns first, then numeric columns.
 * Feature indices {@code [0, catOffsets[cats])} are the one-hot columns of the categorical
 * predictors. The remaining indices map, in order, to the numeric predictors.
 *
 * <p>Call {@link #setInput(double[])} to load a row. The row's state is kept in internal arrays
 * that are overwritten on each call without synchronization. Do not share an instance between
 * threads that load different rows unless access is externally coordinated.
 */
public class MutableOneHotEncoderFVec
implements FVec {
    private final DataInfo _di;
    private final boolean _treatsZeroAsNA;
    /** For each one-hot feature index, the index of the categorical column that owns it. */
    private final int[] _catMap;
    /** For each categorical column, the one-hot feature index that is "hot" for the current row. */
    private final int[] _catValues;
    /** Numeric values of the current row, with zeros replaced by NaN when zeros are treated as NA. */
    private final float[] _numValues;
    /** Value reported for a one-hot column that is not active in the current row. */
    private final float _notHot;

    /**
     * Creates an encoder for rows described by {@code di}.
     *
     * @param di data layout providing the categorical/numeric column counts and one-hot offsets;
     *     a {@code null} {@code _catOffsets} is treated as "no one-hot columns"
     * @param treatsZeroAsNA if {@code true}, numeric zeros and inactive one-hot columns are
     *     reported as {@link Float#NaN} instead of {@code 0}
     */
    public MutableOneHotEncoderFVec(DataInfo di, boolean treatsZeroAsNA) {
        _di = di;
        _treatsZeroAsNA = treatsZeroAsNA;
        // When zeros mean "missing", an inactive one-hot column must also read as missing.
        _notHot = treatsZeroAsNA ? Float.NaN : 0.0f;
        _catValues = new int[di._cats];
        _catMap = buildCatMap(di);
        _numValues = new float[di._nums];
    }

    /** Builds the one-hot-index to categorical-column lookup; empty when there are no offsets. */
    private static int[] buildCatMap(DataInfo di) {
        if (di._catOffsets == null) {
            return new int[0];
        }
        int[] catMap = new int[di._catOffsets[di._cats]];
        for (int c = 0; c < di._cats; ++c) {
            for (int j = di._catOffsets[c]; j < di._catOffsets[c + 1]; ++j) {
                catMap[j] = c;
            }
        }
        return catMap;
    }

    /**
     * Loads a new row, replacing any previously loaded one.
     *
     * @param input row in {@link DataInfo} layout: the first {@code _cats} entries are the
     *     categorical values (decoded via {@link GenModel#setCats}), followed by {@code _nums}
     *     numeric values
     */
    public void setInput(double[] input) {
        GenModel.setCats(input, _catValues, _di._cats, _di._catOffsets, _di._useAllFactorLevels);
        final int numStart = _di._cats;
        for (int i = 0; i < _numValues.length; ++i) {
            float val = (float) input[numStart + i];
            _numValues[i] = _treatsZeroAsNA && val == 0.0f ? Float.NaN : val;
        }
    }

    /**
     * Returns the value of one feature of the current row.
     *
     * @param index feature index; indices below the number of one-hot columns address
     *     categorical levels, higher indices address numeric columns
     * @return for a numeric feature, its (possibly NaN-substituted) value; for a one-hot feature,
     *     {@code 1.0f} if it is the active level of its column, otherwise the "not hot" value
     */
    public final float fvalue(int index) {
        final int oneHotCount = _catMap.length;
        if (index >= oneHotCount) {
            return _numValues[index - oneHotCount];
        }
        boolean isHot = _catValues[_catMap[index]] == index;
        return isHot ? 1.0f : _notHot;
    }

    /**
     * Collapses a vector indexed like this encoder's features back to the original column layout.
     *
     * <p>Each categorical column receives the sum of the entries of its one-hot features. The
     * numeric entries are copied unchanged after the categorical ones.
     *
     * @param encoded values indexed by one-hot feature index; its length is presumably at least
     *     (one-hot columns + {@code _nums}), but verify this against the callers
     * @param output destination with room for {@code _cats + _nums} values
     */
    public void decodeAggregate(float[] encoded, float[] output) {
        final int cats = _di._cats;
        final int[] catOffsets = _di._catOffsets;
        // Mirror the constructor: null offsets mean there are no one-hot columns to aggregate.
        if (catOffsets != null) {
            for (int c = 0; c < cats; ++c) {
                float sum = 0.0f;
                for (int i = catOffsets[c]; i < catOffsets[c + 1]; ++i) {
                    sum += encoded[i];
                }
                output[c] = sum;
            }
        }
        // _catMap.length equals catOffsets[cats], or 0 when there are no offsets.
        final int numStart = _catMap.length;
        if (_di._nums > 0) {
            System.arraycopy(encoded, numStart, output, cats, _di._nums);
        }
    }
}

