/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling;

import ai.libs.jaicore.ml.core.dataset.IDataset;
import ai.libs.jaicore.ml.core.dataset.INumericArrayInstance;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.ClusterStratiAssigner;
import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Stratum assigner that groups the instances of a dataset by running k-means++ clustering
 * (Apache Commons Math {@link KMeansPlusPlusClusterer}) over them.
 *
 * <p>{@link #init(IDataset, int)} stores the computed clusters in the inherited {@code clusters}
 * field. How those clusters are then mapped to strati is decided by {@link ClusterStratiAssigner}.
 *
 * @param <I> the instance type (numeric array instances, so they can be clustered)
 * @param <D> the dataset type holding instances of type {@code I}
 */
public class KMeansStratiAssigner<I extends INumericArrayInstance, D extends IDataset<I>>
extends ClusterStratiAssigner<I, D> {

    /** One logger per class; the instance previously created its own mutable logger field. */
    private static final Logger LOGGER = LoggerFactory.getLogger(KMeansStratiAssigner.class);

    /**
     * Creates a k-means++ based stratum assigner.
     *
     * @param distanceMeasure distance measure handed to the k-means++ clusterer
     * @param randomSeed seed for the random generator passed to the clusterer, so that the
     *        randomized center initialization is reproducible across runs
     */
    public KMeansStratiAssigner(DistanceMeasure distanceMeasure, int randomSeed) {
        this.randomSeed = randomSeed;
        this.distanceMeasure = distanceMeasure;
    }

    /**
     * Clusters {@code dataset} into {@code stratiAmount} clusters with k-means++ and stores the
     * result in the inherited {@code clusters} field.
     *
     * @param dataset the dataset whose instances are clustered
     * @param stratiAmount the number of clusters (k) to compute
     */
    @Override
    public void init(D dataset, int stratiAmount) {
        // Fresh generator seeded on every call, so repeated init calls behave the same way.
        RandomGenerator rand = new JDKRandomGenerator();
        rand.setSeed(this.randomSeed);
        // maxIterations = -1: Commons Math treats a negative value as "no iteration limit".
        // NOTE(review): the clusterer is used as a raw type; parameterizing it requires I to be
        // Clusterable and D to be a Collection<I> -- confirm against IDataset/INumericArrayInstance.
        KMeansPlusPlusClusterer clusterer = new KMeansPlusPlusClusterer(stratiAmount, -1, this.distanceMeasure, rand);
        LOGGER.info("Clustering dataset with {} instances.", dataset.size());
        this.clusters = clusterer.cluster(dataset);
        LOGGER.info("Finished clustering");
    }
}

