/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.bayesian.graphicalmodel;

import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.classifiers.bayesian.ConditionalProbabilityTable;
import jsat.classifiers.bayesian.graphicalmodel.DirectedGraph;
import jsat.exceptions.FailedToFitException;
import jsat.utils.IntSet;

/**
 * A classifier over categorical attributes whose structure is described by a
 * directed graph of dependencies between variables.
 * <p>
 * Graph node ids are categorical attribute indices. The class (target) variable
 * uses the id equal to the number of categorical attributes, i.e. one past the
 * last attribute index.
 * <p>
 * If no edges have been declared with {@link #depends(int, int)} before
 * training, the class node is connected to every attribute. Each attribute's
 * table is then trained over just that attribute and the class, which
 * resembles a naive Bayes structure.
 */
public class DiscreteBayesNetwork implements Classifier {
    private static final long serialVersionUID = 2980734594356260141L;

    /** Default for whether class priors contribute to classification scores. */
    public static final boolean DEFAULT_USE_PRIORS = true;

    /** Dependency graph; {@link #depends(int, int)} adds edges from parent to child. */
    protected DirectedGraph<Integer> dag = new DirectedGraph<Integer>();
    /** One table per child of the class node, keyed by that child's node id. */
    protected Map<Integer, ConditionalProbabilityTable> cpts;
    /** Target variable description captured at training time. */
    protected CategoricalData predicting;
    /** Class prior probabilities captured at training time. */
    protected double[] priors;
    private boolean usePriors = DEFAULT_USE_PRIORS;

    /**
     * Computes the posterior distribution over the target classes for a point.
     * <p>
     * For each candidate class, the logs of the CPT queries for every child of
     * the class node are summed, and the log prior is added when priors are
     * enabled. The log-scores are then normalized with the log-sum-exp trick,
     * so the returned probabilities sum to one without underflowing. If every
     * class has zero likelihood, a uniform distribution is returned instead of
     * NaN.
     *
     * @param data the point to classify
     * @return the per-class probabilities
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(this.predicting.getNumOfCategories());
        // The class variable's node id is one past the last categorical attribute.
        int classId = data.numCategoricalValues();

        double[] logProbs = new double[cr.size()];
        double maxLog = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < logProbs.length; i++) {
            // Pair the point with candidate class i so the CPTs can be queried for it.
            DataPointPair<Integer> dpp = new DataPointPair<Integer>(data, i);
            for (int child : this.dag.getChildren(classId)) {
                logProbs[i] += Math.log(this.cpts.get(child).query(child, dpp));
            }
            if (this.usePriors) {
                logProbs[i] += Math.log(this.priors[i]);
            }
            if (logProbs[i] > maxLog) {
                maxLog = logProbs[i];
            }
        }

        if (maxLog == Double.NEGATIVE_INFINITY) {
            // Every class scored zero likelihood; (-inf) - (-inf) would be NaN,
            // so fall back to treating all classes as equally likely.
            for (int i = 0; i < logProbs.length; i++) {
                cr.setProb(i, 1.0 / logProbs.length);
            }
            return cr;
        }

        // log(sum_j exp(logProbs[j])), computed stably by factoring out the max.
        double scaledSum = 0.0;
        for (double lp : logProbs) {
            scaledSum += Math.exp(lp - maxLog);
        }
        double logNormalizer = maxLog + Math.log(scaledSum);
        for (int i = 0; i < logProbs.length; i++) {
            cr.setProb(i, Math.exp(logProbs[i] - logNormalizer));
        }
        return cr;
    }

    /**
     * Adds both nodes to the dependency graph and records an edge from
     * {@code parent} to {@code child}.
     * <p>
     * Node ids are categorical attribute indices. The class variable uses the
     * id equal to the number of categorical attributes.
     *
     * @param parent the parent node id
     * @param child the child node id
     */
    public void depends(int parent, int child) {
        this.dag.addNode(child);
        this.dag.addNode(parent);
        this.dag.addEdge(parent, child);
    }

    /**
     * Trains the network. The thread pool is ignored; training runs on the
     * calling thread via {@link #trainC(ClassificationDataSet)}.
     *
     * @param dataSet the training data
     * @param threadPool unused
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        this.trainC(dataSet);
    }

    /**
     * Trains one conditional probability table for every child of the class
     * node.
     * <p>
     * When the graph is empty, the class node is first connected to every
     * categorical attribute. A graph that already has nodes (including one left
     * from a previous call) is used as-is. Each child's table is trained over
     * that child, the child's own children in the graph, and the class
     * variable.
     *
     * @param dataSet the training data
     * @throws FailedToFitException if the data set has no categorical attributes
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        int classID = dataSet.getNumCategoricalVars();
        if (classID == 0) {
            throw new FailedToFitException("Network needs categorical attribtues to work");
        }
        this.predicting = dataSet.getPredicting();
        this.priors = dataSet.getPriors();
        this.cpts = new HashMap<Integer, ConditionalProbabilityTable>();

        if (this.dag.getNodes().isEmpty()) {
            for (int attribute = 0; attribute < classID; ++attribute) {
                this.depends(classID, attribute);
            }
        }

        // Reused across iterations; cleared before each table is trained.
        IntSet cptTrainSet = new IntSet();
        for (int child : this.dag.getChildren(classID)) {
            Set<Integer> grandChildren = this.dag.getChildren(child);
            cptTrainSet.clear();
            cptTrainSet.addAll(grandChildren);
            cptTrainSet.add(Integer.valueOf(child));
            cptTrainSet.add(Integer.valueOf(classID));

            ConditionalProbabilityTable cpt = new ConditionalProbabilityTable();
            cpt.trainC(dataSet, cptTrainSet);
            this.cpts.put(child, cpt);
        }
    }

    /**
     * Sets whether class priors are added to the log-scores in
     * {@link #classify(DataPoint)}. When {@code false}, no prior term is
     * added, so every class is scored as if it were equally likely a priori.
     *
     * @param usePriors whether to use class priors
     */
    public void setUsePriors(boolean usePriors) {
        this.usePriors = usePriors;
    }

    /**
     * Returns whether class priors are used during classification.
     *
     * @return {@code true} if priors are used
     */
    public boolean isUsePriors() {
        return this.usePriors;
    }

    /**
     * Weighted training data is not supported.
     *
     * @return always {@code false}
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Cloning is not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Classifier clone() {
        throw new UnsupportedOperationException("Not supported yet.");
    }
}

