/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.weka.rangequery;

import ai.libs.jaicore.ml.weka.rangequery.AbstractAugmentedSpaceSampler;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.core.DistanceFunction;
import weka.core.EuclideanDistance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.neighboursearch.NearestNeighbourSearch;

/**
 * Augmented-space sampler that builds each sample from a randomly chosen precise instance together
 * with its {@code k} nearest neighbours.
 *
 * <p>Neighbours are found with the supplied {@link NearestNeighbourSearch}, which this class
 * configures with a Euclidean distance over every attribute except the last one (presumably the
 * class attribute — confirm against callers).
 */
public class KNNAugSpaceSampler
extends AbstractAugmentedSpaceSampler {
    private static final Logger logger = LoggerFactory.getLogger(KNNAugSpaceSampler.class);
    private final NearestNeighbourSearch nearestNeighbour;
    private int k;

    /**
     * Creates a sampler and configures {@code nearestNeighbour} over {@code preciseInsts}.
     *
     * <p>If configuring the distance function or indexing the instances fails, the error is logged
     * and construction still completes; the failure then surfaces when {@link #augSpaceSample()} is
     * called.
     *
     * @param preciseInsts instances to draw query points from and to search for neighbours in
     * @param rng random source used to pick the query instance
     * @param k number of nearest neighbours requested for each sample
     * @param nearestNeighbour neighbour search to use; it is mutated here (distance function,
     *     indexed instances, performance measuring switched off) and retained by this sampler
     */
    public KNNAugSpaceSampler(Instances preciseInsts, Random rng, int k, NearestNeighbourSearch nearestNeighbour) {
        super(preciseInsts, rng);
        this.k = k;
        EuclideanDistance dist = new EuclideanDistance(preciseInsts);
        // Weka parses options token by token: the flag and its value must be separate array
        // elements. A single "-R first-N" element is not recognised as the -R flag, which left the
        // distance on its default range "first-last" (i.e. still including the last attribute).
        String[] distOptions = {"-R", "first-" + (preciseInsts.numAttributes() - 1)};
        try {
            dist.setOptions(distOptions);
            nearestNeighbour.setDistanceFunction(dist);
            nearestNeighbour.setInstances(preciseInsts);
        } catch (Exception e) {
            logger.error("Could not configure distance function or setup nearest neighbour.", e);
        }
        nearestNeighbour.setMeasurePerformance(false);
        this.nearestNeighbour = nearestNeighbour;
    }

    /**
     * Draws one augmented-space sample.
     *
     * <p>A precise instance is picked with the sampler's RNG. That instance and its {@code k}
     * nearest neighbours are handed to {@code generateAugPoint}, and its result is returned.
     *
     * @return the point produced by {@code generateAugPoint} for the query instance and its neighbours
     * @throws IllegalStateException if the nearest-neighbour search fails
     */
    @Override
    public Instance augSpaceSample() {
        Instances preciseInsts = this.getPreciseInsts();
        Instance x = preciseInsts.get(this.getRng().nextInt(preciseInsts.size()));
        Instances kNNs;
        try {
            kNNs = this.nearestNeighbour.kNearestNeighbours(x, this.k);
        } catch (Exception e) {
            // Previously this was only logged, after which addAll(null) threw an uninformative NPE.
            throw new IllegalStateException("Creating the augmented space sample failed with exception.", e);
        }
        // NOTE(review): x is drawn from the indexed instances, so depending on the search
        // implementation it may also appear among its own neighbours — confirm this is intended.
        ArrayList<Instance> sampledPoints = new ArrayList<>(kNNs.size() + 1);
        sampledPoints.add(x);
        sampledPoints.addAll(kNNs);
        return KNNAugSpaceSampler.generateAugPoint(sampledPoints);
    }

    /**
     * Returns the number of nearest neighbours requested per sample.
     *
     * @return the current {@code k}
     */
    public int getK() {
        return this.k;
    }

    /**
     * Sets the number of nearest neighbours requested per sample; used by subsequent calls to
     * {@link #augSpaceSample()}.
     *
     * @param k new number of neighbours
     */
    public void setK(int k) {
        this.k = k;
    }
}

