/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator.naive_bayes;

import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.math3.util.Precision;
import org.dmg.pmml.ContinuousDistribution;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Discretize;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Extension;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.GaussianDistribution;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.PoissonDistribution;
import org.dmg.pmml.naive_bayes.BayesInput;
import org.dmg.pmml.naive_bayes.BayesInputs;
import org.dmg.pmml.naive_bayes.BayesOutput;
import org.dmg.pmml.naive_bayes.NaiveBayesModel;
import org.dmg.pmml.naive_bayes.PairCounts;
import org.dmg.pmml.naive_bayes.TargetValueCount;
import org.dmg.pmml.naive_bayes.TargetValueCounts;
import org.dmg.pmml.naive_bayes.TargetValueStat;
import org.dmg.pmml.naive_bayes.TargetValueStats;
import org.jpmml.evaluator.CacheUtil;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.DiscretizationUtil;
import org.jpmml.evaluator.DistributionUtil;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.EvaluationException;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.FieldValueUtil;
import org.jpmml.evaluator.HasParsedValueMapping;
import org.jpmml.evaluator.InvalidFeatureException;
import org.jpmml.evaluator.InvalidResultException;
import org.jpmml.evaluator.ModelEvaluationContext;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.ProbabilityDistribution;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.UnsupportedFeatureException;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueUtil;
import org.jpmml.evaluator.VerificationUtil;
import org.jpmml.evaluator.naive_bayes.ProbabilityMap;

/**
 * Evaluator for PMML {@code NaiveBayesModel} elements.
 *
 * <p>Only the {@code classification} mining function and the {@code double} math context are
 * supported. Per-category scores are accumulated as sums of log-probabilities (one term per
 * contributing {@code BayesInput} plus the prior from {@code BayesOutput}) and then normalized
 * with {@link ValueUtil#normalizeSoftMax}.
 */
