/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.models.logisticregression;

import org.apache.commons.lang3.mutable.MutableInt;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Dimensions;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.batch.Batch;
import org.neo4j.gds.ml.core.batch.SingletonBatch;
import org.neo4j.gds.ml.core.features.FeatureExtraction;
import org.neo4j.gds.ml.core.functions.Constant;
import org.neo4j.gds.ml.core.functions.MatrixMultiplyWithTransposedSecondOperand;
import org.neo4j.gds.ml.core.functions.MatrixVectorSum;
import org.neo4j.gds.ml.core.functions.ReducedSoftmax;
import org.neo4j.gds.ml.core.functions.Sigmoid;
import org.neo4j.gds.ml.core.functions.Softmax;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.Vector;
import org.neo4j.gds.models.Classifier;
import org.neo4j.gds.models.Features;
import org.neo4j.gds.models.logisticregression.LogisticRegressionData;

/**
 * {@link Classifier} backed by logistic regression weights, bias and a class id mapping held in a
 * {@link LogisticRegressionData}.
 *
 * <p>Single-node predictions are delegated to a {@link LogisticRegressionPredictionStrategy} chosen
 * once in {@link #from(LogisticRegressionData)}; batch predictions always evaluate the computation
 * graph built by {@link #predictionsVariable(Constant)}.
 */
public final class LogisticRegressionClassifier implements Classifier {

    private final LogisticRegressionData data;
    private final LogisticRegressionPredictionStrategy predictionStrategy;

    private LogisticRegressionClassifier(
        LogisticRegressionData data,
        LogisticRegressionPredictionStrategy predictionStrategy
    ) {
        this.data = data;
        this.predictionStrategy = predictionStrategy;
    }

    /**
     * Creates a classifier for the given model data.
     *
     * <p>If the class id map has exactly two entries and the weight matrix has a single row, the
     * {@link LogisticRegressionPredictionStrategy#binary() binary} strategy is used for single-node
     * predictions; otherwise the {@link LogisticRegressionPredictionStrategy#multiClass() multi-class}
     * strategy is used.
     *
     * @param data weights, bias and class id mapping of the model
     * @return a classifier wrapping {@code data}
     */
    public static LogisticRegressionClassifier from(LogisticRegressionData data) {
        Weights<Matrix> weights = data.weights();
        boolean binaryWithSingleWeightRow = data.classIdMap().size() == 2 && weights.data().rows() == 1;
        LogisticRegressionPredictionStrategy predictionStrategy = binaryWithSingleWeightRow
            ? LogisticRegressionPredictionStrategy.binary()
            : LogisticRegressionPredictionStrategy.multiClass();
        return new LogisticRegressionClassifier(data, predictionStrategy);
    }

    /**
     * Estimates the memory, in bytes, used when evaluating {@link #predictionsVariable(Constant)} for
     * one batch.
     *
     * <p>The estimate sums four parts:
     * <ul>
     *   <li>the feature extractors;</li>
     *   <li>the constant {@code batchSize x numberOfFeatures} feature matrix;</li>
     *   <li>the matrix product against the transposed weights;</li>
     *   <li>either a {@link Softmax} (when {@code numberOfClasses == normalizedNumberOfClasses}) or a
     *       {@link ReducedSoftmax}.</li>
     * </ul>
     *
     * <p>NOTE(review): the {@link MatrixVectorSum} (bias addition) node created by
     * {@link #predictionsVariable(Constant)} is not included here; confirm whether that is intentional.
     *
     * @param batchSize                 number of nodes per batch
     * @param numberOfFeatures          feature dimension
     * @param numberOfClasses           number of distinct classes
     * @param normalizedNumberOfClasses number of rows in the weight matrix
     * @return estimated size in bytes
     */
    public static long sizeOfPredictionsVariableInBytes(
        int batchSize,
        int numberOfFeatures,
        int numberOfClasses,
        int normalizedNumberOfClasses
    ) {
        int[] batchFeatureDimensions = Dimensions.matrix(batchSize, numberOfFeatures);
        long softmaxSize = numberOfClasses == normalizedNumberOfClasses
            ? Softmax.sizeInBytes(batchSize, numberOfClasses)
            : ReducedSoftmax.sizeInBytes(batchSize, numberOfClasses);
        return sizeOfFeatureExtractorsInBytes(numberOfFeatures)
            + Constant.sizeInBytes(batchFeatureDimensions)
            + MatrixMultiplyWithTransposedSecondOperand.sizeInBytes(batchSize, normalizedNumberOfClasses)
            + softmaxSize;
    }

