/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.decisiontree;

import org.neo4j.gds.core.utils.paged.HugeIntArray;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.decisiontree.DecisionTreeLoss;
import org.neo4j.gds.decisiontree.GroupSizes;
import org.neo4j.gds.decisiontree.Groups;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;

/**
 * Split loss for classification decision trees based on Gini impurity.
 *
 * <p>For a group of {@code n} elements whose labels fall into classes with counts {@code c_k},
 * the group's size-scaled impurity is {@code n * (1 - sum_k (c_k / n)^2) = n - sum_k c_k^2 / n}.
 * The loss of a split is the sum of this quantity over the left and right groups, divided by the
 * total number of elements in both groups. This equals the size-weighted average Gini impurity
 * of the two groups.
 */
public class GiniIndex
implements DecisionTreeLoss {
    // Mapped class label per element, looked up by the element ids stored in a group. Values are
    // used directly as indices into an array of length numberOfClasses, so they must lie in
    // [0, numberOfClasses).
    private final HugeIntArray expectedMappedLabels;
    private final int numberOfClasses;

    /**
     * Creates a Gini loss over already-mapped labels.
     *
     * @param expectedMappedLabels mapped class label of each element, each in {@code [0, numberOfClasses)}
     * @param numberOfClasses      number of distinct mapped classes
     */
    public GiniIndex(HugeIntArray expectedMappedLabels, int numberOfClasses) {
        this.expectedMappedLabels = expectedMappedLabels;
        this.numberOfClasses = numberOfClasses;
    }

    /**
     * Creates a Gini loss from original (unmapped) labels by translating each label through
     * {@code classMapping}.
     *
     * @param expectedOriginalLabels original label of each element; expected to be non-empty
     * @param classMapping           mapping from original labels to mapped class ids; its
     *                               {@code size()} is used as the number of classes
     * @return a loss over the mapped labels
     */
    public static GiniIndex fromOriginalLabels(HugeLongArray expectedOriginalLabels, LocalIdMap classMapping) {
        assert expectedOriginalLabels.size() > 0L;
        HugeIntArray mappedLabels = HugeIntArray.newArray(expectedOriginalLabels.size());
        mappedLabels.setAll(idx -> classMapping.toMapped(expectedOriginalLabels.get(idx)));
        return new GiniIndex(mappedLabels, classMapping.size());
    }

    /**
     * Computes the size-weighted average Gini impurity of a binary split.
     *
     * @param groups     element ids of the left and right groups; only the first
     *                   {@code groupSizes.left()} / {@code groupSizes.right()} entries are read
     * @param groupSizes number of valid entries in each group
     * @return the split loss
     * @throws IllegalStateException if both groups are empty
     */
    @Override
    public double splitLoss(Groups groups, GroupSizes groupSizes) {
        long totalSize = groupSizes.left() + groupSizes.right();
        if (totalSize == 0L) {
            throw new IllegalStateException("Cannot compute loss over only empty groups");
        }
        double loss = this.computeGroupLoss(groups.left(), groupSizes.left())
            + this.computeGroupLoss(groups.right(), groupSizes.right());
        return loss / (double) totalSize;
    }

    /**
     * Computes {@code groupSize - sum_k c_k^2 / groupSize} for the first {@code groupSize}
     * elements of {@code group}, where {@code c_k} is the number of those elements with mapped
     * label {@code k}. An empty group contributes {@code 0.0}.
     */
    private double computeGroupLoss(HugeLongArray group, long groupSize) {
        assert group.size() >= groupSize;
        if (groupSize == 0L) {
            return 0.0;
        }

        long[] groupClassCounts = new long[this.numberOfClasses];
        for (long i = 0L; i < groupSize; i++) {
            int expectedLabel = this.expectedMappedLabels.get(group.get(i));
            groupClassCounts[expectedLabel]++;
        }

        // Accumulate in double: `count * count` in long arithmetic silently overflows once a single
        // class holds more than ~3.04e9 elements. Double arithmetic is exact while the running
        // total stays at or below 2^53 and only loses low-order precision beyond that.
        double sumOfSquaredCounts = 0.0;
        for (long count : groupClassCounts) {
            sumOfSquaredCounts += (double) count * count;
        }

        return groupSize - sumOfSquaredCounts / groupSize;
    }
}

