/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.boosting;

import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;

/**
 * Arc-x4 ensemble classifier.
 *
 * <p>Each of {@link #getIterations()} rounds trains a fresh clone of the weak learner on a copy of
 * the training data. In that copy, point {@code i} is weighted by {@code 1 + coef * m_i^expo},
 * where {@code m_i} is the number of earlier hypotheses that misclassified it. Prediction is an
 * unweighted majority vote of all trained hypotheses.
 */
public class ArcX4 implements Classifier, Parameterized {
    private static final long serialVersionUID = 3831448932874147550L;
    /** Prototype learner cloned and trained once per round; must support weighted data. */
    private Classifier weakLearner;
    /** Number of boosting rounds, i.e. the size of the trained ensemble. */
    private int iterations;
    /** Multiplier {@code coef} in the per-point weight {@code 1 + coef * m^expo}. */
    private double coef = 1.0;
    /** Exponent {@code expo} in the per-point weight {@code 1 + coef * m^expo}. */
    private double expo = 4.0;
    /** Target-variable info from the training set. The name is left unchanged (sic) so any serialized form is not disturbed. */
    private CategoricalData predicing;
    /** Trained ensemble members, or {@code null} before training has completed. */
    private Classifier[] hypoths;

    /**
     * Creates a new Arc-x4 ensemble.
     *
     * @param weakLearner the learner cloned and trained each round; must support weighted data
     * @param iterations the number of rounds (ensemble size); must be at least 1
     * @throws IllegalArgumentException if {@code weakLearner} does not support weighted data or
     *     {@code iterations} is less than 1
     */
    public ArcX4(Classifier weakLearner, int iterations) {
        this.setWeakLearner(weakLearner);
        this.setIterations(iterations);
    }

    /**
     * Sets the learner that is cloned and trained in every round.
     *
     * @param weakLearner the weak learner; must report {@code supportsWeightedData() == true}
     * @throws IllegalArgumentException if the learner does not support weighted data
     */
    public void setWeakLearner(Classifier weakLearner) {
        if (!weakLearner.supportsWeightedData()) {
            throw new IllegalArgumentException("Weak learners must support weighted data samples");
        }
        this.weakLearner = weakLearner;
    }

    /** @return the prototype weak learner */
    public Classifier getWeakLearner() {
        return this.weakLearner;
    }

    /**
     * Sets the number of boosting rounds.
     *
     * @param iterations the number of rounds; must be at least 1
     * @throws IllegalArgumentException if {@code iterations} is less than 1
     */
    public void setIterations(int iterations) {
        if (iterations < 1) {
            throw new IllegalArgumentException("The number of iterations must be positive, not " + iterations);
        }
        this.iterations = iterations;
    }

    /** @return the number of boosting rounds */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * Sets the multiplier applied to the error term of each point's weight.
     *
     * @param coef a finite, strictly positive value
     * @throws ArithmeticException if {@code coef} is not finite and positive
     */
    public void setCoefficient(double coef) {
        if (coef <= 0.0 || Double.isInfinite(coef) || Double.isNaN(coef)) {
            throw new ArithmeticException("The coefficient must be a positive constant");
        }
        this.coef = coef;
    }

    /** @return the multiplier applied to the error term of each point's weight */
    public double getCoefficient() {
        return this.coef;
    }

    /**
     * Sets the power to which each point's misclassification count is raised.
     *
     * @param expo a finite, strictly positive value
     * @throws ArithmeticException if {@code expo} is not finite and positive
     */
    public void setExponent(double expo) {
        if (expo <= 0.0 || Double.isInfinite(expo) || Double.isNaN(expo)) {
            throw new ArithmeticException("The exponent must be a positive constant");
        }
        this.expo = expo;
    }

    /** @return the power to which each point's misclassification count is raised */
    public double getExponent() {
        return this.expo;
    }

