/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.IEvaluation;
import ml.dmlc.xgboost4j.java.IObjective;
import ml.dmlc.xgboost4j.java.Rabit;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Static entry points for loading, training and cross-validating XGBoost boosters.
 */
public class XGBoost {
    private static final Log logger = LogFactory.getLog(XGBoost.class);

    /**
     * Loads a booster from a saved model file.
     *
     * @param modelPath path of the saved model
     * @return the loaded booster
     * @throws XGBoostError propagated from {@code Booster.loadModel(String)}
     */
    public static Booster loadModel(String modelPath) throws XGBoostError {
        return Booster.loadModel(modelPath);
    }

    /**
     * Loads a booster from a stream containing a saved model.
     *
     * @param in stream to read the model from
     * @return the loaded booster
     * @throws XGBoostError propagated from {@code Booster.loadModel(InputStream)}
     * @throws IOException if reading the stream fails
     */
    public static Booster loadModel(InputStream in) throws XGBoostError, IOException {
        return Booster.loadModel(in);
    }

    /**
     * Trains a new booster without early stopping. Per-iteration metrics are computed but not
     * returned.
     *
     * @see #train(DMatrix, Map, int, Map, float[][], IObjective, IEvaluation, int, Booster)
     */
    public static Booster train(DMatrix dtrain, Map<String, Object> params, int round,
            Map<String, DMatrix> watches, IObjective obj, IEvaluation eval2) throws XGBoostError {
        return XGBoost.train(dtrain, params, round, watches, null, obj, eval2, 0);
    }

    /**
     * Trains a new booster. This is the same as the full overload with {@code booster == null}.
     *
     * @see #train(DMatrix, Map, int, Map, float[][], IObjective, IEvaluation, int, Booster)
     */
    public static Booster train(DMatrix dtrain, Map<String, Object> params, int round,
            Map<String, DMatrix> watches, float[][] metrics, IObjective obj, IEvaluation eval2,
            int earlyStoppingRound) throws XGBoostError {
        return XGBoost.train(dtrain, params, round, watches, metrics, obj, eval2,
                earlyStoppingRound, null);
    }