public class NaiveBayesModelEvaluator
extends ModelEvaluator<NaiveBayesModel> {

    /** Per-instance memo of {@link #bayesInputCache}; filled lazily by {@link #getBayesInputs()}. */
    private transient List<BayesInput> bayesInputs = null;

    /** Per-instance memo of {@link #fieldCountSumCache}; filled lazily by {@link #getFieldCountSums()}. */
    private transient Map<FieldName, Map<String, Double>> fieldCountSums = null;

    /**
     * Shared cache of the effective BayesInput list: the standard {@code BayesInput} elements
     * followed by any {@code BayesInput} objects embedded in {@code Extension} content.
     */
    private static final LoadingCache<NaiveBayesModel, List<BayesInput>> bayesInputCache = CacheUtil.buildLoadingCache(new CacheLoader<NaiveBayesModel, List<BayesInput>>(){

        @Override
        public List<BayesInput> load(NaiveBayesModel naiveBayesModel) {
            return ImmutableList.copyOf(parseBayesInputs(naiveBayesModel));
        }
    });

    /**
     * Shared cache of per-field count sums: for each BayesInput field name, the total
     * {@code TargetValueCount} count per target category, summed over all of its PairCounts.
     */
    private static final LoadingCache<NaiveBayesModel, Map<FieldName, Map<String, Double>>> fieldCountSumCache = CacheUtil.buildLoadingCache(new CacheLoader<NaiveBayesModel, Map<FieldName, Map<String, Double>>>(){

        @Override
        public Map<FieldName, Map<String, Double>> load(NaiveBayesModel naiveBayesModel) {
            return ImmutableMap.copyOf(calculateFieldCountSums(naiveBayesModel));
        }
    });

    /**
     * Creates an evaluator for the {@code NaiveBayesModel} selected from the given PMML document.
     *
     * @param pmml the PMML document
     */
    public NaiveBayesModelEvaluator(PMML pmml) {
        this(pmml, NaiveBayesModelEvaluator.selectModel(pmml, NaiveBayesModel.class));
    }

    /**
     * Creates an evaluator for the given model, validating its mandatory structure.
     *
     * @param pmml the enclosing PMML document
     * @param naiveBayesModel the model to evaluate
     * @throws InvalidFeatureException if {@code BayesInputs} is absent or holds neither BayesInput
     *     elements nor extensions, or if {@code BayesOutput} or its non-empty
     *     {@code TargetValueCounts} is absent
     */
    public NaiveBayesModelEvaluator(PMML pmml, NaiveBayesModel naiveBayesModel) {
        super(pmml, naiveBayesModel);

        BayesInputs bayesInputs = naiveBayesModel.getBayesInputs();
        if (bayesInputs == null) {
            throw new InvalidFeatureException(naiveBayesModel);
        }
        if (!bayesInputs.hasBayesInputs() && !bayesInputs.hasExtensions()) {
            throw new InvalidFeatureException(bayesInputs);
        }

        BayesOutput bayesOutput = naiveBayesModel.getBayesOutput();
        if (bayesOutput == null) {
            throw new InvalidFeatureException(naiveBayesModel);
        }

        TargetValueCounts targetValueCounts = bayesOutput.getTargetValueCounts();
        if (targetValueCounts == null) {
            throw new InvalidFeatureException(bayesOutput);
        }
        if (!targetValueCounts.hasTargetValueCounts()) {
            throw new InvalidFeatureException(targetValueCounts);
        }
    }

    @Override
    public String getSummary() {
        return "Naive Bayes model";
    }

    /**
     * Evaluates the model in the given context.
     *
     * @throws InvalidResultException if the model is not scorable
     * @throws UnsupportedFeatureException for any math context other than {@code DOUBLE} or any
     *     mining function other than {@code CLASSIFICATION}
     */
    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext context) {
        NaiveBayesModel naiveBayesModel = this.getModel();
        if (!naiveBayesModel.isScorable()) {
            throw new InvalidResultException(naiveBayesModel);
        }

        ValueFactory<Double> valueFactory;

        MathContext mathContext = naiveBayesModel.getMathContext();
        switch (mathContext) {
            case DOUBLE:
                valueFactory = this.getValueFactory();
                break;
            default:
                throw new UnsupportedFeatureException(naiveBayesModel, mathContext);
        }

        Map<FieldName, ? extends Classification> predictions;

        MiningFunction miningFunction = naiveBayesModel.getMiningFunction();
        switch (miningFunction) {
            case CLASSIFICATION:
                predictions = this.evaluateClassification(valueFactory, context);
                break;
            default:
                throw new UnsupportedFeatureException(naiveBayesModel, miningFunction);
        }

        return OutputUtil.evaluate(predictions, context);
    }

    /**
     * Scores every target category and returns the normalized classification result.
     *
     * @param valueFactory factory for the per-category accumulators
     * @param context supplies input field values
     * @throws InvalidFeatureException if {@code BayesOutput} does not name this model's target field
     */
    private Map<FieldName, ? extends Classification> evaluateClassification(final ValueFactory<Double> valueFactory, EvaluationContext context) {
        NaiveBayesModel naiveBayesModel = this.getModel();
        TargetField targetField = this.getTargetField();

        // Validate the output mapping before doing any per-input work.
        BayesOutput bayesOutput = naiveBayesModel.getBayesOutput();
        FieldName targetFieldName = bayesOutput.getFieldName();
        if (targetFieldName == null || !Objects.equals(targetField.getName(), targetFieldName)) {
            throw new InvalidFeatureException(bayesOutput);
        }

        double threshold = naiveBayesModel.getThreshold();
        Map<FieldName, Map<String, Double>> fieldCountSums = this.getFieldCountSums();

        // Scores are kept in log space: "multiplying" by a probability adds its logarithm.
        ProbabilityMap<String, Double> probabilities = new ProbabilityMap<String, Double>(){

            @Override
            public ValueFactory<Double> getValueFactory() {
                return valueFactory;
            }

            @Override
            public void multiply(String key, double probability) {
                Value<Double> value = this.ensureValue(key);
                value.add(Math.log(probability));
            }
        };

        for (BayesInput bayesInput : this.getBayesInputs()) {
            FieldName name = bayesInput.getFieldName();

            FieldValue value = context.evaluate(name);
            // A missing input value contributes nothing to any category's score.
            if (value == null) {
                continue;
            }

            TargetValueStats targetValueStats = getTargetValueStats(bayesInput);
            if (targetValueStats != null) {
                this.calculateContinuousProbabilities(probabilities, targetValueStats, threshold, value);
                continue;
            }

            value = applyDerivedField(bayesInput.getDerivedField(), value);

            TargetValueCounts targetValueCounts = getTargetValueCounts(bayesInput, value);
            // A value that matches no PairCounts element contributes nothing.
            if (targetValueCounts == null) {
                continue;
            }

            this.calculateDiscreteProbabilities(probabilities, targetValueCounts, threshold, fieldCountSums.get(name));
        }

        this.calculatePriorProbabilities(probabilities, bayesOutput.getTargetValueCounts());

        ValueUtil.normalizeSoftMax(probabilities);

        ProbabilityDistribution result = new ProbabilityDistribution(probabilities.asDoubleMap());

        return TargetUtil.evaluateClassification(targetField, result);
    }

    /**
     * Passes a raw input value through the BayesInput's optional DerivedField.
     * Only {@code Discretize} expressions are supported.
     *
     * @return the value unchanged when {@code derivedField} is null, else the discretized and refined value
     * @throws InvalidFeatureException if the DerivedField has no expression
     * @throws UnsupportedFeatureException if the expression is not a {@code Discretize}
     * @throws EvaluationException if discretization yields no value
     */
    private static FieldValue applyDerivedField(DerivedField derivedField, FieldValue value) {
        if (derivedField == null) {
            return value;
        }

        Expression expression = derivedField.getExpression();
        if (expression == null) {
            throw new InvalidFeatureException(derivedField);
        }
        if (!(expression instanceof Discretize)) {
            throw new UnsupportedFeatureException(expression);
        }

        FieldValue discretized = DiscretizationUtil.discretize((Discretize)expression, value);
        if (discretized == null) {
            throw new EvaluationException();
        }

        return FieldValueUtil.refine((Field)derivedField, discretized);
    }

    /**
     * Adds the log-likelihood of a continuous value under each category's distribution.
     * Distributions reported as no-op by {@code DistributionUtil.isNoOp} are skipped. Every
     * likelihood is clamped from below to {@code threshold}.
     *
     * @throws InvalidFeatureException if a TargetValueStat lacks a category value, or its
     *     distribution is neither Gaussian nor Poisson
     */
    private void calculateContinuousProbabilities(ProbabilityMap<String, Double> probabilities, TargetValueStats targetValueStats, double threshold, FieldValue value) {
        Number x = value.asNumber();

        for (TargetValueStat targetValueStat : targetValueStats) {
            String targetCategory = targetValueStat.getValue();
            if (targetCategory == null) {
                throw new InvalidFeatureException(targetValueStat);
            }

            ContinuousDistribution distribution = targetValueStat.getContinuousDistribution();
            if (!(distribution instanceof GaussianDistribution) && !(distribution instanceof PoissonDistribution)) {
                throw new InvalidFeatureException(targetValueStat);
            }

            if (DistributionUtil.isNoOp(distribution)) {
                continue;
            }

            double probability = Math.max(DistributionUtil.probability(distribution, x), threshold);

            probabilities.multiply(targetCategory, probability);
        }
    }

    /**
     * Adds the log of P(value | category) for a discrete value, computed as the category's count
     * divided by its total count over the field. A zero count would give {@code log(0)}, so the
     * model's {@code threshold} is used in its place.
     *
     * @param countSums per-category totals for this field (from {@link #getFieldCountSums()})
     * @throws InvalidFeatureException if a TargetValueCount lacks a category value, or no total is
     *     known for its category
     */
    private void calculateDiscreteProbabilities(ProbabilityMap<String, Double> probabilities, TargetValueCounts targetValueCounts, double threshold, Map<String, Double> countSums) {
        for (TargetValueCount targetValueCount : targetValueCounts) {
            String targetCategory = targetValueCount.getValue();
            if (targetCategory == null) {
                throw new InvalidFeatureException(targetValueCount);
            }

            double count = targetValueCount.getCount();

            double probability;
            if (VerificationUtil.isZero(count, Precision.EPSILON)) {
                probability = threshold;
            } else {
                Double countSum = (countSums != null) ? countSums.get(targetCategory) : null;
                if (countSum == null) {
                    throw new InvalidFeatureException(targetValueCount);
                }
                probability = count / countSum;
            }

            probabilities.multiply(targetCategory, probability);
        }
    }

    /**
     * Adds the log of each category's raw prior count. The counts are not divided by their total.
     * That only shifts every score by the same constant, which the subsequent softmax
     * normalization cancels out.
     *
     * @throws InvalidFeatureException if a TargetValueCount lacks a category value
     */
    private void calculatePriorProbabilities(ProbabilityMap<String, Double> probabilities, TargetValueCounts targetValueCounts) {
        for (TargetValueCount targetValueCount : targetValueCounts) {
            String targetCategory = targetValueCount.getValue();
            if (targetCategory == null) {
                throw new InvalidFeatureException(targetValueCount);
            }

            probabilities.multiply(targetCategory, targetValueCount.getCount());
        }
    }

    /**
     * Returns the effective BayesInput list (standard elements followed by extension-embedded ones).
     * The list is cached per model and memoized per evaluator.
     */
    protected List<BayesInput> getBayesInputs() {
        if (this.bayesInputs == null) {
            this.bayesInputs = this.getValue(bayesInputCache);
        }
        return this.bayesInputs;
    }

    /**
     * Returns, per BayesInput field name, the per-category total of its PairCounts counts.
     * Both the outer map and the per-field maps are immutable.
     */
    protected Map<FieldName, Map<String, Double>> getFieldCountSums() {
        if (this.fieldCountSums == null) {
            this.fieldCountSums = this.getValue(fieldCountSumCache);
        }
        return this.fieldCountSums;
    }

    /**
     * Sums TargetValueCount counts per target category over all PairCounts of each BayesInput.
     * A later BayesInput with the same field name replaces an earlier one's totals.
     *
     * @throws InvalidFeatureException if a PairCounts lacks TargetValueCounts, or a
     *     TargetValueCount lacks a category value
     */
    private static Map<FieldName, Map<String, Double>> calculateFieldCountSums(NaiveBayesModel naiveBayesModel) {
        Map<FieldName, Map<String, Double>> result = new LinkedHashMap<>();

        List<BayesInput> bayesInputs = CacheUtil.getValue(naiveBayesModel, bayesInputCache);
        for (BayesInput bayesInput : bayesInputs) {
            Map<String, Double> counts = new LinkedHashMap<>();

            List<PairCounts> pairCounts = bayesInput.getPairCounts();
            for (PairCounts pairCount : pairCounts) {
                TargetValueCounts targetValueCounts = pairCount.getTargetValueCounts();
                if (targetValueCounts == null) {
                    throw new InvalidFeatureException(pairCount);
                }

                for (TargetValueCount targetValueCount : targetValueCounts) {
                    String targetCategory = targetValueCount.getValue();
                    if (targetCategory == null) {
                        throw new InvalidFeatureException(targetValueCount);
                    }

                    Double sum = counts.get(targetCategory);
                    counts.put(targetCategory, (sum != null ? sum : 0.0) + targetValueCount.getCount());
                }
            }

            // Freeze the per-field map: it lives in a statically shared cache and is exposed to subclasses.
            result.put(bayesInput.getFieldName(), ImmutableMap.copyOf(counts));
        }

        return result;
    }

    /**
     * Collects the standard BayesInput elements, followed by any BayesInput objects found in the
     * content of BayesInputs' extensions. Other extension content is ignored.
     */
    private static List<BayesInput> parseBayesInputs(NaiveBayesModel naiveBayesModel) {
        BayesInputs bayesInputs = naiveBayesModel.getBayesInputs();
        if (!bayesInputs.hasExtensions()) {
            return bayesInputs.getBayesInputs();
        }

        List<BayesInput> result = new ArrayList<>(bayesInputs.getBayesInputs());

        List<Extension> extensions = bayesInputs.getExtensions();
        for (Extension extension : extensions) {
            List<Object> objects = extension.getContent();
            for (Object object : objects) {
                if (object instanceof BayesInput) {
                    result.add((BayesInput)object);
                }
            }
        }

        return result;
    }

    /** Returns the BayesInput's TargetValueStats, or {@code null} when it is a discrete input. */
    private static TargetValueStats getTargetValueStats(BayesInput bayesInput) {
        return bayesInput.getTargetValueStats();
    }

    /**
     * Finds the TargetValueCounts of the PairCounts element whose value matches {@code value}.
     * Uses the pre-parsed mapping when the BayesInput supports {@code HasParsedValueMapping}.
     *
     * @return the matching TargetValueCounts, or {@code null} when no PairCounts matches
     * @throws InvalidFeatureException if the matching PairCounts lacks TargetValueCounts
     */
    private static TargetValueCounts getTargetValueCounts(BayesInput bayesInput, FieldValue value) {
        if (bayesInput instanceof HasParsedValueMapping) {
            HasParsedValueMapping hasParsedValueMapping = (HasParsedValueMapping)bayesInput;
            return (TargetValueCounts)value.getMapping(hasParsedValueMapping);
        }

        List<PairCounts> pairCounts = bayesInput.getPairCounts();
        for (PairCounts pairCount : pairCounts) {
            if (!value.equalsString(pairCount.getValue())) {
                continue;
            }

            TargetValueCounts targetValueCounts = pairCount.getTargetValueCounts();
            if (targetValueCounts == null) {
                throw new InvalidFeatureException(pairCount);
            }
            return targetValueCounts;
        }

        return null;
    }
}

