/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.hashgnn;

import java.util.Collection;
import java.util.List;
import java.util.SplittableRandom;
import java.util.stream.Collectors;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.embeddings.hashgnn.BinarizeFeaturesConfig;
import org.neo4j.gds.embeddings.hashgnn.HashGNNConfig;
import org.neo4j.gds.ml.core.features.FeatureConsumer;
import org.neo4j.gds.ml.core.features.FeatureExtraction;
import org.neo4j.gds.ml.core.features.FeatureExtractor;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Binarizes the node property features of the nodes of one {@link Partition}.
 *
 * <p>For every node the task accumulates, per output dimension, the scalar product of the node's
 * extracted feature values with the corresponding column of {@code propertyEmbeddings}. An output
 * bit is set when that scalar product is strictly greater than the configured threshold. The
 * resulting bitset is stored in the shared {@code truncatedFeatures} array at the node's id.
 *
 * <p>Each task keeps its own running totals (number of set bits, sum and sum of squares of the
 * scalar products); {@link #compute} reads them once the concurrent run call has returned.
 */
class BinarizeTask
implements Runnable {
    private final Partition partition;
    private final HugeObjectArray<HugeAtomicBitSet> truncatedFeatures;
    private final List<FeatureExtractor> featureExtractors;
    private final double[][] propertyEmbeddings;
    private final double threshold;
    private final int dimension;
    private final ProgressTracker progressTracker;
    // Per-task accumulators, only written from run()/round().
    private long totalFeatureCount;
    private double scalarProductSum;
    private double scalarProductSumOfSquares;

    /**
     * @param partition          nodes this task processes
     * @param config             supplies the output {@code dimension()} and the {@code threshold()}
     * @param truncatedFeatures  output array, indexed by node id, receiving one bitset per node
     * @param featureExtractors  extractors that feed the raw property values
     * @param propertyEmbeddings matrix indexed as {@code [inputFeature][outputDimension]}
     * @param progressTracker    receives the partition's node count once the task is done
     */
    BinarizeTask(Partition partition, BinarizeFeaturesConfig config, HugeObjectArray<HugeAtomicBitSet> truncatedFeatures, List<FeatureExtractor> featureExtractors, double[][] propertyEmbeddings, ProgressTracker progressTracker) {
        this.partition = partition;
        this.dimension = config.dimension();
        this.threshold = config.threshold();
        this.truncatedFeatures = truncatedFeatures;
        this.featureExtractors = featureExtractors;
        this.propertyEmbeddings = propertyEmbeddings;
        this.progressTracker = progressTracker;
    }

    /**
     * Creates one task per partition, runs them with the configured concurrency and returns the
     * binarized features of all nodes.
     *
     * <p>Besides the result, the method adds the total number of set bits to
     * {@code totalFeatureCountOutput} and logs the mean and standard deviation of all scalar
     * products as a hint for choosing the threshold.
     *
     * @param graph                   graph whose node properties are binarized
     * @param partition               partitions of the node space; one task is created per entry
     * @param config                  algorithm configuration; {@code binarizeFeatures()} must be present
     * @param rng                     source of randomness for the property embeddings
     * @param progressTracker         progress and log sink
     * @param terminationFlag         passed on to the concurrent runner
     * @param totalFeatureCountOutput incremented by the number of bits set over all nodes
     * @return array holding one bitset per node id
     * @throws java.util.NoSuchElementException if {@code config.binarizeFeatures()} is empty
     */
    static HugeObjectArray<HugeAtomicBitSet> compute(Graph graph, List<Partition> partition, HashGNNConfig config, SplittableRandom rng, ProgressTracker progressTracker, TerminationFlag terminationFlag, MutableLong totalFeatureCountOutput) {
        progressTracker.beginSubTask("Binarize node property features");
        BinarizeFeaturesConfig binarizationConfig = config.binarizeFeatures().orElseThrow();
        List<FeatureExtractor> featureExtractors = FeatureExtraction.propertyExtractors(graph, config.featureProperties());
        int inputDimension = FeatureExtraction.featureCount(featureExtractors);
        double[][] propertyEmbeddings = BinarizeTask.embedProperties(binarizationConfig.dimension(), rng, inputDimension);
        HugeObjectArray<HugeAtomicBitSet> truncatedFeatures = HugeObjectArray.newArray(HugeAtomicBitSet.class, graph.nodeCount());

        List<BinarizeTask> tasks = partition.stream()
            .map(p -> new BinarizeTask(p, binarizationConfig, truncatedFeatures, featureExtractors, propertyEmbeddings, progressTracker))
            .collect(Collectors.toList());
        RunWithConcurrency.builder().concurrency(config.concurrency()).tasks(tasks).terminationFlag(terminationFlag).run();

        totalFeatureCountOutput.add(tasks.stream().mapToLong(BinarizeTask::totalFeatureCount).sum());

        double squaredSum = tasks.stream().mapToDouble(BinarizeTask::scalarProductSumOfSquares).sum();
        double sum = tasks.stream().mapToDouble(BinarizeTask::scalarProductSum).sum();
        // One scalar product is produced per node and output dimension.
        long exampleCount = graph.nodeCount() * (long) binarizationConfig.dimension();
        double avg = sum / (double) exampleCount;
        double variance = (squaredSum - (double) exampleCount * avg * avg) / (double) exampleCount;
        // E[x^2] - E[x]^2 can round to a tiny negative number; clamp so sqrt does not yield NaN.
        double std = Math.sqrt(Math.max(0.0, variance));
        progressTracker.logInfo(StringFormatting.formatWithLocale("Hyperplane scalar products have mean %.4f and standard deviation %.4f. A threshold for binarization may be set to the mean plus a few standard deviations.", avg, std));
        progressTracker.endSubTask("Binarize node property features");
        return truncatedFeatures;
    }

    /**
     * Builds an {@code inputDimension x vectorDimension} matrix whose entries are drawn from
     * {@link #boxMullerGaussianRandom(SplittableRandom)}.
     *
     * @param vectorDimension number of columns (output dimensions)
     * @param rng             source of randomness; consumed row by row, column by column
     * @param inputDimension  number of rows (input features)
     * @return the freshly allocated matrix, indexed as {@code [inputFeature][outputDimension]}
     */
    public static double[][] embedProperties(int vectorDimension, SplittableRandom rng, int inputDimension) {
        double[][] propertyEmbeddings = new double[inputDimension][];
        for (int inputFeature = 0; inputFeature < inputDimension; ++inputFeature) {
            double[] row = new double[vectorDimension];
            for (int feature = 0; feature < vectorDimension; ++feature) {
                row[feature] = BinarizeTask.boxMullerGaussianRandom(rng);
            }
            propertyEmbeddings[inputFeature] = row;
        }
        return propertyEmbeddings;
    }

    /**
     * Draws one sample using the Box-Muller formula {@code sqrt(-2 ln u1) * cos(2 pi u2)}.
     *
     * <p>{@code nextDouble(0.0, 1.0)} may return exactly 0.0, for which {@code ln} is negative
     * infinity; such a draw for {@code u1} is rejected and repeated so the result is always finite.
     * For every other draw the consumed random sequence is the same as without the check.
     */
    private static double boxMullerGaussianRandom(SplittableRandom rng) {
        double radialUniform;
        do {
            radialUniform = rng.nextDouble(0.0, 1.0);
        } while (radialUniform == 0.0);
        double angularUniform = rng.nextDouble(0.0, 1.0);
        return Math.sqrt(-2.0 * Math.log(radialUniform)) * Math.cos(Math.PI * 2 * angularUniform);
    }

    /**
     * Processes every node of the partition: accumulates the scalar products, thresholds them
     * into a bitset, stores the bitset and updates the task's statistics. Progress is reported
     * once, after the whole partition has been consumed.
     */
    @Override
    public void run() {
        this.partition.consume(nodeId -> {
            final float[] featureVector = new float[this.dimension];
            FeatureExtraction.extract(nodeId, -1L, this.featureExtractors, new FeatureConsumer(){

                // Scalar property: contributes value * embedding row `offset` to every dimension.
                public void acceptScalar(long nodeOffset, int offset, double value) {
                    double[] embedding = BinarizeTask.this.propertyEmbeddings[offset];
                    for (int feature = 0; feature < BinarizeTask.this.dimension; ++feature) {
                        featureVector[feature] += value * embedding[feature];
                    }
                }

                // Array property: element i uses embedding row `offset + i`.
                public void acceptArray(long nodeOffset, int offset, double[] values) {
                    for (int inputFeatureOffset = 0; inputFeatureOffset < values.length; ++inputFeatureOffset) {
                        double value = values[inputFeatureOffset];
                        double[] embedding = BinarizeTask.this.propertyEmbeddings[offset + inputFeatureOffset];
                        for (int feature = 0; feature < BinarizeTask.this.dimension; ++feature) {
                            featureVector[feature] += value * embedding[feature];
                        }
                    }
                }
            });
            HugeAtomicBitSet featureSet = this.round(featureVector);
            this.totalFeatureCount += featureSet.cardinality();
            this.truncatedFeatures.set(nodeId, featureSet);
        });
        this.progressTracker.logProgress(this.partition.nodeCount());
    }

    /**
     * Thresholds the scalar products into a bitset and adds them to the running sum and sum of
     * squares.
     *
     * @param floatVector scalar products, one per output dimension
     * @return bitset with bit {@code i} set iff {@code floatVector[i] > threshold}
     */
    private HugeAtomicBitSet round(float[] floatVector) {
        HugeAtomicBitSet bitset = HugeAtomicBitSet.create(floatVector.length);
        for (int feature = 0; feature < floatVector.length; ++feature) {
            // Widen before squaring so the square is not rounded (or overflowed) in float.
            double scalarProduct = floatVector[feature];
            this.scalarProductSum += scalarProduct;
            this.scalarProductSumOfSquares += scalarProduct * scalarProduct;
            if (scalarProduct > this.threshold) {
                bitset.set(feature);
            }
        }
        return bitset;
    }

    /** @return number of bits set over all nodes processed by this task so far */
    public long totalFeatureCount() {
        return this.totalFeatureCount;
    }

    /** @return sum of all scalar products seen by this task so far */
    public double scalarProductSum() {
        return this.scalarProductSum;
    }

    /** @return sum of the squares of all scalar products seen by this task so far */
    public double scalarProductSumOfSquares() {
        return this.scalarProductSumOfSquares;
    }
}

