/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree;

import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.PredictionFailedException;
import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.RangeQueryPredictor;
import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.aggregation.AggressiveAggregator;
import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.aggregation.IntervalAggregator;
import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.util.RQPHelper;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.geometry.euclidean.oned.Interval;
import weka.classifiers.trees.m5.M5Base;
import weka.classifiers.trees.m5.PreConstructedLinearModel;
import weka.classifiers.trees.m5.RuleNode;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;

/**
 * An unpruned Weka M5 model tree that can additionally answer range queries.
 *
 * <p>Given one interval per attribute (an axis-aligned box), it collects predictions from every
 * leaf whose region intersects the box. It then reduces them to a single {@link Interval} with an
 * {@link IntervalAggregator}.
 */
public class ExtendedM5Tree
extends M5Base
implements RangeQueryPredictor {
    private static final long serialVersionUID = 6099808075887732225L;

    /** Reduces the collected leaf predictions to the interval returned by {@link #predictInterval}. */
    private final IntervalAggregator intervalAggregator;

    /**
     * Creates an unpruned tree whose leaf predictions are aggregated by an {@link AggressiveAggregator}.
     */
    public ExtendedM5Tree() {
        this(new AggressiveAggregator());
    }

    /**
     * Creates an unpruned tree with the given aggregation strategy.
     *
     * @param intervalAggregator strategy that turns the collected leaf predictions into an interval
     * @throws IllegalStateException if the {@code -U} option cannot be applied
     */
    public ExtendedM5Tree(IntervalAggregator intervalAggregator) {
        try {
            // "-U" is passed to request an unpruned tree (see the failure message below).
            this.setOptions(new String[]{"-U"});
        }
        catch (Exception e) {
            // Chain the original failure so its cause is not lost.
            throw new IllegalStateException("Couldn't unprune the tree", e);
        }
        this.intervalAggregator = intervalAggregator;
    }

    /**
     * Predicts the range of target values for a box-shaped query.
     *
     * <p>The tree is traversed depth-first with an explicit stack. Each stack entry pairs a node
     * with the part of the query box that lies inside that node's region.
     * <ul>
     *   <li>At an inner node, the box is routed to the child it overlaps. If it straddles the
     *       threshold, it is split between both children. Weka's M5 rule nodes send
     *       {@code value <= threshold} left, so the right child only sees values strictly above it.</li>
     *   <li>At a leaf, the leaf's linear model is evaluated at two corners of the box.</li>
     * </ul>
     * All leaf values are finally passed to the aggregator.
     *
     * @param intervalAndHeader the queried intervals (indexed by attribute) and the dataset header
     * @return the interval produced by the aggregator from all collected leaf predictions
     * @throws PredictionFailedException if a leaf model fails to classify a corner instance
     */
    @Override
    public Interval predictInterval(RQPHelper.IntervalAndHeader intervalAndHeader) {
        Interval[] queriedInterval = intervalAndHeader.getIntervals();
        Deque<Map.Entry<Interval[], RuleNode>> stack = new ArrayDeque<>();
        stack.push(RQPHelper.getEntry(queriedInterval, this.getM5RootNode()));
        ArrayList<Double> list = new ArrayList<>();
        while (!stack.isEmpty()) {
            Map.Entry<Interval[], RuleNode> toProcess = stack.pop();
            RuleNode nextTree = toProcess.getValue();
            if (nextTree.isLeaf()) {
                this.predictLeaf(list, toProcess, nextTree, intervalAndHeader.getHeaderInformation());
                continue;
            }
            Interval[] currentBox = toProcess.getKey();
            int attribute = nextTree.splitAtt();
            double threshold = nextTree.splitVal();
            // Route on the box already narrowed by the ancestors' splits, not on the original query.
            // Otherwise a repeated split on the same attribute would re-widen the box.
            Interval intervalForAttribute = currentBox[attribute];
            double lower = intervalForAttribute.getInf();
            double upper = intervalForAttribute.getSup();
            RuleNode leftChild = nextTree.leftNode();
            RuleNode rightChild = nextTree.rightNode();
            if (upper <= threshold) {
                // Entirely within the left region (value <= threshold).
                stack.push(RQPHelper.getEntry(currentBox, leftChild));
            }
            else if (lower > threshold) {
                // Entirely within the right region (value > threshold).
                stack.push(RQPHelper.getEntry(currentBox, rightChild));
            }
            else {
                // lower <= threshold < upper: the box straddles the split, so descend into both halves.
                Interval[] leftInterval = RQPHelper.substituteInterval(currentBox, new Interval(lower, threshold), attribute);
                stack.push(RQPHelper.getEntry(leftInterval, leftChild));
                Interval[] rightInterval = RQPHelper.substituteInterval(currentBox, new Interval(threshold, upper), attribute);
                stack.push(RQPHelper.getEntry(rightInterval, rightChild));
            }
        }
        return this.intervalAggregator.aggregate(list);
    }

    /**
     * Evaluates a leaf's linear model at two opposite corners of the given box and appends both
     * predictions to {@code list}.
     *
     * <p>Over a box, a linear function attains its extremes at corners that are chosen per
     * attribute by the sign of the attribute's coefficient. A negative coefficient contributes its
     * largest term at the interval's lower end; a non-negative one contributes it at the upper end.
     *
     * <p>NOTE(review): box attribute {@code i} is written to instance index {@code i + 1}, and
     * index 0 is set to 1.0. However, its sign is read from {@code coefficients[i]}, while
     * {@link PreConstructedLinearModel} pairs coefficients with instance indices. Confirm that this
     * offset matches the header / class-index layout produced by {@code RQPHelper}.
     *
     * @param list receives the two corner predictions (high-corner first, then low-corner)
     * @param toProcess the leaf paired with the part of the query box inside its region
     * @param nextTree the leaf node whose model is evaluated
     * @param header dataset header attached to the corner instances
     * @throws PredictionFailedException if the leaf model fails to classify a corner instance
     */
    private void predictLeaf(List<Double> list, Map.Entry<Interval[], RuleNode> toProcess, RuleNode nextTree, Instances header) {
        Interval[] usedBounds = toProcess.getKey();
        PreConstructedLinearModel model = nextTree.getModel();
        double[] coefficients = model.coefficients();
        int numValues = usedBounds.length + 1;
        // Corner meant to maximize the model output (under the index mapping noted above).
        DenseInstance highCorner = new DenseInstance(numValues);
        // Corner meant to minimize the model output (under the index mapping noted above).
        DenseInstance lowCorner = new DenseInstance(numValues);
        for (int i = 0; i < usedBounds.length; ++i) {
            double inf = usedBounds[i].getInf();
            double sup = usedBounds[i].getSup();
            boolean negativeCoefficient = coefficients[i] < 0.0;
            highCorner.setValue(i + 1, negativeCoefficient ? inf : sup);
            lowCorner.setValue(i + 1, negativeCoefficient ? sup : inf);
        }
        highCorner.setValue(0, 1.0);
        lowCorner.setValue(0, 1.0);
        highCorner.setDataset(header);
        lowCorner.setDataset(header);
        try {
            // Compute both predictions before appending either one.
            double highPrediction = model.classifyInstance(highCorner);
            double lowPrediction = model.classifyInstance(lowCorner);
            list.add(highPrediction);
            list.add(lowPrediction);
        }
        catch (Exception e) {
            throw new PredictionFailedException(e);
        }
    }
}

