/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.models.logisticregression;

import org.immutables.value.Value;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.Vector;
import org.neo4j.gds.models.Classifier;
import org.neo4j.gds.models.TrainingMethod;
import org.neo4j.gds.models.logisticregression.ImmutableLogisticRegressionData;

@ValueClass
public interface LogisticRegressionData
extends Classifier.ClassifierData {

    /**
     * Weight matrix of the model.
     *
     * <p>When built through the factories of this interface, its shape is
     * {@code [classCount x featureCount]}.
     *
     * @return the trainable weight matrix
     */
    Weights<Matrix> weights();

    /**
     * Bias vector of the model.
     *
     * <p>When built through the factories of this interface, it has one entry per weight row.
     *
     * @return the trainable bias vector
     */
    Weights<Vector> bias();

    /**
     * Identifies the training method for this model data.
     *
     * @return always {@link TrainingMethod#LogisticRegression}
     */
    @Override
    @Value.Derived
    default TrainingMethod trainerMethod() {
        return TrainingMethod.LogisticRegression;
    }

    /**
     * Number of input features, derived from the weight matrix.
     *
     * @return the size of dimension 1 (the column count) of {@link #weights()}
     */
    @Override
    @Value.Derived
    default int featureDimension() {
        return weights().dimension(1);
    }

    /**
     * Creates zero-initialised model data with one weight row (and one bias entry) per class.
     *
     * @param featureCount number of input features (weight matrix columns)
     * @param classIdMap   mapping of class ids; its size determines the number of weight rows
     * @return new model data
     */
    static LogisticRegressionData standard(int featureCount, LocalIdMap classIdMap) {
        return create(classIdMap.size(), featureCount, classIdMap);
    }

    /**
     * Creates model data with one fewer weight row (and bias entry) than there are classes.
     *
     * <p>NOTE(review): presumably one class is modelled implicitly as a reference class;
     * confirm against the trainer/predictor that consumes this data.
     *
     * @param featureCount number of input features (weight matrix columns)
     * @param classIdMap   mapping of class ids; the weight matrix gets {@code size() - 1} rows
     * @return new model data
     */
    static LogisticRegressionData withReducedClassCount(int featureCount, LocalIdMap classIdMap) {
        return create(classIdMap.size() - 1, featureCount, classIdMap);
    }

    /**
     * Allocates a {@code classCount x featureCount} weight matrix and a bias vector of
     * length {@code classCount}, then assembles the model data.
     */
    private static LogisticRegressionData create(int classCount, int featureCount, LocalIdMap classIdMap) {
        Weights<Matrix> weights = Weights.ofMatrix(classCount, featureCount);
        Weights<Vector> bias = new Weights<>(new Vector(classCount));
        return create(weights, bias, classIdMap);
    }

    /**
     * Assembles model data from already existing weights and bias.
     *
     * @param weights    the weight matrix
     * @param bias       the bias vector
     * @param classIdMap mapping of class ids
     * @return new immutable model data
     */
    static LogisticRegressionData create(Weights<Matrix> weights, Weights<Vector> bias, LocalIdMap classIdMap) {
        return ImmutableLogisticRegressionData.builder()
            .weights(weights)
            .bias(bias)
            .classIdMap(classIdMap)
            .build();
    }

    /**
     * Estimates the memory needed by the model data.
     *
     * @param isReduced        whether the reduced layout ({@code numberOfClasses - 1} weight rows) is used
     * @param numberOfClasses  number of distinct classes; always used in full for the class id map
     * @param featureDimension range of possible feature dimensions (weight matrix columns)
     * @return memory estimation covering the class id map, the weights and the bias
     */
    static MemoryEstimation memoryEstimation(boolean isReduced, int numberOfClasses, MemoryRange featureDimension) {
        // The class id map always holds every class; only the weight/bias rows shrink in the reduced layout.
        int normalizedNumberOfClasses = isReduced ? numberOfClasses - 1 : numberOfClasses;
        return MemoryEstimations.builder(LogisticRegressionData.class)
            .add("classIdMap", LocalIdMap.memoryEstimation(numberOfClasses))
            .fixed(
                "weights",
                featureDimension.apply(featureDim ->
                    Weights.sizeInBytes(normalizedNumberOfClasses, Math.toIntExact(featureDim)))
            )
            .fixed("bias", Weights.sizeInBytes(normalizedNumberOfClasses, 1))
            .build();
    }

    /**
     * Returns a fresh builder for the generated immutable implementation.
     *
     * @return a new {@link ImmutableLogisticRegressionData.Builder}
     */
    static ImmutableLogisticRegressionData.Builder builder() {
        return ImmutableLogisticRegressionData.builder();
    }
}

