/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.rtree.impurity;

import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.List;
import org.tribuo.regression.rtree.impurity.RegressorImpurity;

/**
 * A regression tree impurity based on the weighted absolute deviation of the targets.
 * <p>
 * Every method first computes the weighted mean of the selected targets, then accumulates
 * {@code |mean - target| * weight} over the same selection. Sums are accumulated in {@code float}.
 * <p>
 * NOTE(review): deviations are measured from the weighted <em>mean</em>, whereas absolute error is
 * minimised by the (weighted) median. Switching to the median would presumably change split
 * selection and therefore trained models, so it is intentionally left as-is.
 */
public class MeanAbsoluteError implements RegressorImpurity {

    /**
     * Computes the weight-normalised absolute deviation of {@code targets} from their weighted mean.
     *
     * @param targets The target values.
     * @param weights The per-example weights. Indexed in parallel with {@code targets}, so it must be
     *                at least as long as {@code targets}; any extra entries are ignored.
     * @return {@code sum(|mean - t_i| * w_i) / sum(w_i)}. This is NaN when the total weight is zero,
     *         which includes the case where {@code targets} is empty.
     */
    @Override
    public double impurity(float[] targets, float[] weights) {
        float weightedSum = 0.0f;
        float weightSum = 0.0f;
        for (int i = 0; i < targets.length; ++i) {
            weightedSum += targets[i] * weights[i];
            weightSum += weights[i];
        }
        float mean = weightedSum / weightSum;
        float absoluteError = 0.0f;
        for (int i = 0; i < targets.length; ++i) {
            float error = Math.abs(mean - targets[i]);
            absoluteError += error * weights[i];
        }
        return absoluteError / weightSum;
    }

    /**
     * Computes the unnormalised weighted absolute deviation over the first {@code indicesLength}
     * entries of {@code indices}.
     *
     * @param indices       Indices into {@code targets} and {@code weights}; only the first
     *                      {@code indicesLength} entries are read.
     * @param indicesLength The number of leading entries of {@code indices} to use.
     * @param targets       The target values.
     * @param weights       The per-example weights.
     * @return A tuple of ({@code sum(|mean - t| * w)}, {@code sum(w)}). The first component is NOT
     *         divided by the total weight. A single example always yields zero impurity, and zero
     *         examples yield {@code (0, 0)}. With two or more examples whose weights sum to zero, the
     *         impurity is NaN.
     */
    @Override
    public RegressorImpurity.ImpurityTuple impurityTuple(int[] indices, int indicesLength, float[] targets, float[] weights) {
        // A lone example has no spread; short-circuit so a zero weight cannot produce a 0/0 mean.
        if (indicesLength == 1) {
            return new RegressorImpurity.ImpurityTuple(0.0f, weights[indices[0]]);
        }
        float weightedSum = 0.0f;
        float weightSum = 0.0f;
        for (int i = 0; i < indicesLength; ++i) {
            int idx = indices[i];
            weightedSum += targets[idx] * weights[idx];
            weightSum += weights[idx];
        }
        float mean = weightedSum / weightSum;
        float absoluteError = 0.0f;
        for (int i = 0; i < indicesLength; ++i) {
            int idx = indices[i];
            float error = Math.abs(mean - targets[idx]);
            absoluteError += error * weights[idx];
        }
        return new RegressorImpurity.ImpurityTuple(absoluteError, weightSum);
    }

    /**
     * Computes the unnormalised weighted absolute deviation over the union of all index arrays in
     * {@code indices}. The mean is taken over every selected example.
     *
     * @param indices A list of index arrays into {@code targets} and {@code weights}; every entry of
     *                every array is used.
     * @param targets The target values.
     * @param weights The per-example weights.
     * @return A tuple of ({@code sum(|mean - t| * w)}, {@code sum(w)}). The first component is NOT
     *         divided by the total weight. A single example in total yields zero impurity, consistent
     *         with the {@code int[]} overload, and zero examples yield {@code (0, 0)}. With two or more
     *         examples whose weights sum to zero, the impurity is NaN.
     */
    @Override
    public RegressorImpurity.ImpurityTuple impurityTuple(List<int[]> indices, float[] targets, float[] weights) {
        // Mirror the int[] overload's single-example shortcut so both overloads agree on identical
        // inputs. Without it, a lone zero-weight example yields NaN (0/0 mean), and a lone example
        // with a non-unit weight can pick up rounding noise from (t * w) / w.
        int count = 0;
        int onlyIdx = -1;
        for (int[] curIndices : indices) {
            if (curIndices.length > 0) {
                count += curIndices.length;
                onlyIdx = curIndices[0];
                if (count > 1) {
                    break;
                }
            }
        }
        if (count == 1) {
            return new RegressorImpurity.ImpurityTuple(0.0f, weights[onlyIdx]);
        }
        float weightedSum = 0.0f;
        float weightSum = 0.0f;
        for (int[] curIndices : indices) {
            for (int i = 0; i < curIndices.length; ++i) {
                int idx = curIndices[i];
                weightedSum += targets[idx] * weights[idx];
                weightSum += weights[idx];
            }
        }
        float mean = weightedSum / weightSum;
        float absoluteError = 0.0f;
        for (int[] curIndices : indices) {
            for (int i = 0; i < curIndices.length; ++i) {
                int idx = curIndices[i];
                float error = Math.abs(mean - targets[idx]);
                absoluteError += error * weights[idx];
            }
        }
        return new RegressorImpurity.ImpurityTuple(absoluteError, weightSum);
    }

    @Override
    public String toString() {
        return "MeanAbsoluteError";
    }

    /**
     * Returns provenance describing this object as a configured {@code "RegressorImpurity"}.
     *
     * @return The configured-object provenance for this impurity.
     */
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl((Configurable)this, "RegressorImpurity");
    }
}

