/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java.flink;

import java.io.InputStream;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.Communicator;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.RabitTracker;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.flink.XGBoostModel;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.util.Collector;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Entry points for loading and for distributed training of XGBoost models on Apache Flink's
 * DataSet API.
 */
public class XGBoost {
    private static final Logger logger = LoggerFactory.getLogger(XGBoost.class);

    /**
     * Loads a serialized XGBoost model from a Hadoop-accessible location.
     *
     * <p>The file system is resolved from the path itself. Fully-qualified URIs whose scheme
     * differs from the configured default file system therefore work. Scheme-less paths
     * still resolve against the default file system.
     *
     * @param modelPath location of the saved model
     * @return the loaded booster wrapped in an {@link XGBoostModel}
     * @throws Exception if the file cannot be opened or the model cannot be parsed
     */
    public static XGBoostModel loadModelFromHadoopFile(String modelPath) throws Exception {
        Path path = new Path(modelPath);
        // FileSystem.get(conf) always returns the *default* FS, which rejects paths on other
        // file systems ("Wrong FS"); resolve the FS that actually owns this path instead.
        // Hadoop may hand back a shared cached instance, so it is intentionally not closed.
        FileSystem fileSystem = path.getFileSystem(new Configuration());
        try (FSDataInputStream in = fileSystem.open(path)) {
            return new XGBoostModel(ml.dmlc.xgboost4j.java.XGBoost.loadModel(in));
        }
    }

