/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.easyai.randomForest;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.dromara.easyai.randomForest.DataTable;
import org.dromara.easyai.randomForest.Node;
import org.dromara.easyai.randomForest.RfModel;
import org.dromara.easyai.randomForest.Tree;
import org.dromara.easyai.randomForest.TreeWithTrust;
import org.dromara.easyai.tools.ArithUtil;

/**
 * A random forest made of a fixed number of {@link Tree} instances that vote on a sample's type.
 *
 * <p>Lifecycle: the {@link #RandomForest(int)} constructor only allocates the tree slots; the
 * trees themselves are created by {@link #init(DataTable)}. Rows are then fed to every tree with
 * {@link #insert(Object)}, trees are trained by {@link #study()}, and samples are classified with
 * {@link #forest(Object)}. Calling those methods before {@link #init(DataTable)} fails because the
 * tree slots are still {@code null}.
 *
 * <p>This class holds mutable state and performs no synchronization.
 */
public class RandomForest {
    /** Source of randomness used to pick each tree's feature columns. */
    private final Random random = new Random();
    /** The trees of the forest; stays {@code null} when the no-arg constructor is used. */
    private Tree[] forest;
    /**
     * Fraction of the tree count that the winning type's summed trust must reach; below it
     * {@link #forest(Object)} returns 0.
     */
    private float trustTh = 0.1f;
    /** Value passed to every {@link Tree} constructor by {@link #init(DataTable)}. */
    private float trustPunishment = 0.1f;

    /** Returns the value handed to each tree's constructor in {@link #init(DataTable)}. */
    public float getTrustPunishment() {
        return this.trustPunishment;
    }

    /** Sets the value handed to each tree's constructor; affects only later {@link #init} calls. */
    public void setTrustPunishment(float trustPunishment) {
        this.trustPunishment = trustPunishment;
    }

    /** Returns the vote threshold, as a fraction of the number of trees. */
    public float getTrustTh() {
        return this.trustTh;
    }

    /** Sets the vote threshold, as a fraction of the number of trees. */
    public void setTrustTh(float trustTh) {
        this.trustTh = trustTh;
    }

    /**
     * Creates a forest without tree slots. Only the trust getters/setters are usable on such an
     * instance; every other method throws {@link IllegalStateException}.
     */
    public RandomForest() {
    }

    /**
     * Creates a forest with {@code treeNub} empty tree slots, to be filled by {@link #init}.
     *
     * @param treeNub number of trees; must be positive
     * @throws Exception ({@link IllegalArgumentException}) if {@code treeNub <= 0}
     */
    public RandomForest(int treeNub) throws Exception {
        if (treeNub <= 0) {
            throw new IllegalArgumentException("Number of trees must be greater than 0");
        }
        this.forest = new Tree[treeNub];
    }

    /**
     * Exports the forest as an {@link RfModel} whose node map associates each tree's index with
     * that tree's root node.
     *
     * @return a new model snapshot of the current root nodes
     * @throws IllegalStateException if the forest was created with the no-arg constructor
     */
    public RfModel getModel() {
        Tree[] trees = requireForest();
        RfModel rfModel = new RfModel();
        HashMap<Integer, Node> nodeMap = new HashMap<Integer, Node>();
        for (int i = 0; i < trees.length; ++i) {
            nodeMap.put(i, trees[i].getRootNode());
        }
        rfModel.setNodeMap(nodeMap);
        return rfModel;
    }

