/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.graphalgo;

import java.util.Arrays;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.stream.Stream;
import org.neo4j.graphalgo.core.GraphLoader;
import org.neo4j.graphalgo.core.ProcedureConfiguration;
import org.neo4j.graphalgo.core.heavyweight.HeavyGraph;
import org.neo4j.graphalgo.core.heavyweight.HeavyGraphFactory;
import org.neo4j.graphalgo.core.utils.Pools;
import org.neo4j.graphalgo.core.utils.ProgressLogger;
import org.neo4j.graphalgo.core.utils.ProgressTimer;
import org.neo4j.graphalgo.core.utils.TerminationFlag;
import org.neo4j.graphalgo.core.write.Exporter;
import org.neo4j.graphalgo.core.write.IntArrayTranslator;
import org.neo4j.graphalgo.impl.LabelPropagation;
import org.neo4j.graphalgo.results.LabelPropagationStats;
import org.neo4j.graphdb.Direction;
import org.neo4j.kernel.api.KernelTransaction;
import org.neo4j.kernel.internal.GraphDatabaseAPI;
import org.neo4j.logging.Log;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

/**
 * Neo4j procedure {@code algo.labelPropagation}.
 *
 * <p>Loads a {@link HeavyGraph} and runs {@link LabelPropagation} over it. Unless writing is
 * disabled, it stores each node's resulting label back as a node property. It returns a single
 * row of {@link LabelPropagationStats}.
 */
public final class LabelPropagationProc {
    /** Config key naming the property read as both relationship weight and node weight. */
    public static final String CONFIG_WEIGHT_KEY = "weightProperty";
    /** Config key naming the node property that is loaded and that labels are written to. */
    public static final String CONFIG_PARTITION_KEY = "partitionProperty";
    public static final Integer DEFAULT_ITERATIONS = 1;
    public static final Boolean DEFAULT_WRITE = Boolean.TRUE;
    public static final String DEFAULT_WEIGHT_KEY = "weight";
    public static final String DEFAULT_PARTITION_KEY = "partition";

    /** Directions accepted by {@link #parseDirection(String)}: every value except {@code BOTH}. */
    private static final Direction[] ALLOWED_DIRECTION = Arrays.stream(Direction.values())
            .filter(d -> d != Direction.BOTH)
            .toArray(Direction[]::new);

    @Context
    public GraphDatabaseAPI dbAPI;
    @Context
    public Log log;
    @Context
    public KernelTransaction transaction;

    /**
     * Runs label propagation and returns a single row of statistics.
     *
     * @param label            node label passed to {@code overrideNodeLabelOrQuery}; defaults to ""
     * @param relationshipType relationship type passed to {@code overrideRelationshipTypeOrQuery};
     *                         defaults to ""
     * @param directionName    direction name, matched case-insensitively; defaults to
     *                         {@code OUTGOING}, and {@code BOTH} is rejected
     * @param config           options read through {@link ProcedureConfiguration} (iterations,
     *                         batch size, concurrency, write flag, weight/partition property names)
     * @return a one-element stream holding the run's statistics
     * @throws IllegalArgumentException if {@code directionName} is not a legal direction
     */
    @Procedure(name="algo.labelPropagation", mode=Mode.WRITE)
    @Description(value="CALL algo.labelPropagation(label:String, relationship:String, direction:String, {iterations:1, weightProperty:'weight', partitionProperty:'partition', write:true, concurrency:4}) YIELD nodes, iterations, didConverge, loadMillis, computeMillis, writeMillis, write, weightProperty, partitionProperty - simple label propagation kernel")
    public Stream<LabelPropagationStats> labelPropagation(@Name(value="label", defaultValue="") String label, @Name(value="relationship", defaultValue="") String relationshipType, @Name(value="direction", defaultValue="OUTGOING") String directionName, @Name(value="config", defaultValue="{}") Map<String, Object> config) {
        ProcedureConfiguration configuration = ProcedureConfiguration.create(config)
                .overrideNodeLabelOrQuery(label)
                .overrideRelationshipTypeOrQuery(relationshipType);
        Direction direction = parseDirection(directionName);
        int iterations = configuration.getIterations(DEFAULT_ITERATIONS);
        int batchSize = configuration.getBatchSize();
        int concurrency = configuration.getConcurrency();
        String partitionProperty = configuration.getString(CONFIG_PARTITION_KEY, DEFAULT_PARTITION_KEY);
        String weightProperty = configuration.getString(CONFIG_WEIGHT_KEY, DEFAULT_WEIGHT_KEY);

        LabelPropagationStats.Builder stats = new LabelPropagationStats.Builder()
                .iterations(iterations)
                .partitionProperty(partitionProperty)
                .weightProperty(weightProperty);

        HeavyGraph graph = load(
                configuration.getNodeLabelOrQuery(),
                configuration.getRelationshipOrQuery(),
                direction,
                partitionProperty,
                weightProperty,
                batchSize,
                concurrency,
                stats);
        int[] labels = compute(direction, iterations, batchSize, concurrency, graph, stats);
        if (configuration.isWriteFlag(DEFAULT_WRITE) && partitionProperty != null) {
            write(concurrency, partitionProperty, graph, labels, stats);
        }
        return Stream.of(stats.build());
    }

