/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.predictions.smile;

import java.text.ParseException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.kie.api.runtime.process.WorkItem;
import org.kie.kogito.prediction.api.PredictionOutcome;
import org.kie.kogito.prediction.api.PredictionService;
import org.kie.kogito.predictions.smile.AbstractPredictionEngine;
import org.kie.kogito.predictions.smile.AttributeType;
import org.kie.kogito.predictions.smile.RandomForestConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.classification.RandomForest;
import smile.data.Attribute;
import smile.data.AttributeDataset;
import smile.data.NominalAttribute;
import smile.data.NumericAttribute;
import smile.data.StringAttribute;

/**
 * {@link PredictionService} backed by a Smile {@link RandomForest} classifier.
 *
 * <p>Training samples passed to {@link #train} are accumulated in an in-memory
 * {@link AttributeDataset}. Every {@link #predict} call trains a fresh forest from that dataset,
 * provided at least two distinct outcomes have been observed; otherwise a zero-confidence
 * outcome is returned.
 *
 * <p>NOTE(review): the mutable state (dataset, observation counter, confidence threshold) is not
 * synchronized; confirm callers do not share an instance across threads.
 */
public class SmileRandomForest
extends AbstractPredictionEngine
implements PredictionService {
    public static final String IDENTIFIER = "SMILERandomForest";
    private static final String UNABLE_PARSE_TEXT = "Unable to parse text";
    private static final Logger logger = LoggerFactory.getLogger(SmileRandomForest.class);
    /** Once more than this many samples have been trained, the confidence threshold is raised. */
    private static final int MINIMUM_OBSERVATIONS = 1200;
    /** Confidence threshold applied after more than {@link #MINIMUM_OBSERVATIONS} samples. */
    private static final double MATURE_CONFIDENCE_THRESHOLD = 0.75;
    /** Key under which the prediction confidence is published in the outcome map. */
    private static final String CONFIDENCE_KEY = "confidence";

    private final AttributeDataset dataset;
    // Insertion-ordered so that its iteration order equals attributeNames: the dataset's
    // attribute array, training rows and prediction rows must all share one column order.
    private final Map<String, Attribute> smileAttributes;
    private final Attribute outcomeAttribute;
    private final AttributeType outcomeAttributeType;
    private final int numAttributes;
    private final int numberTrees;
    /** Input feature names, in the column order used for every feature vector. */
    protected List<String> attributeNames = new ArrayList<>();
    /** Distinct outcome values that have been successfully added to the dataset. */
    private final Set<String> outcomeSet = new HashSet<>();
    private int observations = 0;

    /**
     * Creates an engine from a configuration object.
     *
     * @param configuration source of input features, outcome name/type, confidence threshold
     *                      and number of trees
     */
    public SmileRandomForest(RandomForestConfiguration configuration) {
        this(configuration.getInputFeatures(), configuration.getOutcomeName(), configuration.getOutcomeType(), configuration.getConfidenceThreshold(), configuration.getNumTrees());
    }

    /**
     * Creates an engine with an empty training dataset.
     *
     * @param inputFeatures       input feature names mapped to their attribute types; the map's
     *                            iteration order becomes the feature-vector column order
     * @param outputFeatureName   name of the predicted outcome
     * @param outputFeatureType   type of the predicted outcome
     * @param confidenceThreshold initial confidence threshold reported with each prediction
     * @param numberTrees         number of trees grown for each forest
     */
    public SmileRandomForest(Map<String, AttributeType> inputFeatures, String outputFeatureName, AttributeType outputFeatureType, double confidenceThreshold, int numberTrees) {
        super(inputFeatures, outputFeatureName, outputFeatureType, confidenceThreshold);
        this.numberTrees = numberTrees;
        this.smileAttributes = new LinkedHashMap<>();
        for (Map.Entry<String, AttributeType> inputFeature : inputFeatures.entrySet()) {
            String name = inputFeature.getKey();
            this.smileAttributes.put(name, this.createAttribute(name, inputFeature.getValue()));
            this.attributeNames.add(name);
        }
        this.numAttributes = this.smileAttributes.size();
        this.outcomeAttribute = this.createAttribute(outputFeatureName, outputFeatureType);
        this.outcomeAttributeType = outputFeatureType;
        this.dataset = new AttributeDataset("dataset", this.smileAttributes.values().toArray(new Attribute[this.numAttributes]), this.outcomeAttribute);
    }

    /**
     * Maps an {@link AttributeType} to the Smile attribute used to encode it.
     *
     * @param name attribute name
     * @param type attribute type
     * @return a {@link NominalAttribute} for NOMINAL and BOOLEAN, a {@link NumericAttribute} for
     *         NUMERIC, and a {@link StringAttribute} for anything else
     */
    protected Attribute createAttribute(String name, AttributeType type) {
        if (type == AttributeType.NOMINAL || type == AttributeType.BOOLEAN) {
            return new NominalAttribute(name);
        }
        if (type == AttributeType.NUMERIC) {
            return new NumericAttribute(name);
        }
        return new StringAttribute(name);
    }

    /**
     * Converts the textual form of a predicted value back to a Java value of the given type.
     *
     * @param value textual value
     * @param type  target attribute type
     * @return a {@link Long} for NUMERIC, a {@link Boolean} for BOOLEAN, otherwise {@code value}
     * @throws NumberFormatException if a NUMERIC value is not an integral number within the
     *                               {@code long} range
     */
    protected Object convertValue(String value, AttributeType type) {
        if (type == AttributeType.NUMERIC) {
            return parseIntegral(value);
        }
        if (type == AttributeType.BOOLEAN) {
            return Boolean.valueOf(value);
        }
        // NOMINAL and any other type are returned as-is.
        return value;
    }

    /**
     * Parses {@code value} as a {@code long}, also accepting integral decimal forms such as "1.0".
     * NOTE(review): the response attribute's {@code toString(double)} may render numeric values in
     * decimal form, which plain {@link Long#valueOf(String)} rejects; confirm against the Smile
     * version in use.
     */
    private static Long parseIntegral(String value) {
        try {
            return Long.valueOf(value);
        } catch (NumberFormatException e) {
            double d = Double.parseDouble(value);
            // Accept only integral values inside the long range; NaN and infinities fail here.
            if (d == Math.rint(d) && d >= (double) Long.MIN_VALUE && d < -(double) Long.MIN_VALUE) {
                return (long) d;
            }
            throw e;
        }
    }

    /**
     * Adds one training sample to the dataset.
     *
     * <p>The outcome is recorded as a known class only after the sample has been added, so
     * the set of known outcomes never includes a value that is absent from the dataset.
     *
     * @param data    input feature values keyed by feature name
     * @param outcome observed outcome for this sample
     * @throws NullPointerException if a feature value or the outcome is missing
     */
    public void addData(Map<String, Object> data, Object outcome) {
        // Same column order as prediction, so training and prediction vectors line up.
        double[] features = this.buildFeatures(data);
        String outcomeStr = Objects.requireNonNull(outcome, "Missing outcome value").toString();
        try {
            double label = this.outcomeAttribute.valueOf(outcomeStr);
            this.dataset.add(features, label);
            this.outcomeSet.add(outcomeStr);
        } catch (ParseException e) {
            logger.error(UNABLE_PARSE_TEXT, e);
        }
    }

    /**
     * Encodes input values as a feature vector, in {@link #attributeNames} order.
     *
     * <p>A value its attribute cannot parse is logged, and its slot is left at {@code 0.0}.
     *
     * @param data input feature values keyed by feature name
     * @return feature vector of length equal to the number of input features
     * @throws NullPointerException if a configured feature has no value in {@code data}
     */
    protected double[] buildFeatures(Map<String, Object> data) {
        double[] features = new double[this.numAttributes];
        for (int i = 0; i < this.numAttributes; ++i) {
            String attrName = this.attributeNames.get(i);
            Object value = Objects.requireNonNull(data.get(attrName), () -> "Missing value for input feature '" + attrName + "'");
            try {
                features[i] = this.smileAttributes.get(attrName).valueOf(value.toString());
            } catch (ParseException e) {
                logger.error(UNABLE_PARSE_TEXT, e);
            }
        }
        return features;
    }

    /**
     * Returns this engine's identifier.
     *
     * @return {@value #IDENTIFIER}
     */
    public String getIdentifier() {
        return IDENTIFIER;
    }

    /**
     * Predicts the outcome for {@code inputData}.
     *
     * <p>Once more than {@link #MINIMUM_OBSERVATIONS} samples have been trained, the confidence
     * threshold is raised to {@link #MATURE_CONFIDENCE_THRESHOLD}. Fewer than two distinct
     * outcomes leaves nothing to classify, so a zero-confidence outcome is returned.
     *
     * @param task      work item being predicted (its id is logged)
     * @param inputData input feature values keyed by feature name
     * @return prediction outcome holding the predicted value and its confidence
     */
    public PredictionOutcome predict(WorkItem task, Map<String, Object> inputData) {
        logger.debug("Predicting with input data: {}", inputData);
        if (this.observations > MINIMUM_OBSERVATIONS) {
            this.confidenceThreshold = MATURE_CONFIDENCE_THRESHOLD;
        }
        Map<String, Object> outcomes = new HashMap<>();
        int numClasses = this.outcomeSet.size();
        if (numClasses < 2) {
            outcomes.put(CONFIDENCE_KEY, 0.0);
            return new PredictionOutcome(0.0, this.confidenceThreshold, outcomes);
        }
        RandomForest model = new RandomForest(this.dataset, this.numberTrees);
        double[] features = this.buildFeatures(inputData);
        // One posterior probability slot per distinct outcome (class).
        double[] posteriori = new double[numClasses];
        double prediction = model.predict(features, posteriori);
        String predictionStr = this.dataset.responseAttribute().toString(prediction);
        outcomes.put(this.outcomeAttribute.getName(), this.convertValue(predictionStr, this.outcomeAttributeType));
        double confidence = posteriori[(int) prediction];
        outcomes.put(CONFIDENCE_KEY, confidence);
        logger.debug("task id {}, total {} observations, prediction = {}, confidence = {} (threshold = {})", task.getId(), this.observations, predictionStr, confidence, this.confidenceThreshold);
        return new PredictionOutcome(confidence, this.confidenceThreshold, outcomes);
    }

    /**
     * Records one training observation.
     *
     * @param task       work item that produced the observation (unused here)
     * @param inputData  input feature values keyed by feature name
     * @param outputData output values; the entry named after the outcome attribute is the label
     */
    public void train(WorkItem task, Map<String, Object> inputData, Map<String, Object> outputData) {
        logger.debug("Training with input data: {}", inputData);
        logger.debug("Training with output data: {}", outputData);
        ++this.observations;
        this.addData(inputData, outputData.get(this.outcomeAttribute.getName()));
    }
}

