/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.IEvaluation;
import ml.dmlc.xgboost4j.java.IObjective;
import ml.dmlc.xgboost4j.java.Rabit;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Static entry points for loading boosters, training them (optionally with watch lists and
 * early stopping) and running n-fold cross validation.
 */
public class XGBoost {
    private static final Log logger = LogFactory.getLog(XGBoost.class);

    /** Training parameter read by early stopping to decide whether larger metric values are better. */
    private static final String MAXIMIZE_METRICS_PARAM = "maximize_evaluation_metrics";

    /**
     * Loads a booster by delegating to {@link Booster#loadModel(String)}.
     *
     * @param string model location handed to the booster loader (presumably a file path - confirm
     *               against {@code Booster})
     * @return the loaded booster
     * @throws XGBoostError if the underlying loader fails
     */
    public static Booster loadModel(String string) throws XGBoostError {
        return Booster.loadModel(string);
    }

    /**
     * Loads a booster by delegating to {@link Booster#loadModel(InputStream)}.
     *
     * @param inputStream stream holding the serialized model
     * @return the loaded booster
     * @throws XGBoostError if the underlying loader fails
     * @throws IOException  if reading the stream fails
     */
    public static Booster loadModel(InputStream inputStream) throws XGBoostError, IOException {
        return Booster.loadModel(inputStream);
    }

    /**
     * Trains a new booster without a caller-supplied metrics buffer and with early stopping
     * disabled.
     *
     * @param dMatrix     training data
     * @param map         booster parameters
     * @param n           number of boosting rounds
     * @param map2        named evaluation matrices, evaluated after every round
     * @param iObjective  custom objective, or {@code null} for the built-in one
     * @param iEvaluation custom evaluation, or {@code null} for the built-in one
     * @return the trained booster
     * @throws XGBoostError if a booster operation fails
     */
    public static Booster train(DMatrix dMatrix, Map<String, Object> map, int n, Map<String, DMatrix> map2, IObjective iObjective, IEvaluation iEvaluation) throws XGBoostError {
        return XGBoost.train(dMatrix, map, n, map2, null, iObjective, iEvaluation, 0);
    }

    /**
     * Trains a new booster; same as the nine-argument overload with no pre-existing booster.
     *
     * @param dMatrix     training data
     * @param map         booster parameters
     * @param n           number of boosting rounds
     * @param map2        named evaluation matrices, evaluated after every round
     * @param fArray      output buffer indexed {@code [watch][round]}, or {@code null} to allocate one
     * @param iObjective  custom objective, or {@code null} for the built-in one
     * @param iEvaluation custom evaluation, or {@code null} for the built-in one
     * @param n2          early stopping rounds; values {@code <= 0} disable early stopping
     * @return the trained booster
     * @throws XGBoostError if a booster operation fails
     */
    public static Booster train(DMatrix dMatrix, Map<String, Object> map, int n, Map<String, DMatrix> map2, float[][] fArray, IObjective iObjective, IEvaluation iEvaluation, int n2) throws XGBoostError {
        return XGBoost.train(dMatrix, map, n, map2, fArray, iObjective, iEvaluation, n2, null);
    }

