/*
 * Decompiled with CFR 0.152.
 */
package com.microsoft.azure.synapse.ml.lightgbm;

import com.microsoft.azure.synapse.ml.lightgbm.LightGBMDelegate;
import com.microsoft.azure.synapse.ml.lightgbm.PartitionTaskContext;
import com.microsoft.azure.synapse.ml.lightgbm.PartitionTaskTrainingState;
import com.microsoft.azure.synapse.ml.lightgbm.TrainingContext;
import com.microsoft.azure.synapse.ml.lightgbm.booster.LightGBMBooster;
import com.microsoft.azure.synapse.ml.lightgbm.dataset.LightGBMDataset;
import com.microsoft.azure.synapse.ml.lightgbm.params.BaseTrainParams;
import com.microsoft.azure.synapse.ml.lightgbm.params.FObjTrait;
import java.io.Serializable;
import org.slf4j.Logger;
import scala.Array$;
import scala.Function1;
import scala.Function3;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Option$;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.immutable.Map;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

public final class TrainUtils$
implements Serializable {
    public static TrainUtils$ MODULE$;

    static {
        new TrainUtils$();
    }

    public LightGBMBooster createBooster(BaseTrainParams trainParams, LightGBMDataset trainDataset, Option<LightGBMDataset> validDatasetOpt) {
        String parameters = trainParams.toString();
        LightGBMBooster booster = new LightGBMBooster(trainDataset, parameters);
        trainParams.generalParams().modelString().foreach((Function1 & Serializable & scala.Serializable)modelStr -> {
            booster.mergeBooster(modelStr);
            return BoxedUnit.UNIT;
        });
        validDatasetOpt.foreach((Function1 & Serializable & scala.Serializable)dataset -> {
            booster.addValidationDataset(dataset);
            return BoxedUnit.UNIT;
        });
        return booster;
    }

    public void beforeTrainIteration(PartitionTaskTrainingState state, Logger log) {
        if (state.ctx().trainingParams().delegate().isDefined()) {
            ((LightGBMDelegate)state.ctx().trainingParams().delegate().get()).beforeTrainIteration(state.ctx().trainingCtx().batchIndex(), state.ctx().partitionId(), state.iteration(), log, state.ctx().trainingParams(), state.booster(), state.ctx().trainingCtx().hasValidationData());
            return;
        }
    }

    public void afterTrainIteration(PartitionTaskTrainingState state, Logger log, Option<Map<String, Object>> trainEvalResults, Option<Map<String, Object>> validEvalResults) {
        PartitionTaskContext ctx = state.ctx();
        TrainingContext trainingCtx = ctx.trainingCtx();
        if (ctx.trainingParams().delegate().isDefined()) {
            ((LightGBMDelegate)ctx.trainingParams().delegate().get()).afterTrainIteration(trainingCtx.batchIndex(), ctx.partitionId(), state.iteration(), log, trainingCtx.trainingParams(), state.booster(), trainingCtx.hasValidationData(), state.isFinished(), trainEvalResults, validEvalResults);
            return;
        }
    }

    public double getLearningRate(PartitionTaskTrainingState state, Logger log) {
        Option<LightGBMDelegate> option = state.ctx().trainingParams().delegate();
        if (option instanceof Some) {
            Some some = (Some)option;
            LightGBMDelegate delegate = (LightGBMDelegate)some.value();
            return delegate.getLearningRate(state.ctx().trainingCtx().batchIndex(), state.ctx().partitionId(), state.iteration(), log, state.ctx().trainingParams(), state.learningRate());
        }
        if (None$.MODULE$.equals(option)) {
            return state.learningRate();
        }
        throw new MatchError(option);
    }

    public void updateOneIteration(PartitionTaskTrainingState state, Logger log) {
        block4: {
            try {
                log.info(new StringBuilder(33).append("LightGBM task starting iteration ").append(state.iteration()).toString());
                Option<FObjTrait> fobj = state.ctx().trainingParams().objectiveParams().fobj();
                if (fobj.isDefined()) {
                    Tuple2<float[], float[]> tuple2 = ((FObjTrait)fobj.get()).getGradient(state.booster().innerPredict(0, state.ctx().trainingCtx().isClassification()), (LightGBMDataset)state.booster().trainDataset().get());
                    if (tuple2 == null) {
                        throw new MatchError(tuple2);
                    }
                    float[] gradient = (float[])tuple2._1();
                    float[] hessian = (float[])tuple2._2();
                    Tuple2 tuple22 = new Tuple2((Object)gradient, (Object)hessian);
                    float[] gradient2 = (float[])tuple22._1();
                    float[] hessian2 = (float[])tuple22._2();
                    state.isFinished_$eq(state.booster().updateOneIterationCustom(gradient2, hessian2));
                    break block4;
                }
                state.isFinished_$eq(state.booster().updateOneIteration());
            }
            catch (Exception e) {
                log.warn(new StringBuilder(126).append("LightGBM reached early termination on one task, stopping training on task. This message should rarely occur. Inner exception: ").append(e.toString()).toString());
                state.isFinished_$eq(true);
            }
        }
    }

    public Option<Object> executeTrainingIterations(PartitionTaskTrainingState state, Logger log) {
        log.info(new StringBuilder(60).append("Beginning training on LightGBM Booster for task ").append(state.ctx().taskId()).append(", partition ").append(state.ctx().partitionId()).toString());
        state.ctx().measures().markTrainingIterationsStart();
        Option result = this.iterationLoop$1(state.ctx().trainingParams().generalParams().numIterations(), state, log);
        state.ctx().measures().markTrainingIterationsStop();
        return result;
    }

    public Option<Map<String, Object>> getTrainEvalResults(PartitionTaskTrainingState state, Logger log) {
        Tuple2<String, Object>[] evalResults = state.booster().getEvalResults(state.evalNames(), 0);
        new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])evalResults)).foreach((Function1 & Serializable & scala.Serializable)x0$1 -> {
            TrainUtils$.$anonfun$getTrainEvalResults$1(log, x0$1);
            return BoxedUnit.UNIT;
        });
        return Option$.MODULE$.apply((Object)Predef$.MODULE$.Map().apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])evalResults)));
    }

    public Option<Map<String, Object>> getValidEvalResults(PartitionTaskTrainingState state, Logger log) {
        Tuple2<String, Object>[] evalResults = state.booster().getEvalResults(state.evalNames(), 1);
        Tuple2[] results = (Tuple2[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])evalResults)).zipWithIndex(Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class))))).map((Function1 & Serializable & scala.Serializable)x0$1 -> {
            Tuple2 tuple2 = x0$1;
            if (tuple2 != null) {
                Tuple2 tuple22 = (Tuple2)tuple2._1();
                int index = tuple2._2$mcI$sp();
                if (tuple22 != null) {
                    Function3 & Serializable & scala.Serializable cmp;
                    String evalName = (String)tuple22._1();
                    double evalScore = tuple22._2$mcD$sp();
                    log.info(new StringBuilder(7).append("Valid ").append(evalName).append("=").append(evalScore).toString());
                    Function3 & Serializable & scala.Serializable intersect = cmp = evalName.startsWith("auc") || evalName.startsWith("ndcg@") || evalName.startsWith("map@") || evalName.startsWith("average_precision") ? (Function3 & Serializable & scala.Serializable)(x, y, tol) -> BoxesRunTime.boxToBoolean((boolean)TrainUtils$.$anonfun$getValidEvalResults$2(BoxesRunTime.unboxToDouble((Object)x), BoxesRunTime.unboxToDouble((Object)y), BoxesRunTime.unboxToDouble((Object)tol))) : (Function3 & Serializable & scala.Serializable)(x, y, tol) -> BoxesRunTime.boxToBoolean((boolean)TrainUtils$.$anonfun$getValidEvalResults$3(BoxesRunTime.unboxToDouble((Object)x), BoxesRunTime.unboxToDouble((Object)y), BoxesRunTime.unboxToDouble((Object)tol)));
                    if (state.bestScores()[index] == null || BoxesRunTime.unboxToBoolean((Object)cmp.apply((Object)BoxesRunTime.boxToDouble((double)evalScore), (Object)BoxesRunTime.boxToDouble((double)state.bestScore()[index]), (Object)BoxesRunTime.boxToDouble((double)state.ctx().trainingCtx().improvementTolerance())))) {
                        state$2.bestScore()[index] = evalScore;
                        state$2.bestIteration()[index] = state.iteration();
                        state$2.bestScores()[index] = (double[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])evalResults)).map((Function1 & Serializable & scala.Serializable)x$2 -> BoxesRunTime.boxToDouble((double)x$2._2$mcD$sp()), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double()));
                    } else if (state.iteration() - state.bestIteration()[index] >= state.ctx().trainingCtx().earlyStoppingRound()) {
                        state.isFinished_$eq(true);
                        log.info(new StringBuilder(34).append("Early stopping, best iteration is ").append(state.bestIteration()[index]).toString());
                        state.bestIterationResult_$eq((Option<Object>)new Some((Object)BoxesRunTime.boxToInteger((int)state.bestIteration()[index])));
                    }
                    return new Tuple2((Object)evalName, (Object)BoxesRunTime.boxToDouble((double)evalScore));
                }
            }
            throw new MatchError((Object)tuple2);
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)));
        return Option$.MODULE$.apply((Object)Predef$.MODULE$.Map().apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])results)));
    }

    public void beforeGenerateTrainDataset(PartitionTaskContext ctx, Logger log) {
        TrainingContext trainingCtx = ctx.trainingCtx();
        if (trainingCtx.trainingParams().delegate().isDefined()) {
            ((LightGBMDelegate)trainingCtx.trainingParams().delegate().get()).beforeGenerateTrainDataset(trainingCtx.batchIndex(), ctx.partitionId(), trainingCtx.columnParams(), trainingCtx.schema(), log, trainingCtx.trainingParams());
            return;
        }
    }

    public void afterGenerateTrainDataset(PartitionTaskContext ctx, Logger log) {
        TrainingContext trainingCtx = ctx.trainingCtx();
        if (trainingCtx.trainingParams().delegate().isDefined()) {
            ((LightGBMDelegate)trainingCtx.trainingParams().delegate().get()).afterGenerateTrainDataset(trainingCtx.batchIndex(), ctx.partitionId(), trainingCtx.columnParams(), trainingCtx.schema(), log, trainingCtx.trainingParams());
            return;
        }
    }

    public void beforeGenerateValidDataset(PartitionTaskContext ctx, Logger log) {
        TrainingContext trainingCtx = ctx.trainingCtx();
        if (ctx.trainingCtx().trainingParams().delegate().isDefined()) {
            ((LightGBMDelegate)trainingCtx.trainingParams().delegate().get()).beforeGenerateValidDataset(trainingCtx.batchIndex(), ctx.partitionId(), trainingCtx.columnParams(), trainingCtx.schema(), log, trainingCtx.trainingParams());
            return;
        }
    }

    public void afterGenerateValidDataset(PartitionTaskContext ctx, Logger log) {
        TrainingContext trainingCtx = ctx.trainingCtx();
        if (trainingCtx.trainingParams().delegate().isDefined()) {
            ((LightGBMDelegate)trainingCtx.trainingParams().delegate().get()).afterGenerateValidDataset(trainingCtx.batchIndex(), ctx.partitionId(), trainingCtx.columnParams(), trainingCtx.schema(), log, trainingCtx.trainingParams());
            return;
        }
    }

    private Object readResolve() {
        return MODULE$;
    }

    private final Option iterationLoop$1(int maxIterations, PartitionTaskTrainingState state$1, Logger log$1) {
        do {
            this.beforeTrainIteration(state$1, log$1);
            double newLearningRate = this.getLearningRate(state$1, log$1);
            if (newLearningRate != state$1.learningRate()) {
                log$1.info(new StringBuilder(86).append("LightGBM task calling booster.resetParameter to reset learningRate").append(" (newLearningRate: ").append(newLearningRate).append(")").toString());
                state$1.booster().resetParameter(new StringBuilder(14).append("learning_rate=").append(newLearningRate).toString());
                state$1.learningRate_$eq(newLearningRate);
            }
            this.updateOneIteration(state$1, log$1);
            Option<Map<String, Object>> trainEvalResults = state$1.ctx().trainingCtx().isProvideTrainingMetric() && !state$1.isFinished() ? this.getTrainEvalResults(state$1, log$1) : None$.MODULE$;
            Option<Map<String, Object>> validEvalResults = state$1.ctx().trainingCtx().hasValidationData() && !state$1.isFinished() ? this.getValidEvalResults(state$1, log$1) : None$.MODULE$;
            this.afterTrainIteration(state$1, log$1, trainEvalResults, validEvalResults);
            state$1.iteration_$eq(state$1.iteration() + 1);
        } while (!state$1.isFinished() && state$1.iteration() < maxIterations);
        return state$1.bestIterationResult();
    }

    public static final /* synthetic */ void $anonfun$getTrainEvalResults$1(Logger log$2, Tuple2 x0$1) {
        Tuple2 tuple2 = x0$1;
        if (tuple2 != null) {
            String evalName = (String)tuple2._1();
            double score = tuple2._2$mcD$sp();
            if (evalName != null) {
                String string = evalName;
                double d = score;
                log$2.info(new StringBuilder(7).append("Train ").append(string).append("=").append(d).toString());
                return;
            }
        }
        throw new MatchError((Object)tuple2);
    }

    public static final /* synthetic */ boolean $anonfun$getValidEvalResults$2(double x, double y, double tol) {
        return x - y > tol;
    }

    public static final /* synthetic */ boolean $anonfun$getValidEvalResults$3(double x, double y, double tol) {
        return x - y < tol;
    }

    private TrainUtils$() {
        MODULE$ = this;
    }
}

