/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.filter.sampling.inmemory.stratified.sampling;

import ai.libs.jaicore.basic.sets.ListView;
import ai.libs.jaicore.ml.clustering.learner.GMeans;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.stratified.sampling.ClusterStratiAssigner;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.stratified.sampling.IStratiAmountSelector;
import java.util.Collection;
import org.apache.commons.math3.ml.clustering.Clusterable;
import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.distance.ManhattanDistance;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.api4.java.ai.ml.core.dataset.IDataset;

/**
 * Strati assigner that can also choose the number of strati itself.
 *
 * <p>When {@link #selectStratiAmount(IDataset)} is called, the dataset is clustered with
 * {@link GMeans} and the number of resulting clusters is reported as the strati amount; the
 * clusters are kept for the later assignment. When only {@link #init(IDataset, int)} is called,
 * the dataset is clustered with {@link KMeansPlusPlusClusterer} using the given strati amount.
 */
public class GMeansStratiAmountSelectorAndAssigner
extends ClusterStratiAssigner
implements IStratiAmountSelector {
    /**
     * Iteration limit handed to {@link KMeansPlusPlusClusterer}. According to the Commons Math
     * documentation a negative value means that no maximum is applied.
     */
    private static final int UNLIMITED_ITERATIONS = -1;

    /** Clusterer created by {@link #selectStratiAmount(IDataset)}; {@code null} until then. */
    private GMeans<Clusterable> clusterer;

    /**
     * Creates an instance that measures distances with {@link ManhattanDistance}.
     *
     * @param randomSeed seed used for the clustering
     */
    public GMeansStratiAmountSelectorAndAssigner(int randomSeed) {
        this(new ManhattanDistance(), randomSeed);
    }

    /**
     * Creates an instance with a custom distance measure.
     *
     * @param distanceMeasure distance measure used for the clustering
     * @param randomSeed seed used for the clustering
     */
    public GMeansStratiAmountSelectorAndAssigner(DistanceMeasure distanceMeasure, int randomSeed) {
        this.randomSeed = randomSeed;
        this.distanceMeasure = distanceMeasure;
    }

    /**
     * Clusters the dataset with {@link GMeans}, stores the resulting clusters and returns how
     * many clusters were found.
     *
     * @param dataset dataset whose elements are clustered (viewed as {@link Clusterable})
     * @return number of clusters produced by the clustering
     */
    @Override
    public int selectStratiAmount(IDataset<?> dataset) {
        // ListView exposes the dataset's elements as Clusterable without copying them.
        ListView<Clusterable> points = new ListView<>(dataset);
        this.clusterer = new GMeans<>(points, this.distanceMeasure, this.randomSeed);
        this.setClusters(this.clusterer.cluster());
        return this.getClusters().size();
    }

    /**
     * Registers the dataset and makes sure clusters are available.
     *
     * <p>If {@link #selectStratiAmount(IDataset)} has not produced clusters before, the dataset is
     * clustered here with k-means++ into {@code stratiAmount} clusters. Otherwise the previously
     * computed clusters are kept and {@code stratiAmount} is not used.
     *
     * @param dataset dataset to assign to strati
     * @param stratiAmount number of clusters to build when no clusters exist yet
     */
    @Override
    public void init(IDataset<?> dataset, int stratiAmount) {
        this.setDataset(dataset);
        boolean clustersAvailable = this.clusterer != null && this.getClusters() != null;
        if (clustersAvailable) {
            // Clusters from the strati amount selection are reused as they are.
            return;
        }
        RandomGenerator rand = new JDKRandomGenerator();
        rand.setSeed(this.randomSeed);
        KMeansPlusPlusClusterer<Clusterable> kmeans =
                new KMeansPlusPlusClusterer<>(stratiAmount, UNLIMITED_ITERATIONS, this.distanceMeasure, rand);
        ListView<Clusterable> points = new ListView<>(dataset);
        this.setClusters(kmeans.cluster(points));
    }
}