    /**
     * Trains a booster for up to {@code round} boosting iterations.
     *
     * <p>After every iteration, each matrix in {@code watches} is evaluated. Its score is stored in
     * {@code metrics[watchIndex][iter]}, where the watch index follows the iteration order of
     * {@code watches}. The score of the <em>last</em> watch entry drives early stopping, so pass an
     * ordered map (e.g. {@link java.util.LinkedHashMap}) when supplying more than one watch.
     * When early stopping triggers, the returned booster is at the iteration where training
     * stopped, not at the best iteration.
     *
     * @param dtrain training data
     * @param params booster parameters; {@code maximize_evaluation_metrics} (default false),
     *     {@code silent} and {@code verbose_eval} are also read here
     * @param round maximum number of boosting iterations
     * @param watches named matrices to evaluate each iteration; {@code null} means none
     * @param metrics output array sized at least {@code [watches.size()][round]}; allocated
     *     internally when {@code null}
     * @param obj custom objective; when {@code null}, {@code booster.update(dtrain, iter)} is used
     * @param eval2 custom evaluation; when {@code null}, the booster's own
     *     {@code evalSet(mats, names, iter, out)} is used
     * @param earlyStoppingRounds stop after this many iterations without improvement of the best
     *     score; {@code <= 0} disables early stopping
     * @param booster booster to continue training, or {@code null} to create a new one
     * @return the trained booster
     * @throws XGBoostError propagated from the booster calls
     */
    public static Booster train(DMatrix dtrain, Map<String, Object> params, int round,
            Map<String, DMatrix> watches, float[][] metrics, IObjective obj, IEvaluation eval2,
            int earlyStoppingRounds, Booster booster) throws XGBoostError {
        Map<String, DMatrix> watchList =
                watches == null ? Collections.<String, DMatrix>emptyMap() : watches;
        List<String> names = new ArrayList<String>(watchList.size());
        List<DMatrix> mats = new ArrayList<DMatrix>(watchList.size());
        for (Map.Entry<String, DMatrix> evalEntry : watchList.entrySet()) {
            names.add(evalEntry.getKey());
            mats.add(evalEntry.getValue());
        }
        String[] evalNames = names.toArray(new String[names.size()]);
        DMatrix[] evalMats = mats.toArray(new DMatrix[mats.size()]);

        // The parameter does not change during training, so read it once.
        final boolean maximize = XGBoost.isMaximizeEvaluation(params);
        float bestScore = maximize ? -Float.MAX_VALUE : Float.MAX_VALUE;
        int bestIteration = 0;
        if (metrics == null) {
            metrics = new float[evalNames.length][round];
        }

        // The booster caches the training matrix first, followed by every watch matrix.
        DMatrix[] allMats = new DMatrix[evalMats.length + 1];
        allMats[0] = dtrain;
        System.arraycopy(evalMats, 0, allMats, 1, evalMats.length);

        if (booster == null) {
            booster = Booster.newBooster(params, allMats);
            // Presumably restores state from a Rabit checkpoint (distributed recovery);
            // getVersion() below determines the iteration training resumes from.
            booster.loadRabitCheckpoint();
        } else {
            booster.setParams(params);
        }

        // Two checkpoints are saved per iteration (after update, after eval), so version / 2 is the
        // iteration to resume at and an even version means this iteration's update has not run.
        for (int iter = booster.getVersion() / 2; iter < round; ++iter) {
            if (booster.getVersion() % 2 == 0) {
                if (obj != null) {
                    booster.update(dtrain, obj);
                } else {
                    booster.update(dtrain, iter);
                }
                booster.saveRabitCheckpoint();
            }
            if (evalMats.length > 0) {
                float[] metricsOut = new float[evalMats.length];
                String evalInfo = eval2 != null
                        ? booster.evalSet(evalMats, evalNames, eval2, metricsOut)
                        : booster.evalSet(evalMats, evalNames, iter, metricsOut);
                for (int i = 0; i < metricsOut.length; ++i) {
                    metrics[i][iter] = metricsOut[i];
                }
                // The last watch entry decides early stopping; ties do not update the best.
                float score = metricsOut[metricsOut.length - 1];
                if (maximize ? score > bestScore : score < bestScore) {
                    bestScore = score;
                    bestIteration = iter;
                }
                // Print before the early-stop check so the final iteration's metrics are shown.
                if (Rabit.getRank() == 0 && XGBoost.shouldPrint(params, iter)) {
                    Rabit.trackerPrint(evalInfo + '\n');
                }
                if (earlyStoppingRounds > 0
                        && XGBoost.shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
                    Rabit.trackerPrint(String.format(Locale.ROOT,
                            "early stopping after %d rounds away from the best iteration",
                            earlyStoppingRounds));
                    break;
                }
            }
            booster.saveRabitCheckpoint();
        }
        return booster;
    }

    /**
     * Interprets a parameter value as an int.
     *
     * @return the value if it is an {@code Integer} or a parseable (trimmed) {@code String},
     *     otherwise {@code null}
     */
    private static Integer tryGetIntFromObject(Object o) {
        if (o instanceof Integer) {
            return (Integer) o;
        }
        if (o instanceof String) {
            try {
                return Integer.valueOf(((String) o).trim());
            } catch (NumberFormatException e) {
                return null;
            }
        }
        return null;
    }

    /** True for {@code Boolean.TRUE} or a string equal to "true" ignoring case. */
    private static boolean isTrueFlag(Object value) {
        return Boolean.TRUE.equals(value)
                || (value instanceof String && "true".equalsIgnoreCase(((String) value).trim()));
    }

    /** True for {@code Boolean.FALSE} or a string equal to "false" ignoring case. */
    private static boolean isFalseFlag(Object value) {
        return Boolean.FALSE.equals(value)
                || (value instanceof String && "false".equalsIgnoreCase(((String) value).trim()));
    }

    /**
     * Decides whether to print the evaluation line for {@code iter}.
     *
     * <p>{@code silent} set to true or a non-zero int suppresses all output. Otherwise,
     * {@code verbose_eval} absent means always print. If it is false or 0, nothing is printed.
     * If it is an int {@code k}, output is printed every {@code k} iterations. Any other value
     * means always print.
     */
    private static boolean shouldPrint(Map<String, Object> params, int iter) {
        Object silent = params.get("silent");
        if (isTrueFlag(silent)) {
            return false;
        }
        Integer silentInt = XGBoost.tryGetIntFromObject(silent);
        if (silentInt != null && silentInt != 0) {
            return false;
        }

        Object verboseEval = params.get("verbose_eval");
        if (verboseEval == null) {
            return true;
        }
        if (isFalseFlag(verboseEval)) {
            return false;
        }
        Integer verboseEvalInt = XGBoost.tryGetIntFromObject(verboseEval);
        if (verboseEvalInt != null) {
            return verboseEvalInt != 0 && iter % verboseEvalInt == 0;
        }
        // Unrecognised value: default to printing.
        return true;
    }