    /**
     * Trains a booster for up to {@code n} rounds, evaluating every watch after each round and
     * optionally stopping early.
     *
     * <p>Early stopping looks at the last row of the metrics buffer, which is the last watch in the
     * iteration order of {@code map2} when the buffer is allocated here.
     *
     * @param dMatrix     training data
     * @param map         booster parameters; {@code maximize_evaluation_metrics} is read when early
     *                    stopping is enabled
     * @param n           number of boosting rounds
     * @param map2        named evaluation matrices; {@code null} is treated as no watches
     * @param fArray      output buffer indexed {@code [watch][round]}, or {@code null} to allocate a
     *                    {@code [watches][n]} buffer internally
     * @param iObjective  custom objective, or {@code null} for the built-in one
     * @param iEvaluation custom evaluation, or {@code null} for the built-in one
     * @param n2          early stopping rounds; values {@code <= 0} disable early stopping
     * @param booster     booster to continue training (its parameters are overwritten with
     *                    {@code map}), or {@code null} to create a fresh one
     * @return the trained booster (the supplied instance when one was given)
     * @throws XGBoostError if a booster operation fails
     */
    public static Booster train(DMatrix dMatrix, Map<String, Object> map, int n, Map<String, DMatrix> map2, float[][] fArray, IObjective iObjective, IEvaluation iEvaluation, int n2, Booster booster) throws XGBoostError {
        // Split the watch list into the parallel name/matrix arrays that evalSet takes.
        List<String> watchNames = new ArrayList<String>();
        List<DMatrix> watchMats = new ArrayList<DMatrix>();
        if (map2 != null) {
            for (Map.Entry<String, DMatrix> watch : map2.entrySet()) {
                watchNames.add(watch.getKey());
                watchMats.add(watch.getValue());
            }
        }
        String[] evalNames = watchNames.toArray(new String[watchNames.size()]);
        DMatrix[] evalMats = watchMats.toArray(new DMatrix[watchMats.size()]);
        float[][] metrics = fArray != null ? fArray : new float[evalNames.length][n];

        // The booster is constructed over the training matrix followed by every watch matrix.
        DMatrix[] allMats = new DMatrix[evalMats.length + 1];
        allMats[0] = dMatrix;
        System.arraycopy(evalMats, 0, allMats, 1, evalMats.length);

        if (booster == null) {
            booster = new Booster(map, allMats);
            booster.loadRabitCheckpoint();
        } else {
            booster.setParams(map);
        }

        // NOTE(review): the version arithmetic suggests two checkpoints per round (one after the
        // update, one after evaluation), so an odd version would mean the update of the current
        // round is already done - confirm against Booster/Rabit.
        for (int iter = booster.getVersion() / 2; iter < n; ++iter) {
            if (booster.getVersion() % 2 == 0) {
                if (iObjective != null) {
                    booster.update(dMatrix, iObjective);
                } else {
                    booster.update(dMatrix, iter);
                }
                booster.saveRabitCheckpoint();
            }
            if (evalMats.length > 0) {
                float[] roundMetrics = new float[evalMats.length];
                String evalInfo;
                if (iEvaluation != null) {
                    evalInfo = booster.evalSet(evalMats, evalNames, iEvaluation, roundMetrics);
                } else {
                    evalInfo = booster.evalSet(evalMats, evalNames, iter, roundMetrics);
                }
                for (int watch = 0; watch < roundMetrics.length; ++watch) {
                    metrics[watch][iter] = roundMetrics[watch];
                }
                if (n2 > 0 && !XGBoost.judgeIfTrainingOnTrack(map, n2, metrics, iter)) {
                    String reversedDirection = XGBoost.getReversedDirection(map);
                    Rabit.trackerPrint(String.format("early stopping after %d %s rounds", n2, reversedDirection));
                    break;
                }
                if (Rabit.getRank() == 0) {
                    Rabit.trackerPrint(evalInfo + '\n');
                }
            }
            booster.saveRabitCheckpoint();
        }
        return booster;
    }

    /**
     * Decides whether training should continue under early stopping.
     *
     * <p>The criterion is the last row of {@code fArray}. Up to {@code min(n2, n) - 1} consecutive
     * round-to-round comparisons ending at round {@code n2} are inspected; training is on track if
     * at least one of them moved in the expected direction (or stayed equal). When no comparison
     * is available yet (first rounds, or {@code n <= 1}) training is reported as on track, so that
     * early stopping cannot fire before there is any history to judge.
     *
     * @param map    booster parameters; {@code maximize_evaluation_metrics} selects the direction
     * @param n      early stopping rounds
     * @param fArray metrics indexed {@code [watch][round]}
     * @param n2     index of the round that has just been evaluated
     * @return {@code true} if training should continue, {@code false} if it should stop
     */
    static boolean judgeIfTrainingOnTrack(Map<String, Object> map, int n, float[][] fArray, int n2) {
        boolean maximize = XGBoost.getMetricsExpectedDirection(map);
        int comparisons = Math.min(n2, n) - 1;
        if (comparisons <= 0) {
            // Nothing to compare against yet: never stop on an empty window.
            return true;
        }
        float[] criterion = fArray[fArray.length - 1];
        boolean onTrack = false;
        for (int shift = 0; shift < comparisons; ++shift) {
            float current = criterion[n2 - shift];
            float previous = criterion[n2 - shift - 1];
            onTrack |= maximize ? current >= previous : current <= previous;
        }
        return onTrack;
    }

