/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline.linkPipeline.linkfunctions;

import java.util.List;
import java.util.Map;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.properties.nodes.NodePropertyContainer;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.ml.pipeline.FeatureStepUtil;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureAppender;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureStep;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureStepFactory;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Link feature step that produces the cosine similarity between the source and target node of a
 * link, treating the values of all configured node properties as one concatenated vector.
 *
 * <p>Supported property value types are {@code LONG}, {@code DOUBLE}, {@code LONG_ARRAY},
 * {@code FLOAT_ARRAY} and {@code DOUBLE_ARRAY}. Every property contributes its components to a
 * shared dot product and to the squared norms of both endpoint vectors. The resulting feature has
 * dimension 1.
 */
public class CosineFeatureStep implements LinkFeatureStep {
    private final List<String> nodePropertyNames;

    /**
     * @param nodeProperties names of the node properties whose values form the compared vectors;
     *                       the list is stored as-is (not copied)
     */
    public CosineFeatureStep(List<String> nodeProperties) {
        this.nodePropertyNames = nodeProperties;
    }

    /**
     * Creates an appender that writes the cosine similarity of {@code source} and {@code target}
     * into {@code linkFeatures[offset]}.
     *
     * <p>When the product of the two vector norms is zero, the raw dot product is written instead
     * of a quotient. A NaN norm is reported via {@link FeatureStepUtil#throwNanError}.
     *
     * @param graph graph whose node properties are read
     * @return an appender producing a single feature value per link
     * @throws IllegalStateException if a configured property has an unsupported value type
     */
    @Override
    public LinkFeatureAppender linkFeatureAppender(Graph graph) {
        final PartialL2WithNormsComputer[] computers = this.nodePropertyNames.stream()
            .map(nodePropertyName -> this.createComputer(graph, nodePropertyName))
            .toArray(PartialL2WithNormsComputer[]::new);
        return new LinkFeatureAppender() {

            @Override
            public void appendFeatures(long source, long target, double[] linkFeatures, int offset) {
                CosineComputationResult partialResults = new CosineComputationResult();
                for (PartialL2WithNormsComputer computer : computers) {
                    computer.compute(source, target, partialResults);
                }
                linkFeatures[offset] = partialResults.dotProduct;
                // Take the square roots before multiplying: the product of the two squared norms
                // can overflow to Infinity (cosine collapses to 0) or underflow to 0 (division
                // skipped) even when both norms themselves are representable doubles.
                double l2Norm = Math.sqrt(partialResults.sourceSquareNorm)
                    * Math.sqrt(partialResults.targetSquareNorm);
                if (Double.isNaN(l2Norm)) {
                    FeatureStepUtil.throwNanError(
                        CosineFeatureStep.this.name(),
                        CosineFeatureStep.this.nodePropertyNames,
                        source,
                        target
                    );
                } else if (l2Norm != 0.0) {
                    linkFeatures[offset] = partialResults.dotProduct / l2Norm;
                }
            }

            @Override
            public int dimension() {
                return 1;
            }
        };
    }

    @Override
    public List<String> inputNodeProperties() {
        return this.nodePropertyNames;
    }

    @Override
    public String name() {
        return LinkFeatureStepFactory.COSINE.name();
    }

    @Override
    public Map<String, Object> configuration() {
        return Map.of("nodeProperties", this.nodePropertyNames);
    }

    /**
     * Picks the accumulator matching the value type of {@code propertyName}.
     *
     * @throws IllegalStateException if the value type is not one of the five supported types
     */
    private PartialL2WithNormsComputer createComputer(NodePropertyContainer graph, String propertyName) {
        NodePropertyValues values = graph.nodeProperties(propertyName);
        switch (values.valueType()) {
            case DOUBLE_ARRAY:
                return new DoubleArrayComputer(values);
            case FLOAT_ARRAY:
                return new FloatArrayComputer(values);
            case LONG_ARRAY:
                return new LongArrayComputer(values);
            case LONG:
                return new LongComputer(values);
            case DOUBLE:
                return new DoubleComputer(values);
        }
        throw new IllegalStateException(
            StringFormatting.formatWithLocale("Unsupported ValueType %s", values.valueType())
        );
    }

