/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.gradientdescent;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.DoubleAdder;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.gradientdescent.GradientDescentConfig;
import org.neo4j.gds.gradientdescent.Objective;
import org.neo4j.gds.gradientdescent.TrainingStopper;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.batch.Batch;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.optimizer.AdamOptimizer;
import org.neo4j.gds.ml.core.optimizer.Updater;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.TensorFunctions;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Gradient-descent training loop for an {@link Objective}.
 *
 * <p>The objective's weights are updated with an {@link AdamOptimizer}. Every epoch consumes a
 * fresh {@link BatchQueue} in parallel, averages the per-batch weight gradients and applies them
 * in a single optimizer step. Training continues until the {@link TrainingStopper} built from the
 * {@link GradientDescentConfig} reports that it has terminated.
 */
public class Training {
    private final GradientDescentConfig config;
    private final ProgressTracker progressTracker;
    private final long trainSize;
    private final TerminationFlag terminationFlag;

    /**
     * @param config          configuration used to build the default {@link TrainingStopper}
     * @param progressTracker receives one progress step per epoch and a final summary message
     * @param trainSize       number passed to {@code Objective.loss} for every batch
     *                        (presumably the training-set size — confirm against {@link Objective})
     * @param terminationFlag checked before every epoch and passed to every parallel batch pass
     */
    public Training(GradientDescentConfig config, ProgressTracker progressTracker, long trainSize, TerminationFlag terminationFlag) {
        this.config = config;
        this.progressTracker = progressTracker;
        this.trainSize = trainSize;
        this.terminationFlag = terminationFlag;
    }

    /**
     * Estimates the memory needed for training with a fixed number of features.
     *
     * @param numberOfFeatures number of features per weight row
     * @param numberOfClasses  number of classes (weight rows)
     * @return estimation covering the optimizer state and the per-thread gradient buffers
     */
    public static MemoryEstimation memoryEstimation(int numberOfFeatures, int numberOfClasses) {
        return memoryEstimation(MemoryRange.of(numberOfFeatures), numberOfClasses);
    }

    /**
     * Estimates the memory needed for training when the number of features is a range.
     *
     * <p>Accounts for one {@link AdamOptimizer} (shared) and one gradient buffer of
     * {@link Weights} size per thread.
     *
     * @param numberOfFeaturesRange possible numbers of features
     * @param numberOfClasses       number of classes (weight rows)
     * @return the combined memory estimation
     */
    public static MemoryEstimation memoryEstimation(MemoryRange numberOfFeaturesRange, int numberOfClasses) {
        MemoryRange updaterSize = numberOfFeaturesRange.apply(
            features -> AdamOptimizer.sizeInBytes(numberOfClasses, Math.toIntExact(features))
        );
        MemoryRange weightGradientSize = numberOfFeaturesRange.apply(
            features -> Weights.sizeInBytes(numberOfClasses, Math.toIntExact(features))
        );
        return MemoryEstimations.builder(Training.class)
            .add(MemoryEstimations.of("updater", updaterSize))
            .perThread("weight gradients", weightGradientSize)
            .build();
    }

    /**
     * Trains the weights of {@code objective} until the configured stopper terminates.
     *
     * @param objective     provides the weights to train and the per-batch loss
     * @param queueSupplier supplies a fresh batch queue for every pass over the data: once for the
     *                      initial loss, then twice per epoch (gradient pass and loss evaluation)
     * @param concurrency   number of threads used to consume batches
     */
    public void train(Objective<?> objective, Supplier<BatchQueue> queueSupplier, int concurrency) {
        Updater updater = new AdamOptimizer(objective.weights());
        TrainingStopper stopper = TrainingStopper.defaultStopper(this.config);

        double initialLoss = this.evaluateLoss(objective, queueSupplier.get(), concurrency);
        double lastLoss = initialLoss;
        int epoch = 0;

        while (!stopper.terminated()) {
            // Gives an external cancellation a chance to abort between epochs.
            this.terminationFlag.assertRunning();
            this.trainEpoch(objective, queueSupplier.get(), concurrency, updater);
            lastLoss = this.evaluateLoss(objective, queueSupplier.get(), concurrency);
            stopper.registerLoss(lastLoss);
            epoch++;
            this.progressTracker.logProgress(
                1L,
                StringFormatting.formatWithLocale(":: Epoch %d with loss %s", epoch, lastLoss)
            );
        }

        boolean converged = stopper.converged();
        this.progressTracker.logMessage(StringFormatting.formatWithLocale(
            "%s after %d epochs. Initial loss: %s, Last loss: %s.%s",
            converged ? "converged" : "terminated",
            epoch,
            initialLoss,
            lastLoss,
            converged ? "" : " Did not converge"
        ));
    }

