/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.models.logisticregression;

import java.util.function.Supplier;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.gradientdescent.Training;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.core.batch.HugeBatchQueue;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.models.Features;
import org.neo4j.gds.models.Trainer;
import org.neo4j.gds.models.logisticregression.LogisticRegressionClassifier;
import org.neo4j.gds.models.logisticregression.LogisticRegressionData;
import org.neo4j.gds.models.logisticregression.LogisticRegressionObjective;
import org.neo4j.gds.models.logisticregression.LogisticRegressionTrainConfig;

/**
 * {@link Trainer} implementation that produces a {@link LogisticRegressionClassifier}.
 *
 * <p>An instance bundles the training configuration, the class id mapping, the
 * concurrency to train with, and the progress / termination hooks that are handed
 * to {@link Training}. All fields are assigned once in the constructor.
 */
public final class LogisticRegressionTrainer
implements Trainer {
    private final LogisticRegressionTrainConfig trainConfig;
    private final ProgressTracker progressTracker;
    private final TerminationFlag terminationFlag;
    private final LocalIdMap classIdMap;
    private final boolean reduceClassCount;
    private final int concurrency;

    /**
     * Builds the memory estimation for a training run.
     *
     * <p>The estimation is composed of three parts: the model data, the weight
     * update bookkeeping of {@link Training}, and a per-thread computation graph
     * whose size is derived from {@code featureDimension} and {@code batchSize}.
     *
     * @param isReduced        forwarded to the model data and batch size estimations
     * @param numberOfClasses  number of classes the model distinguishes
     * @param featureDimension range of possible feature dimensions
     * @param batchSize        number of examples per training batch
     * @return the combined memory estimation
     */
    public static MemoryEstimation memoryEstimation(boolean isReduced, int numberOfClasses, MemoryRange featureDimension, int batchSize) {
        return MemoryEstimations.builder("train logistic regression", LogisticRegressionTrainer.class)
            .add("model data", LogisticRegressionData.memoryEstimation(isReduced, numberOfClasses, featureDimension))
            .add("update weights", Training.memoryEstimation(featureDimension, numberOfClasses))
            .perThread(
                "computation graph",
                // The range carries wider values; each bound is narrowed to int before sizing one batch.
                featureDimension.apply(dimension -> sizeInBytesOfComputationGraph(isReduced, batchSize, (int)dimension, numberOfClasses))
            )
            .build();
    }

    /**
     * Size in bytes of the computation graph for one batch; delegates to
     * {@link LogisticRegressionObjective#sizeOfBatchInBytes}.
     */
    private static long sizeInBytesOfComputationGraph(boolean isReduced, int batchSize, int numberOfFeatures, int numberOfClasses) {
        return LogisticRegressionObjective.sizeOfBatchInBytes(isReduced, batchSize, numberOfFeatures, numberOfClasses);
    }

    /**
     * Creates a trainer.
     *
     * @param concurrency      concurrency passed to {@link Training#train}
     * @param trainConfig      training configuration (provides penalty and batch size)
     * @param classIdMap       class id mapping used to create the model data
     * @param reduceClassCount whether the model data is created with a reduced class count
     * @param terminationFlag  termination flag handed to {@link Training}
     * @param progressTracker  progress tracker handed to {@link Training}
     */
    public LogisticRegressionTrainer(int concurrency, LogisticRegressionTrainConfig trainConfig, LocalIdMap classIdMap, boolean reduceClassCount, TerminationFlag terminationFlag, ProgressTracker progressTracker) {
        this.trainConfig = trainConfig;
        this.progressTracker = progressTracker;
        this.terminationFlag = terminationFlag;
        this.classIdMap = classIdMap;
        this.reduceClassCount = reduceClassCount;
        this.concurrency = concurrency;
    }

    /**
     * Trains a classifier on the given features and labels.
     *
     * @param features feature vectors; their dimension determines the model data shape
     * @param labels   labels passed to the training objective
     * @param trainSet ids that make up the training set; batches are drawn from it
     * @return the classifier that was created from the model data and passed through training
     */
    @Override
    public LogisticRegressionClassifier train(Features features, HugeLongArray labels, ReadOnlyHugeLongArray trainSet) {
        final LogisticRegressionData data;
        if (this.reduceClassCount) {
            data = LogisticRegressionData.withReducedClassCount(features.featureDimension(), this.classIdMap);
        } else {
            data = LogisticRegressionData.standard(features.featureDimension(), this.classIdMap);
        }

        LogisticRegressionClassifier classifier = LogisticRegressionClassifier.from(data);
        LogisticRegressionObjective objective = new LogisticRegressionObjective(
            classifier,
            this.trainConfig.penalty(),
            features,
            labels
        );
        Training training = new Training(
            this.trainConfig,
            this.progressTracker,
            trainSet.size(),
            this.terminationFlag
        );

        // A fresh queue over the training set is created every time the supplier is invoked.
        Supplier<BatchQueue> batchQueues = () -> new HugeBatchQueue(trainSet, this.trainConfig.batchSize());
        training.train(objective, batchQueues, this.concurrency);

        return classifier;
    }
}

