/*
 * Decompiled with CFR 0.152.
 */
package io.github.metarank.ltrlib.booster;

import io.github.metarank.ltrlib.booster.Booster;
import io.github.metarank.ltrlib.booster.BoosterDataset;
import io.github.metarank.ltrlib.booster.XGBoostBooster;
import io.github.metarank.ltrlib.booster.XGBoostOptions;
import io.github.metarank.ltrlib.util.Logging;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.Serializable;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import org.slf4j.Logger;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.ArrayOps$;
import scala.collection.Map;
import scala.collection.StringOps$;
import scala.collection.immutable.Seq;
import scala.jdk.CollectionConverters$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.ModuleSerializationProxy;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;
import scala.runtime.java8.JFunction1;

/**
 * Singleton companion of {@code XGBoostBooster}: a Scala {@code object} as seen through the CFR
 * decompiler. The only instance is {@link #MODULE$}.
 *
 * <p>Acts as the {@code Booster.BoosterFactory} for XGBoost models:
 * <ul>
 *   <li>{@link #formatData} turns a {@code BoosterDataset} into an XGBoost {@code DMatrix};</li>
 *   <li>{@link #train} fits a {@code rank:pairwise} model round by round, with NDCG logging and
 *       optional early stopping;</li>
 *   <li>{@link #apply(byte[])} restores a model from its versioned binary form.</li>
 * </ul>
 *
 * <p>NOTE(review): this is decompiler output, not hand-written Java.
 * <ul>
 *   <li>Members such as {@code $init$}, {@code io$github$metarank$ltrlib$util$Logging$_setter_$logger_$eq},
 *       {@code $anonfun$train$1} and {@code writeReplace} follow Scala compiler conventions.</li>
 *   <li>The file imports both {@code io.github.metarank.ltrlib.booster.Booster} and
 *       {@code ml.dmlc.xgboost4j.java.Booster}, so it does not compile as Java as-is. Make
 *       behavioral changes in the original Scala source.</li>
 *   <li>Below, a bare {@code Booster} model parameter or local is the xgboost4j type, since it is
 *       what {@code XGBoost.train} / {@code XGBoost.loadModel} return.
 *       {@code Booster.BoosterFactory} and {@code Booster.DatasetOptions} are the ltrlib types.</li>
 * </ul>
 */