    /**
     * Classifies {@code object} by trust-weighted voting across all trees.
     *
     * <p>Each tree's {@code judge} result adds its trust to the total of the type it predicts. The
     * type with the largest positive total wins; on ties, the first one met in the hash map's
     * iteration order is kept. If the winning total is below {@code treeCount * trustTh}, 0 is
     * returned instead. NOTE(review): 0 presumably means "no confident answer"; confirm with the
     * callers.
     *
     * @param object the sample to classify
     * @return the winning type, or 0 when no type collects enough trust
     * @throws IllegalStateException if the forest was created with the no-arg constructor
     * @throws Exception propagated from the trees' {@code judge} calls
     */
    public int forest(Object object) throws Exception {
        Tree[] trees = requireForest();
        Map<Integer, Float> votes = new HashMap<Integer, Float>();
        for (Tree tree : trees) {
            TreeWithTrust verdict = tree.judge(object);
            int type = verdict.getType();
            float trust = verdict.getTrust();
            Float previous = votes.get(type);
            votes.put(type, previous == null ? trust : previous + trust);
        }
        int winner = 0;
        float best = 0.0f;
        for (Map.Entry<Integer, Float> entry : votes.entrySet()) {
            float total = entry.getValue();
            if (total > best) {
                winner = entry.getKey();
                best = total;
            }
        }
        if (best < ArithUtil.mul(trees.length, this.trustTh)) {
            winner = 0;
        }
        return winner;
    }

    /**
     * Creates every tree, each over its own randomly chosen subset of {@code dataTable}'s feature
     * columns plus its key column. Only the column layout of {@code dataTable} is used here; rows
     * are supplied afterwards through {@link #insert(Object)}.
     *
     * @param dataTable table describing the available columns and the key column
     * @throws IllegalStateException if the forest was created with the no-arg constructor
     * @throws Exception ({@link IllegalArgumentException}) if {@code dataTable.getSize() <= 4},
     *     or anything thrown while building the per-tree tables
     */
    public void init(DataTable dataTable) throws Exception {
        Tree[] trees = requireForest();
        // NOTE(review): the check requires getSize() > 4 while the message says "greater than 3";
        // presumably getSize() also counts the key column — confirm.
        if (dataTable.getSize() <= 4) {
            throw new IllegalArgumentException("Number of feature categories must be greater than 3");
        }
        // floor(log2(size)) features per tree. The log is taken at full precision and truncated
        // only once; truncating ln(size) first under-counts (size 8 would give 2 instead of 3).
        int kNub = (int) (Math.log(dataTable.getSize()) / Math.log(2.0));
        for (int i = 0; i < trees.length; ++i) {
            trees[i] = new Tree(this.getRandomData(dataTable, kNub), this.trustPunishment);
        }
    }

    /**
     * Trains every tree by calling its {@code study} method, in array order.
     *
     * @throws IllegalStateException if the forest was created with the no-arg constructor
     * @throws Exception propagated from a tree's {@code study} call
     */
    public void study() throws Exception {
        for (Tree tree : requireForest()) {
            tree.study();
        }
    }

    /**
     * Inserts {@code object} into every tree's own data table.
     *
     * @param object the row to add
     * @throws IllegalStateException if the forest was created with the no-arg constructor
     */
    public void insert(Object object) {
        for (Tree tree : requireForest()) {
            tree.getDataTable().insert(object);
        }
    }

    /** Returns the tree array, failing fast with a clear message if it was never allocated. */
    private Tree[] requireForest() {
        if (this.forest == null) {
            throw new IllegalStateException("Forest has not been created; use RandomForest(int treeNub)");
        }
        return this.forest;
    }

    /**
     * Builds a new, row-less {@link DataTable} over up to {@code kNub} feature columns drawn at
     * random without replacement from {@code dataTable}, plus its key column.
     *
     * @param dataTable source of the column names and key
     * @param kNub requested number of feature columns; capped at the number available
     * @return the new table with its key set
     * @throws Exception anything thrown while constructing the {@link DataTable}
     */
    private DataTable getRandomData(DataTable dataTable, int kNub) throws Exception {
        String key = dataTable.getKey();
        List<String> candidates = new ArrayList<String>(dataTable.getKeyType());
        candidates.remove(key);
        // Never ask for more columns than exist: Random.nextInt(0) would throw.
        int picks = Math.min(kNub, candidates.size());
        HashSet<String> chosen = new HashSet<String>();
        for (int i = 0; i < picks; ++i) {
            chosen.add(candidates.remove(this.random.nextInt(candidates.size())));
        }
        chosen.add(key);
        DataTable data = new DataTable(chosen);
        data.setKey(key);
        return data;
    }
}

