/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.xgboost4j.java;

import ai.h2o.xgboost4j.java.Booster;
import ai.h2o.xgboost4j.java.DMatrix;
import ai.h2o.xgboost4j.java.ExternalCheckpointManager;
import ai.h2o.xgboost4j.java.IEvaluation;
import ai.h2o.xgboost4j.java.IObjective;
import ai.h2o.xgboost4j.java.Rabit;
import ai.h2o.xgboost4j.java.XGBoostError;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileSystem;

public class XGBoost {
    private static final Log logger = LogFactory.getLog(XGBoost.class);

    /**
     * Loads a booster from the model stored at the given path.
     *
     * <p>This is a thin convenience wrapper that forwards to
     * {@code Booster.loadModel(String)}.
     *
     * @param modelPath location of the saved model, passed through unchanged
     * @return the booster produced by {@code Booster.loadModel}
     * @throws XGBoostError if the underlying load fails
     */
    public static Booster loadModel(String modelPath) throws XGBoostError {
        Booster loaded = Booster.loadModel(modelPath);
        return loaded;
    }

    /**
     * Loads a booster from a stream containing a saved model.
     *
     * <p>This is a thin convenience wrapper that forwards to
     * {@code Booster.loadModel(InputStream)}; this method itself does not close the stream.
     *
     * @param in stream to read the model from, passed through unchanged
     * @return the booster produced by {@code Booster.loadModel}
     * @throws XGBoostError if the underlying load fails
     * @throws IOException if reading from the stream fails
     */
    public static Booster loadModel(InputStream in) throws XGBoostError, IOException {
        Booster loaded = Booster.loadModel(in);
        return loaded;
    }

    /**
     * Trains a booster without a caller-supplied metrics buffer and without early stopping.
     *
     * <p>Delegates to the eight-argument {@code train} overload with a {@code null} metrics
     * array and an early-stopping round count of {@code 0}.
     *
     * @param dtrain training data
     * @param params booster parameters
     * @param round number of boosting rounds
     * @param watches named evaluation matrices
     * @param obj custom objective, or {@code null} to use the built-in one
     * @param eval custom evaluation, or {@code null} to use the built-in one
     * @return the trained booster
     * @throws XGBoostError if training fails
     */
    public static Booster train(DMatrix dtrain, Map<String, Object> params, int round, Map<String, DMatrix> watches, IObjective obj, IEvaluation eval) throws XGBoostError {
        float[][] noMetricsBuffer = null;
        int noEarlyStopping = 0;
        return XGBoost.train(dtrain, params, round, watches, noMetricsBuffer, obj, eval, noEarlyStopping);
    }

