/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.image.objectdetection.haar;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.openimaj.image.objectdetection.haar.Classifier;
import org.openimaj.image.objectdetection.haar.HaarFeature;
import org.openimaj.image.objectdetection.haar.HaarFeatureClassifier;
import org.openimaj.image.objectdetection.haar.Stage;
import org.openimaj.image.objectdetection.haar.StageTreeClassifier;
import org.openimaj.image.objectdetection.haar.ValueClassifier;
import org.openimaj.image.objectdetection.haar.WeightedRectangle;
import org.xmlpull.v1.XmlPullParser;
import org.xmlpull.v1.XmlPullParserException;
import org.xmlpull.v1.XmlPullParserFactory;

/**
 * Loader for Haar-cascade classifiers stored in OpenCV's XML format (the
 * cascade element must carry {@code type_id="opencv-haar-classifier"}).
 * <p>
 * Loading happens in two phases: {@link #readXPP(InputStream)} pulls the XML
 * into a lightweight intermediate form ({@link OCVHaarClassifierNode},
 * {@link StageNode}, {@link TreeNode}), and {@link #read(InputStream)} then
 * converts that into a {@link StageTreeClassifier}. Both plain cascades
 * (stages linked only through {@code parent}) and tree-structured classifiers
 * (stages that also use {@code next}) are handled.
 */
public class OCVHaarLoader {
    /**
     * Bias subtracted from every {@code stage_threshold} as it is read.
     * NOTE(review): presumably mirrors OpenCV's {@code icv_stage_threshold_bias}.
     */
    private static final float ICV_STAGE_THRESHOLD_BIAS = 1.0E-4f;
    private static final String NEXT_NODE = "next";
    private static final String PARENT_NODE = "parent";
    private static final String STAGE_THRESHOLD_NODE = "stage_threshold";
    private static final String ANONYMOUS_NODE = "_";
    private static final String RIGHT_NODE_NODE = "right_node";
    private static final String RIGHT_VAL_NODE = "right_val";
    private static final String LEFT_NODE_NODE = "left_node";
    private static final String LEFT_VAL_NODE = "left_val";
    private static final String THRESHOLD_NODE = "threshold";
    private static final String TILTED_NODE = "tilted";
    private static final String RECTS_NODE = "rects";
    private static final String FEATURE_NODE = "feature";
    private static final String TREES_NODE = "trees";
    private static final String STAGES_NODE = "stages";
    private static final String SIZE_NODE = "size";
    private static final String OCV_STORAGE_NODE = "opencv_storage";

    /**
     * Parses an OpenCV Haar-cascade XML document into the intermediate
     * {@link OCVHaarClassifierNode} representation.
     *
     * @param in the stream to read the XML from; it is not closed here
     * @return the parsed (but not yet linked) classifier description
     * @throws IOException if the stream cannot be read, the XML is malformed,
     *         the cascade has an unsupported {@code type_id}, or the element
     *         structure differs from what is expected
     */
    static OCVHaarClassifierNode readXPP(InputStream in) throws IOException {
        try {
            final XmlPullParserFactory factory = XmlPullParserFactory.newInstance();
            final XmlPullParser reader = factory.newPullParser();
            reader.setInput(in, null);

            reader.nextTag();
            checkNode(reader, OCV_STORAGE_NODE);

            // The first child of <opencv_storage> is the cascade itself: its tag
            // name is used as the classifier name and it must declare the
            // expected type_id attribute.
            reader.nextTag();
            final String typeId = reader.getAttributeValue(null, "type_id");
            if (!"opencv-haar-classifier".equals(typeId)) {
                throw new IOException("Unsupported format: " + typeId);
            }

            final OCVHaarClassifierNode root = new OCVHaarClassifierNode();
            root.name = reader.getName();

            reader.nextTag();
            checkNode(reader, SIZE_NODE);
            readSize(reader, root);

            reader.nextTag();
            checkNode(reader, STAGES_NODE);
            // Each anonymous child of <stages> is one stage; the loop ends on </stages>.
            while (reader.nextTag() == XmlPullParser.START_TAG) {
                checkNode(reader, ANONYMOUS_NODE);
                root.stages.add(readStage(reader, root));
            }
            return root;
        } catch (final XmlPullParserException ex) {
            throw new IOException(ex);
        }
    }

