/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.graphsage.algo;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors;
import org.immutables.value.Value;
import org.neo4j.gds.annotation.Configuration;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.config.BatchSizeConfig;
import org.neo4j.gds.config.EmbeddingDimensionConfig;
import org.neo4j.gds.config.FeaturePropertiesConfig;
import org.neo4j.gds.config.IterationsConfig;
import org.neo4j.gds.config.RandomSeedConfig;
import org.neo4j.gds.config.RelationshipWeightConfig;
import org.neo4j.gds.config.ToleranceConfig;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.embeddings.graphsage.ActivationFunction;
import org.neo4j.gds.embeddings.graphsage.Aggregator;
import org.neo4j.gds.embeddings.graphsage.LayerConfig;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfigImpl;
import org.neo4j.gds.embeddings.graphsage.algo.ImmutableGraphSageTrainConfig;
import org.neo4j.gds.model.ModelConfig;

/**
 * Training configuration for the GraphSage embedding algorithm.
 *
 * <p>Combines the shared algorithm, model, batch-size, iteration, tolerance, embedding-dimension,
 * relationship-weight, feature-property and random-seed settings with GraphSage-specific options
 * such as sample sizes, aggregator and activation function. Instances are created from user input
 * via {@link #of(String, CypherMapWrapper)} or, in tests, via {@link #testBuilder()}.
 */
