/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.node2vec;

import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.Pools;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.embeddings.node2vec.CompressedRandomWalks;
import org.neo4j.gds.embeddings.node2vec.Node2VecBaseConfig;
import org.neo4j.gds.embeddings.node2vec.Node2VecModel;
import org.neo4j.gds.embeddings.node2vec.RandomWalkProbabilities;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.traversal.RandomWalk;

/**
 * Node2Vec embedding algorithm.
 *
 * <p>Runs random walks over the graph and derives sampling probabilities from those walks.
 * It then trains a {@link Node2VecModel} on the collected walks and returns the model's result.
 */
public class Node2Vec
extends Algorithm<Node2VecModel.Result> {
    private final Graph graph;
    private final Node2VecBaseConfig config;

    /**
     * Estimates the memory required to run Node2Vec with the given configuration.
     *
     * <p>The estimate has three parts:
     * <ul>
     *   <li>storage for {@code walksPerNode} random walks per node, each sized as a
     *       {@code long[]} of {@code walkLength} entries;</li>
     *   <li>the random-walk probability cache;</li>
     *   <li>the model itself.</li>
     * </ul>
     *
     * @param config the Node2Vec configuration supplying walk and model parameters
     * @return the combined memory estimation
     */
    public static MemoryEstimation memoryEstimation(Node2VecBaseConfig config) {
        return MemoryEstimations.builder(Node2Vec.class.getSimpleName())
            .perNode("random walks", nodeCount -> {
                // Widen before multiplying so the walk count is computed in long arithmetic.
                long numberOfRandomWalks = nodeCount * (long) config.walksPerNode();
                long randomWalkMemoryUsage = MemoryUsage.sizeOfLongArray(config.walkLength());
                return HugeObjectArray.memoryEstimation(numberOfRandomWalks, randomWalkMemoryUsage);
            })
            .add("probability cache", RandomWalkProbabilities.memoryEstimation())
            .add("model", Node2VecModel.memoryEstimation(config))
            .build();
    }

    /**
     * @param graph           the graph to embed
     * @param config          walk, sampling and training parameters
     * @param progressTracker tracker passed to the base class and used during walking and training
     */
    public Node2Vec(Graph graph, Node2VecBaseConfig config, ProgressTracker progressTracker) {
        super(progressTracker);
        this.graph = graph;
        this.config = config;
    }

    /**
     * Runs the random walks, builds the sampling probabilities and trains the embedding model.
     *
     * <p>All work is reported under a {@code "Node2Vec"} progress sub task.
     *
     * @return the result produced by {@link Node2VecModel#train()}
     */
    public Node2VecModel.Result compute() {
        this.progressTracker.beginSubTask("Node2Vec");

        RandomWalk randomWalk = RandomWalk.create(this.graph, this.config, this.progressTracker, Pools.DEFAULT);
        RandomWalkProbabilities.Builder probabilitiesBuilder = new RandomWalkProbabilities.Builder(
            this.graph.nodeCount(),
            this.config.positiveSamplingFactor(),
            this.config.negativeSamplingExponent(),
            this.config.concurrency()
        );
        // Sized for walksPerNode walks per node, matching the "random walks" memory estimate.
        CompressedRandomWalks walks = new CompressedRandomWalks(this.graph.nodeCount() * (long) this.config.walksPerNode());

        // Each walk feeds the sampling statistics and is also stored for training.
        // The walk is already a long[]; the decompiled `(long) walk` cast was invalid Java.
        randomWalk.compute().forEach(walk -> {
            probabilitiesBuilder.registerWalk(walk);
            walks.add(walk);
        });

        Node2VecModel node2VecModel = new Node2VecModel(
            this.graph::toOriginalNodeId,
            this.graph.nodeCount(),
            this.config,
            walks,
            probabilitiesBuilder.build(),
            this.progressTracker
        );
        Node2VecModel.Result result = node2VecModel.train();

        this.progressTracker.endSubTask("Node2Vec");
        return result;
    }

    /**
     * No-op: this class holds only the graph and config references it was constructed with.
     */
    public void release() {
    }
}

