/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java.flink;

import java.io.IOException;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.stream.StreamSupport;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.util.Collector;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class XGBoostModel
implements Serializable {
    private static final Logger logger = LoggerFactory.getLogger(XGBoostModel.class);
    private final Booster booster;
    private final PredictorFunction predictorFunction;

    public XGBoostModel(Booster booster) {
        this.booster = booster;
        this.predictorFunction = new PredictorFunction(booster);
    }

    public void saveModelAsHadoopFile(String modelPath) throws IOException, XGBoostError {
        this.booster.saveModel((OutputStream)FileSystem.get((Configuration)new Configuration()).create(new Path(modelPath)));
    }

    public byte[] toByteArray(String format) throws XGBoostError {
        return this.booster.toByteArray(format);
    }

    public void saveModelAsHadoopFile(String modelPath, String format) throws IOException, XGBoostError {
        this.booster.saveModel((OutputStream)FileSystem.get((Configuration)new Configuration()).create(new Path(modelPath)), format);
    }

    public float[][] predict(DMatrix testSet) throws XGBoostError {
        return this.booster.predict(testSet, true, 0);
    }

    public DataSet<Float[]> predict(DataSet<Vector> data) {
        return data.mapPartition((MapPartitionFunction)this.predictorFunction);
    }

    private static class PredictorFunction
    implements MapPartitionFunction<Vector, Float[]> {
        private final Booster booster;

        public PredictorFunction(Booster booster) {
            this.booster = booster;
        }

        public void mapPartition(Iterable<Vector> it, Collector<Float[]> out) throws Exception {
            Iterator dataIter = StreamSupport.stream(it.spliterator(), false).map(Vector::toSparse).map(PredictorFunction::fromVector).iterator();
            if (dataIter.hasNext()) {
                DMatrix data = new DMatrix(dataIter, null);
                float[][] predictions = this.booster.predict(data, true, 2);
                Arrays.stream(predictions).map(ArrayUtils::toObject).forEach(arg_0 -> out.collect(arg_0));
            } else {
                logger.debug("Empty partition");
            }
        }

        private static LabeledPoint fromVector(SparseVector vector) {
            int[] index = vector.indices;
            double[] value = vector.values;
            int size = value.length;
            float[] values = new float[size];
            for (int i = 0; i < size; ++i) {
                values[i] = (float)value[i];
            }
            return new LabeledPoint(0.0f, vector.size(), index, values);
        }
    }
}