    /**
     * Reads the text of the current {@code <size>} element ("width height")
     * into {@code root}. On exit the parser is on {@code </size>}.
     */
    private static void readSize(XmlPullParser reader, OCVHaarClassifierNode root)
            throws IOException, XmlPullParserException {
        final String sizeStr = reader.nextText();
        // Split on any whitespace run so "24 24", "24  24" and "24\t24" all parse.
        final String[] widthHeight = sizeStr.trim().split("\\s+");
        if (widthHeight.length != 2) {
            throw new IOException("expecting 'w h' for size element, got: " + sizeStr);
        }
        root.width = Integer.parseInt(widthHeight[0]);
        root.height = Integer.parseInt(widthHeight[1]);
    }

    /**
     * Reads one stage. On entry the parser is on the stage's anonymous start
     * tag; on exit it is on the matching end tag.
     *
     * @param root receives the {@code hasTiltedFeatures} flag if any feature is tilted
     */
    private static StageNode readStage(XmlPullParser reader, OCVHaarClassifierNode root)
            throws IOException, XmlPullParserException {
        final StageNode stage = new StageNode();

        reader.nextTag();
        checkNode(reader, TREES_NODE);
        while (reader.nextTag() == XmlPullParser.START_TAG) {
            // one anonymous element per tree ...
            checkNode(reader, ANONYMOUS_NODE);
            final List<TreeNode> tree = new ArrayList<TreeNode>();
            stage.trees.add(tree);
            // ... containing one anonymous element per tree node
            while (reader.nextTag() == XmlPullParser.START_TAG) {
                checkNode(reader, ANONYMOUS_NODE);
                tree.add(readTreeNode(reader, root));
            }
        }

        reader.nextTag();
        checkNode(reader, STAGE_THRESHOLD_NODE);
        stage.threshold = Float.parseFloat(reader.nextText()) - ICV_STAGE_THRESHOLD_BIAS;

        reader.nextTag();
        checkNode(reader, PARENT_NODE);
        stage.parent = parseIntTrimmed(reader.nextText());

        reader.nextTag();
        checkNode(reader, NEXT_NODE);
        stage.next = parseIntTrimmed(reader.nextText());

        reader.nextTag();
        checkNode(reader, ANONYMOUS_NODE);
        return stage;
    }

    /**
     * Reads one decision-tree node (feature, threshold and left/right
     * branches). On entry the parser is on the node's anonymous start tag; on
     * exit it is on the matching end tag.
     *
     * @param root receives the {@code hasTiltedFeatures} flag if the feature is tilted
     */
    private static TreeNode readTreeNode(XmlPullParser reader, OCVHaarClassifierNode root)
            throws IOException, XmlPullParserException {
        final ArrayList<WeightedRectangle> regions = new ArrayList<WeightedRectangle>(3);

        reader.nextTag();
        checkNode(reader, FEATURE_NODE);
        reader.nextTag();
        checkNode(reader, RECTS_NODE);
        while (reader.nextTag() == XmlPullParser.START_TAG) {
            checkNode(reader, ANONYMOUS_NODE);
            regions.add(WeightedRectangle.parse(reader.nextText()));
        }

        reader.nextTag();
        checkNode(reader, TILTED_NODE);
        final boolean tilted = "1".equals(reader.nextText().trim());
        if (tilted) {
            root.hasTiltedFeatures = true;
        }

        // closing </feature>
        reader.nextTag();
        checkNode(reader, FEATURE_NODE);

        final TreeNode node = new TreeNode();
        node.feature = HaarFeature.create(regions, tilted);

        reader.nextTag();
        checkNode(reader, THRESHOLD_NODE);
        node.threshold = (float) Double.parseDouble(reader.nextText());

        // Each branch is either a leaf value (*_val) or an index into this
        // tree's node list (*_node); nextText() leaves the parser on the end
        // tag, whose name tells us which one it was.
        reader.nextTag();
        checkNode(reader, LEFT_VAL_NODE, LEFT_NODE_NODE);
        final String leftText = reader.nextText();
        if (LEFT_VAL_NODE.equals(reader.getName())) {
            node.left_val = Float.parseFloat(leftText);
        } else {
            node.left_node = parseIntTrimmed(leftText);
        }

        reader.nextTag();
        checkNode(reader, RIGHT_VAL_NODE, RIGHT_NODE_NODE);
        final String rightText = reader.nextText();
        if (RIGHT_VAL_NODE.equals(reader.getName())) {
            node.right_val = Float.parseFloat(rightText);
        } else {
            node.right_node = parseIntTrimmed(rightText);
        }

        reader.nextTag();
        checkNode(reader, ANONYMOUS_NODE);
        return node;
    }

