/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.FakeExecutor;

public class OneVSAll
implements Classifier,
Parameterized {
    private static final long serialVersionUID = -326668337438092217L;
    private Classifier[] oneVsAlls;
    @Parameter.ParameterHolder
    private Classifier baseClassifier;
    private CategoricalData predicting;
    private boolean concurrentTraining;
    private boolean useScoreIfAvailable = true;

    public OneVSAll(Classifier baseClassifier) {
        this(baseClassifier, true);
    }

    public OneVSAll(Classifier baseClassifier, boolean concurrentTraining) {
        this.baseClassifier = baseClassifier;
        this.concurrentTraining = concurrentTraining;
    }

    public void setConcurrentTraining(boolean concurrentTraining) {
        this.concurrentTraining = concurrentTraining;
    }

    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(this.predicting.getNumOfCategories());
        if (this.useScoreIfAvailable && this.oneVsAlls[0] instanceof BinaryScoreClassifier) {
            int maxIndx = 0;
            double maxScore = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < this.predicting.getNumOfCategories(); ++i) {
                double score = -((BinaryScoreClassifier)this.oneVsAlls[i]).getScore(data);
                if (!(score > maxScore)) continue;
                maxIndx = i;
                maxScore = score;
            }
            cr.setProb(maxIndx, 1.0);
        } else {
            for (int i = 0; i < this.predicting.getNumOfCategories(); ++i) {
                CategoricalResults oneVsAllCR = this.oneVsAlls[i].classify(data);
                double tmp = oneVsAllCR.getProb(0);
                if (!(tmp > 0.0)) continue;
                cr.setProb(i, tmp);
            }
            cr.normalize();
        }
        return cr;
    }

    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        this.oneVsAlls = new Classifier[dataSet.getClassSize()];
        this.predicting = dataSet.getPredicting();
        ArrayList categorized = new ArrayList();
        for (int i = 0; i < this.oneVsAlls.length; ++i) {
            List<DataPoint> tmp = dataSet.getSamples(i);
            ArrayList<DataPoint> oneCat = new ArrayList<DataPoint>(tmp.size());
            oneCat.addAll(tmp);
            categorized.add(oneCat);
        }
        int numer = dataSet.getNumNumericalVars();
        CategoricalData[] categories = dataSet.getCategories();
        final CountDownLatch latch = new CountDownLatch(this.oneVsAlls.length);
        for (int i = 0; i < this.oneVsAlls.length; ++i) {
            final ClassificationDataSet cds = new ClassificationDataSet(numer, categories, new CategoricalData(2));
            for (Object dp : (List)categorized.get(i)) {
                cds.addDataPoint(((DataPoint)dp).getNumericalValues(), ((DataPoint)dp).getCategoricalValues(), 0);
            }
            for (int j = 0; j < categorized.size(); ++j) {
                Object dp;
                if (j == i) continue;
                dp = ((List)categorized.get(j)).iterator();
                while (dp.hasNext()) {
                    DataPoint dp2 = (DataPoint)dp.next();
                    cds.addDataPoint(dp2.getNumericalValues(), dp2.getCategoricalValues(), 1);
                }
            }
            if (!this.concurrentTraining) {
                this.oneVsAlls[i] = this.baseClassifier.clone();
                if (threadPool == null || threadPool instanceof FakeExecutor) {
                    this.oneVsAlls[i].trainC(cds);
                    continue;
                }
                this.oneVsAlls[i].trainC(cds, threadPool);
                continue;
            }
            final Classifier aClassifier = this.baseClassifier.clone();
            final int ii = i;
            threadPool.submit(new Runnable(){

                @Override
                public void run() {
                    aClassifier.trainC(cds);
                    ((OneVSAll)OneVSAll.this).oneVsAlls[ii] = aClassifier;
                    latch.countDown();
                }
            });
        }
        if (this.concurrentTraining) {
            try {
                latch.await();
            }
            catch (InterruptedException ex) {
                Logger.getLogger(OneVSAll.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
    }

    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, new FakeExecutor());
    }

    @Override
    public OneVSAll clone() {
        OneVSAll clone = new OneVSAll(this.baseClassifier.clone(), this.concurrentTraining);
        if (this.predicting != null) {
            clone.predicting = this.predicting.clone();
        }
        if (this.oneVsAlls != null) {
            clone.oneVsAlls = new Classifier[this.oneVsAlls.length];
            for (int i = 0; i < this.oneVsAlls.length; ++i) {
                if (this.oneVsAlls[i] == null) continue;
                clone.oneVsAlls[i] = this.oneVsAlls[i].clone();
            }
        }
        return clone;
    }

    @Override
    public boolean supportsWeightedData() {
        return this.baseClassifier.supportsWeightedData();
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }
}