    /**
     * Returns true once {@code iter} is at least {@code earlyStoppingRounds} iterations past
     * {@code bestIteration}.
     */
    static boolean shouldEarlyStop(int earlyStoppingRounds, int iter, int bestIteration) {
        return iter - bestIteration >= earlyStoppingRounds;
    }

    /**
     * Reads {@code maximize_evaluation_metrics}. Its string form is parsed with
     * {@link Boolean#parseBoolean}. A missing value means false (minimize).
     */
    private static boolean isMaximizeEvaluation(Map<String, Object> params) {
        Object maximize = params.get("maximize_evaluation_metrics");
        return maximize != null && Boolean.parseBoolean(String.valueOf(maximize));
    }

    /**
     * Runs {@code nfold}-fold cross-validation for {@code round} iterations.
     *
     * @param data full dataset; rows are shuffled randomly into folds
     * @param params booster parameters for every fold
     * @param round number of boosting iterations
     * @param nfold number of folds, between 2 and the number of rows
     * @param metrics optional {@code eval_metric} values set on every fold's booster
     * @param obj custom objective, or {@code null} to use {@code update(dtrain, iter)}
     * @param eval2 custom evaluation, or {@code null} to use {@code evalSet(..., iter)}
     * @return one aggregated evaluation line per iteration
     * @throws IllegalArgumentException if {@code nfold} is out of range
     * @throws XGBoostError propagated from the booster / matrix calls
     */
    public static String[] crossValidation(DMatrix data, Map<String, Object> params, int round,
            int nfold, String[] metrics, IObjective obj, IEvaluation eval2) throws XGBoostError {
        CVPack[] cvPacks = XGBoost.makeNFold(data, nfold, params, metrics);
        try {
            String[] evalHist = new String[round];
            String[] results = new String[cvPacks.length];
            for (int i = 0; i < round; ++i) {
                for (CVPack cvPack : cvPacks) {
                    if (obj != null) {
                        cvPack.update(obj);
                    } else {
                        cvPack.update(i);
                    }
                }
                for (int j = 0; j < cvPacks.length; ++j) {
                    results[j] = eval2 != null ? cvPacks[j].eval(eval2) : cvPacks[j].eval(i);
                }
                evalHist[i] = XGBoost.aggCVResults(results);
                logger.info(evalHist[i]);
            }
            return evalHist;
        } finally {
            // The fold boosters and sliced matrices are never exposed, so release them here.
            disposeAll(cvPacks);
        }
    }

    /**
     * Shuffles the rows of {@code data} and builds one {@link CVPack} per fold.
     *
     * <p>Fold {@code i} tests on shuffled positions {@code [i*n/nfold, (i+1)*n/nfold)} and trains
     * on the rest. Every row is therefore tested in exactly one fold, and fold sizes differ by at
     * most one.
     */
    private static CVPack[] makeNFold(DMatrix data, int nfold, Map<String, Object> params,
            String[] evalMetrics) throws XGBoostError {
        int rowNum = (int) data.rowNum();
        if (nfold < 2 || nfold > rowNum) {
            throw new IllegalArgumentException("nfold must be between 2 and the number of rows ("
                    + rowNum + "), got " + nfold);
        }
        List<Integer> samples = XGBoost.genRandPermutationNums(0, rowNum);
        CVPack[] cvPacks = new CVPack[nfold];
        boolean built = false;
        try {
            for (int i = 0; i < nfold; ++i) {
                // long arithmetic avoids overflow of i * rowNum.
                int testStart = (int) ((long) i * rowNum / nfold);
                int testEnd = (int) ((long) (i + 1) * rowNum / nfold);
                int[] testSlice = new int[testEnd - testStart];
                int[] trainSlice = new int[rowNum - testSlice.length];
                int testId = 0;
                int trainId = 0;
                for (int j = 0; j < rowNum; ++j) {
                    int row = samples.get(j);
                    if (j >= testStart && j < testEnd) {
                        testSlice[testId++] = row;
                    } else {
                        trainSlice[trainId++] = row;
                    }
                }
                CVPack cvPack = new CVPack(data.slice(trainSlice), data.slice(testSlice), params);
                // Register before configuring so a failure below still releases this fold.
                cvPacks[i] = cvPack;
                if (evalMetrics != null) {
                    for (String type : evalMetrics) {
                        cvPack.booster.setParam("eval_metric", type);
                    }
                }
            }
            built = true;
            return cvPacks;
        } finally {
            if (!built) {
                disposeAll(cvPacks);
            }
        }
    }

