/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling;

import ai.libs.jaicore.ml.clustering.GMeans;
import ai.libs.jaicore.ml.core.dataset.IDataset;
import ai.libs.jaicore.ml.core.dataset.INumericArrayInstance;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.ClusterStratiAssigner;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.IStratiAmountSelector;
import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.distance.ManhattanDistance;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;

/**
 * Clustering-based strati amount selector and strati assigner.
 *
 * <p>When {@link #selectStratiAmount} is called, G-Means is run on the dataset; the number of
 * clusters it finds is returned as the strati amount and the clusters themselves are stored in
 * the inherited {@code clusters} field. A subsequent {@link #init} call then reuses those
 * clusters. If {@code init} is called without a preceding {@code selectStratiAmount}, k-means++
 * with exactly {@code stratiAmount} clusters is run instead.
 *
 * @param <I> the numeric instance type
 * @param <D> the dataset type
 */
public class GMeansStratiAmountSelectorAndAssigner<I extends INumericArrayInstance, D extends IDataset<I>>
extends ClusterStratiAssigner<I, D>
implements IStratiAmountSelector<D> {
    /** G-Means instance created by {@link #selectStratiAmount}; {@code null} until that method runs. */
    private GMeans<I> clusterer;

    /**
     * Creates a selector/assigner that measures distances with {@link ManhattanDistance}.
     *
     * @param randomSeed seed used for the clustering algorithm's randomness
     */
    public GMeansStratiAmountSelectorAndAssigner(int randomSeed) {
        this(new ManhattanDistance(), randomSeed);
    }

    /**
     * Creates a selector/assigner with a custom distance measure.
     *
     * @param distanceMeasure distance measure passed to G-Means / k-means++
     * @param randomSeed seed used for the clustering algorithm's randomness
     */
    public GMeansStratiAmountSelectorAndAssigner(DistanceMeasure distanceMeasure, int randomSeed) {
        this.randomSeed = randomSeed;
        this.distanceMeasure = distanceMeasure;
    }

    /**
     * Runs G-Means on the dataset and uses the number of resulting clusters as the strati amount.
     * The clusters are kept so that a later {@link #init} call can reuse them.
     *
     * @param dataset the dataset to cluster
     * @return the number of clusters found by G-Means
     */
    @Override
    public int selectStratiAmount(D dataset) {
        this.clusterer = new GMeans<>(dataset, this.distanceMeasure, this.randomSeed);
        this.clusters = this.clusterer.cluster();
        return this.clusters.size();
    }

    /**
     * Prepares the clusters used for strati assignment.
     *
     * <p>If {@link #selectStratiAmount} has already produced clusters, they are reused as-is and
     * {@code stratiAmount} is ignored. Otherwise the dataset is clustered with k-means++ into
     * {@code stratiAmount} clusters, seeded with the configured random seed.
     *
     * @param dataset the dataset to cluster
     * @param stratiAmount number of clusters to create when no G-Means clusters are available
     */
    @Override
    public void init(D dataset, int stratiAmount) {
        if (this.clusterer != null && this.clusters != null) {
            // G-Means clusters from selectStratiAmount are already available; reuse them.
            return;
        }
        JDKRandomGenerator rand = new JDKRandomGenerator();
        rand.setSeed(this.randomSeed);
        // A negative maxIterations means k-means++ runs without an iteration limit.
        KMeansPlusPlusClusterer<I> kmeans = new KMeansPlusPlusClusterer<>(stratiAmount, -1, this.distanceMeasure, rand);
        this.clusters = kmeans.cluster(dataset);
    }
}