    /**
     * Sums the loss of every batch in {@code batches}, consuming them in parallel.
     *
     * @return the sum of all per-batch losses; 0.0 if the queue yields no batches
     */
    private double evaluateLoss(Objective<?> objective, BatchQueue batches, int concurrency) {
        DoubleAdder totalLoss = new DoubleAdder();
        // A single consumer is shared by all threads; the DoubleAdder makes the summation thread-safe.
        batches.parallelConsume(new LossEvalConsumer(objective, totalLoss, this.trainSize), concurrency, this.terminationFlag);
        return totalLoss.doubleValue();
    }

    /**
     * Runs one epoch: accumulates gradients over all batches, then applies their average once.
     *
     * <p>If no batch was consumed (e.g. empty queue, or consumption stopped early), no update is
     * applied: averaging over zero batches would divide by zero and poison the weights with
     * NaN/Infinity.
     */
    private void trainEpoch(Objective<?> objective, BatchQueue batches, int concurrency, Updater updater) {
        // One accumulating consumer per thread, so the gradient sums need no synchronization.
        List<ObjectiveUpdateConsumer> consumers = new ArrayList<>(concurrency);
        for (int i = 0; i < concurrency; ++i) {
            consumers.add(new ObjectiveUpdateConsumer(objective, this.trainSize));
        }
        batches.parallelConsume(concurrency, consumers, this.terminationFlag);

        int numberOfBatches = consumers.stream().mapToInt(ObjectiveUpdateConsumer::consumedBatches).sum();
        if (numberOfBatches == 0) {
            return;
        }

        List<List<? extends Tensor<?>>> localGradientSums = consumers.stream()
            .map(ObjectiveUpdateConsumer::summedWeightGradients)
            .collect(Collectors.toList());
        List<? extends Tensor<?>> avgWeightGradients = TensorFunctions.averageTensors(localGradientSums, numberOfBatches);
        updater.update(avgWeightGradients);
    }

    /**
     * Adds the loss of each consumed batch to a shared {@link DoubleAdder}.
     */
    static class LossEvalConsumer
    implements Consumer<Batch> {
        private final Objective<?> objective;
        private final DoubleAdder totalLoss;
        private final long trainSize;

        LossEvalConsumer(Objective<?> objective, DoubleAdder lossAdder, long trainSize) {
            this.objective = objective;
            this.totalLoss = lossAdder;
            this.trainSize = trainSize;
        }

        @Override
        public void accept(Batch batch) {
            Variable<Scalar> loss = this.objective.loss(batch, this.trainSize);
            // A fresh context per batch keeps concurrent forward passes independent.
            ComputationContext ctx = new ComputationContext();
            this.totalLoss.add(ctx.forward(loss).value());
        }
    }

    /**
     * Accumulates the weight gradients of every batch it consumes and counts those batches.
     *
     * <p>Not synchronized. Presumably {@code parallelConsume} drives each instance from a single
     * thread — TODO confirm against {@link BatchQueue}.
     */
    static class ObjectiveUpdateConsumer
    implements Consumer<Batch> {
        private final Objective<?> objective;
        private final long trainSize;
        // One zero-initialized buffer per weight, shaped like that weight's data.
        private final List<? extends Tensor<?>> summedWeightGradients;
        private int consumedBatches;

        ObjectiveUpdateConsumer(Objective<?> objective, long trainSize) {
            this.objective = objective;
            this.trainSize = trainSize;
            this.summedWeightGradients = objective.weights().stream()
                .map(weight -> weight.data().createWithSameDimensions())
                .collect(Collectors.toList());
            this.consumedBatches = 0;
        }

        @Override
        public void accept(Batch batch) {
            Variable<Scalar> loss = this.objective.loss(batch, this.trainSize);
            ComputationContext ctx = new ComputationContext();
            ctx.forward(loss);
            ctx.backward(loss);
            // Gradients are read in the same order as objective.weights(), matching the buffers.
            List<? extends Tensor<?>> localWeightGradients = this.objective.weights().stream()
                .map(ctx::gradient)
                .collect(Collectors.toList());
            for (int i = 0; i < this.summedWeightGradients.size(); ++i) {
                this.summedWeightGradients.get(i).addInPlace(localWeightGradients.get(i));
            }
            ++this.consumedBatches;
        }

        /** @return the per-weight gradient sums accumulated so far (live buffers, not a copy) */
        List<? extends Tensor<?>> summedWeightGradients() {
            return this.summedWeightGradients;
        }

        /** @return the number of batches consumed by this instance */
        int consumedBatches() {
            return this.consumedBatches;
        }
    }
}

