/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.Communicator;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.ExternalCheckpointManager;
import ml.dmlc.xgboost4j.java.IEvaluation;
import ml.dmlc.xgboost4j.java.IObjective;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileSystem;

public class XGBoost {
    private static final Log logger = LogFactory.getLog(XGBoost.class);

    /**
     * Metric-name prefixes treated as "higher is better" by {@link #isMaximizeEvaluation}.
     * The (misspelled) public name and the mutable array type are kept for API compatibility.
     */
    public static final String[] MAXIMIZ_METRICES = new String[]{"auc", "aucpr", "pre", "pre@", "map", "ndcg", "auc@", "aucpr@", "map@", "ndcg@"};

    /**
     * Loads a model by delegating to {@code Booster.loadModel(String)}.
     *
     * @param modelPath passed through unchanged; presumably a model file path (see Booster)
     * @return the loaded booster
     * @throws XGBoostError if the underlying load fails
     */
    public static Booster loadModel(String modelPath) throws XGBoostError {
        return Booster.loadModel(modelPath);
    }

    /**
     * Reads {@code in} to exhaustion and loads a model from the resulting bytes.
     * The stream is always closed, including when reading fails.
     *
     * @param in stream holding the serialized model; closed by this method
     * @return the loaded booster
     * @throws XGBoostError if the underlying load fails
     * @throws IOException if reading the stream fails
     */
    public static Booster loadModel(InputStream in) throws XGBoostError, IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try {
            byte[] chunk = new byte[1 << 20]; // 1 MiB read buffer
            int read;
            while ((read = in.read(chunk)) != -1) {
                buffer.write(chunk, 0, read);
            }
        } finally {
            // Previously the stream leaked if read() threw.
            in.close();
        }
        return Booster.loadModel(buffer.toByteArray());
    }

    /**
     * Loads a model from an in-memory byte array by delegating to {@code Booster.loadModel(byte[])}.
     *
     * @param modelBytes serialized model
     * @return the loaded booster
     * @throws XGBoostError if the underlying load fails
     * @throws IOException declared for API compatibility
     */
    public static Booster loadModel(byte[] modelBytes) throws XGBoostError, IOException {
        return Booster.loadModel(modelBytes);
    }

    /**
     * Trains a new booster with no metrics output array and no early stopping.
     *
     * @param dtrain training data
     * @param params booster parameters
     * @param numRounds number of boosting rounds
     * @param watches named evaluation sets (may be empty)
     * @param obj custom objective, or {@code null} to use the built-in one
     * @param eval custom evaluation, or {@code null} to use the built-in metrics
     * @return the trained booster
     * @throws XGBoostError on training failure
     */
    public static Booster train(DMatrix dtrain, Map<String, Object> params, int numRounds, Map<String, DMatrix> watches, IObjective obj, IEvaluation eval) throws XGBoostError {
        return XGBoost.train(dtrain, params, numRounds, watches, null, obj, eval, 0);
    }

    /**
     * Trains a new booster, optionally recording per-round metrics and early stopping.
     *
     * @param metrics receives {@code metrics[watchIndex][round]}; allocated internally when {@code null}
     * @param earlyStoppingRounds stop after this many rounds without improvement; {@code <= 0} disables
     * @see #train(DMatrix, Map, int, Map, IObjective, IEvaluation)
     */
    public static Booster train(DMatrix dtrain, Map<String, Object> params, int numRounds, Map<String, DMatrix> watches, float[][] metrics, IObjective obj, IEvaluation eval, int earlyStoppingRounds) throws XGBoostError {
        return XGBoost.train(dtrain, params, numRounds, watches, metrics, obj, eval, earlyStoppingRounds, null);
    }

    /**
     * Asks the checkpoint manager to persist {@code booster} if {@code round} is one of the
     * scheduled checkpoint rounds. Any failure is logged and rethrown as {@link XGBoostError}.
     * When {@code checkpointRounds} is empty, the manager is never touched (it may be null).
     */
    private static void saveCheckpoint(Booster booster, int round, Set<Integer> checkpointRounds, ExternalCheckpointManager ecm) throws XGBoostError {
        try {
            if (checkpointRounds.contains(round)) {
                ecm.updateCheckpoint(booster);
            }
        } catch (Exception e) {
            String message = "failed to save checkpoint in XGBoost4J at iteration " + round;
            logger.error(message, e);
            throw new XGBoostError(message, e);
        }
    }

    /**
     * Core training loop with optional evaluation, early stopping and external checkpoints.
     *
     * <p>Each watch is evaluated every round. Early stopping and the {@code best_iteration} /
     * {@code best_score} booster attributes are driven by the score of the LAST watch in
     * iteration order. Whether higher or lower is better is decided once, from the first
     * round's evaluation output, via {@link #isMaximizeEvaluation}.
     *
     * @param dtrain training data
     * @param params booster parameters
     * @param numRounds total number of boosting rounds
     * @param watches named evaluation sets; {@code null} is treated as empty
     * @param metrics receives {@code metrics[watchIndex][round]}; allocated when {@code null}
     * @param obj custom objective or {@code null}
     * @param eval custom evaluation or {@code null}
     * @param earlyStoppingRounds stop after this many non-improving rounds; {@code <= 0} disables
     * @param booster booster to continue training, or {@code null} to create a new one
     * @param checkpointInterval forwarded to {@code ExternalCheckpointManager.getCheckpointRounds}
     *     together with {@code numRounds}; presumably the checkpoint interval
     * @param checkpointPath checkpoint location; {@code null} disables external checkpoints
     * @param fs file system handed to the checkpoint manager
     * @return the trained booster
     * @throws XGBoostError on training or checkpoint failure
     * @throws IOException if the checkpoint manager reports an I/O failure
     */
    public static Booster trainAndSaveCheckpoint(DMatrix dtrain, Map<String, Object> params, int numRounds, Map<String, DMatrix> watches, float[][] metrics, IObjective obj, IEvaluation eval, int earlyStoppingRounds, Booster booster, int checkpointInterval, String checkpointPath, FileSystem fs) throws XGBoostError, IOException {
        ExternalCheckpointManager ecm = null;
        if (checkpointPath != null) {
            ecm = new ExternalCheckpointManager(checkpointPath, fs);
        }

        Map<String, DMatrix> watchMap = watches == null ? Collections.<String, DMatrix>emptyMap() : watches;
        List<String> nameList = new ArrayList<String>();
        List<DMatrix> matList = new ArrayList<DMatrix>();
        for (Map.Entry<String, DMatrix> entry : watchMap.entrySet()) {
            nameList.add(entry.getKey());
            matList.add(entry.getValue());
        }
        String[] evalNames = nameList.toArray(new String[nameList.size()]);
        DMatrix[] evalMats = matList.toArray(new DMatrix[matList.size()]);
        if (metrics == null) {
            metrics = new float[evalNames.length][numRounds];
        }

        // The booster is constructed over the training matrix followed by all watches.
        DMatrix[] allMats = new DMatrix[evalMats.length + 1];
        allMats[0] = dtrain;
        System.arraycopy(evalMats, 0, allMats, 1, evalMats.length);

        if (booster == null) {
            booster = new Booster(params, allMats);
            booster.setFeatureNames(dtrain.getFeatureNames());
            booster.setFeatureTypes(dtrain.getFeatureTypes());
            booster.loadRabitCheckpoint();
        } else {
            booster.setParams(params);
        }

        Set<Integer> checkpointRounds = new HashSet<Integer>();
        if (ecm != null) {
            checkpointRounds = new HashSet<Integer>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
        }

        boolean directionKnown = false;
        boolean maximize = false;
        float bestScore = 1.0f; // replaced on the first evaluation
        int bestIteration = 0;
        // Resume point and "update already done" are derived from the booster version.
        // NOTE(review): this assumes each saveRabitCheckpoint() bumps the version by one
        // (two saves per completed round) — confirm against Booster.
        for (int iter = booster.getVersion() / 2; iter < numRounds; ++iter) {
            if (booster.getVersion() % 2 == 0) {
                if (obj != null) {
                    booster.update(dtrain, obj);
                } else {
                    booster.update(dtrain, iter);
                }
                saveCheckpoint(booster, iter, checkpointRounds, ecm);
                booster.saveRabitCheckpoint();
            }

            if (evalMats.length > 0) {
                float[] scores = new float[evalMats.length];
                String evalInfo = eval != null
                        ? booster.evalSet(evalMats, evalNames, eval, scores)
                        : booster.evalSet(evalMats, evalNames, iter, scores);
                if (!directionKnown) {
                    maximize = XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params);
                    bestScore = maximize ? -Float.MAX_VALUE : Float.MAX_VALUE;
                    directionKnown = true;
                }
                for (int j = 0; j < scores.length; ++j) {
                    metrics[j][iter] = scores[j];
                }

                float score = scores[scores.length - 1];
                boolean improved = maximize ? score > bestScore : score < bestScore;
                if (improved) {
                    bestScore = score;
                    bestIteration = iter;
                    booster.setAttr("best_iteration", String.valueOf(bestIteration));
                    booster.setAttr("best_score", String.valueOf(bestScore));
                }

                if (XGBoost.shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
                    if (XGBoost.shouldPrint(params, iter)) {
                        Communicator.communicatorPrint(String.format("early stopping after %d rounds away from the best iteration", earlyStoppingRounds));
                    }
                    break;
                }
                if (Communicator.getRank() == 0 && XGBoost.shouldPrint(params, iter)) {
                    Communicator.communicatorPrint(evalInfo + "\n");
                }
            }
            booster.saveRabitCheckpoint();
        }
        return booster;
    }

    /**
     * Trains without external checkpoints; any {@link IOException} is wrapped in {@link XGBoostError}.
     *
     * @see #trainAndSaveCheckpoint
     */
    public static Booster train(DMatrix dtrain, Map<String, Object> params, int numRounds, Map<String, DMatrix> watches, float[][] metrics, IObjective obj, IEvaluation eval, int earlyStoppingRounds, Booster booster) throws XGBoostError {
        try {
            return XGBoost.trainAndSaveCheckpoint(dtrain, params, numRounds, watches, metrics, obj, eval, earlyStoppingRounds, booster, -1, null, null);
        } catch (IOException e) {
            logger.error("training failed in xgboost4j", e);
            throw new XGBoostError("training failed in xgboost4j ", e);
        }
    }

    /** Returns {@code value} as an Integer if it is an Integer or a parseable String, else {@code null}. */
    private static Integer tryGetIntFromObject(Object value) {
        if (value instanceof Integer) {
            return (Integer) value;
        }
        if (value instanceof String) {
            try {
                return Integer.parseInt((String) value);
            } catch (NumberFormatException ignored) {
                return null;
            }
        }
        return null;
    }

    /**
     * Decides whether round {@code iteration} should print evaluation output.
     *
     * <p>{@code silent} set to {@code Boolean.TRUE}, {@code "true"}, {@code "True"} or a non-zero
     * integer suppresses output. Otherwise {@code verbose_eval} set to {@code Boolean.FALSE},
     * {@code "false"}, {@code "False"} or {@code 0} suppresses output; a non-zero integer
     * {@code k} prints every round where {@code iteration % k == 0}. Absent or unrecognized
     * values print.
     */
    private static boolean shouldPrint(Map<String, Object> params, int iteration) {
        Object silent = params.get("silent");
        if (silent != null) {
            Integer silentLevel = XGBoost.tryGetIntFromObject(silent);
            if (Boolean.TRUE.equals(silent) || "true".equals(silent) || "True".equals(silent)
                    || (silentLevel != null && silentLevel != 0)) {
                return false;
            }
        }

        Object verboseEval = params.get("verbose_eval");
        if (verboseEval == null) {
            return true;
        }
        if (Boolean.FALSE.equals(verboseEval) || "false".equals(verboseEval) || "False".equals(verboseEval)) {
            return false;
        }
        Integer period = XGBoost.tryGetIntFromObject(verboseEval);
        if (period != null) {
            return period != 0 && iteration % period == 0;
        }
        return true;
    }

    /**
     * Returns true when early stopping is enabled ({@code earlyStoppingRounds > 0}) and at least
     * that many rounds have passed since {@code bestIteration}.
     */
    static boolean shouldEarlyStop(int earlyStoppingRounds, int iter, int bestIteration) {
        if (earlyStoppingRounds <= 0) {
            return false;
        }
        return iter - bestIteration >= earlyStoppingRounds;
    }

    /**
     * Extracts the metric name following {@code "<evalNames[0]>-"} in an evaluation log line,
     * stopping at the first {@code ':'} or whitespace. The old greedy {@code (.*)} captured up
     * to the LAST colon (e.g. {@code "mape:0.1\ttest-mape"}), which misclassified metrics.
     *
     * @return the metric name, or {@code null} if no eval names are given or nothing matches
     */
    private static String getMetricNameFromlog(String evalInfo, String[] evalNames) {
        if (evalInfo == null || evalNames == null || evalNames.length == 0) {
            return null;
        }
        Pattern pattern = Pattern.compile(Pattern.quote(evalNames[0]) + "-([^:\\s]+):");
        Matcher matcher = pattern.matcher(evalInfo);
        if (!matcher.find()) {
            return null;
        }
        String metricName = matcher.group(1);
        logger.debug("Got the metric name: " + metricName);
        return metricName;
    }

    /**
     * Decides whether the evaluation metric should be maximized.
     *
     * <p>An explicit {@code maximize_evaluation_metrics} parameter wins; its string form is
     * parsed with {@link Boolean#valueOf(String)}. Otherwise the metric name comes from the
     * {@code eval_metric} parameter or, failing that, from {@code evalInfo}. {@code "mape"} is
     * minimized; names starting with any {@link #MAXIMIZ_METRICES} prefix are maximized; all
     * others, including an undeterminable name, are minimized.
     */
    public static boolean isMaximizeEvaluation(String evalInfo, String[] evalNames, Map<String, Object> params) {
        Object explicit = params.get("maximize_evaluation_metrics");
        if (explicit != null) {
            return Boolean.valueOf(String.valueOf(explicit));
        }
        Object configured = params.get("eval_metric");
        String metricName = configured != null
                ? String.valueOf(configured)
                : XGBoost.getMetricNameFromlog(evalInfo, evalNames);
        if (metricName == null) {
            // Previously an assert (NPE with assertions disabled); fall back to minimizing.
            logger.warn("could not determine the evaluation metric name; assuming it is minimized");
            return false;
        }
        if ("mape".equals(metricName)) {
            return false;
        }
        for (String prefix : MAXIMIZ_METRICES) {
            if (metricName.startsWith(prefix)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Runs {@code nfold}-fold cross validation for {@code numRounds} rounds.
     *
     * @param evalMetrics extra {@code eval_metric} values set on every fold booster, or {@code null}
     * @return one aggregated result line per round (see {@link #aggCVResults})
     * @throws XGBoostError on training/evaluation failure
     * @throws IllegalArgumentException if {@code nfold} is not positive or exceeds the row count
     */
    public static String[] crossValidation(DMatrix data, Map<String, Object> params, int numRounds, int nfold, String[] evalMetrics, IObjective obj, IEvaluation eval) throws XGBoostError {
        CVPack[] packs = XGBoost.makeNFold(data, nfold, params, evalMetrics);
        String[] results = new String[numRounds];
        String[] foldResults = new String[packs.length];
        for (int round = 0; round < numRounds; ++round) {
            for (CVPack pack : packs) {
                if (obj != null) {
                    pack.update(obj);
                } else {
                    pack.update(round);
                }
            }
            for (int f = 0; f < packs.length; ++f) {
                foldResults[f] = eval != null ? packs[f].eval(eval) : packs[f].eval(round);
            }
            results[round] = XGBoost.aggCVResults(foldResults);
            logger.info(results[round]);
        }
        return results;
    }

    /**
     * Shuffles the row indices and builds one train/test pack per fold. Fold {@code i}'s test
     * set is the {@code foldSize} shuffled rows at positions {@code [i*foldSize, (i+1)*foldSize)};
     * every other row trains. Rows beyond {@code nfold*foldSize} (the remainder) always train.
     * (The old {@code j > start} test dropped the first row of each window and refilled the
     * test set with an arbitrary trailing row.)
     */
    private static CVPack[] makeNFold(DMatrix data, int nfold, Map<String, Object> params, String[] evalMetrics) throws XGBoostError {
        if (nfold <= 0) {
            throw new IllegalArgumentException("nfold must be positive, got " + nfold);
        }
        List<Integer> permutation = XGBoost.genRandPermutationNums(0, (int) data.rowNum());
        int numRows = permutation.size();
        int foldSize = numRows / nfold;
        if (foldSize == 0) {
            throw new IllegalArgumentException("nfold (" + nfold + ") must not exceed the number of rows (" + numRows + ")");
        }
        CVPack[] packs = new CVPack[nfold];
        for (int fold = 0; fold < nfold; ++fold) {
            int testStart = fold * foldSize;
            int testEnd = testStart + foldSize;
            int[] testRows = new int[foldSize];
            int[] trainRows = new int[numRows - foldSize];
            int testCount = 0;
            int trainCount = 0;
            for (int j = 0; j < numRows; ++j) {
                if (j >= testStart && j < testEnd) {
                    testRows[testCount++] = permutation.get(j);
                } else {
                    trainRows[trainCount++] = permutation.get(j);
                }
            }
            DMatrix trainSlice = data.slice(trainRows);
            DMatrix testSlice = data.slice(testRows);
            CVPack pack = new CVPack(trainSlice, testSlice, params);
            if (evalMetrics != null) {
                for (String metric : evalMetrics) {
                    pack.booster.setParam("eval_metric", metric);
                }
            }
            packs[fold] = pack;
        }
        return packs;
    }

    /** Returns the integers in {@code [from, to)} in shuffled order. */
    private static List<Integer> genRandPermutationNums(int from, int to) {
        List<Integer> values = new ArrayList<Integer>();
        for (int i = from; i < to; ++i) {
            values.add(i);
        }
        Collections.shuffle(values);
        return values;
    }

    /**
     * Averages per-fold evaluation lines.
     *
     * <p>Each input is tab-separated: a leading prefix field followed by {@code name:value}
     * fields. The result is the first input's prefix followed by {@code \tcv-<name>:<mean>} for
     * each metric, in first-seen order. Formatting uses {@link Locale#ROOT} so the decimal
     * separator is always {@code '.'}.
     */
    private static String aggCVResults(String[] foldResults) {
        Map<String, List<Float>> valuesByMetric = new LinkedHashMap<String, List<Float>>();
        for (String foldResult : foldResults) {
            String[] fields = foldResult.split("\t");
            for (int i = 1; i < fields.length; ++i) {
                String[] nameAndValue = fields[i].split(":");
                List<Float> values = valuesByMetric.get(nameAndValue[0]);
                if (values == null) {
                    values = new ArrayList<Float>();
                    valuesByMetric.put(nameAndValue[0], values);
                }
                values.add(Float.valueOf(nameAndValue[1]));
            }
        }
        StringBuilder result = new StringBuilder(foldResults[0].split("\t")[0]);
        for (Map.Entry<String, List<Float>> entry : valuesByMetric.entrySet()) {
            float sum = 0.0f;
            for (Float value : entry.getValue()) {
                sum += value;
            }
            float mean = sum / (float) entry.getValue().size();
            result.append(String.format(Locale.ROOT, "\tcv-%s:%f", entry.getKey(), mean));
        }
        return result.toString();
    }

    /** One cross-validation fold: a booster built over its train and test slices. */
    private static class CVPack {
        final DMatrix dtrain;
        final DMatrix dtest;
        final DMatrix[] dmats;
        final String[] names;
        final Booster booster;

        CVPack(DMatrix dtrain, DMatrix dtest, Map<String, Object> params) throws XGBoostError {
            this.dmats = new DMatrix[]{dtrain, dtest};
            this.booster = new Booster(params, this.dmats);
            this.names = new String[]{"train", "test"};
            this.dtrain = dtrain;
            this.dtest = dtest;
        }

        /** Runs one built-in-objective update on the train slice. */
        public void update(int iter) throws XGBoostError {
            this.booster.update(this.dtrain, iter);
        }

        /** Runs one custom-objective update on the train slice. */
        public void update(IObjective obj) throws XGBoostError {
            this.booster.update(this.dtrain, obj);
        }

        /** Evaluates train and test slices with the built-in metrics. */
        public String eval(int iter) throws XGBoostError {
            return this.booster.evalSet(this.dmats, this.names, iter);
        }

        /** Evaluates train and test slices with a custom evaluation. */
        public String eval(IEvaluation eval) throws XGBoostError {
            return this.booster.evalSet(this.dmats, this.names, eval);
        }
    }
}