    /**
     * Parses a decimal integer, ignoring surrounding whitespace (which
     * {@link Integer#parseInt(String)} would otherwise reject).
     */
    private static int parseIntTrimmed(String text) {
        return Integer.parseInt(text.trim());
    }

    /**
     * Reads an OpenCV Haar-cascade XML document and builds the classifier.
     *
     * @param is the stream to read from; it is not closed here
     * @return the loaded classifier
     * @throws IOException if the document cannot be read or is inconsistent
     */
    public static StageTreeClassifier read(InputStream is) throws IOException {
        final OCVHaarClassifierNode root = readXPP(is);
        return buildCascade(root);
    }

    /** Converts the parsed description into a {@link StageTreeClassifier}. */
    private static StageTreeClassifier buildCascade(OCVHaarClassifierNode root) throws IOException {
        return new StageTreeClassifier(root.width, root.height, root.name, root.hasTiltedFeatures,
                buildStages(root.stages));
    }

    /**
     * Builds the {@link Stage} objects and links them into a cascade or tree.
     * <p>
     * The stage with {@code parent == -1 && next == -1} is the root. The first
     * stage (in document order) that names a given parent becomes that parent's
     * success stage. If any stage uses {@code next}, the structure is a tree,
     * and every stage's failure stage is resolved from the sibling/parent links.
     *
     * @param stageNodes the parsed stages in document order
     * @return the root stage
     * @throws IOException if there is not exactly one root or a
     *         {@code parent}/{@code next} index is invalid
     */
    private static Stage buildStages(List<StageNode> stageNodes) throws IOException {
        final int nStages = stageNodes.size();
        final Stage[] stages = new Stage[nStages];
        for (int i = 0; i < nStages; ++i) {
            final StageNode node = stageNodes.get(i);
            stages[i] = new Stage(node.threshold, buildClassifiers(node.trees), null, null);
        }

        Stage root = null;
        boolean isCascade = true;
        for (int i = 0; i < nStages; ++i) {
            final StageNode node = stageNodes.get(i);
            checkStageIndex(node.parent, nStages, i, PARENT_NODE);
            checkStageIndex(node.next, nStages, i, NEXT_NODE);

            if (node.parent == -1 && node.next == -1) {
                if (root != null) {
                    throw new IOException("Inconsistent cascade/tree: multiple roots found");
                }
                root = stages[i];
            }

            // only the first child of a parent becomes its success branch; later
            // children are reached through their siblings' next links
            if (node.parent != -1 && stages[node.parent].successStage == null) {
                stages[node.parent].successStage = stages[i];
            }

            if (node.next != -1) {
                isCascade = false;
            }
        }

        if (root == null) {
            throw new IOException("Inconsistent cascade/tree: no root found");
        }

        if (!isCascade) {
            resolveFailureStages(stageNodes, stages);
        }
        return root;
    }

    /**
     * Validates a {@code parent}/{@code next} stage index: it must be -1 or
     * refer to another existing stage.
     */
    private static void checkStageIndex(int index, int nStages, int stage, String field) throws IOException {
        if (index < -1 || index >= nStages || index == stage) {
            throw new IOException("Inconsistent cascade/tree: stage " + stage + " has invalid " + field
                    + " index " + index);
        }
    }

