/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.clustering.gmm;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.clustering.gmm.GmmPartitionData;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;

/**
 * Aggregates, for a single mixture component, the statistics needed to re-estimate its mean and
 * its prior probability: the sum of rows weighted by their per-row component probability
 * ({@code pcxi}), the sum of those probabilities, and the number of rows seen.
 */
class MeanWithClusterProbAggregator
implements Serializable {
    /** Serial version uid. */
    private static final long serialVersionUID = 2700985110021774629L;

    /** Sum of {@code x * pcxi} over all added rows; {@code null} until the first row is added. */
    private Vector weightedXsSum;

    /** Sum of {@code pcxi} over all added rows. */
    private double pcxiSum;

    /** Number of added rows. */
    private int rowCount;

    /** Creates an empty aggregator (no rows, {@code null} weighted sum). */
    MeanWithClusterProbAggregator() {
    }

    /**
     * Creates an aggregator with explicit state.
     *
     * @param weightedXsSum Sum of rows weighted by their component probability (may be {@code null}
     *     when no rows contributed).
     * @param pcxiSum Sum of component probabilities.
     * @param rowCount Number of rows.
     */
    MeanWithClusterProbAggregator(Vector weightedXsSum, double pcxiSum, int rowCount) {
        this.weightedXsSum = weightedXsSum;
        this.pcxiSum = pcxiSum;
        this.rowCount = rowCount;
    }

    /**
     * Returns the weighted mean {@code weightedXsSum / pcxiSum}.
     *
     * <p>Throws {@link NullPointerException} if no row was ever added. If every added row had
     * {@code pcxi == 0}, the divisor is zero; the resulting values depend on {@code Vector.divide}
     * (presumably non-finite — TODO confirm).
     *
     * @return Weighted mean vector of this component.
     */
    public Vector mean() {
        return weightedXsSum.divide(pcxiSum);
    }

    /**
     * Returns the average component probability over all rows, {@code pcxiSum / rowCount}.
     * Yields {@code NaN} when no rows were added (0.0 / 0).
     *
     * @return Estimated prior probability of this component.
     */
    public double clusterProb() {
        return pcxiSum / rowCount;
    }

    /**
     * Computes per-component statistics over the whole dataset: each partition is mapped with
     * {@link #map} and partition results are merged with {@link #reduce}.
     *
     * <p>The method name keeps its historical spelling for compatibility with existing callers.
     *
     * @param dataset Dataset with GMM partition data.
     * @param countOfComponents Number of mixture components.
     * @return Aggregated cluster probabilities and means.
     */
    public static AggregatedStats aggreateStats(Dataset<EmptyContext, GmmPartitionData> dataset, int countOfComponents) {
        List<MeanWithClusterProbAggregator> stats = dataset.compute(
            data -> map(data, countOfComponents),
            MeanWithClusterProbAggregator::reduce
        );
        return new AggregatedStats(stats);
    }

    /**
     * Adds one row with its component probability.
     *
     * @param x Row vector.
     * @param pcxi Probability of this component for the row; must lie in {@code [0, 1]}.
     */
    void add(Vector x, double pcxi) {
        A.ensure(pcxi >= 0.0 && pcxi <= 1.0, "pcxi >= 0 && pcxi <= 1.");

        Vector weightedVector = x.times(pcxi);
        weightedXsSum = weightedXsSum == null ? weightedVector : weightedXsSum.plus(weightedVector);
        pcxiSum += pcxi;
        rowCount++;
    }

    /**
     * Merges two aggregators into a new one. An aggregator that has seen no rows (its weighted sum
     * is still {@code null}, e.g. one built from an empty partition) contributes nothing to the
     * vector sum instead of causing a {@link NullPointerException}.
     *
     * @param other Aggregator to merge with.
     * @return New aggregator holding the combined statistics.
     */
    MeanWithClusterProbAggregator plus(MeanWithClusterProbAggregator other) {
        Vector mergedSum;
        if (weightedXsSum == null)
            mergedSum = other.weightedXsSum;
        else if (other.weightedXsSum == null)
            mergedSum = weightedXsSum;
        else
            mergedSum = weightedXsSum.plus(other.weightedXsSum);

        return new MeanWithClusterProbAggregator(mergedSum, pcxiSum + other.pcxiSum, rowCount + other.rowCount);
    }

    /**
     * Builds one aggregator per component from a single partition.
     *
     * @param data Partition data.
     * @param countOfComponents Number of mixture components.
     * @return List of {@code countOfComponents} aggregators, indexed by component.
     */
    static List<MeanWithClusterProbAggregator> map(GmmPartitionData data, int countOfComponents) {
        List<MeanWithClusterProbAggregator> aggregators = new ArrayList<>();
        for (int c = 0; c < countOfComponents; c++)
            aggregators.add(new MeanWithClusterProbAggregator());

        for (int i = 0; i < data.size(); i++) {
            // The row is the same for every component; fetch it once.
            Vector x = data.getX(i);
            for (int c = 0; c < countOfComponents; c++)
                aggregators.get(c).add(x, data.pcxi(c, i));
        }

        return aggregators;
    }

    /**
     * Merges two partition results component-wise. A {@code null} or empty side is treated as
     * absent and the other side is returned as-is.
     *
     * @param l Left partition result.
     * @param r Right partition result.
     * @return Merged per-component aggregators.
     */
    static List<MeanWithClusterProbAggregator> reduce(List<MeanWithClusterProbAggregator> l, List<MeanWithClusterProbAggregator> r) {
        A.ensure(l != null || r != null, "Both partitions cannot equal to null");

        if (l == null || l.isEmpty())
            return r;
        if (r == null || r.isEmpty())
            return l;

        A.ensure(l.size() == r.size(), "l.size() == r.size()");

        List<MeanWithClusterProbAggregator> res = new ArrayList<>();
        for (int i = 0; i < l.size(); i++)
            res.add(l.get(i).plus(r.get(i)));

        return res;
    }

    /** Final statistics derived from the merged per-component aggregators. */
    public static final class AggregatedStats {
        /** Prior probability of each component, indexed by component. */
        private final Vector clusterProbs;

        /** Mean vector of each component, indexed by component. */
        private final List<Vector> means;

        /**
         * @param stats Merged aggregators, one per component.
         */
        private AggregatedStats(List<MeanWithClusterProbAggregator> stats) {
            clusterProbs = VectorUtils.of(stats.stream().mapToDouble(MeanWithClusterProbAggregator::clusterProb).toArray());
            means = stats.stream().map(MeanWithClusterProbAggregator::mean).collect(Collectors.toList());
        }

        /**
         * @return Prior probability of each component.
         */
        public Vector clusterProbabilities() {
            return clusterProbs;
        }

        /**
         * @return Mean vector of each component.
         */
        public List<Vector> means() {
            return means;
        }
    }
}