    /**
     * Classifies a point by an unweighted majority vote of the trained hypotheses.
     *
     * @param data the point to classify
     * @return the normalized vote share of each category
     * @throws IllegalStateException if the model has not been trained
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.hypoths == null) {
            throw new IllegalStateException("ArcX4 must be trained before it can classify");
        }
        CategoricalResults cr = new CategoricalResults(this.predicing.getNumOfCategories());
        for (Classifier hypoth : this.hypoths) {
            cr.incProb(hypoth.classify(data).mostLikely(), 1.0);
        }
        cr.normalize();
        return cr;
    }

    /**
     * Trains the ensemble.
     *
     * <p>A {@code null} pool or a {@link FakeExecutor} means everything runs on the calling thread.
     * The trained model is published only after every round succeeds, so a failed call leaves any
     * previously trained model intact.
     *
     * @param dataSet the training data; its data points are wrapped, not re-weighted in place
     * @param threadPool the pool used for training and error counting, or {@code null}
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        // Wrap every point in a fresh DataPoint so the per-round weights set below never leak
        // into the caller's data set.
        ClassificationDataSet cds = dataSet.shallowClone();
        for (int i = 0; i < cds.getSampleSize(); ++i) {
            DataPoint dp = cds.getDataPoint(i);
            cds.setDataPoint(i, new DataPoint(dp.getNumericalValues(), dp.getCategoricalValues(), dp.getCategoricalData()));
        }
        boolean parallel = threadPool != null && !(threadPool instanceof FakeExecutor);
        int[] errors = new int[cds.getSampleSize()];
        Classifier[] trained = new Classifier[this.iterations];
        for (int t = 0; t < trained.length; ++t) {
            for (int i = 0; i < errors.length; ++i) {
                cds.getDataPoint(i).setWeight(1.0 + this.coef * Math.pow(errors[i], this.expo));
            }
            Classifier hypoth = this.weakLearner.clone();
            if (parallel) {
                hypoth.trainC(cds, threadPool);
            } else {
                hypoth.trainC(cds);
            }
            trained[t] = hypoth;
            countErrors(cds, errors, hypoth, threadPool, parallel);
        }
        this.hypoths = trained;
        this.predicing = cds.getPredicting();
    }

    /**
     * Increments {@code errors[i]} for every point {@code i} of {@code cds} that {@code hypoth}
     * misclassifies. The work is split over {@code threadPool} when {@code parallel} is set and
     * there are at least as many points as logical cores.
     *
     * @throws RuntimeException the first exception raised by a counting task, or a wrapper if the
     *     calling thread is interrupted while waiting (its interrupt status is restored)
     */
    private static void countErrors(ClassificationDataSet cds, int[] errors, Classifier hypoth,
            ExecutorService threadPool, boolean parallel) {
        int cores = SystemInfo.LogicalCores;
        int blockSize = errors.length / cores;
        Tester[] tasks;
        if (!parallel || blockSize == 0) {
            tasks = new Tester[]{new Tester(cds, errors, 0, errors.length, hypoth, new CountDownLatch(1))};
            tasks[0].run();
        } else {
            // The first (length % cores) blocks take one extra point. This gives exactly `cores`
            // disjoint blocks covering [0, errors.length), matching the latch count.
            int extra = errors.length % cores;
            CountDownLatch latch = new CountDownLatch(cores);
            tasks = new Tester[cores];
            int start = 0;
            for (int block = 0; block < cores; ++block) {
                int end = start + blockSize + (block < extra ? 1 : 0);
                tasks[block] = new Tester(cds, errors, start, end, hypoth, latch);
                threadPool.submit(tasks[block]);
                start = end;
            }
            try {
                latch.await();
            } catch (InterruptedException ex) {
                // Abort rather than continue: worker tasks may still be writing `errors`.
                Thread.currentThread().interrupt();
                throw new RuntimeException("Interrupted while counting ArcX4 training errors", ex);
            }
        }
        for (Tester task : tasks) {
            if (task.failure != null) {
                throw task.failure;
            }
        }
    }

    /** Trains the ensemble entirely on the calling thread. */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, new FakeExecutor());
    }

    /**
     * Reports that input weights are not honored. Training replaces every point with an
     * unweighted copy and then assigns its own Arc-x4 weights.
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /** @return a deep copy of this ensemble, including clones of every trained hypothesis */
    @Override
    public ArcX4 clone() {
        ArcX4 clone = new ArcX4(this.weakLearner.clone(), this.iterations);
        clone.coef = this.coef;
        clone.expo = this.expo;
        if (this.predicing != null) {
            clone.predicing = this.predicing.clone();
        }
        if (this.hypoths != null) {
            clone.hypoths = new Classifier[this.hypoths.length];
            for (int i = 0; i < clone.hypoths.length; ++i) {
                clone.hypoths[i] = this.hypoths[i].clone();
            }
        }
        return clone;
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }

    /**
     * Counts the misclassifications of one hypothesis over the index range {@code [start, end)}.
     * Each task writes only its own slice of {@code errors}.
     */
    private static final class Tester implements Runnable {
        final ClassificationDataSet cds;
        final int[] errors;
        final int start;
        final int end;
        final Classifier hypoth;
        final CountDownLatch latch;
        /** Exception that aborted this task, or {@code null} if it finished normally. */
        volatile RuntimeException failure;

        Tester(ClassificationDataSet cds, int[] errors, int start, int end, Classifier hypoth, CountDownLatch latch) {
            this.cds = cds;
            this.errors = errors;
            this.start = start;
            this.end = end;
            this.hypoth = hypoth;
            this.latch = latch;
        }

        @Override
        public void run() {
            try {
                for (int i = this.start; i < this.end; ++i) {
                    if (this.hypoth.classify(this.cds.getDataPoint(i)).mostLikely() != this.cds.getDataPointCategory(i)) {
                        this.errors[i]++;
                    }
                }
            } catch (RuntimeException ex) {
                this.failure = ex;
            } finally {
                // Always release the latch so a failing task can never leave training blocked.
                this.latch.countDown();
            }
        }
    }
}