    /**
     * Names the direction the metric moved in when early stopping fires: the opposite of the
     * direction requested through {@code maximize_evaluation_metrics}.
     *
     * @return {@code "descending"} when metrics are maximized, otherwise {@code "ascending"}
     */
    private static String getReversedDirection(Map<String, Object> params) {
        boolean maximize = Boolean.parseBoolean(String.valueOf(params.get(MAXIMIZE_METRICS_PARAM)));
        return maximize ? "descending" : "ascending";
    }

    /**
     * Reads {@code maximize_evaluation_metrics} from the parameters.
     *
     * <p>Any value whose string form is not {@code "true"} (ignoring case), including a missing
     * entry, yields {@code false}. Failures while reading the map are logged and rethrown.
     *
     * @return {@code true} if larger metric values are better
     */
    private static boolean getMetricsExpectedDirection(Map<String, Object> params) {
        try {
            return Boolean.parseBoolean(String.valueOf(params.get(MAXIMIZE_METRICS_PARAM)));
        }
        catch (RuntimeException ex) {
            logger.error("maximize_evaluation_metrics has to be specified for enabling early stop, allowed value: true/false", ex);
            throw ex;
        }
    }

    /**
     * Runs n-fold cross validation and returns one aggregated evaluation line per round.
     *
     * @param dMatrix     full data set, split into folds internally
     * @param map         booster parameters used for every fold
     * @param n           number of boosting rounds
     * @param n2          number of folds
     * @param stringArray evaluation metric names set on every fold's booster, or {@code null}
     * @param iObjective  custom objective, or {@code null} for the built-in one
     * @param iEvaluation custom evaluation, or {@code null} for the built-in one
     * @return array of length {@code n}; element {@code i} holds the fold-averaged metrics of round {@code i}
     * @throws XGBoostError if a booster operation fails
     */
    public static String[] crossValidation(DMatrix dMatrix, Map<String, Object> map, int n, int n2, String[] stringArray, IObjective iObjective, IEvaluation iEvaluation) throws XGBoostError {
        CVPack[] folds = XGBoost.makeNFold(dMatrix, n2, map, stringArray);
        String[] history = new String[n];
        String[] foldResults = new String[folds.length];
        for (int round = 0; round < n; ++round) {
            for (CVPack fold : folds) {
                if (iObjective != null) {
                    fold.update(iObjective);
                } else {
                    fold.update(round);
                }
            }
            for (int f = 0; f < folds.length; ++f) {
                foldResults[f] = iEvaluation != null ? folds[f].eval(iEvaluation) : folds[f].eval(round);
            }
            history[round] = XGBoost.aggCVResults(foldResults);
            logger.info(history[round]);
        }
        return history;
    }

    /**
     * Builds the per-fold train/test packs from a random permutation of the row indices.
     *
     * <p>Fold {@code i} is tested on permutation positions {@code [i*step, (i+1)*step)} with
     * {@code step = rows / nfold} and trained on all remaining positions, so test folds are
     * disjoint. Rows beyond {@code nfold * step} are part of every training set and of no test set.
     *
     * @param data    full data set
     * @param nfold   number of folds
     * @param params  booster parameters for every fold
     * @param metrics evaluation metric names to set on each fold's booster, or {@code null}
     * @return one pack per fold
     * @throws XGBoostError if slicing or booster creation fails
     */
    private static CVPack[] makeNFold(DMatrix data, int nfold, Map<String, Object> params, String[] metrics) throws XGBoostError {
        List<Integer> permutation = XGBoost.genRandPermutationNums(0, (int)data.rowNum());
        int total = permutation.size();
        int step = total / nfold;
        CVPack[] packs = new CVPack[nfold];
        for (int fold = 0; fold < nfold; ++fold) {
            int testStart = fold * step;
            int testEnd = testStart + step;
            int[] testRows = new int[step];
            int[] trainRows = new int[total - step];
            int testCount = 0;
            int trainCount = 0;
            for (int pos = 0; pos < total; ++pos) {
                int row = permutation.get(pos);
                if (pos >= testStart && pos < testEnd) {
                    testRows[testCount++] = row;
                } else {
                    trainRows[trainCount++] = row;
                }
            }
            CVPack pack = new CVPack(data.slice(trainRows), data.slice(testRows), params);
            if (metrics != null) {
                for (String metric : metrics) {
                    pack.booster.setParam("eval_metric", metric);
                }
            }
            packs[fold] = pack;
        }
        return packs;
    }