    /** Releases the native resources of every non-null pack. */
    private static void disposeAll(CVPack[] cvPacks) {
        for (CVPack cvPack : cvPacks) {
            if (cvPack != null) {
                cvPack.dispose();
            }
        }
    }

    /** Returns the integers {@code [start2, end)} in random order. */
    private static List<Integer> genRandPermutationNums(int start2, int end) {
        List<Integer> samples = new ArrayList<Integer>(Math.max(0, end - start2));
        for (int i = start2; i < end; ++i) {
            samples.add(i);
        }
        Collections.shuffle(samples);
        return samples;
    }

    /**
     * Averages the per-fold evaluation lines into one line.
     *
     * <p>Each input is a tab-separated string whose first token is kept verbatim (taken from the
     * first result) and whose remaining tokens are {@code name:value}. The output appends
     * {@code \tcv-name:mean} for each metric, in first-seen order.
     */
    private static String aggCVResults(String[] results) {
        Map<String, List<Float>> cvMap = new LinkedHashMap<String, List<Float>>();
        StringBuilder aggResult = new StringBuilder(results[0].split("\t")[0]);
        for (String result : results) {
            String[] items = result.split("\t");
            for (int i = 1; i < items.length; ++i) {
                // Split on the last ':' so metric names containing ':' stay intact.
                int sep = items[i].lastIndexOf(':');
                if (sep < 0) {
                    continue; // not a name:value token
                }
                String key = items[i].substring(0, sep);
                List<Float> values = cvMap.get(key);
                if (values == null) {
                    values = new ArrayList<Float>();
                    cvMap.put(key, values);
                }
                values.add(Float.valueOf(items[i].substring(sep + 1)));
            }
        }
        for (Map.Entry<String, List<Float>> entry : cvMap.entrySet()) {
            float sum = 0.0f;
            for (float value : entry.getValue()) {
                sum += value;
            }
            float mean = sum / entry.getValue().size();
            // Locale.ROOT keeps '.' as the decimal separator regardless of the JVM locale.
            aggResult.append(String.format(Locale.ROOT, "\tcv-%s:%f", entry.getKey(), mean));
        }
        return aggResult.toString();
    }

    /**
     * One cross-validation fold: a booster trained on {@code dtrain} and evaluated on
     * {@code dtrain} (as "train") and {@code dtest} (as "test").
     */
    private static class CVPack {
        final DMatrix dtrain;
        final DMatrix dtest;
        final DMatrix[] dmats;
        final String[] names;
        final Booster booster;

        public CVPack(DMatrix dtrain, DMatrix dtest, Map<String, Object> params)
                throws XGBoostError {
            this.dtrain = dtrain;
            this.dtest = dtest;
            this.dmats = new DMatrix[]{dtrain, dtest};
            this.names = new String[]{"train", "test"};
            this.booster = Booster.newBooster(params, this.dmats);
        }

        /** Runs one boosting iteration with the booster's built-in objective. */
        public void update(int iter) throws XGBoostError {
            this.booster.update(this.dtrain, iter);
        }

        /** Runs one boosting iteration with a custom objective. */
        public void update(IObjective obj) throws XGBoostError {
            this.booster.update(this.dtrain, obj);
        }

        /** Evaluates both matrices with the booster's own metrics. */
        public String eval(int iter) throws XGBoostError {
            return this.booster.evalSet(this.dmats, this.names, iter);
        }

        /** Evaluates both matrices with a custom evaluation. */
        public String eval(IEvaluation eval2) throws XGBoostError {
            return this.booster.evalSet(this.dmats, this.names, eval2);
        }

        /** Frees the booster and both fold matrices; the pack must not be used afterwards. */
        void dispose() {
            this.booster.dispose();
            this.dtrain.dispose();
            this.dtest.dispose();
        }
    }
}