    /**
     * Loads the requested subgraph as a {@link HeavyGraph}.
     *
     * <p>The weight property is requested as both an optional relationship weight and an
     * optional node weight, each defaulting to 1.0. The partition property is requested as an
     * optional node property defaulting to 0.0. The load is wrapped in {@code stats.timeLoad()}.
     */
    private HeavyGraph load(String label, String relationshipType, Direction direction, String partitionKey, String weightKey, int batchSize, int concurrency, LabelPropagationStats.Builder stats) {
        try (ProgressTimer timer = stats.timeLoad()) {
            return (HeavyGraph) new GraphLoader(dbAPI, Pools.DEFAULT)
                    .withLog(log)
                    .withOptionalLabel(label)
                    .withOptionalRelationshipType(relationshipType)
                    .withOptionalRelationshipWeightsFromProperty(weightKey, 1.0)
                    .withOptionalNodeWeightsFromProperty(weightKey, 1.0)
                    .withOptionalNodeProperty(partitionKey, 0.0)
                    .withDirection(direction)
                    .withBatchSize(batchSize)
                    .withConcurrency(concurrency)
                    .load(HeavyGraphFactory.class);
        }
    }

    /**
     * Runs label propagation on {@code graph} and returns one label per node.
     *
     * <p>The actual iteration count, the convergence flag and the node count are recorded in
     * {@code stats}. Both the algorithm and the graph are released before returning. The
     * computation is wrapped in {@code stats.timeEval()}.
     */
    private int[] compute(Direction direction, int iterations, int batchSize, int concurrency, HeavyGraph graph, LabelPropagationStats.Builder stats) {
        try (ProgressTimer timer = stats.timeEval()) {
            // A non-positive batch size disables the shared pool; the batch size handed to the
            // algorithm is then clamped to at least 1.
            ExecutorService pool = batchSize > 0 ? Pools.DEFAULT : null;
            int effectiveBatchSize = Math.max(1, batchSize);
            LabelPropagation labelPropagation =
                    new LabelPropagation(graph, effectiveBatchSize, concurrency, pool);
            ((LabelPropagation) ((LabelPropagation) labelPropagation
                    .withProgressLogger(ProgressLogger.wrap(log, "LabelPropagation")))
                    .withTerminationFlag(TerminationFlag.wrap(transaction)))
                    .compute(direction, iterations);

            int[] result = labelPropagation.labels();
            stats.iterations(labelPropagation.ranIterations());
            stats.didConverge(labelPropagation.didConverge());
            stats.nodes(result.length);

            labelPropagation.release();
            // NOTE(review): the graph is released here but is still passed to write() afterwards.
            // This presumes HeavyGraph.release() keeps whatever Exporter needs (e.g. node id
            // mapping) -- confirm before reordering.
            graph.release();
            return result;
        }
    }

    /**
     * Writes {@code labels} to the node property {@code partitionKey}.
     *
     * <p>The write runs in parallel on the shared pool and is bound to the current transaction's
     * termination flag. The write flag is set in {@code stats}, and the export is wrapped in
     * {@code stats.timeWrite()}.
     */
    private void write(int concurrency, String partitionKey, HeavyGraph graph, int[] labels, LabelPropagationStats.Builder stats) {
        stats.write(true);
        try (ProgressTimer timer = stats.timeWrite()) {
            Exporter.of(dbAPI, graph)
                    .withLog(log)
                    .parallel(Pools.DEFAULT, concurrency, TerminationFlag.wrap(transaction))
                    .build()
                    .write(partitionKey, labels, IntArrayTranslator.INSTANCE);
        }
    }

    /**
     * Parses a direction name case-insensitively.
     *
     * @param directionString the name to parse; {@code null} yields {@link Direction#OUTGOING}
     * @return the parsed direction, never {@link Direction#BOTH}
     * @throws IllegalArgumentException if the name is not a {@link Direction} constant, or is
     *                                  {@code BOTH}
     */
    private static Direction parseDirection(String directionString) {
        if (directionString == null) {
            return Direction.OUTGOING;
        }
        Direction direction;
        try {
            // Locale.ROOT: under e.g. a Turkish default locale, "incoming".toUpperCase() yields
            // "İNCOMİNG" (dotted capital I), which never matches an enum constant.
            direction = Direction.valueOf(directionString.toUpperCase(Locale.ROOT));
        } catch (IllegalArgumentException e) {
            throw new IllegalArgumentException(String.format(
                    "Cannot convert value '%s' to Direction. Legal values are '%s'.",
                    directionString, Arrays.toString(ALLOWED_DIRECTION)), e);
        }
        if (direction == Direction.BOTH) {
            throw new IllegalArgumentException(String.format(
                    "Direction BOTH is not allowed. Legal values are '%s'.",
                    Arrays.toString(ALLOWED_DIRECTION)));
        }
        return direction;
    }
}

