/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.boosting;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.exceptions.FailedToFitException;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;

/**
 * AdaBoost.M1 boosting for multi-class classification.
 * <p>
 * A weak learner that supports weighted data is trained repeatedly on the
 * training points. After each round, the weights of correctly classified points
 * are multiplied by {@code beta = error / (1 - error)}, and a clone of the
 * trained learner is stored with a vote weight of {@code ln(1 / beta)}.
 * Training stops early when a round's weighted error is above 0.5 or exactly 0.
 * If that happens on the very first round, the learner from that round is kept
 * as the sole voter (with weight 1), so the trained model is never empty.
 */
public class AdaBoostM1
implements Classifier,
Parameterized {
    private static final long serialVersionUID = 4205232097748332861L;
    /** Learner re-trained on every boosting round; must support weighted data. */
    private Classifier weakLearner;
    /** Upper bound on the number of boosting rounds; validated to be at least 1. */
    private int maxIterations;
    /** Clones of the weak learner, one per accepted round; {@code null} until trained. */
    protected List<Classifier> hypoths;
    /** Vote weight for each entry of {@link #hypoths}, index-aligned; {@code null} until trained. */
    protected List<Double> hypWeights;
    /** Target variable description from the training set; {@code null} until trained. */
    protected CategoricalData predicting;

    /**
     * Creates a new AdaBoost.M1 learner.
     *
     * @param weakLearner the weak learner to boost; must support weighted data
     * @param maxIterations the maximum number of boosting rounds; must be positive
     * @throws FailedToFitException if {@code weakLearner} does not support weighted data
     * @throws IllegalArgumentException if {@code maxIterations} is less than 1
     */
    public AdaBoostM1(Classifier weakLearner, int maxIterations) {
        this.setWeakLearner(weakLearner);
        // Validate through the same check the setter uses, so an instance can
        // never be constructed with a non-positive number of rounds.
        this.maxIterations = requirePositiveIterations(maxIterations);
    }

    /**
     * Returns the maximum number of boosting rounds.
     *
     * @return the maximum number of boosting rounds
     */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Returns a read-only view of the weak models learned by the last training
     * call, or an empty list if the classifier has not been trained.
     *
     * @return the learned weak models
     */
    public List<Classifier> getModels() {
        if (this.hypoths == null) {
            return Collections.emptyList();
        }
        return Collections.unmodifiableList(this.hypoths);
    }

    /**
     * Returns a read-only view of the vote weights, index-aligned with
     * {@link #getModels()}, or an empty list if the classifier has not been trained.
     *
     * @return the vote weight of each learned model
     */
    public List<Double> getModelWeights() {
        if (this.hypWeights == null) {
            return Collections.emptyList();
        }
        return Collections.unmodifiableList(this.hypWeights);
    }

    /**
     * Sets the maximum number of boosting rounds.
     *
     * @param maxIterations the maximum number of rounds; must be positive
     * @throws IllegalArgumentException if {@code maxIterations} is less than 1
     */
    public void setMaxIterations(int maxIterations) {
        this.maxIterations = requirePositiveIterations(maxIterations);
    }

    /**
     * Returns the weak learner being boosted.
     *
     * @return the weak learner
     */
    public Classifier getWeakLearner() {
        return this.weakLearner;
    }

    /**
     * Sets the weak learner to boost.
     *
     * @param weakLearner the weak learner; must support weighted data
     * @throws FailedToFitException if {@code weakLearner} does not support weighted data
     */
    public void setWeakLearner(Classifier weakLearner) {
        if (!weakLearner.supportsWeightedData()) {
            throw new FailedToFitException("WeakLearner must support weighted data to be boosted");
        }
        this.weakLearner = weakLearner;
    }

    /**
     * Classifies a point by a weighted vote of the learned models.
     * <p>
     * Each model votes for its most likely class with its weight, and the vote
     * totals are then normalized.
     *
     * @param data the point to classify
     * @return the normalized weighted vote over all classes
     * @throws IllegalStateException if the classifier has not been trained yet
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.predicting == null) {
            throw new IllegalStateException("Classifier has not been trained yet");
        }
        CategoricalResults cr = new CategoricalResults(this.predicting.getNumOfCategories());
        for (int i = 0; i < this.hypoths.size(); ++i) {
            cr.incProb(this.hypoths.get(i).classify(data).mostLikely(), this.hypWeights.get(i));
        }
        cr.normalize();
        return cr;
    }

    /**
     * Trains the boosted ensemble, replacing any previously learned model.
     * <p>
     * The model fields are only replaced once training finishes. An exception
     * from the weak learner therefore leaves the previous model intact.
     * NOTE(review): this method overwrites the weight of every point returned by
     * {@code dataSet.getAsDPPList()}. If that list shares {@code DataPoint}
     * instances with {@code dataSet}, the caller's weights are clobbered — confirm.
     *
     * @param dataSet the training data
     * @param threadPool executor handed to the weak learner, or {@code null} to
     *        train the weak learner single-threaded
     * @throws FailedToFitException if {@code dataSet} contains no points
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        List<DataPointPair<Integer>> dataPoints = dataSet.getAsDPPList();
        if (dataPoints.isEmpty()) {
            throw new FailedToFitException("Can not train AdaBoostM1 on an empty data set");
        }
        CategoricalData target = dataSet.getPredicting();
        List<Classifier> models = new ArrayList<Classifier>(this.maxIterations);
        List<Double> weights = new DoubleList(this.maxIterations);
        // Stored point weights are the true distribution multiplied by scaledBy.
        // A weight of 1.0 for every point with scaledBy = n is the uniform 1/n distribution.
        for (DataPointPair<Integer> dpp : dataPoints) {
            dpp.getDataPoint().setWeight(1.0);
        }
        double scaledBy = dataPoints.size();
        boolean[] wasCorrect = new boolean[dataPoints.size()];
        for (int t = 0; t < this.maxIterations; ++t) {
            ClassificationDataSet roundData = new ClassificationDataSet(dataPoints, target);
            if (threadPool != null) {
                this.weakLearner.trainC(roundData, threadPool);
            } else {
                this.weakLearner.trainC(roundData);
            }
            double error = this.weightedMisclassification(dataPoints, wasCorrect) / scaledBy;
            if (error > 0.5 || error == 0.0) {
                if (models.isEmpty()) {
                    // Nothing accepted yet: keep this learner as the sole voter.
                    // Otherwise the ensemble would be empty and classify() would
                    // normalize an all-zero vote.
                    models.add(this.weakLearner.clone());
                    weights.add(1.0);
                }
                break;
            }
            double bt = error / (1.0 - error);
            scaledBy = reweight(dataPoints, wasCorrect, bt, scaledBy);
            models.add(this.weakLearner.clone());
            weights.add(Math.log(1.0 / bt));
        }
        this.predicting = target;
        this.hypoths = models;
        this.hypWeights = weights;
    }

    /**
     * Classifies every point with the current weak learner and records in
     * {@code wasCorrect} whether each prediction matched the point's label.
     *
     * @param dataPoints the labelled training points
     * @param wasCorrect output array, one entry per point
     * @return the sum of the stored (scaled) weights of the misclassified points
     */
    private double weightedMisclassification(List<DataPointPair<Integer>> dataPoints, boolean[] wasCorrect) {
        double error = 0.0;
        for (int i = 0; i < dataPoints.size(); ++i) {
            DataPointPair<Integer> dpp = dataPoints.get(i);
            wasCorrect[i] = this.weakLearner.classify(dpp.getDataPoint()).mostLikely() == dpp.getPair().intValue();
            if (!wasCorrect[i]) {
                error += dpp.getDataPoint().getWeight();
            }
        }
        return error;
    }

    /**
     * Updates the weight distribution for the next round.
     * <p>
     * The weights of correctly classified points are multiplied by {@code bt},
     * and the distribution is renormalized. The weights are then re-stored under
     * a new scale. That scale is the larger of {@code scaledBy} and the
     * reciprocal of the smallest (pre-normalization) true weight.
     *
     * @param dataPoints the training points whose weights are updated in place
     * @param wasCorrect which points the last weak learner classified correctly
     * @param bt the beta factor {@code error / (1 - error)}
     * @param scaledBy the scale the current stored weights are expressed in
     * @return the scale the updated stored weights are expressed in
     */
    private static double reweight(List<DataPointPair<Integer>> dataPoints, boolean[] wasCorrect,
            double bt, double scaledBy) {
        double normalizer = 0.0;
        double newScale = scaledBy;
        for (int i = 0; i < wasCorrect.length; ++i) {
            DataPoint dp = dataPoints.get(i).getDataPoint();
            if (wasCorrect[i]) {
                dp.setWeight(dp.getWeight() * bt);
            }
            double trueWeight = dp.getWeight() / scaledBy;
            if (1.0 / trueWeight > newScale) {
                newScale = 1.0 / trueWeight;
            }
            normalizer += trueWeight;
        }
        for (DataPointPair<Integer> dpp : dataPoints) {
            DataPoint dp = dpp.getDataPoint();
            dp.setWeight(dp.getWeight() / scaledBy * newScale / normalizer);
        }
        return newScale;
    }

    /**
     * Trains the ensemble without an executor; equivalent to
     * {@code trainC(dataSet, null)}.
     *
     * @param dataSet the training data
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, null);
    }

    /**
     * Reports that input point weights are not honored. Training resets every
     * point's weight to 1.0 before boosting starts.
     *
     * @return {@code false}
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Returns a copy whose weak learner and learned models are cloned.
     *
     * @return an independent copy of this classifier
     */
    @Override
    public AdaBoostM1 clone() {
        AdaBoostM1 copy = new AdaBoostM1(this.weakLearner.clone(), this.maxIterations);
        if (this.hypWeights != null) {
            copy.hypWeights = new DoubleList(this.hypWeights);
        }
        if (this.hypoths != null) {
            copy.hypoths = new ArrayList<Classifier>(this.hypoths.size());
            for (int i = 0; i < this.hypoths.size(); ++i) {
                copy.hypoths.add(this.hypoths.get(i).clone());
            }
        }
        if (this.predicting != null) {
            copy.predicting = this.predicting.clone();
        }
        return copy;
    }

    /** {@inheritDoc} */
    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    /** {@inheritDoc} */
    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }

    /**
     * Validates a boosting-round count.
     *
     * @param maxIterations the candidate count
     * @return {@code maxIterations}, unchanged
     * @throws IllegalArgumentException if {@code maxIterations} is less than 1
     */
    private static int requirePositiveIterations(int maxIterations) {
        if (maxIterations < 1) {
            throw new IllegalArgumentException("Number of iterations must be a positive value, not " + maxIterations);
        }
        return maxIterations;
    }
}