    /**
     * Returns the integers {@code start} (inclusive) to {@code end} (exclusive) in shuffled order.
     */
    private static List<Integer> genRandPermutationNums(int start, int end) {
        List<Integer> numbers = new ArrayList<Integer>();
        for (int value = start; value < end; ++value) {
            numbers.add(value);
        }
        Collections.shuffle(numbers);
        return numbers;
    }

    /**
     * Averages the per-fold evaluation lines of one round into a single line.
     *
     * <p>Each input is split on tabs; the first field of the first input is kept as the prefix and
     * every following {@code name:value} field is averaged per name across folds. The result is
     * the prefix followed by one {@code \tcv-name:mean} field per metric, in the map's iteration
     * order.
     *
     * @param results evaluation lines, one per fold; must contain at least one element
     * @return the aggregated line
     */
    private static String aggCVResults(String[] results) {
        Map<String, List<Float>> valuesByMetric = new HashMap<String, List<Float>>();
        String prefix = results[0].split("\t")[0];
        for (String result : results) {
            String[] items = result.split("\t");
            // items[0] is the round prefix; metrics start at index 1.
            for (int i = 1; i < items.length; ++i) {
                String[] pair = items[i].split(":");
                String metric = pair[0];
                Float value = Float.valueOf(pair[1]);
                List<Float> values = valuesByMetric.get(metric);
                if (values == null) {
                    values = new ArrayList<Float>();
                    valuesByMetric.put(metric, values);
                }
                values.add(value);
            }
        }
        StringBuilder aggregated = new StringBuilder(prefix);
        for (Map.Entry<String, List<Float>> entry : valuesByMetric.entrySet()) {
            List<Float> values = entry.getValue();
            float sum = 0.0f;
            for (Float value : values) {
                sum += value.floatValue();
            }
            float mean = sum / (float)values.size();
            aggregated.append(String.format("\tcv-%s:%f", entry.getKey(), Float.valueOf(mean)));
        }
        return aggregated.toString();
    }

    /**
     * One cross-validation fold: a booster together with its train and test matrices.
     */
    private static class CVPack {
        final DMatrix dtrain;
        final DMatrix dtest;
        /** Train and test matrices in the order matching {@link #names}. */
        final DMatrix[] dmats;
        final String[] names;
        final Booster booster;

        /**
         * @param dMatrix  training slice of this fold
         * @param dMatrix2 test slice of this fold
         * @param map      booster parameters
         * @throws XGBoostError if the booster cannot be created
         */
        public CVPack(DMatrix dMatrix, DMatrix dMatrix2, Map<String, Object> map) throws XGBoostError {
            this.dtrain = dMatrix;
            this.dtest = dMatrix2;
            this.dmats = new DMatrix[]{dMatrix, dMatrix2};
            this.names = new String[]{"train", "test"};
            this.booster = new Booster(map, this.dmats);
        }

        /** Runs one boosting update on the training slice with the built-in objective. */
        public void update(int n) throws XGBoostError {
            this.booster.update(this.dtrain, n);
        }

        /** Runs one boosting update on the training slice with a custom objective. */
        public void update(IObjective iObjective) throws XGBoostError {
            this.booster.update(this.dtrain, iObjective);
        }

        /** Evaluates the train and test slices with the built-in metric for round {@code n}. */
        public String eval(int n) throws XGBoostError {
            return this.booster.evalSet(this.dmats, this.names, n);
        }

        /** Evaluates the train and test slices with a custom evaluation. */
        public String eval(IEvaluation iEvaluation) throws XGBoostError {
            return this.booster.evalSet(this.dmats, this.names, iEvaluation);
        }
    }
}