    /**
     * Trains a new booster, optionally recording per-round metrics and stopping early.
     *
     * <p>Delegates to the nine-argument {@code train} overload with a {@code null} booster,
     * so training does not continue from an existing booster.
     *
     * @param dtrain training data
     * @param params booster parameters
     * @param round number of boosting rounds
     * @param watches named evaluation matrices
     * @param metrics buffer that receives evaluation results, or {@code null}
     * @param obj custom objective, or {@code null} to use the built-in one
     * @param eval custom evaluation, or {@code null} to use the built-in one
     * @param earlyStoppingRound early-stopping patience; values {@code <= 0} disable it
     * @return the trained booster
     * @throws XGBoostError if training fails
     */
    public static Booster train(DMatrix dtrain, Map<String, Object> params, int round, Map<String, DMatrix> watches, float[][] metrics, IObjective obj, IEvaluation eval, int earlyStoppingRound) throws XGBoostError {
        Booster noExistingBooster = null;
        return XGBoost.train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, noExistingBooster);
    }

    /**
     * Writes an external checkpoint for {@code booster} when {@code iter2} is one of the
     * scheduled checkpoint iterations; otherwise does nothing.
     *
     * @param booster booster to persist
     * @param iter2 current training iteration
     * @param checkpointIterations iterations at which a checkpoint must be written
     * @param ecm checkpoint manager; only dereferenced when {@code iter2} is scheduled
     * @throws XGBoostError wrapping any failure raised while checking or writing the checkpoint
     */
    private static void saveCheckpoint(Booster booster, int iter2, Set<Integer> checkpointIterations, ExternalCheckpointManager ecm) throws XGBoostError {
        try {
            if (checkpointIterations.contains(iter2)) {
                ecm.updateCheckpoint(booster);
            }
        }
        catch (Exception e2) {
            // One shared message so the log line and the exception agree (the exception text
            // previously lacked the space before the iteration number).
            String message = "failed to save checkpoint in XGBoost4J at iteration " + iter2;
            logger.error(message, e2);
            throw new XGBoostError(message, e2);
        }
    }

    /**
     * Runs the boosting loop, evaluating the watch list after every round and optionally
     * writing external checkpoints and stopping early.
     *
     * <p>Training starts at iteration {@code booster.getVersion() / 2} and runs until
     * {@code numRounds}. When evaluation matrices are present, the score of the <em>last</em>
     * watch entry drives the best-iteration tracking used for early stopping.
     *
     * @param dtrain training data
     * @param params booster parameters; also consulted for {@code maximize_evaluation_metrics},
     *               {@code silent} and {@code verbose_eval}
     * @param numRounds total number of boosting rounds
     * @param watches named evaluation matrices; must not be {@code null}
     * @param metrics buffer indexed as {@code [watch][iteration]} that receives the evaluation
     *                results; when {@code null} an internal buffer is allocated
     * @param obj custom objective, or {@code null} to use the built-in one
     * @param eval custom evaluation, or {@code null} to use the built-in one
     * @param earlyStoppingRounds early-stopping patience; values {@code <= 0} disable it
     * @param booster booster to continue training, or {@code null} to create a new one
     * @param checkpointInterval forwarded to {@code ExternalCheckpointManager.getCheckpointRounds}
     * @param checkpointPath external checkpoint location, or {@code null} to disable
     *                       external checkpoints
     * @param fs file system handed to the checkpoint manager
     * @return the trained booster
     * @throws XGBoostError if training, evaluation or checkpointing fails
     * @throws IOException declared for the checkpoint manager calls
     */
    public static Booster trainAndSaveCheckpoint(DMatrix dtrain, Map<String, Object> params, int numRounds, Map<String, DMatrix> watches, float[][] metrics, IObjective obj, IEvaluation eval, int earlyStoppingRounds, Booster booster, int checkpointInterval, String checkpointPath, FileSystem fs) throws XGBoostError, IOException {
        List<String> names = new ArrayList<String>();
        List<DMatrix> mats = new ArrayList<DMatrix>();
        ExternalCheckpointManager ecm = null;
        if (checkpointPath != null) {
            ecm = new ExternalCheckpointManager(checkpointPath, fs);
        }
        for (Map.Entry<String, DMatrix> evalEntry : watches.entrySet()) {
            names.add(evalEntry.getKey());
            mats.add(evalEntry.getValue());
        }
        String[] evalNames = names.toArray(new String[names.size()]);
        DMatrix[] evalMats = mats.toArray(new DMatrix[mats.size()]);

        // The optimisation direction is read from params once; it is loop-invariant.
        boolean maximize = XGBoost.isMaximizeEvaluation(params);
        float bestScore = maximize ? -Float.MAX_VALUE : Float.MAX_VALUE;
        int bestIteration = 0;
        float[][] metricHistory = metrics != null ? metrics : new float[evalNames.length][numRounds];

        // The booster is created over the training matrix followed by every evaluation matrix.
        DMatrix[] allMats = new DMatrix[evalMats.length + 1];
        allMats[0] = dtrain;
        System.arraycopy(evalMats, 0, allMats, 1, evalMats.length);

        if (booster == null) {
            booster = Booster.newBooster(params, allMats);
            booster.loadRabitCheckpoint();
        } else {
            booster.setParams(params);
        }

        Set<Integer> checkpointIterations = new HashSet<Integer>();
        if (ecm != null) {
            checkpointIterations = new HashSet<Integer>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
        }

        for (int iter2 = booster.getVersion() / 2; iter2 < numRounds; ++iter2) {
            // NOTE(review): the version parity check together with the two
            // saveRabitCheckpoint() calls per round presumably lets a restarted worker skip
            // the update step it already completed - confirm against Booster/Rabit.
            if (booster.getVersion() % 2 == 0) {
                if (obj != null) {
                    booster.update(dtrain, obj);
                } else {
                    booster.update(dtrain, iter2);
                }
                XGBoost.saveCheckpoint(booster, iter2, checkpointIterations, ecm);
                booster.saveRabitCheckpoint();
            }
            if (evalMats.length > 0) {
                float[] metricsOut = new float[evalMats.length];
                String evalInfo = eval != null
                        ? booster.evalSet(evalMats, evalNames, eval, metricsOut)
                        : booster.evalSet(evalMats, evalNames, iter2, metricsOut);
                for (int i2 = 0; i2 < metricsOut.length; ++i2) {
                    metricHistory[i2][iter2] = metricsOut[i2];
                }
                // With several evaluation sets, the last one decides early stopping.
                float score = metricsOut[metricsOut.length - 1];
                // Strict comparison: an equal score does not move the best iteration.
                boolean improved = maximize ? score > bestScore : score < bestScore;
                if (improved) {
                    bestScore = score;
                    bestIteration = iter2;
                }
                if (earlyStoppingRounds > 0 && XGBoost.shouldEarlyStop(earlyStoppingRounds, iter2, bestIteration)) {
                    Rabit.trackerPrint(String.format("early stopping after %d rounds away from the best iteration", earlyStoppingRounds));
                    break;
                }
                if (Rabit.getRank() == 0 && XGBoost.shouldPrint(params, iter2)) {
                    Rabit.trackerPrint(evalInfo + '\n');
                }
            }
            booster.saveRabitCheckpoint();
        }
        return booster;
    }

    /**
     * Trains (or continues training) a booster without external checkpoints.
     *
     * <p>Delegates to {@code trainAndSaveCheckpoint} with a checkpoint interval of {@code -1}
     * and no checkpoint path or file system. An {@link IOException} from that call is logged
     * and rethrown wrapped in an {@code XGBoostError}.
     *
     * @param dtrain training data
     * @param params booster parameters
     * @param round number of boosting rounds
     * @param watches named evaluation matrices
     * @param metrics buffer that receives evaluation results, or {@code null}
     * @param obj custom objective, or {@code null} to use the built-in one
     * @param eval custom evaluation, or {@code null} to use the built-in one
     * @param earlyStoppingRounds early-stopping patience; values {@code <= 0} disable it
     * @param booster booster to continue training, or {@code null} to create a new one
     * @return the trained booster
     * @throws XGBoostError if training fails
     */
    public static Booster train(DMatrix dtrain, Map<String, Object> params, int round, Map<String, DMatrix> watches, float[][] metrics, IObjective obj, IEvaluation eval, int earlyStoppingRounds, Booster booster) throws XGBoostError {
        Booster trained;
        try {
            trained = XGBoost.trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRounds, booster, -1, null, null);
        }
        catch (IOException ioFailure) {
            logger.error("training failed in xgboost4j", ioFailure);
            throw new XGBoostError("training failed in xgboost4j ", ioFailure);
        }
        return trained;
    }

    /**
     * Interprets a loosely typed parameter value as an integer.
     *
     * @param o2 value to interpret; may be {@code null}
     * @return the value itself when it is an {@link Integer}, the parsed number when it is a
     *         {@link String} holding a decimal integer, and {@code null} in every other case
     *         (including an unparsable string)
     */
    private static Integer tryGetIntFromObject(Object o2) {
        if (o2 instanceof Integer) {
            return (Integer)o2;
        }
        if (!(o2 instanceof String)) {
            return null;
        }
        String text = (String)o2;
        try {
            return Integer.parseInt(text);
        }
        catch (NumberFormatException notAnInteger) {
            // Not a decimal integer: reported as "no value" rather than as an error.
            return null;
        }
    }

    /**
     * Decides whether the evaluation line for iteration {@code iter2} should be printed.
     *
     * <p>Printing is suppressed when {@code silent} is {@code "true"}, {@code "True"} or a
     * non-zero integer. Otherwise {@code verbose_eval} decides:
     * <ul>
     *   <li>absent: print;</li>
     *   <li>{@code "false"} or {@code "False"}: do not print;</li>
     *   <li>integer {@code 0}: do not print;</li>
     *   <li>any other integer {@code n}: print when {@code iter2 % n == 0};</li>
     *   <li>anything else: print.</li>
     * </ul>
     * Integers may be given as {@link Integer} or as a decimal string.
     *
     * @param params training parameters holding the optional {@code silent} and
     *               {@code verbose_eval} entries
     * @param iter2 current training iteration
     * @return {@code true} if the evaluation line should be printed
     */
    private static boolean shouldPrint(Map<String, Object> params, int iter2) {
        Object silentSetting = params.get("silent");
        if (silentSetting != null) {
            if (silentSetting.equals("true") || silentSetting.equals("True")) {
                return false;
            }
            Integer silentFlag = XGBoost.tryGetIntFromObject(silentSetting);
            if (silentFlag != null && silentFlag != 0) {
                return false;
            }
        }

        Object verboseSetting = params.get("verbose_eval");
        if (verboseSetting == null) {
            return true;
        }
        if (verboseSetting.equals("false") || verboseSetting.equals("False")) {
            return false;
        }
        Integer printPeriod = XGBoost.tryGetIntFromObject(verboseSetting);
        if (printPeriod == null) {
            // Present but not an integer and not an explicit "false": default to printing.
            return true;
        }
        // The zero check comes first so the modulo is never evaluated with a zero divisor.
        return printPeriod != 0 && iter2 % printPeriod == 0;
    }

    /**
     * Tells whether training has gone long enough without improvement to stop.
     *
     * @param earlyStoppingRounds number of rounds without improvement that triggers a stop
     * @param iter2 current training iteration
     * @param bestIteration iteration at which the best score was observed
     * @return {@code true} when at least {@code earlyStoppingRounds} iterations separate
     *         {@code iter2} from {@code bestIteration}
     */
    static boolean shouldEarlyStop(int earlyStoppingRounds, int iter2, int bestIteration) {
        int roundsSinceBest = iter2 - bestIteration;
        return roundsSinceBest >= earlyStoppingRounds;
    }

    /**
     * Reads the {@code maximize_evaluation_metrics} parameter.
     *
     * <p>The value is converted to text and parsed as a boolean, so only a value whose text is
     * {@code "true"} (ignoring case) yields {@code true}; a missing entry yields {@code false}.
     *
     * @param params training parameters
     * @return {@code true} if a larger evaluation score is better
     */
    private static boolean isMaximizeEvaluation(Map<String, Object> params) {
        try {
            Object configured = params.get("maximize_evaluation_metrics");
            // String.valueOf(Object) yields the text "null" for a null reference, never null
            // itself, so no null check on the text is needed.
            String asText = String.valueOf(configured);
            return Boolean.parseBoolean(asText);
        }
        catch (Exception ex) {
            logger.error("maximize_evaluation_metrics has to be specified for enabling early stop, allowed value: true/false", ex);
            throw ex;
        }
    }

    /**
     * Runs n-fold cross validation and returns one aggregated evaluation line per round.
     *
     * <p>For every round each fold is updated (with {@code obj} when given, otherwise with the
     * round index), then evaluated (with {@code eval} when given, otherwise with the round
     * index). The per-fold results are aggregated, stored and logged at info level.
     *
     * @param data full data set to split into folds
     * @param params booster parameters used for every fold
     * @param round number of boosting rounds
     * @param nfold number of folds
     * @param metrics evaluation metric names applied to every fold, or {@code null}
     * @param obj custom objective, or {@code null} to use the built-in one
     * @param eval custom evaluation, or {@code null} to use the built-in one
     * @return array of length {@code round} holding the aggregated evaluation line per round
     * @throws XGBoostError if fold creation, training or evaluation fails
     */
    public static String[] crossValidation(DMatrix data, Map<String, Object> params, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) throws XGBoostError {
        CVPack[] folds = XGBoost.makeNFold(data, nfold, params, metrics);
        String[] history = new String[round];
        String[] foldResults = new String[folds.length];
        for (int iteration = 0; iteration < round; ++iteration) {
            // Advance every fold by one boosting round before evaluating any of them.
            for (CVPack fold : folds) {
                if (obj == null) {
                    fold.update(iteration);
                } else {
                    fold.update(obj);
                }
            }
            int slot = 0;
            for (CVPack fold : folds) {
                foldResults[slot] = eval == null ? fold.eval(iteration) : fold.eval(eval);
                ++slot;
            }
            String aggregated = XGBoost.aggCVResults(foldResults);
            history[iteration] = aggregated;
            logger.info(aggregated);
        }
        return history;
    }

    /**
     * Splits {@code data} into {@code nfold} train/test pairs over a random row permutation.
     *
     * <p>With {@code step = rows / nfold}, fold {@code i} uses permutation positions
     * {@code [i * step, (i + 1) * step)} as its test rows and every other position as its
     * training rows. Positions past {@code nfold * step} (the remainder of the division) are
     * training rows in every fold.
     *
     * @param data full data set
     * @param nfold number of folds
     * @param params booster parameters used to create each fold's booster
     * @param evalMetrics metric names set as {@code eval_metric} on each fold's booster, or
     *                    {@code null} to leave the metric configuration untouched
     * @return one {@code CVPack} per fold
     * @throws XGBoostError if slicing the data or creating a booster fails
     */
    private static CVPack[] makeNFold(DMatrix data, int nfold, Map<String, Object> params, String[] evalMetrics) throws XGBoostError {
        List<Integer> samples = XGBoost.genRandPermutationNums(0, (int)data.rowNum());
        int sampleCount = samples.size();
        int step = sampleCount / nfold;
        // Both buffers are completely overwritten for every fold, so they are reused.
        int[] testSlice = new int[step];
        int[] trainSlice = new int[sampleCount - step];
        CVPack[] cvPacks = new CVPack[nfold];
        for (int i2 = 0; i2 < nfold; ++i2) {
            int testStart = i2 * step;
            int testEnd = testStart + step;
            int testid = 0;
            int trainid = 0;
            for (int j2 = 0; j2 < sampleCount; ++j2) {
                int row = samples.get(j2);
                // The lower bound is inclusive. The previous exclusive bound (j2 > i2 * step)
                // dropped the first row of each fold from its test slice and back-filled the
                // slice with a tail row that was then shared by every fold.
                if (j2 >= testStart && j2 < testEnd) {
                    testSlice[testid] = row;
                    ++testid;
                } else {
                    trainSlice[trainid] = row;
                    ++trainid;
                }
            }
            DMatrix dtrain = data.slice(trainSlice);
            DMatrix dtest = data.slice(testSlice);
            CVPack cvPack = new CVPack(dtrain, dtest, params);
            if (evalMetrics != null) {
                for (String type : evalMetrics) {
                    cvPack.booster.setParam("eval_metric", type);
                }
            }
            cvPacks[i2] = cvPack;
        }
        return cvPacks;
    }

    /**
     * Builds a randomly ordered list of the integers in {@code [start, end)}.
     *
     * @param start first value, inclusive
     * @param end last value, exclusive; when {@code end <= start} the result is empty
     * @return a new mutable list holding every value of the range exactly once, shuffled with
     *         {@link Collections#shuffle(List)}
     */
    private static List<Integer> genRandPermutationNums(int start, int end) {
        List<Integer> permutation = new ArrayList<Integer>();
        int next = start;
        while (next < end) {
            permutation.add(next);
            ++next;
        }
        Collections.shuffle(permutation);
        return permutation;
    }

    /**
     * Averages the per-fold evaluation lines of one cross-validation round.
     *
     * <p>Each input line is expected to look like
     * {@code <prefix>\t<name>:<value>\t<name>:<value>...}. The prefix of the first line is
     * kept, and for every metric name the mean of its values over all lines is appended as
     * {@code \tcv-<name>:<mean>}, formatted with {@code %f}. Metrics are emitted in the
     * iteration order of a {@link HashMap}, i.e. in no guaranteed order.
     *
     * @param results one evaluation line per fold; must contain at least one element
     * @return the aggregated evaluation line
     * @throws NumberFormatException if a metric value cannot be parsed as a float
     */
    private static String aggCVResults(String[] results) {
        // Typed collections replace the raw HashMap/List, whose elements cannot be iterated
        // as String/Float without casts.
        Map<String, List<Float>> cvMap = new HashMap<String, List<Float>>();
        StringBuilder aggResult = new StringBuilder(results[0].split("\t")[0]);
        for (String result : results) {
            String[] items = result.split("\t");
            // items[0] is the line prefix; the metrics start at index 1.
            for (int i2 = 1; i2 < items.length; ++i2) {
                String[] tup = items[i2].split(":");
                String key = tup[0];
                Float value = Float.valueOf(tup[1]);
                List<Float> values = cvMap.get(key);
                if (values == null) {
                    values = new ArrayList<Float>();
                    cvMap.put(key, values);
                }
                values.add(value);
            }
        }
        for (Map.Entry<String, List<Float>> entry : cvMap.entrySet()) {
            List<Float> values = entry.getValue();
            float sum = 0.0f;
            for (Float tvalue : values) {
                sum += tvalue.floatValue();
            }
            float mean = sum / (float)values.size();
            aggResult.append(String.format("\tcv-%s:%f", entry.getKey(), Float.valueOf(mean)));
        }
        return aggResult.toString();
    }

    /**
     * One cross-validation fold: a training matrix, a test matrix and the booster trained on
     * them. Evaluation always reports both matrices under the names {@code "train"} and
     * {@code "test"}.
     */
    private static class CVPack {
        final DMatrix dtrain;
        final DMatrix dtest;
        /** Training matrix followed by test matrix, in the same order as {@link #names}. */
        final DMatrix[] dmats;
        final String[] names;
        final Booster booster;

        /**
         * Creates the fold and its booster.
         *
         * @param dtrain training rows of the fold
         * @param dtest held-out rows of the fold
         * @param params booster parameters
         * @throws XGBoostError if the booster cannot be created
         */
        public CVPack(DMatrix dtrain, DMatrix dtest, Map<String, Object> params) throws XGBoostError {
            this.dtrain = dtrain;
            this.dtest = dtest;
            this.names = new String[]{"train", "test"};
            this.dmats = new DMatrix[]{dtrain, dtest};
            this.booster = Booster.newBooster(params, this.dmats);
        }

        /**
         * Runs one boosting round on the training matrix with the built-in objective.
         *
         * @param iter2 current iteration
         * @throws XGBoostError if the update fails
         */
        public void update(int iter2) throws XGBoostError {
            Booster target = this.booster;
            target.update(this.dtrain, iter2);
        }

        /**
         * Runs one boosting round on the training matrix with a custom objective.
         *
         * @param obj custom objective
         * @throws XGBoostError if the update fails
         */
        public void update(IObjective obj) throws XGBoostError {
            Booster target = this.booster;
            target.update(this.dtrain, obj);
        }

        /**
         * Evaluates the training and test matrices with the built-in metrics.
         *
         * @param iter2 current iteration
         * @return the evaluation line returned by the booster
         * @throws XGBoostError if evaluation fails
         */
        public String eval(int iter2) throws XGBoostError {
            String evaluation = this.booster.evalSet(this.dmats, this.names, iter2);
            return evaluation;
        }

        /**
         * Evaluates the training and test matrices with a custom evaluation.
         *
         * @param eval custom evaluation
         * @return the evaluation line returned by the booster
         * @throws XGBoostError if evaluation fails
         */
        public String eval(IEvaluation eval) throws XGBoostError {
            String evaluation = this.booster.evalSet(this.dmats, this.names, eval);
            return evaluation;
        }
    }
}

