/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DenseSparseMetric;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.utils.FakeExecutor;

/**
 * Nearest-centroid ("Rocchio") classifier.
 *
 * <p>Training computes, for every class, the weighted mean of the numerical
 * vectors of that class's samples. Classification measures the distance from
 * the query vector to each class mean and converts those distances into
 * per-class scores, so that a smaller distance yields a larger score.
 *
 * <p>When the configured {@link DistanceMetric} is also a
 * {@link DenseSparseMetric}, one constant per class mean is computed at
 * training time via {@code getVectorConstant} and passed back to the metric
 * on every distance evaluation.
 */
public class Rocchio implements Classifier {
    private static final long serialVersionUID = 889524967453326517L;

    /** One mean vector per class, indexed by class; {@code null} until trained. */
    private List<Vec> rocVecs;
    /** Metric used to compare a query vector against the class means. */
    private final DistanceMetric dm;
    /** {@link #dm} viewed as a {@link DenseSparseMetric}, or {@code null} if it is not one. */
    private final DenseSparseMetric dsdm;
    /**
     * Per-class constants produced by {@code dsdm.getVectorConstant}, indexed by class.
     * {@code null} when {@link #dsdm} is {@code null} or the model is untrained.
     */
    private double[] summaryConsts;

    /** Creates a classifier that uses {@link EuclideanDistance}. */
    public Rocchio() {
        this(new EuclideanDistance());
    }

    /**
     * Creates a classifier that uses the given distance metric.
     *
     * @param dm the metric used to compare query vectors with class means
     */
    public Rocchio(DistanceMetric dm) {
        this.dm = dm;
        this.dsdm = dm instanceof DenseSparseMetric ? (DenseSparseMetric) dm : null;
        this.rocVecs = null;
    }

    /**
     * Scores every class for the given data point.
     *
     * <p>With {@code d_i} the distance to the mean of class {@code i}, the score
     * stored for class {@code i} is {@code 1 - d_i / sum(d)}. For two classes the
     * scores add up to 1; for {@code N} classes they add up to {@code N - 1}.
     * If all distances are zero, every class receives the same score {@code 1/N}
     * instead of the NaN that {@code 0/0} would produce.
     *
     * @param data the point to classify; only its numerical values are read
     * @return the per-class scores
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        final int classes = this.rocVecs.size();
        CategoricalResults cr = new CategoricalResults(classes);
        Vec target = data.getNumericalValues();
        // The constants are only meaningful (and only computed) for a DenseSparseMetric.
        final boolean useConstants = this.dsdm != null && this.summaryConsts != null;

        double sum = 0.0;
        for (int i = 0; i < classes; ++i) {
            Vec centroid = this.rocVecs.get(i);
            double distance = useConstants
                    ? this.dsdm.dist(this.summaryConsts[i], centroid, target)
                    : this.dm.dist(centroid, target);
            sum += distance;
            // Temporarily park the raw distance; it is turned into a score below.
            cr.setProb(i, distance);
        }

        if (sum == 0.0) {
            // Every mean coincides with the target: no class is preferred.
            for (int i = 0; i < classes; ++i) {
                cr.setProb(i, 1.0 / classes);
            }
            return cr;
        }

        for (int i = 0; i < classes; ++i) {
            cr.setProb(i, 1.0 - cr.getProb(i) / sum);
        }
        return cr;
    }

    /**
     * Trains the classifier, submitting one task per class to {@code threadPool}
     * and blocking until all of them have finished.
     *
     * @param dataSet    the training data; must contain no categorical variables
     * @param threadPool executor that runs the per-class tasks
     * @throws FailedToFitException if the data set has categorical variables
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        if (dataSet.getNumCategoricalVars() != 0) {
            throw new FailedToFitException("Classifier requires all variables be numerical");
        }
        int N = dataSet.getClassSize();
        this.rocVecs = new ArrayList<Vec>(N);
        TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet, threadPool);
        int d = dataSet.getNumNumericalVars();
        // One constant per class (not per dimension), and only when the metric can
        // use them; classify() relies on null meaning "use the plain metric".
        this.summaryConsts = this.dsdm != null ? new double[N] : null;
        CountDownLatch cdl = new CountDownLatch(N);
        for (int i = 0; i < N; ++i) {
            DenseVector rochVec = new DenseVector(d);
            this.rocVecs.add(rochVec);
            threadPool.submit(new RocchioAdder(cdl, i, rochVec, dataSet.getSamples(i)));
        }
        try {
            cdl.await();
        } catch (InterruptedException ex) {
            // Preserve the interrupt so the caller can observe it.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Trains the classifier on the calling thread by delegating to
     * {@link #trainC(ClassificationDataSet, ExecutorService)} with a {@link FakeExecutor}.
     *
     * @param dataSet the training data
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, new FakeExecutor());
    }

    /**
     * @return {@code true}; training weights each sample by {@code DataPoint.getWeight()}
     */
    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * Returns a copy with its own clones of the class means and its own copy of
     * the per-class constants. The distance metric instance is shared with this
     * object, not copied.
     *
     * @return the copy
     */
    @Override
    public Rocchio clone() {
        Rocchio copy = new Rocchio(this.dm);
        if (this.rocVecs != null) {
            copy.rocVecs = new ArrayList<Vec>(this.rocVecs.size());
            for (Vec v : this.rocVecs) {
                copy.rocVecs.add(v.clone());
            }
        }
        if (this.summaryConsts != null) {
            copy.summaryConsts = Arrays.copyOf(this.summaryConsts, this.summaryConsts.length);
        }
        return copy;
    }

    /**
     * Task that accumulates the weighted mean of one class's samples into a
     * caller-supplied vector and, when applicable, records that mean's metric
     * constant in the enclosing classifier.
     */
    private class RocchioAdder implements Runnable {
        final CountDownLatch latch;
        final Vec rocchioVec;
        final List<DataPoint> input;
        final int index;

        /**
         * @param latch      counted down exactly once when this task ends
         * @param index      class index; slot written in {@code summaryConsts}
         * @param rocchioVec vector that receives the weighted mean (mutated in place)
         * @param input      the samples of the class
         */
        public RocchioAdder(CountDownLatch latch, int index, Vec rocchioVec, List<DataPoint> input) {
            this.latch = latch;
            this.index = index;
            this.rocchioVec = rocchioVec;
            this.input = input;
        }

        @Override
        public void run() {
            try {
                double weightSum = 0.0;
                for (DataPoint dp : this.input) {
                    double w = dp.getWeight();
                    weightSum += w;
                    this.rocchioVec.mutableAdd(w, dp.getNumericalValues());
                }
                this.rocchioVec.mutableDivide(weightSum);
                if (Rocchio.this.dsdm != null) {
                    Rocchio.this.summaryConsts[this.index] =
                            Rocchio.this.dsdm.getVectorConstant(this.rocchioVec);
                }
            } finally {
                // Always release the latch so a failing task cannot hang trainC().
                this.latch.countDown();
            }
        }
    }
}