    /** Memory used by feature extraction for {@code numberOfFeatures} features. */
    private static long sizeOfFeatureExtractorsInBytes(int numberOfFeatures) {
        return FeatureExtraction.memoryUsageInBytes(numberOfFeatures);
    }

    @Override
    public LocalIdMap classIdMap() {
        return data.classIdMap();
    }

    /**
     * Predicts class probabilities for a single node using the strategy chosen in
     * {@link #from(LogisticRegressionData)}.
     */
    @Override
    public double[] predictProbabilities(long id, Features features) {
        return predictionStrategy.predictProbabilities(id, features, this);
    }

    /**
     * Predicts class probabilities for every node in {@code batch}.
     *
     * @return a matrix with one row per batch node, in the iteration order of
     *     {@code batch.nodeIds()}
     */
    @Override
    public Matrix predictProbabilities(Batch batch, Features features) {
        ComputationContext ctx = new ComputationContext();
        return ctx.forward(predictionsVariable(batchFeatureMatrix(batch, features)));
    }

    /**
     * Builds the graph {@code softmax(batchFeatures * weights^T + bias)}.
     *
     * <p>A plain {@link Softmax} is used when the weight matrix has one row per class. Otherwise a
     * {@link ReducedSoftmax} is used, presumably because one class is represented implicitly by the
     * missing row (confirm against {@code ReducedSoftmax}).
     */
    Variable<Matrix> predictionsVariable(Constant<Matrix> batchFeatures) {
        Weights<Matrix> weights = data.weights();
        MatrixMultiplyWithTransposedSecondOperand weightedFeatures =
            MatrixMultiplyWithTransposedSecondOperand.of(batchFeatures, weights);
        MatrixVectorSum softmaxInput = new MatrixVectorSum(weightedFeatures, data.bias());
        return weights.data().rows() == numberOfClasses()
            ? new Softmax(softmaxInput)
            : new ReducedSoftmax(softmaxInput);
    }

    /**
     * Copies the features of every node in {@code batch} into consecutive rows of a new matrix.
     *
     * @return a constant of shape {@code batch.size() x features.featureDimension()}
     */
    static Constant<Matrix> batchFeatureMatrix(Batch batch, Features features) {
        Matrix batchFeatures = new Matrix(batch.size(), features.featureDimension());
        MutableInt row = new MutableInt();
        batch.nodeIds().forEach(id -> batchFeatures.setRow(row.getAndIncrement(), features.get(id)));
        return new Constant<>(batchFeatures);
    }

    @Override
    public LogisticRegressionData data() {
        return data;
    }

    /** Computes class probabilities for a single node. */
    @FunctionalInterface
    interface LogisticRegressionPredictionStrategy {

        /**
         * @param id         node id to predict for
         * @param features   feature source for {@code id}
         * @param classifier classifier providing weights, bias and class mapping
         * @return one probability per class
         */
        double[] predictProbabilities(long id, Features features, LogisticRegressionClassifier classifier);

        /**
         * Evaluates {@code sigmoid(w . x + b)} directly, without building a computation graph.
         *
         * <p>{@code w} is the first {@code featureDimension} entries of the weight matrix and {@code b}
         * is bias entry 0. The result is {@code {p, 1 - p}}, presumably matching what a
         * {@link ReducedSoftmax} yields for a one-row weight matrix (confirm).
         */
        static LogisticRegressionPredictionStrategy binary() {
            return (id, features, classifier) -> {
                double[] nodeFeatures = features.get(id);
                Matrix weights = classifier.data().weights().data();
                double affinity = 0.0;
                for (int i = 0; i < features.featureDimension(); ++i) {
                    affinity += weights.dataAt(i) * nodeFeatures[i];
                }
                Vector bias = classifier.data().bias().data();
                double sigmoid = Sigmoid.sigmoid(affinity + bias.dataAt(0));
                return new double[]{sigmoid, 1.0 - sigmoid};
            };
        }

        /**
         * Runs the full softmax computation graph on a single-node batch and returns its only row.
         */
        static LogisticRegressionPredictionStrategy multiClass() {
            return (id, features, classifier) -> {
                Batch batch = new SingletonBatch(id);
                ComputationContext ctx = new ComputationContext();
                Matrix probabilities = ctx.forward(
                    classifier.predictionsVariable(batchFeatureMatrix(batch, features))
                );
                return probabilities.data();
            };
        }
    }
}

