/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.rtree.impurity;

import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.List;
import org.tribuo.regression.rtree.impurity.RegressorImpurity;

/**
 * Weighted mean squared error impurity for regression trees.
 * <p>
 * For targets {@code y_i} with weights {@code w_i} the impurity is the weighted variance
 * {@code sum(w_i * (y_i - mean)^2) / sum(w_i)}, where {@code mean = sum(w_i * y_i) / sum(w_i)}.
 * All accumulation is done in single precision.
 */
public class MeanSquaredError
implements RegressorImpurity {

    /**
     * Computes the weighted mean squared error of {@code targets} around their weighted mean,
     * using an explicit two-pass computation.
     *
     * @param targets the target values.
     * @param weights per-example weights; must have at least {@code targets.length} entries,
     *                and {@code weights[i]} is applied to {@code targets[i]}.
     * @return the weighted mean squared error; NaN if the weights sum to zero (e.g. empty input).
     */
    @Override
    public double impurity(float[] targets, float[] weights) {
        float weightedSum = 0.0f;
        float weightSum = 0.0f;
        for (int i = 0; i < targets.length; ++i) {
            weightedSum += targets[i] * weights[i];
            weightSum += weights[i];
        }
        float mean = weightedSum / weightSum;
        float squaredError = 0.0f;
        for (int i = 0; i < targets.length; ++i) {
            float error = mean - targets[i];
            squaredError += error * error * weights[i];
        }
        return squaredError / weightSum;
    }

    /**
     * Computes the weighted variance of the targets selected by the first {@code indicesLength}
     * entries of {@code indices}, in a single pass.
     *
     * @param indices       positions into {@code targets}/{@code weights}.
     * @param indicesLength how many leading entries of {@code indices} to use.
     * @param targets       the target values.
     * @param weights       per-example weights.
     * @return a tuple of the weighted variance (clamped to be non-negative) and the summed weight.
     *         A single sample yields zero variance. The variance is NaN if the weights sum to zero.
     */
    @Override
    public RegressorImpurity.ImpurityTuple impurityTuple(int[] indices, int indicesLength, float[] targets, float[] weights) {
        // A single sample has zero variance by definition; skip the arithmetic, which could
        // otherwise produce a tiny non-zero value from float rounding.
        if (indicesLength == 1) {
            return new RegressorImpurity.ImpurityTuple(0.0f, weights[indices[0]]);
        }
        float weightedSquaredSum = 0.0f;
        float weightedSum = 0.0f;
        float weightSum = 0.0f;
        for (int i = 0; i < indicesLength; ++i) {
            int idx = indices[i];
            float weight = weights[idx];
            float target = targets[idx];
            float curVal = target * weight;
            weightedSum += curVal;
            weightedSquaredSum += curVal * target;
            weightSum += weight;
        }
        return new RegressorImpurity.ImpurityTuple(weightedVariance(weightedSquaredSum, weightedSum, weightSum), weightSum);
    }

    /**
     * Computes the weighted variance of the targets selected by every index in every array of
     * {@code indices}, in a single pass.
     *
     * @param indices arrays of positions into {@code targets}/{@code weights}; all are combined.
     * @param targets the target values.
     * @param weights per-example weights.
     * @return a tuple of the weighted variance (clamped to be non-negative) and the summed weight.
     *         A single sample yields zero variance. The variance is NaN if the weights sum to zero.
     */
    @Override
    public RegressorImpurity.ImpurityTuple impurityTuple(List<int[]> indices, float[] targets, float[] weights) {
        float weightedSquaredSum = 0.0f;
        float weightedSum = 0.0f;
        float weightSum = 0.0f;
        int count = 0;
        for (int[] curIndices : indices) {
            for (int idx : curIndices) {
                float weight = weights[idx];
                float target = targets[idx];
                float curVal = target * weight;
                weightedSum += curVal;
                weightedSquaredSum += curVal * target;
                weightSum += weight;
                ++count;
            }
        }
        // Match the array overload: one sample has exactly zero variance. Here weightSum is
        // exactly that sample's weight (0.0f + w == w).
        if (count == 1) {
            return new RegressorImpurity.ImpurityTuple(0.0f, weightSum);
        }
        return new RegressorImpurity.ImpurityTuple(weightedVariance(weightedSquaredSum, weightedSum, weightSum), weightSum);
    }

    /**
     * Converts weighted first and second moments into a weighted variance.
     * <p>
     * The one-pass formula {@code E[y^2] - E[y]^2} suffers from catastrophic cancellation in
     * single precision. When the targets are (nearly) equal or large in magnitude, it can
     * evaluate to a small negative number. A variance is never negative, so the result is
     * clamped at zero. {@link Math#max(float, float)} propagates NaN, so a zero weight sum
     * still yields NaN, as before.
     *
     * @param weightedSquaredSum {@code sum(w_i * y_i^2)}.
     * @param weightedSum        {@code sum(w_i * y_i)}.
     * @param weightSum          {@code sum(w_i)}.
     * @return the non-negative weighted variance, or NaN if {@code weightSum} is zero.
     */
    private static float weightedVariance(float weightedSquaredSum, float weightedSum, float weightSum) {
        float mean = weightedSum / weightSum;
        float variance = weightedSquaredSum / weightSum - mean * mean;
        return Math.max(0.0f, variance);
    }

    @Override
    public String toString() {
        return "MeanSquaredError";
    }

    /**
     * Returns the provenance of this impurity, recorded with host type {@code "RegressorImpurity"}.
     *
     * @return the configured-object provenance for this instance.
     */
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl((Configurable)this, "RegressorImpurity");
    }
}

