/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.trees;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.utils.FakeExecutor;
import jsat.utils.IntSet;
import jsat.utils.ModifiableCountDownLatch;

/**
 * ID3 decision tree classifier for data sets that contain only categorical
 * attributes.
 *
 * <p>Training recursively picks, among the attributes not yet used on the
 * current path, the one with the highest information gain, partitions the
 * data by that attribute's value and builds one child per value. Child
 * subtrees are built by tasks submitted to the supplied
 * {@link ExecutorService}; a {@link ModifiableCountDownLatch} tracks the
 * number of node-building calls that have not finished yet.
 *
 * <p>Training state (including the latch) is kept in instance fields, so a
 * single instance must not be trained from several threads at once.
 */
public class ID3
implements Classifier {
    private static final long serialVersionUID = -8473683139353205898L;
    /** Description of the class label being predicted. */
    private CategoricalData predicting;
    /** Descriptions of the categorical input attributes, indexed by attribute id. */
    private CategoricalData[] attributes;
    /** Root of the trained tree; {@code null} until training has run. */
    private ID3Node root;
    /** Tracks outstanding {@code buildTree} calls during training. */
    private ModifiableCountDownLatch latch;

    /**
     * Classifies a data point by walking the trained tree from the root.
     *
     * @param data the point to classify
     * @return the result stored in the leaf that {@code data} reaches
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        return ID3.walkTree(this.root, data);
    }

    /**
     * Follows the child selected by the data point's attribute values until
     * a leaf is reached.
     *
     * @param node the node to start from
     * @param data the point whose categorical values select the path
     * @return the result stored in the reached leaf
     */
    private static CategoricalResults walkTree(ID3Node node, DataPoint data) {
        ID3Node current = node;
        while (!current.isLeaf()) {
            current = current.getNode(data.getCategoricalValue(current.getAttributeId()));
        }
        return current.getResult();
    }

    /**
     * Trains the tree, submitting the construction of child subtrees to
     * {@code threadPool}, and blocks until the latch reports that all
     * node-building work has completed.
     *
     * @param dataSet the training data; must not contain numerical variables
     * @param threadPool executor that receives the subtree-building tasks
     * @throws RuntimeException if the data set has numerical variables
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        if (dataSet.getNumNumericalVars() != 0) {
            throw new RuntimeException("ID3 only supports categorical data");
        }
        this.predicting = dataSet.getPredicting();
        this.attributes = dataSet.getCategories();
        List<DataPointPair<Integer>> dataPoints = dataSet.getAsDPPList();
        int numAttributes = dataSet.getNumCategoricalVars();
        IntSet availableAttributes = new IntSet(numAttributes);
        for (int i = 0; i < numAttributes; ++i) {
            availableAttributes.add(Integer.valueOf(i));
        }
        // The initial count of 1 accounts for the root buildTree call below.
        this.latch = new ModifiableCountDownLatch(1);
        this.root = this.buildTree(dataPoints, availableAttributes, threadPool);
        try {
            this.latch.await();
        }
        catch (InterruptedException ex) {
            Logger.getLogger(ID3.class.getName()).log(Level.SEVERE, null, ex);
            // Restore the interrupt status so callers can still observe it.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Trains the tree using a {@link FakeExecutor}.
     *
     * @param dataSet the training data; must not contain numerical variables
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, new FakeExecutor());
    }

    /**
     * Builds the (sub)tree for the given data points.
     *
     * <p>Every call counts the latch down exactly once before returning; the
     * caller is responsible for the matching count (the initial count for
     * the root, {@code countUp()} for each submitted child).
     *
     * @param dataPoints the training points that reach this node
     * @param remainingAtribues ids of the attributes not yet used on the path
     * @param threadPool executor that receives the child-building tasks
     * @return a leaf if no attribute remains or the points are pure,
     *         otherwise an internal node whose non-empty children are filled
     *         in by the submitted tasks
     */
    private ID3Node buildTree(final List<DataPointPair<Integer>> dataPoints, Set<Integer> remainingAtribues, final ExecutorService threadPool) {
        double curEntropy = this.entropy(dataPoints);
        double size = dataPoints.size();
        if (remainingAtribues.isEmpty() || curEntropy == 0.0) {
            ID3Node leaf = new ID3Node(this.classDistribution(dataPoints));
            this.latch.countDown();
            return leaf;
        }
        int bestAttribute = -1;
        // Must start below every possible gain: Double.MIN_VALUE is the smallest
        // POSITIVE double, so a gain of 0.0 (or a slightly negative rounding
        // artefact) would never be selected and bestAttribute would stay -1.
        double bestInfoGain = Double.NEGATIVE_INFINITY;
        List<List<DataPointPair<Integer>>> bestSplit = null;
        for (int attribute : remainingAtribues) {
            int numValues = this.attributes[attribute].getNumOfCategories();
            List<List<DataPointPair<Integer>>> newSplit = new ArrayList<List<DataPointPair<Integer>>>(numValues);
            for (int i = 0; i < numValues; ++i) {
                newSplit.add(new ArrayList<DataPointPair<Integer>>());
            }
            for (DataPointPair<Integer> dpp : dataPoints) {
                newSplit.get(dpp.getDataPoint().getCategoricalValue(attribute)).add(dpp);
            }
            // Size-weighted average entropy of the partitions.
            double splitEntrop = 0.0;
            for (List<DataPointPair<Integer>> part : newSplit) {
                splitEntrop += this.entropy(part) * (double)part.size() / size;
            }
            double infoGain = curEntropy - splitEntrop;
            if (!(infoGain > bestInfoGain)) continue;
            bestAttribute = attribute;
            bestInfoGain = infoGain;
            bestSplit = newSplit;
        }
        final ID3Node node = new ID3Node(this.attributes[bestAttribute].getNumOfCategories(), bestAttribute);
        final IntSet newRemaining = new IntSet(remainingAtribues);
        newRemaining.remove((Object)bestAttribute);
        for (int i = 0; i < bestSplit.size(); ++i) {
            final int childIndex = i;
            final List<DataPointPair<Integer>> partition = bestSplit.get(i);
            if (partition.isEmpty()) {
                // No training point took this attribute value. Recursing on an
                // empty list would produce a leaf of NaN probabilities (0 / 0),
                // so fall back to this node's own class distribution instead.
                node.setNode(childIndex, new ID3Node(this.classDistribution(dataPoints)));
                continue;
            }
            // Register the child before submitting so the latch cannot hit
            // zero while its task is still pending.
            this.latch.countUp();
            threadPool.submit(new Runnable(){

                @Override
                public void run() {
                    node.setNode(childIndex, ID3.this.buildTree(partition, newRemaining, threadPool));
                }
            });
        }
        this.latch.countDown();
        return node;
    }

    /**
     * Computes the relative frequency of each class label among the points.
     *
     * @param dataPoints the labelled points to summarise
     * @return per-class frequencies; for an empty list the counts are left
     *         undivided to avoid a division by zero (assumes a freshly
     *         constructed {@code CategoricalResults} starts at zero — the
     *         counting below relies on the same assumption)
     */
    private CategoricalResults classDistribution(List<DataPointPair<Integer>> dataPoints) {
        CategoricalResults cr = new CategoricalResults(this.predicting.getNumOfCategories());
        for (DataPointPair<Integer> dpp : dataPoints) {
            cr.setProb(dpp.getPair(), cr.getProb(dpp.getPair()) + 1.0);
        }
        if (!dataPoints.isEmpty()) {
            cr.divideConst((double)dataPoints.size());
        }
        return cr;
    }

    /**
     * Computes the base-2 Shannon entropy of the class labels in {@code s}.
     *
     * @param s the labelled points
     * @return the entropy in bits; {@code 0.0} for an empty list
     */
    private double entropy(List<DataPointPair<Integer>> s) {
        if (s.isEmpty()) {
            return 0.0;
        }
        double[] counts = new double[this.predicting.getNumOfCategories()];
        for (DataPointPair<Integer> dpp : s) {
            int label = dpp.getPair();
            counts[label] = counts[label] + 1.0;
        }
        double total = s.size();
        double log2 = Math.log(2.0);
        double entr = 0.0;
        for (int i = 0; i < counts.length; ++i) {
            // Classes that never occur contribute nothing (and log(0) is undefined).
            if (counts[i] == 0.0) continue;
            double p = counts[i] / total;
            entr += p * (Math.log(p) / log2);
        }
        // The sum of p*log2(p) is non-positive; entropy is its magnitude.
        return Math.abs(entr);
    }

    /**
     * @return {@code false}; this implementation does not use data weights
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Creates a copy of this classifier. The tree nodes are copied; the
     * attribute and label descriptions are shared with the original.
     *
     * @return the copy; it is untrained if this classifier is untrained
     */
    @Override
    public Classifier clone() {
        ID3 copy = new ID3();
        copy.attributes = this.attributes;
        copy.latch = null;
        copy.predicting = this.predicting;
        // An untrained classifier has no tree to copy.
        copy.root = this.root == null ? null : this.root.copy();
        return copy;
    }

    /**
     * A tree node: either a leaf holding a classification result, or an
     * internal node that splits on one attribute and has one child per
     * attribute value.
     */
    private static class ID3Node {
        /** Children indexed by attribute value; {@code null} for leaves. */
        ID3Node[] children;
        /** Classification result; non-null exactly for leaves. */
        CategoricalResults cr;
        /** Id of the attribute an internal node splits on. */
        int attributeId;

        /** Creates an empty node to be filled in by {@link #copy()}. */
        private ID3Node() {
        }

        /**
         * Creates an internal node.
         *
         * @param atributes number of children (values of the split attribute)
         * @param attributeId id of the attribute this node splits on
         */
        public ID3Node(int atributes, int attributeId) {
            this.cr = null;
            this.children = new ID3Node[atributes];
            this.attributeId = attributeId;
        }

        /**
         * Creates a leaf node.
         *
         * @param cr the result returned for points reaching this leaf
         */
        public ID3Node(CategoricalResults cr) {
            this.children = null;
            this.cr = cr;
        }

        /** @return {@code true} if this node holds a result */
        public boolean isLeaf() {
            return this.cr != null;
        }

        /**
         * Sets the child for attribute value {@code i}.
         *
         * @param i the attribute value / child index
         * @param node the child to store
         */
        public void setNode(int i, ID3Node node) {
            this.children[i] = node;
        }

        /**
         * @param i the attribute value / child index
         * @return the child stored for {@code i}
         */
        public ID3Node getNode(int i) {
            return this.children[i];
        }

        /** @return id of the attribute this node splits on */
        public int getAttributeId() {
            return this.attributeId;
        }

        /** @return the leaf's result, or {@code null} for internal nodes */
        public CategoricalResults getResult() {
            return this.cr;
        }

        /**
         * Recursively copies this node and its children. The
         * {@code CategoricalResults} instance is shared with the original.
         *
         * @return the copied subtree
         */
        public ID3Node copy() {
            ID3Node copy = new ID3Node();
            copy.cr = this.cr;
            copy.attributeId = this.attributeId;
            if (this.children != null) {
                copy.children = new ID3Node[this.children.length];
                for (int i = 0; i < this.children.length; ++i) {
                    copy.children[i] = this.children[i].copy();
                }
            }
            return copy;
        }
    }
}