    /**
     * Assigns the failure stage of every stage in a tree-structured classifier.
     * A failing stage continues at the {@code next} sibling of the nearest stage
     * (starting with itself, then walking up through {@code parent}) that has
     * one. If no such stage exists, the failure stage is {@code null}.
     * NOTE(review): intended to match OpenCV's tree evaluation in
     * cvRunHaarClassifierCascade — confirm.
     *
     * @throws IOException if the parent links contain a cycle
     */
    private static void resolveFailureStages(List<StageNode> stageNodes, Stage[] stages) throws IOException {
        final int nStages = stages.length;
        for (int i = 0; i < nStages; ++i) {
            int ancestor = i;
            int steps = 0;
            while (ancestor != -1 && stageNodes.get(ancestor).next == -1) {
                ancestor = stageNodes.get(ancestor).parent;
                // an acyclic parent chain reaches -1 in at most nStages hops
                if (++steps > nStages) {
                    throw new IOException("Inconsistent cascade/tree: cycle in parent links at stage " + i);
                }
            }
            stages[i].failureStage = ancestor == -1 ? null : stages[stageNodes.get(ancestor).next];
        }
    }

    /** Builds one {@link Classifier} per parsed tree, preserving order. */
    private static Classifier[] buildClassifiers(List<List<TreeNode>> trees) throws IOException {
        final Classifier[] classifiers = new Classifier[trees.size()];
        for (int i = 0; i < classifiers.length; ++i) {
            classifiers[i] = buildClassifier(trees.get(i));
        }
        return classifiers;
    }

    /**
     * Builds the classifier for a tree, rooted at its first node.
     *
     * @throws IOException if the tree is empty or a branch index is invalid
     */
    private static Classifier buildClassifier(List<TreeNode> tree) throws IOException {
        if (tree.isEmpty()) {
            throw new IOException("Inconsistent tree: tree has no nodes");
        }
        return buildClassifier(tree, tree.get(0));
    }

    /**
     * Recursively builds the classifier for {@code current}. A branch index
     * of -1 means the branch is a leaf holding the corresponding *_val value.
     */
    private static Classifier buildClassifier(List<TreeNode> tree, TreeNode current) throws IOException {
        final HaarFeatureClassifier fc = new HaarFeatureClassifier(current.feature, current.threshold, null, null);
        fc.left = current.left_node == -1
                ? new ValueClassifier(current.left_val)
                : buildClassifier(tree, getTreeNode(tree, current.left_node));
        fc.right = current.right_node == -1
                ? new ValueClassifier(current.right_val)
                : buildClassifier(tree, getTreeNode(tree, current.right_node));
        return fc;
    }

    /** Returns {@code tree.get(index)}, reporting a bad index as an {@link IOException}. */
    private static TreeNode getTreeNode(List<TreeNode> tree, int index) throws IOException {
        if (index < 0 || index >= tree.size()) {
            throw new IOException("Inconsistent tree: node index " + index + " out of range [0, " + tree.size()
                    + ")");
        }
        return tree.get(index);
    }

    /**
     * Checks that the parser's current element name is one of {@code expected}.
     *
     * @throws IOException if the current element name matches none of them
     */
    private static void checkNode(XmlPullParser reader, String ... expected) throws IOException {
        final String name = reader.getName();
        for (final String e : expected) {
            if (e.equals(name)) {
                return;
            }
        }
        throw new IOException("Unexpected tag: " + name + " (expected: " + Arrays.toString(expected) + ")");
    }

    /** Intermediate representation of the whole cascade element. */
    static class OCVHaarClassifierNode {
        // detection window size from <size>
        int width;
        int height;
        // tag name of the cascade element
        String name;
        // set when any feature has <tilted>1</tilted>
        boolean hasTiltedFeatures = false;
        List<StageNode> stages = new ArrayList<StageNode>();

        OCVHaarClassifierNode() {
        }
    }

    /** Intermediate representation of one stage. */
    static class StageNode {
        // indices into the stage list; -1 means "none"
        private int parent = -1;
        private int next = -1;
        // stage_threshold with ICV_STAGE_THRESHOLD_BIAS already subtracted
        private float threshold;
        private List<List<TreeNode>> trees = new ArrayList<List<TreeNode>>();

        StageNode() {
        }
    }

    /** Intermediate representation of one decision-tree node. */
    static class TreeNode {
        HaarFeature feature;
        float threshold;
        // leaf values, used when the corresponding *_node index is -1
        float left_val;
        float right_val;
        // indices of child nodes within the same tree; -1 means "leaf"
        int left_node = -1;
        int right_node = -1;

        TreeNode() {
        }
    }
}