    /**
     * Trains an XGBoost model on a Flink {@link DataSet}, one collective worker per partition.
     *
     * <p>A {@link RabitTracker} sized to the execution environment's parallelism is started on
     * the client. Each partition is then trained by {@link MapFunction}, and one of the
     * resulting models is returned. The tracker is always stopped once the job has finished.
     *
     * @param dtrain training rows as (feature vector, label) pairs
     * @param params booster parameters; the optional key {@code "numEarlyStoppingRounds"} sets
     *     the early-stopping window (defaults to 0)
     * @param numBoostRound number of boosting rounds
     * @return the trained model
     * @throws Error if the tracker cannot be started
     * @throws IllegalStateException if no worker produced a model (e.g. the input is empty)
     * @throws Exception if the Flink job or the tracker fails
     */
    public static XGBoostModel train(DataSet<Tuple2<Vector, Double>> dtrain, Map<String, Object> params, int numBoostRound) throws Exception {
        RabitTracker tracker = new RabitTracker(dtrain.getExecutionEnvironment().getParallelism());
        if (!tracker.start()) {
            // NOTE(review): java.lang.Error is kept for backward compatibility with existing
            // callers; an IllegalStateException would be the conventional choice.
            throw new Error("Tracker cannot be started");
        }
        try {
            List<XGBoostModel> models = dtrain
                    .mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerArgs()))
                    // Keep an arbitrary worker's model; this assumes collective training leaves
                    // every worker with the same booster.
                    .reduce((x, y) -> x)
                    .collect();
            if (models.isEmpty()) {
                throw new IllegalStateException("No model was produced: the training data set is empty");
            }
            return models.get(0);
        } finally {
            // Release the tracker whether training succeeded or failed.
            tracker.stop();
        }
    }

    /**
     * Partition function that turns each Flink partition into one XGBoost training worker.
     *
     * <p>It emits at most one {@link XGBoostModel} per partition. A partition with no rows
     * emits nothing.
     */
    private static class MapFunction
    extends RichMapPartitionFunction<Tuple2<Vector, Double>, XGBoostModel> {
        private final Map<String, Object> params;
        private final int round;
        private final Map<String, Object> workerEnvs;

        /**
         * @param params booster parameters forwarded to {@code XGBoost.train}
         * @param round number of boosting rounds
         * @param workerEnvs communicator settings obtained from the tracker
         */
        public MapFunction(Map<String, Object> params, int round, Map<String, Object> workerEnvs) {
            this.params = params;
            this.round = round;
            this.workerEnvs = workerEnvs;
        }

        /**
         * Converts this partition to a {@link DMatrix}, trains a booster on it as one
         * collective worker and emits the resulting model.
         */
        @Override
        public void mapPartition(Iterable<Tuple2<Vector, Double>> it, Collector<XGBoostModel> collector) throws XGBoostError {
            int taskId = this.getRuntimeContext().getIndexOfThisSubtask();
            // Copy before adding the task id so the map handed to the constructor is never mutated.
            Map<String, Object> envs = new HashMap<>(this.workerEnvs);
            envs.put("DMLC_TASK_ID", String.valueOf(taskId));
            if (logger.isInfoEnabled()) {
                logger.info("start with env: {}", envs.entrySet().stream()
                        .map(e -> String.format("\"%s\": \"%s\"", e.getKey(), e.getValue()))
                        .collect(Collectors.joining(", ")));
            }

            Iterator<LabeledPoint> dataIter = StreamSupport.stream(it.spliterator(), false)
                    .map(VectorToPointMapper.INSTANCE)
                    .iterator();
            if (!dataIter.hasNext()) {
                // NOTE(review): this worker returns without ever calling Communicator.init, while
                // the tracker was sized for the full parallelism; presumably the other workers
                // then wait for it indefinitely — confirm, and consider rebalancing the input.
                logger.warn("Nothing to train with.");
                return;
            }

            // Parse configuration before allocating native memory so a bad value cannot leak it.
            int numEarlyStoppingRounds = Optional.ofNullable(this.params.get("numEarlyStoppingRounds"))
                    .map(x -> Integer.parseInt(x.toString()))
                    .orElse(0);
            DMatrix trainMat = new DMatrix(dataIter, null);
            try {
                Booster booster = this.trainBooster(trainMat, numEarlyStoppingRounds, envs, taskId);
                collector.collect(new XGBoostModel(booster));
            } finally {
                // The matrix holds native memory that is not reclaimed by the Java GC.
                trainMat.dispose();
            }
        }

        /**
         * Joins the collective, trains a booster on {@code trainMat} and always shuts the
         * communicator down afterwards.
         *
         * @param trainMat training data, also used as the "train" watch
         * @param numEarlyStoppingRounds early-stopping window passed to {@code XGBoost.train}
         * @param envs communicator settings including this worker's task id
         * @param taskId index of this subtask, used only for failure logging
         * @return the trained booster
         * @throws XGBoostError if joining the collective or training fails (logged, then rethrown)
         */
        private Booster trainBooster(DMatrix trainMat, int numEarlyStoppingRounds, Map<String, Object> envs, int taskId) throws XGBoostError {
            Map<String, DMatrix> watches = new HashMap<>();
            watches.put("train", trainMat);
            try {
                Communicator.init(envs);
                return ml.dmlc.xgboost4j.java.XGBoost.train(trainMat, this.params, this.round, watches, null, null, null, numEarlyStoppingRounds);
            }
            catch (XGBoostError xgbException) {
                logger.warn(String.format("XGBooster worker %s has failed due to", taskId), xgbException);
                throw xgbException;
            }
            finally {
                Communicator.shutdown();
            }
        }

        /**
         * Converts a (vector, label) tuple into a sparse {@link LabeledPoint}, narrowing the
         * feature values and the label to {@code float}.
         */
        private static final class VectorToPointMapper
        implements Function<Tuple2<Vector, Double>, LabeledPoint> {
            public static final VectorToPointMapper INSTANCE = new VectorToPointMapper();

            private VectorToPointMapper() {
            }

            @Override
            public LabeledPoint apply(Tuple2<Vector, Double> tuple) {
                SparseVector vector = tuple.f0.toSparse();
                double[] values = vector.values;
                float[] floatValues = new float[values.length];
                for (int i = 0; i < values.length; ++i) {
                    floatValues[i] = (float) values[i];
                }
                return new LabeledPoint(tuple.f1.floatValue(), vector.size(), vector.indices, floatValues);
            }
        }
    }
}