@ValueClass
@Configuration(value = "GraphSageTrainConfigImpl")
public interface GraphSageTrainConfig
extends AlgoBaseConfig,
ModelConfig,
BatchSizeConfig,
IterationsConfig,
ToleranceConfig,
EmbeddingDimensionConfig,
RelationshipWeightConfig,
FeaturePropertiesConfig,
RandomSeedConfig {
    long serialVersionUID = 66L;

    /**
     * Size of the embedding vectors and the row count of every layer.
     *
     * @return the embedding dimension; defaults to {@code 64}
     */
    @Value.Default
    default int embeddingDimension() {
        return 64;
    }

    /**
     * Number of neighbours sampled per layer. One layer is built per entry (see
     * {@link #layerConfigs(int)}), and every entry must be at least {@code 1}.
     *
     * @return the per-layer sample sizes; defaults to {@code [25, 10]}
     */
    @Value.Default
    @Configuration.IntegerRange(min = 1)
    @Configuration.ConvertWith(value = "convertToIntSamples")
    default List<Integer> sampleSizes() {
        return List.of(25, 10);
    }

    /**
     * Converts user-supplied sample sizes to {@code int} values.
     *
     * <p>Each value is first converted with {@link Number#longValue()}, which truncates any
     * fractional part. It is then narrowed with {@link Math#toIntExact(long)}.
     *
     * @param input the raw sample sizes as provided by the user
     * @return the sample sizes as integers, in input order
     * @throws IllegalArgumentException if a value does not fit into an {@code int}; the
     *     underlying {@link ArithmeticException} is kept as the cause
     */
    static List<Integer> convertToIntSamples(List<Number> input) {
        try {
            return input.stream()
                .map(Number::longValue)
                .map(Math::toIntExact)
                .collect(Collectors.toList());
        } catch (ArithmeticException e) {
            throw new IllegalArgumentException("Sample size must be smaller than 2^31", e);
        }
    }

    /**
     * Neighbourhood aggregator used by every layer. User input is parsed with
     * {@code AggregatorType#parse} and written back with {@code AggregatorType#toString}.
     *
     * @return the aggregator type; defaults to {@code MEAN}
     */
    @Value.Default
    @Configuration.ConvertWith(value = "org.neo4j.gds.embeddings.graphsage.Aggregator.AggregatorType#parse")
    @Configuration.ToMapValue(value = "org.neo4j.gds.embeddings.graphsage.Aggregator.AggregatorType#toString")
    default Aggregator.AggregatorType aggregator() {
        return Aggregator.AggregatorType.MEAN;
    }

    /**
     * Activation function used by every layer. User input is parsed with
     * {@code ActivationFunction#parse} and written back with {@code ActivationFunction#toString}.
     *
     * @return the activation function; defaults to {@code SIGMOID}
     */
    @Value.Default
    @Configuration.ConvertWith(value = "org.neo4j.gds.embeddings.graphsage.ActivationFunction#parse")
    @Configuration.ToMapValue(value = "org.neo4j.gds.embeddings.graphsage.ActivationFunction#toString")
    default ActivationFunction activationFunction() {
        return ActivationFunction.SIGMOID;
    }

    /**
     * Convergence tolerance.
     *
     * @return the tolerance; defaults to {@code 1.0E-4}
     */
    @Value.Default
    default double tolerance() {
        return 1.0E-4;
    }

    /**
     * Learning rate for training.
     *
     * @return the learning rate; defaults to {@code 0.1}
     */
    @Value.Default
    default double learningRate() {
        return 0.1;
    }

    /**
     * Number of training epochs; must be at least {@code 1}.
     *
     * @return the epoch count; defaults to {@code 1}
     */
    @Value.Default
    @Configuration.IntegerRange(min = 1)
    default int epochs() {
        return 1;
    }

    /**
     * Maximum number of iterations.
     *
     * @return the iteration limit; defaults to {@code 10}
     */
    @Value.Default
    default int maxIterations() {
        return 10;
    }

    /**
     * Search depth setting. NOTE(review): its exact use is not visible in this file; presumably it
     * controls sampling of positive examples. Confirm against the trainer.
     *
     * @return the search depth; defaults to {@code 5}
     */
    @Value.Default
    default int searchDepth() {
        return 5;
    }

    /**
     * Weight given to negative samples.
     *
     * @return the negative sample weight; defaults to {@code 20}
     */
    @Value.Default
    default int negativeSampleWeight() {
        return 20;
    }

    /**
     * Optional projected feature dimension. When present it must be at least {@code 1}. Its
     * presence switches the configuration to multi-label mode (see {@link #isMultiLabel()}).
     *
     * @return the projected feature dimension, if configured
     */
    @Configuration.IntegerRange(min = 1)
    Optional<Integer> projectedFeatureDimension();

    /**
     * Whether feature properties must exist for every node label.
     *
     * @return always {@code false} for this configuration
     */
    @Configuration.Ignore
    default boolean propertiesMustExistForEachNodeLabel() {
        return false;
    }

    /**
     * Whether training uses relationship weights.
     *
     * <p>NOTE(review): this relies on {@link #relationshipWeightProperty()} returning
     * {@code null} when no weight property is configured. Confirm against
     * {@code RelationshipWeightConfig}.
     *
     * @return {@code true} if a relationship weight property is set
     */
    @Configuration.Ignore
    @Value.Derived
    default boolean isWeighted() {
        return relationshipWeightProperty() != null;
    }

    /**
     * Builds one {@link LayerConfig} per entry of {@link #sampleSizes()}, in order.
     *
     * <p>Every layer uses the configured aggregator and activation function and has
     * {@link #embeddingDimension()} rows. The first layer has {@code featureDimension} columns.
     * Every later layer has {@link #embeddingDimension()} columns, matching the rows of the layer
     * before it.
     *
     * <p>Each layer gets its own seed, drawn in layer order from a single {@link Random}. That
     * {@code Random} is seeded with {@link #randomSeed()} when one is present, so the layer seeds
     * are reproducible only when a random seed is configured.
     *
     * @param featureDimension number of input columns of the first layer
     * @return the layer configurations, in layer order
     */
    @Configuration.Ignore
    default List<LayerConfig> layerConfigs(int featureDimension) {
        List<Integer> samples = sampleSizes();
        List<LayerConfig> layers = new ArrayList<>(samples.size());
        Random random = new Random();
        randomSeed().ifPresent(random::setSeed);
        for (int layer = 0; layer < samples.size(); layer++) {
            int inputColumns = layer == 0 ? featureDimension : embeddingDimension();
            layers.add(LayerConfig.builder()
                .aggregatorType(aggregator())
                .activationFunction(activationFunction())
                .rows(embeddingDimension())
                .cols(inputColumns)
                .sampleSize(samples.get(layer))
                .randomSeed(random.nextLong())
                .build());
        }
        return layers;
    }

    /**
     * Whether the model is trained in multi-label mode.
     *
     * @return {@code true} if {@link #projectedFeatureDimension()} is present
     */
    @Configuration.Ignore
    default boolean isMultiLabel() {
        return projectedFeatureDimension().isPresent();
    }

    /**
     * Feature dimension to use for memory estimation.
     *
     * @return {@link #projectedFeatureDimension()} if present, otherwise the number of
     *     configured feature properties
     */
    @Configuration.Ignore
    default int estimationFeatureDimension() {
        return projectedFeatureDimension().orElse(featureProperties().size());
    }

    /**
     * Validates the configuration once it has been constructed.
     *
     * @throws IllegalArgumentException if no feature properties are configured
     */
    @Value.Check
    default void validate() {
        if (featureProperties().isEmpty()) {
            throw new IllegalArgumentException("GraphSage requires at least one property.");
        }
    }

    /**
     * Creates a configuration from user input.
     *
     * @param username the user submitting the configuration
     * @param userInput the raw configuration map
     * @return the parsed configuration
     */
    static GraphSageTrainConfig of(String username, CypherMapWrapper userInput) {
        return new GraphSageTrainConfigImpl(username, userInput);
    }

    /**
     * Returns a builder for constructing configurations directly, intended for tests.
     *
     * @return a fresh Immutables builder
     */
    static ImmutableGraphSageTrainConfig.Builder testBuilder() {
        return ImmutableGraphSageTrainConfig.builder();
    }
}

