/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.filter.sampling.inmemory.stratified.sampling;

import ai.libs.jaicore.ml.core.filter.sampling.inmemory.stratified.sampling.ClusterStratiAssigner;
import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.api4.java.ai.ml.core.dataset.IDataset;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Stratum assigner that partitions a dataset into strata by running k-means++ clustering
 * (Apache Commons Math {@link KMeansPlusPlusClusterer}), requesting one cluster per stratum.
 */
public class KMeansStratiAssigner extends ClusterStratiAssigner {

    /** One logger per class rather than one per assigner instance. */
    private static final Logger LOGGER = LoggerFactory.getLogger(KMeansStratiAssigner.class);

    /**
     * Creates a k-means++ based stratum assigner.
     *
     * @param distanceMeasure distance measure handed to the k-means++ clusterer
     * @param randomSeed seed for the random generator handed to the k-means++ clusterer
     */
    public KMeansStratiAssigner(DistanceMeasure distanceMeasure, int randomSeed) {
        // Both fields are inherited from ClusterStratiAssigner.
        this.randomSeed = randomSeed;
        this.distanceMeasure = distanceMeasure;
    }

    /**
     * Stores the dataset and clusters it into {@code stratiAmount} clusters with k-means++.
     * The resulting clusters are handed to {@code setClusters}.
     *
     * @param dataset the dataset to cluster; passed directly to the clusterer
     * @param stratiAmount number of clusters (k) requested from the clusterer
     */
    @Override
    @SuppressWarnings({"rawtypes", "unchecked"}) // clusterer's point type is not expressible against IDataset<?>
    public void init(IDataset<?> dataset, int stratiAmount) {
        this.setDataset(dataset);

        JDKRandomGenerator rand = new JDKRandomGenerator();
        rand.setSeed(this.randomSeed);

        // A negative maxIterations means "no maximum" for KMeansPlusPlusClusterer.
        KMeansPlusPlusClusterer clusterer =
                new KMeansPlusPlusClusterer(stratiAmount, -1, this.distanceMeasure, rand);

        LOGGER.info("Clustering dataset with {} instances.", dataset.size());
        this.setClusters(clusterer.cluster(dataset));
        LOGGER.info("Finished clustering");
    }
}