    /** Accumulates a scalar {@code long} property. */
    private static class LongComputer extends PartialL2WithNormsComputer {
        LongComputer(NodePropertyValues values) {
            super(values);
        }

        @Override
        void compute(long source, long target, CosineComputationResult result) {
            // Widened to double before multiplying, so large values cannot wrap around in long.
            result.accumulate(this.values.longValue(source), this.values.longValue(target));
        }
    }

    /** Accumulates a scalar {@code double} property. */
    private static class DoubleComputer extends PartialL2WithNormsComputer {
        DoubleComputer(NodePropertyValues values) {
            super(values);
        }

        @Override
        void compute(long source, long target, CosineComputationResult result) {
            result.accumulate(this.values.doubleValue(source), this.values.doubleValue(target));
        }
    }

    /** Accumulates a {@code long[]} property component-wise. */
    private static class LongArrayComputer extends PartialL2WithNormsComputer {
        LongArrayComputer(NodePropertyValues values) {
            super(values);
        }

        @Override
        void compute(long source, long target, CosineComputationResult result) {
            long[] sourceValues = this.values.longArrayValue(source);
            long[] targetValues = this.values.longArrayValue(target);
            assert sourceValues.length == targetValues.length;
            for (int i = 0; i < sourceValues.length; ++i) {
                // Widened to double before multiplying, so large values cannot wrap around in long.
                result.accumulate(sourceValues[i], targetValues[i]);
            }
        }
    }

    /** Accumulates a {@code float[]} property component-wise. */
    private static class FloatArrayComputer extends PartialL2WithNormsComputer {
        FloatArrayComputer(NodePropertyValues values) {
            super(values);
        }

        @Override
        void compute(long source, long target, CosineComputationResult result) {
            float[] sourceValues = this.values.floatArrayValue(source);
            float[] targetValues = this.values.floatArrayValue(target);
            assert sourceValues.length == targetValues.length;
            for (int i = 0; i < sourceValues.length; ++i) {
                // Widened to double before multiplying to avoid float-precision products.
                result.accumulate(sourceValues[i], targetValues[i]);
            }
        }
    }

    /** Accumulates a {@code double[]} property component-wise. */
    private static class DoubleArrayComputer extends PartialL2WithNormsComputer {
        DoubleArrayComputer(NodePropertyValues values) {
            super(values);
        }

        @Override
        void compute(long source, long target, CosineComputationResult result) {
            double[] sourceValues = this.values.doubleArrayValue(source);
            double[] targetValues = this.values.doubleArrayValue(target);
            assert sourceValues.length == targetValues.length;
            for (int i = 0; i < sourceValues.length; ++i) {
                result.accumulate(sourceValues[i], targetValues[i]);
            }
        }
    }

    /** Adds one property's contribution to the running dot product and squared norms. */
    private abstract static class PartialL2WithNormsComputer {
        protected final NodePropertyValues values;

        PartialL2WithNormsComputer(NodePropertyValues values) {
            this.values = values;
        }

        abstract void compute(long source, long target, CosineComputationResult result);
    }

    /** Running sums shared by all property computers for a single link. */
    private static class CosineComputationResult {
        double sourceSquareNorm = 0.0;
        double targetSquareNorm = 0.0;
        double dotProduct = 0.0;

        /**
         * Adds one pair of vector components. Parameters are {@code double}, so {@code long} and
         * {@code float} arguments are widened before any multiplication takes place.
         */
        void accumulate(double sourceValue, double targetValue) {
            dotProduct += sourceValue * targetValue;
            sourceSquareNorm += sourceValue * sourceValue;
            targetSquareNorm += targetValue * targetValue;
        }
    }
}