public final class XGBoostBooster$
implements Booster.BoosterFactory<DMatrix, XGBoostBooster, XGBoostOptions>,
Logging,
Serializable {
    /** The single instance of this Scala object. */
    public static final XGBoostBooster$ MODULE$ = new XGBoostBooster$();
    /** Format version expected as the first byte of a serialized model; see {@link #apply(byte[])}. */
    private static final int BITSTREAM_VERSION;
    /** Logger required by the {@code Logging} trait; expected to be set through the trait setter below. */
    private static Logger logger;

    // Runs after MODULE$ has been constructed, so the trait initializers receive the singleton.
    // Presumably Logging.$init$ assigns `logger` via the setter below; confirm against the Logging trait.
    static {
        Booster.BoosterFactory.$init$(MODULE$);
        Logging.$init$(MODULE$);
        BITSTREAM_VERSION = 2;
    }

    /** Returns the logger stored by the {@code Logging} trait setter. */
    @Override
    public Logger logger() {
        return logger;
    }

    /** Scala-generated setter through which the {@code Logging} trait installs its logger. */
    @Override
    public void io$github$metarank$ltrlib$util$Logging$_setter_$logger_$eq(Logger x$1) {
        logger = x$1;
    }

    /** Returns the binary format version accepted by {@link #apply(byte[])} (currently 2). */
    public int BITSTREAM_VERSION() {
        return BITSTREAM_VERSION;
    }

    /**
     * Deserializes a model from the binary layout for format version {@code BITSTREAM_VERSION}.
     *
     * <p>Layout read here, via {@link DataInputStream} (big-endian ints):
     * <ol>
     *   <li>1 byte: format version. It must equal {@link #BITSTREAM_VERSION()}.</li>
     *   <li>int N, then N {@code readUTF} strings: the per-feature type strings. These are
     *       presumably the {@code "c"}/{@code "q"} codes built by {@link #train}.</li>
     *   <li>int M, then M bytes: the XGBoost model, handed to {@code XGBoost.loadModel}.</li>
     * </ol>
     *
     * @param string the serialized model bytes (a byte array despite the decompiled parameter name)
     * @return a booster wrapping the loaded model and its feature types
     * @throws Exception if the version byte differs from {@link #BITSTREAM_VERSION()}. Truncated
     *     input instead surfaces as the {@code EOFException} thrown by {@link DataInputStream} reads.
     */
    @Override
    public XGBoostBooster apply(byte[] string) {
        byte version;
        DataInputStream stream = new DataInputStream(new ByteArrayInputStream(string));
        byte by = version = stream.readByte();
        if (this.BITSTREAM_VERSION() == by) {
            int featureTypesSize = stream.readInt();
            String[] featureTypes = (String[])RichInt$.MODULE$.until$extension(Predef$.MODULE$.intWrapper(0), featureTypesSize).map((Function1 & Serializable)x$3 -> stream.readUTF()).toArray(ClassTag$.MODULE$.apply(String.class));
            int boosterSize = stream.readInt();
            // NOTE(review): a negative size fails here with NegativeArraySizeException rather than
            // with the format error below.
            byte[] buffer = new byte[boosterSize];
            stream.readFully(buffer);
            Booster booster = XGBoost.loadModel((byte[])buffer);
            return new XGBoostBooster(booster, featureTypes);
        }
        // Any other version byte, including the older format, is rejected outright.
        throw new Exception("you use old binary xgboost serialization format, please re-serialize");
    }

    /**
     * Converts an ltrlib dataset into an XGBoost {@code DMatrix}.
     *
     * <p>What the conversion does:
     * <ul>
     *   <li>Feature values and labels are narrowed from {@code double} to {@code float}.</li>
     *   <li>{@code Float.NaN} is passed as the missing-value marker.</li>
     *   <li>Query groups come from {@code d.groups()}.</li>
     *   <li>Each feature is typed {@code "c"} if its index is in {@code d.categoricalIndices()},
     *       and {@code "q"} otherwise.</li>
     * </ul>
     *
     * @param d source dataset
     * @param parent not used by this implementation
     * @param options not used by this implementation
     * @return a new matrix; callers can release it with {@link #closeData(DMatrix)}
     */
    @Override
    public DMatrix formatData(BoosterDataset d, Option<DMatrix> parent, XGBoostOptions options) {
        // Dense d.rows() x d.cols() matrix. Assumes d.data() uses the element order this DMatrix
        // constructor expects (row-major); TODO confirm.
        DMatrix mat = new DMatrix((float[])ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.doubleArrayOps(d.data()), (Function1)(JFunction1.mcFD.sp & Serializable)x$4 -> (float)x$4, (ClassTag)ClassTag$.MODULE$.Float()), d.rows(), d.cols(), Float.NaN);
        mat.setLabel((float[])ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.doubleArrayOps(d.labels()), (Function1)(JFunction1.mcFD.sp & Serializable)x$5 -> (float)x$5, (ClassTag)ClassTag$.MODULE$.Float()));
        mat.setGroup(d.groups());
        // NOTE(review): sized by d.original().desc().dim(), while the matrix has d.cols() columns.
        // These are assumed equal; confirm. The contains() lookup is linear per feature.
        String[] ftypes = new String[d.original().desc().dim()];
        for (int i = 0; i < d.original().desc().dim(); ++i) {
            ftypes[i] = ArrayOps$.MODULE$.contains$extension(Predef$.MODULE$.intArrayOps(d.categoricalIndices()), (Object)BoxesRunTime.boxToInteger((int)i)) ? "c" : "q";
        }
        mat.setFeatureTypes(ftypes);
        return mat;
    }

    /**
     * Trains a pairwise learning-to-rank model, adding one boosting round at a time.
     *
     * <p>How training proceeds:
     * <ul>
     *   <li>The booster is created by {@code XGBoost.train} with 0 rounds.</li>
     *   <li>It is then grown by {@code model.update} for up to {@code options.trees()} iterations.</li>
     *   <li>After every round, the metric is evaluated on the training matrix and, when
     *       {@code test} is defined, on the test matrix; the values are logged.</li>
     *   <li>Early stopping applies only when a test set is present and
     *       {@code options.earlyStopping()} is {@code Some(k)}. Training then stops once the current
     *       round is more than {@code k} rounds past the round with the best test score.</li>
     * </ul>
     *
     * @param dataset training matrix, e.g. as built by {@link #formatData}
     * @param test optional evaluation matrix used for logging and early stopping
     * @param options hyper-parameters copied into the XGBoost parameter map
     * @param dso dataset options; supplies categorical feature indices and the feature count
     * @return the trained model together with its per-feature {@code "c"}/{@code "q"} types
     */
    @Override
    public XGBoostBooster train(DMatrix dataset, Option<DMatrix> test, XGBoostOptions options, Booster.DatasetOptions dso) {
        // XGBoost parameters, in map order:
        //   objective rank:pairwise; eval_metric "ndcg@" + cutoff; num_round; max_depth; eta;
        //   seed; subsample; tree_method; enable_categorical; lambdarank_unbiased.
        // enable_categorical is "true" only when dso lists categorical features.
        // lambdarank_unbiased comes from options.debias().
        // NOTE(review): "num_round" is passed as a boxed int, while the other numeric values are
        // stringified. The number of rounds actually built is driven by the loop below.
        java.util.Map opts = CollectionConverters$.MODULE$.MapHasAsJava((Map)Predef$.MODULE$.Map().apply((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"objective"), (Object)"rank:pairwise"), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"eval_metric"), (Object)new StringBuilder(5).append("ndcg@").append(options.ndcgCutoff()).toString()), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"num_round"), (Object)options.trees()), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"max_depth"), (Object)Integer.toString(options.maxDepth())), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"eta"), (Object)Double.toString(options.learningRate())), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"seed"), (Object)Integer.toString(options.randomSeed())), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"subsample"), (Object)Double.toString(options.subsample())), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"tree_method"), (Object)options.treeMethod()), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"enable_categorical"), (Object)(ArrayOps$.MODULE$.isEmpty$extension(Predef$.MODULE$.intArrayOps(dso.categoryFeatures())) ? "false" : "true")), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"lambdarank_unbiased"), (Object)(options.debias() ? "true" : "false"))}))).asJava();
        // Zero rounds with an empty watches map: this only constructs the booster. The trailing
        // (null, null) presumably means no custom objective or evaluation.
        Booster model = XGBoost.train((DMatrix)dataset, (java.util.Map)opts, (int)0, (java.util.Map)CollectionConverters$.MODULE$.MapHasAsJava((Map)Predef$.MODULE$.Map().empty()).asJava(), null, null);
        boolean earlyStop = false;
        // The best test score starts at 0.0, so lastBestIter stays 0 until some round scores above 0.
        double lastBest = 0.0;
        int lastBestIter = 0;
        // The `continue`s below are the decompiled shape of nested Scala matches on Option; each
        // one simply ends the current iteration.
        for (int it = 0; it < options.trees() && !earlyStop; ++it) {
            model.update(dataset, it);
            double ndcgTrain = this.evalMetric(model, dataset, it);
            Option<DMatrix> option = test;
            if (option instanceof Some) {
                Some some = (Some)option;
                DMatrix value = (DMatrix)some.value();
                double ndcgTest = this.evalMetric(model, value, it);
                this.logger().info(new StringBuilder(31).append("[").append(it).append("] NDCG@").append(options.ndcgCutoff()).append(":train = ").append(ndcgTrain).append(" NDCG@").append(options.ndcgCutoff()).append(":test = ").append(ndcgTest).toString());
                Option<Object> option2 = options.earlyStopping();
                if (option2 instanceof Some) {
                    Some some2 = (Some)option2;
                    int esThreshold = BoxesRunTime.unboxToInt((Object)some2.value());
                    if (ndcgTest > lastBest) {
                        lastBest = ndcgTest;
                        lastBestIter = it;
                    }
                    // Stops once esThreshold + 1 rounds have passed with no improvement.
                    // NOTE(review): the model is not rolled back to lastBestIter; rounds added
                    // after the best one remain in the returned model.
                    if (it - lastBestIter > esThreshold) {
                        this.logger().info(new StringBuilder(39).append("early stop: ").append(esThreshold).append(" rounds passed, best=").append(lastBest).append(" last=").append(ndcgTest).toString());
                        earlyStop = true;
                    }
                    continue;
                }
                if (None$.MODULE$.equals(option2)) {
                    continue;
                }
                throw new MatchError(option2);
            }
            if (None$.MODULE$.equals(option)) {
                this.logger().info(new StringBuilder(16).append("[").append(it).append("] NDCG@train = ").append(ndcgTrain).toString());
                continue;
            }
            throw new MatchError(option);
        }
        // Per-feature type codes for dso.dims() features, computed by $anonfun$train$1.
        String[] ftypes = (String[])RichInt$.MODULE$.until$extension(Predef$.MODULE$.intWrapper(0), dso.dims()).map((Function1 & Serializable)x -> XGBoostBooster$.$anonfun$train$1(dso, BoxesRunTime.unboxToInt((Object)x))).toArray(ClassTag$.MODULE$.apply(String.class));
        return new XGBoostBooster(model, ftypes);
    }

    /**
     * Releases {@code d} by calling {@code DMatrix.dispose()}. This presumably frees the
     * matrix's native resources, so do not use {@code d} afterwards.
     */
    @Override
    public void closeData(DMatrix d) {
        d.dispose();
    }

    /**
     * Evaluates {@code model} on {@code dataset} and returns the configured eval metric as a double.
     *
     * <p>The matrix is always registered under the eval name {@code "test"}, including when
     * {@link #train} scores the training set. The value is parsed from the text after the last
     * {@code ':'} in the string returned by {@code evalSet}. This assumes that string ends with the
     * single metric's value (one matrix, one {@code eval_metric}); confirm if more metrics are added.
     *
     * @param model xgboost4j booster to evaluate
     * @param dataset matrix to score
     * @param it iteration number passed through to {@code evalSet}
     * @return the parsed metric value (NDCG for models trained by {@link #train})
     */
    public double evalMetric(Booster model, DMatrix dataset, int it) {
        String result = model.evalSet((DMatrix[])((Object[])new DMatrix[]{dataset}), (String[])((Object[])new String[]{"test"}), it);
        return StringOps$.MODULE$.toDouble$extension(Predef$.MODULE$.augmentString((String)ArrayOps$.MODULE$.last$extension(Predef$.MODULE$.refArrayOps((Object[])StringOps$.MODULE$.split$extension(Predef$.MODULE$.augmentString(result), ':')))));
    }

    /** Case-class factory: wraps an xgboost4j model together with its per-feature type strings. */
    public XGBoostBooster apply(Booster model, String[] featureTypes) {
        return new XGBoostBooster(model, featureTypes);
    }

    /**
     * Case-class extractor.
     *
     * @return {@code None} for a null argument, otherwise {@code Some((model, featureTypes))}
     */
    public Option<Tuple2<Booster, String[]>> unapply(XGBoostBooster x$0) {
        if (x$0 == null) {
            return None$.MODULE$;
        }
        return new Some((Object)new Tuple2((Object)x$0.model(), (Object)x$0.featureTypes()));
    }

    /**
     * Java-serialization hook. It replaces this singleton with Scala's
     * {@code ModuleSerializationProxy} for this class, so deserialization resolves to the module
     * instance instead of creating a copy.
     */
    private Object writeReplace() {
        return new ModuleSerializationProxy(XGBoostBooster$.class);
    }

    /**
     * Lambda body used by {@link #train} to build feature types.
     *
     * @return {@code "c"} if feature index {@code x} is listed in {@code dso$1.categoryFeatures()},
     *     otherwise {@code "q"}
     */
    public static final /* synthetic */ String $anonfun$train$1(Booster.DatasetOptions dso$1, int x) {
        if (ArrayOps$.MODULE$.contains$extension(Predef$.MODULE$.intArrayOps(dso$1.categoryFeatures()), (Object)BoxesRunTime.boxToInteger((int)x))) {
            return "c";
        }
        return "q";
    }

    /** Private: the only instance is {@link #MODULE$}. */
    private XGBoostBooster$() {
    }
}

