/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.hashgnn;

import java.util.ArrayList;
import java.util.List;
import org.neo4j.gds.GraphAlgorithmFactory;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.embeddings.hashgnn.BinarizeFeaturesConfig;
import org.neo4j.gds.embeddings.hashgnn.GenerateFeaturesConfig;
import org.neo4j.gds.embeddings.hashgnn.HashGNN;
import org.neo4j.gds.embeddings.hashgnn.HashGNNConfig;
import org.neo4j.gds.embeddings.hashgnn.HashTask;
import org.neo4j.gds.mem.MemoryUsage;

/**
 * Factory for the {@link HashGNN} algorithm: creates algorithm instances, describes their progress
 * task tree and estimates their memory footprint.
 *
 * @param <CONFIG> the concrete HashGNN configuration type
 */
public class HashGNNFactory<CONFIG extends HashGNNConfig> extends GraphAlgorithmFactory<HashGNN, CONFIG> {

    /**
     * Binary embedding dimension assumed by {@link #memoryEstimation} when the configuration specifies
     * neither feature generation nor feature binarization. In that case the features are extracted from
     * raw node properties, whose dimension the configuration does not expose, so a fixed guess is used.
     */
    private static final int FUDGED_BINARY_DIMENSION = 1024;

    /**
     * Returns the name used for the root progress task.
     *
     * @return the constant task name {@code "HashGNN"}
     */
    public String taskName() {
        return "HashGNN";
    }

    /**
     * Creates a {@link HashGNN} instance for the given graph and configuration.
     *
     * @param graph           the graph to compute embeddings for
     * @param configuration   the algorithm configuration
     * @param progressTracker tracker handed through to the algorithm
     * @return a new, not yet executed algorithm instance
     */
    public HashGNN build(Graph graph, CONFIG configuration, ProgressTracker progressTracker) {
        return new HashGNN(graph, configuration, progressTracker);
    }

    /**
     * Builds the progress task tree for a HashGNN run.
     *
     * <p>The tree contains, in order:
     * <ol>
     *   <li>one feature-preparation leaf sized by the node count: feature generation if configured,
     *       otherwise binarization if configured, otherwise raw property extraction;</li>
     *   <li>a fixed-iteration "Propagate embeddings" task whose every iteration has a
     *       "Precompute hashes" and a "Perform min-hashing" leaf;</li>
     *   <li>a "Densify output embeddings" leaf, only when an output dimension is configured.</li>
     * </ol>
     *
     * @param graph  the graph the algorithm will run on
     * @param config the algorithm configuration
     * @return the root task named {@link #taskName()}
     */
    public Task progressTask(Graph graph, CONFIG config) {
        List<Task> tasks = new ArrayList<>();

        if (config.generateFeatures().isPresent()) {
            tasks.add(Tasks.leaf("Generate base node property features", graph.nodeCount()));
        } else if (config.binarizeFeatures().isPresent()) {
            tasks.add(Tasks.leaf("Binarize node property features", graph.nodeCount()));
        } else {
            tasks.add(Tasks.leaf("Extract raw node property features", graph.nodeCount()));
        }

        // In heterogeneous mode every configured relationship type contributes to the hash count.
        // NOTE(review): memoryEstimation derives this count from the graph dimensions rather than the
        // configuration -- confirm the two always agree.
        int numRelTypes = config.heterogeneous() ? config.relationshipTypes().size() : 1;
        // Multiply in long so a large embedding density cannot overflow an int before widening.
        long precomputeHashesSteps = (long) config.embeddingDensity() * (2 + numRelTypes);
        long minHashingSteps = (2L * graph.nodeCount() + graph.relationshipCount()) * config.embeddingDensity();

        tasks.add(Tasks.iterativeFixed(
            "Propagate embeddings",
            () -> List.of(
                Tasks.leaf("Precompute hashes", precomputeHashesSteps),
                Tasks.leaf("Perform min-hashing", minHashingSteps)
            ),
            config.iterations()
        ));

        if (config.outputDimension().isPresent()) {
            tasks.add(Tasks.leaf("Densify output embeddings", graph.nodeCount()));
        }

        return Tasks.task(taskName(), tasks);
    }

    /**
     * Estimates the memory needed by a HashGNN run.
     *
     * <p>The estimate covers two per-node caches of binary embeddings, the precomputed hashes
     * ({@code embeddingDensity} hash sets), and the per-node {@code double[]} output embeddings. The
     * binary dimension is taken from the feature-generation config, else the binarization config, else
     * {@link #FUDGED_BINARY_DIMENSION}. The output dimension defaults to that binary dimension.
     *
     * @param config the algorithm configuration
     * @return the combined memory estimation
     */
    public MemoryEstimation memoryEstimation(CONFIG config) {
        int binaryDimension = config.generateFeatures()
            .map(GenerateFeaturesConfig::dimension)
            .orElseGet(() -> config.binarizeFeatures()
                .map(BinarizeFeaturesConfig::dimension)
                .orElse(FUDGED_BINARY_DIMENSION));

        MemoryEstimations.Builder builder = MemoryEstimations.builder(HashGNN.class.getSimpleName());

        // Both embedding caches hold one bit set of binaryDimension bits per node.
        long bitSetSize = HugeAtomicBitSet.memoryEstimation(binaryDimension);
        builder.perNode("Embeddings cache 1", n -> HugeObjectArray.memoryEstimation(n, bitSetSize));
        builder.perNode("Embeddings cache 2", n -> HugeObjectArray.memoryEstimation(n, bitSetSize));

        builder.perGraphDimension("Hashes cache", (dims, concurrency) -> {
            int numRelTypes = config.heterogeneous() ? dims.relationshipCounts().size() : 1;
            return MemoryRange.of(
                (long) config.embeddingDensity() * HashTask.Hashes.memoryEstimation(binaryDimension, numRelTypes)
            );
        });

        int outputDimension = config.outputDimension().orElse(binaryDimension);
        builder.perNode(
            "Embeddings output",
            n -> HugeObjectArray.memoryEstimation(n, MemoryUsage.sizeOfDoubleArray(outputDimension))
        );

        return builder.build();
    }
}

