/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.clustering.streaming.cluster;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.mahout.clustering.ClusteringUtils;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.WeightedVector;
import org.apache.mahout.math.neighborhood.UpdatableSearcher;
import org.apache.mahout.math.random.Multinomial;
import org.apache.mahout.math.random.WeightedThing;

public class BallKMeans
implements Iterable<Centroid> {
    private final UpdatableSearcher centroids;
    private final int numClusters;
    private final int maxNumIterations;
    private final double trimFraction;
    private final boolean kMeansPlusPlusInit;
    private final boolean correctWeights;
    private final double testProbability;
    private final boolean splitTrainTest;
    private final int numRuns;
    private final Random random;

    public BallKMeans(UpdatableSearcher searcher, int numClusters, int maxNumIterations) {
        this(searcher, numClusters, maxNumIterations, 0.9, true, true, 0.0, 1);
    }

    public BallKMeans(UpdatableSearcher searcher, int numClusters, int maxNumIterations, boolean kMeansPlusPlusInit, int numRuns) {
        this(searcher, numClusters, maxNumIterations, 0.9, kMeansPlusPlusInit, true, 0.1, numRuns);
    }

    public BallKMeans(UpdatableSearcher searcher, int numClusters, int maxNumIterations, double trimFraction, boolean kMeansPlusPlusInit, boolean correctWeights, double testProbability, int numRuns) {
        Preconditions.checkArgument((searcher.size() == 0 ? 1 : 0) != 0, (Object)"Searcher must be empty initially to populate with centroids");
        Preconditions.checkArgument((numClusters > 0 ? 1 : 0) != 0, (Object)"The requested number of clusters must be positive");
        Preconditions.checkArgument((maxNumIterations > 0 ? 1 : 0) != 0, (Object)"The maximum number of iterations must be positive");
        Preconditions.checkArgument((trimFraction > 0.0 ? 1 : 0) != 0, (Object)"The trim fraction must be positive");
        Preconditions.checkArgument((testProbability >= 0.0 && testProbability < 1.0 ? 1 : 0) != 0, (Object)"The testProbability must be in [0, 1)");
        Preconditions.checkArgument((numRuns > 0 ? 1 : 0) != 0, (Object)"There has to be at least one run");
        this.centroids = searcher;
        this.numClusters = numClusters;
        this.maxNumIterations = maxNumIterations;
        this.trimFraction = trimFraction;
        this.kMeansPlusPlusInit = kMeansPlusPlusInit;
        this.correctWeights = correctWeights;
        this.testProbability = testProbability;
        this.splitTrainTest = testProbability > 0.0;
        this.numRuns = numRuns;
        this.random = RandomUtils.getRandom();
    }

    public Pair<List<? extends WeightedVector>, List<? extends WeightedVector>> splitTrainTest(List<? extends WeightedVector> datapoints) {
        if (this.testProbability == 0.0) {
            return new Pair<List<? extends WeightedVector>, List<? extends WeightedVector>>(datapoints, new ArrayList());
        }
        int numTest = (int)(this.testProbability * (double)datapoints.size());
        Preconditions.checkArgument((numTest > 0 && numTest < datapoints.size() ? 1 : 0) != 0, (String)"Must have nonzero number of training and test vectors. Asked for %.1f %% of %d vectors for test", (Object[])new Object[]{this.testProbability * 100.0, datapoints.size()});
        Collections.shuffle(datapoints);
        return new Pair<List<? extends WeightedVector>, List<? extends WeightedVector>>(datapoints.subList(numTest, datapoints.size()), datapoints.subList(0, numTest));
    }

    public UpdatableSearcher cluster(List<? extends WeightedVector> datapoints) {
        Pair<List<? extends WeightedVector>, List<? extends WeightedVector>> trainTestSplit = this.splitTrainTest(datapoints);
        ArrayList bestCentroids = new ArrayList();
        double cost = Double.POSITIVE_INFINITY;
        double bestCost = Double.POSITIVE_INFINITY;
        for (int i = 0; i < this.numRuns; ++i) {
            this.centroids.clear();
            if (this.kMeansPlusPlusInit) {
                this.initializeSeedsKMeansPlusPlus(trainTestSplit.getFirst());
            } else {
                this.initializeSeedsRandomly(trainTestSplit.getFirst());
            }
            if (this.numRuns > 1) {
                this.iterativeAssignment(trainTestSplit.getFirst());
                cost = ClusteringUtils.totalClusterCost(this.splitTrainTest ? datapoints : trainTestSplit.getSecond(), this.centroids);
                if (!(cost < bestCost)) continue;
                bestCost = cost;
                bestCentroids.clear();
                Iterables.addAll(bestCentroids, (Iterable)this.centroids);
                continue;
            }
            this.iterativeAssignment(datapoints);
            return this.centroids;
        }
        if (bestCost == Double.POSITIVE_INFINITY) {
            throw new RuntimeException("No valid clustering was found");
        }
        if (cost != bestCost) {
            this.centroids.clear();
            this.centroids.addAll(bestCentroids);
        }
        if (this.correctWeights) {
            for (WeightedVector weightedVector : trainTestSplit.getSecond()) {
                WeightedVector closest = (WeightedVector)this.centroids.searchFirst((Vector)weightedVector, false).getValue();
                closest.setWeight(closest.getWeight() + weightedVector.getWeight());
            }
        }
        return this.centroids;
    }

    /*
     * WARNING - void declaration
     */
    private void initializeSeedsRandomly(List<? extends WeightedVector> datapoints) {
        void var6_9;
        void var6_7;
        int numDatapoints = datapoints.size();
        double totalWeight = 0.0;
        for (WeightedVector weightedVector : datapoints) {
            totalWeight += weightedVector.getWeight();
        }
        Multinomial seedSelector = new Multinomial();
        boolean bl = false;
        while (var6_7 < numDatapoints) {
            seedSelector.add((Object)((int)var6_7), datapoints.get((int)var6_7).getWeight() / totalWeight);
            ++var6_7;
        }
        boolean bl2 = false;
        while (var6_9 < this.numClusters) {
            int sample = (Integer)seedSelector.sample();
            seedSelector.delete((Object)sample);
            Centroid centroid = new Centroid(datapoints.get(sample));
            centroid.setIndex((int)var6_9);
            this.centroids.add((Vector)centroid);
            ++var6_9;
        }
    }

    /*
     * WARNING - void declaration
     */
    private void initializeSeedsKMeansPlusPlus(List<? extends WeightedVector> datapoints) {
        void var7_10;
        Preconditions.checkArgument((datapoints.size() > 1 ? 1 : 0) != 0, (Object)"Must have at least two datapoints points to cluster sensibly");
        Preconditions.checkArgument((datapoints.size() >= this.numClusters ? 1 : 0) != 0, (Object)String.format("Must have more datapoints [%d] than clusters [%d]", datapoints.size(), this.numClusters));
        Centroid center = new Centroid(datapoints.iterator().next());
        for (WeightedVector row : Iterables.skip(datapoints, (int)1)) {
            center.update((Vector)row);
        }
        double deltaX = 0.0;
        DistanceMeasure distanceMeasure = this.centroids.getDistanceMeasure();
        for (WeightedVector weightedVector : datapoints) {
            deltaX += distanceMeasure.distance((Vector)weightedVector, (Vector)center);
        }
        Multinomial seedSelector = new Multinomial();
        boolean bl = false;
        while (var7_10 < datapoints.size()) {
            double selectionProbability = deltaX + (double)datapoints.size() * distanceMeasure.distance((Vector)datapoints.get((int)var7_10), (Vector)center);
            seedSelector.add((Object)((int)var7_10), selectionProbability);
            ++var7_10;
        }
        int n = this.random.nextInt(datapoints.size());
        Centroid c_1 = new Centroid(datapoints.get(n).clone());
        c_1.setIndex(0);
        for (int i = 0; i < datapoints.size(); ++i) {
            WeightedVector row = datapoints.get(i);
            double w = distanceMeasure.distance((Vector)c_1, (Vector)row) * 2.0 * Math.log(1.0 + row.getWeight());
            seedSelector.set((Object)i, w);
        }
        this.centroids.add((Vector)c_1);
        int clusterIndex = 1;
        while (this.centroids.size() < this.numClusters) {
            int seedIndex = (Integer)seedSelector.sample();
            Centroid nextSeed = new Centroid(datapoints.get(seedIndex));
            nextSeed.setIndex(clusterIndex++);
            this.centroids.add((Vector)nextSeed);
            seedSelector.delete((Object)seedIndex);
            Iterator iterator = seedSelector.iterator();
            while (iterator.hasNext()) {
                int currSeedIndex = (Integer)iterator.next();
                WeightedVector curr = datapoints.get(currSeedIndex);
                double newWeight = nextSeed.getWeight() * distanceMeasure.distance((Vector)nextSeed, (Vector)curr);
                if (!(newWeight < seedSelector.getWeight((Object)currSeedIndex))) continue;
                seedSelector.set((Object)currSeedIndex, newWeight);
            }
        }
    }

    private void iterativeAssignment(List<? extends WeightedVector> datapoints) {
        DistanceMeasure distanceMeasure = this.centroids.getDistanceMeasure();
        ArrayList<Double> closestClusterDistances = new ArrayList<Double>(this.numClusters);
        ArrayList<Integer> clusterAssignments = new ArrayList<Integer>(Collections.nCopies(datapoints.size(), -1));
        boolean changed = true;
        for (int i = 0; changed && i < this.maxNumIterations; ++i) {
            Object center22;
            changed = false;
            closestClusterDistances.clear();
            for (Object center22 : this.centroids) {
                Vector closestOtherCluster = (Vector)this.centroids.searchFirst((Vector)center22, true).getValue();
                closestClusterDistances.add(distanceMeasure.distance((Vector)center22, closestOtherCluster));
            }
            ArrayList<Centroid> arrayList = new ArrayList<Centroid>();
            center22 = this.centroids.iterator();
            while (center22.hasNext()) {
                Vector centroid = (Vector)center22.next();
                Centroid newCentroid = (Centroid)centroid.clone();
                newCentroid.setWeight(0.0);
                arrayList.add(newCentroid);
            }
            for (int j = 0; j < datapoints.size(); ++j) {
                WeightedVector datapoint = datapoints.get(j);
                WeightedThing<Vector> closestPair = this.centroids.searchFirst((Vector)datapoint, false);
                int closestIndex = ((WeightedVector)closestPair.getValue()).getIndex();
                double closestDistance = closestPair.getWeight();
                if (closestIndex != (Integer)clusterAssignments.get(j)) {
                    changed = true;
                    clusterAssignments.set(j, closestIndex);
                }
                if (!(closestDistance < this.trimFraction * (Double)closestClusterDistances.get(closestIndex))) continue;
                ((Centroid)arrayList.get(closestIndex)).update((Vector)datapoint);
            }
            this.centroids.clear();
            this.centroids.addAll(arrayList);
        }
        if (this.correctWeights) {
            for (Vector vector : this.centroids) {
                ((Centroid)vector).setWeight(0.0);
            }
            for (WeightedVector weightedVector : datapoints) {
                Centroid closestCentroid = (Centroid)this.centroids.searchFirst((Vector)weightedVector, false).getValue();
                closestCentroid.setWeight(closestCentroid.getWeight() + weightedVector.getWeight());
            }
        }
    }

    @Override
    public Iterator<Centroid> iterator() {
        return Iterators.transform(this.centroids.iterator(), (Function)new Function<Vector, Centroid>(){

            public Centroid apply(Vector input) {
                Preconditions.checkArgument((boolean)(input instanceof Centroid), (Object)"Non-centroid in centroids searcher");
                return (Centroid)input;
            }
        });
    }
}

