/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator.regression;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.HasValue;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.regression.CategoricalPredictor;
import org.dmg.pmml.regression.NumericPredictor;
import org.dmg.pmml.regression.PMMLAttributes;
import org.dmg.pmml.regression.PredictorTerm;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.FieldValueUtil;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.PMMLUtil;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueMap;
import org.jpmml.evaluator.regression.RegressionModelUtil;
import org.jpmml.model.InvalidAttributeException;
import org.jpmml.model.InvalidElementException;
import org.jpmml.model.InvalidElementListException;
import org.jpmml.model.UnsupportedAttributeException;

public class RegressionModelEvaluator
extends ModelEvaluator<RegressionModel> {
    private RegressionModelEvaluator() {
    }

    public RegressionModelEvaluator(PMML pmml) {
        this(pmml, PMMLUtil.findModel(pmml, RegressionModel.class));
    }

    public RegressionModelEvaluator(PMML pmml, RegressionModel regressionModel) {
        super(pmml, regressionModel);
        List regressionTables = regressionModel.requireRegressionTables();
    }

    @Override
    public String getSummary() {
        return "Regression";
    }

    @Override
    protected <V extends Number> Map<String, ?> evaluateRegression(ValueFactory<V> valueFactory, EvaluationContext context) {
        RegressionModel regressionModel = (RegressionModel)this.getModel();
        TargetField targetField = this.getTargetField();
        String targetFieldName = regressionModel.getTargetField();
        if (targetFieldName != null && !Objects.equals(targetField.getName(), targetFieldName)) {
            throw new InvalidAttributeException((PMMLObject)regressionModel, PMMLAttributes.REGRESSIONMODEL_TARGETFIELD, (Object)targetFieldName);
        }
        List regressionTables = regressionModel.requireRegressionTables();
        if (regressionTables.size() != 1) {
            throw new InvalidElementListException(regressionTables);
        }
        RegressionTable regressionTable = (RegressionTable)regressionTables.get(0);
        Value<V> result = this.evaluateRegressionTable(valueFactory, regressionTable, context);
        if (result == null) {
            return TargetUtil.evaluateRegressionDefault(valueFactory, targetField);
        }
        RegressionModel.NormalizationMethod normalizationMethod = regressionModel.getNormalizationMethod();
        switch (normalizationMethod) {
            case NONE: 
            case SOFTMAX: 
            case LOGIT: 
            case EXP: 
            case PROBIT: 
            case CLOGLOG: 
            case LOGLOG: 
            case CAUCHIT: {
                RegressionModelUtil.normalizeRegressionResult(normalizationMethod, result);
                break;
            }
            case SIMPLEMAX: {
                throw new InvalidAttributeException((PMMLObject)regressionModel, (Enum)normalizationMethod);
            }
            default: {
                throw new UnsupportedAttributeException((PMMLObject)regressionModel, (Enum)normalizationMethod);
            }
        }
        return TargetUtil.evaluateRegression(targetField, result);
    }

    @Override
    protected <V extends Number> Map<String, ? extends Classification<?, V>> evaluateClassification(ValueFactory<V> valueFactory, EvaluationContext context) {
        RegressionModel regressionModel = (RegressionModel)this.getModel();
        TargetField targetField = this.getTargetField();
        String targetFieldName = regressionModel.getTargetField();
        if (targetFieldName != null && !Objects.equals(targetField.getName(), targetFieldName)) {
            throw new InvalidAttributeException((PMMLObject)regressionModel, PMMLAttributes.REGRESSIONMODEL_TARGETFIELD, (Object)targetFieldName);
        }
        OpType opType = targetField.getOpType();
        switch (opType) {
            case CATEGORICAL: 
            case ORDINAL: {
                break;
            }
            default: {
                throw new InvalidElementException((PMMLObject)regressionModel);
            }
        }
        List regressionTables = regressionModel.requireRegressionTables();
        if (regressionTables.size() < 2) {
            throw new InvalidElementListException(regressionTables);
        }
        List<Object> targetCategories = targetField.getCategories();
        if (targetCategories != null && targetCategories.size() != regressionTables.size()) {
            throw new InvalidElementListException(regressionTables);
        }
        ValueMap<Object, Value<V>> values = new ValueMap<Object, Value<V>>(2 * regressionTables.size());
        int max = regressionTables.size();
        for (int i = 0; i < max; ++i) {
            RegressionTable regressionTable = (RegressionTable)regressionTables.get(i);
            Object targetCategory = regressionTable.requireTargetCategory();
            if (targetCategories != null && targetCategories.indexOf(targetCategory) < 0) {
                throw new InvalidAttributeException((PMMLObject)regressionTable, PMMLAttributes.REGRESSIONTABLE_TARGETCATEGORY, targetCategory);
            }
            Value<V> value = this.evaluateRegressionTable(valueFactory, regressionTable, context);
            if (value == null) {
                return TargetUtil.evaluateClassificationDefault(valueFactory, targetField);
            }
            values.put(targetCategory, value);
        }
        RegressionModel.NormalizationMethod normalizationMethod = regressionModel.getNormalizationMethod();
        block3 : switch (opType) {
            case CATEGORICAL: {
                if (values.size() == 2) {
                    switch (normalizationMethod) {
                        case NONE: 
                        case LOGIT: 
                        case PROBIT: 
                        case CLOGLOG: 
                        case LOGLOG: 
                        case CAUCHIT: {
                            RegressionModelUtil.computeBinomialProbabilities(normalizationMethod, values);
                            break block3;
                        }
                        case SOFTMAX: 
                        case SIMPLEMAX: {
                            if (RegressionModelEvaluator.isDefault((RegressionTable)regressionTables.get(1)) && normalizationMethod == RegressionModel.NormalizationMethod.SOFTMAX) {
                                RegressionModelUtil.computeBinomialProbabilities(RegressionModel.NormalizationMethod.LOGIT, values);
                                break block3;
                            }
                            RegressionModelUtil.computeMultinomialProbabilities(normalizationMethod, values);
                            break block3;
                        }
                        case EXP: {
                            throw new InvalidAttributeException((PMMLObject)regressionModel, (Enum)normalizationMethod);
                        }
                    }
                    throw new UnsupportedAttributeException((PMMLObject)regressionModel, (Enum)normalizationMethod);
                }
                switch (normalizationMethod) {
                    case NONE: 
                    case SOFTMAX: 
                    case SIMPLEMAX: {
                        RegressionModelUtil.computeMultinomialProbabilities(normalizationMethod, values);
                        break block3;
                    }
                    case LOGIT: 
                    case EXP: 
                    case PROBIT: 
                    case CLOGLOG: 
                    case LOGLOG: 
                    case CAUCHIT: {
                        if (RegressionModel.NormalizationMethod.LOGIT.equals((Object)normalizationMethod)) {
                            RegressionModelUtil.computeMultinomialProbabilities(normalizationMethod, values);
                            break block3;
                        }
                        throw new InvalidAttributeException((PMMLObject)regressionModel, (Enum)normalizationMethod);
                    }
                }
                throw new UnsupportedAttributeException((PMMLObject)regressionModel, (Enum)normalizationMethod);
            }
            case ORDINAL: {
                switch (normalizationMethod) {
                    case NONE: 
                    case LOGIT: 
                    case PROBIT: 
                    case CLOGLOG: 
                    case LOGLOG: 
                    case CAUCHIT: {
                        RegressionModelUtil.computeOrdinalProbabilities(normalizationMethod, values);
                        break block3;
                    }
                    case SOFTMAX: 
                    case EXP: 
                    case SIMPLEMAX: {
                        throw new InvalidAttributeException((PMMLObject)regressionModel, (Enum)normalizationMethod);
                    }
                }
                throw new UnsupportedAttributeException((PMMLObject)regressionModel, (Enum)normalizationMethod);
            }
            default: {
                throw new InvalidElementException((PMMLObject)regressionModel);
            }
        }
        Classification result = this.createClassification(values);
        return TargetUtil.evaluateClassification(targetField, result);
    }

    private <V extends Number> Value<V> evaluateRegressionTable(ValueFactory<V> valueFactory, RegressionTable regressionTable, EvaluationContext context) {
        Number intercept;
        int i;
        Value<V> result = valueFactory.newValue();
        if (regressionTable.hasNumericPredictors()) {
            List numericPredictors = regressionTable.getNumericPredictors();
            int max = numericPredictors.size();
            for (int i2 = 0; i2 < max; ++i2) {
                NumericPredictor numericPredictor = (NumericPredictor)numericPredictors.get(i2);
                FieldValue value = context.evaluate(numericPredictor.requireField());
                if (FieldValueUtil.isMissing(value)) {
                    return null;
                }
                int exponent = numericPredictor.getExponent();
                if (exponent != 1) {
                    result.add(numericPredictor.requireCoefficient(), value.asNumber(), exponent);
                    continue;
                }
                result.add(numericPredictor.requireCoefficient(), value.asNumber());
            }
        }
        if (regressionTable.hasCategoricalPredictors()) {
            List categoricalPredictors = regressionTable.getCategoricalPredictors();
            String matchedFieldName = null;
            int max = categoricalPredictors.size();
            for (i = 0; i < max; ++i) {
                FieldValue value;
                CategoricalPredictor categoricalPredictor = (CategoricalPredictor)categoricalPredictors.get(i);
                String fieldName = categoricalPredictor.requireField();
                if (matchedFieldName != null) {
                    if (matchedFieldName.equals(fieldName)) continue;
                    matchedFieldName = null;
                }
                if (FieldValueUtil.isMissing(value = context.evaluate(fieldName))) {
                    matchedFieldName = fieldName;
                    continue;
                }
                boolean equals = value.equals((HasValue<?>)categoricalPredictor);
                if (!equals) continue;
                matchedFieldName = fieldName;
                result.add(categoricalPredictor.requireCoefficient());
            }
        }
        if (regressionTable.hasPredictorTerms()) {
            List predictorTerms = regressionTable.getPredictorTerms();
            ArrayList<Number> factors = new ArrayList<Number>();
            int max = predictorTerms.size();
            for (i = 0; i < max; ++i) {
                PredictorTerm predictorTerm = (PredictorTerm)predictorTerms.get(i);
                factors.clear();
                Number coefficient = predictorTerm.requireCoefficient();
                List fieldRefs = predictorTerm.requireFieldRefs();
                for (FieldRef fieldRef : fieldRefs) {
                    FieldValue value = ExpressionUtil.evaluate((Expression)fieldRef, context);
                    if (FieldValueUtil.isMissing(value)) {
                        return null;
                    }
                    factors.add(value.asNumber());
                }
                if (factors.size() == 1) {
                    result.add(coefficient, (Number)factors.get(0));
                    continue;
                }
                if (factors.size() == 2) {
                    result.add(coefficient, (Number)factors.get(0), (Number)factors.get(1));
                    continue;
                }
                result.add(coefficient, factors.toArray(new Number[factors.size()]));
            }
        }
        if ((intercept = regressionTable.requireIntercept()).doubleValue() != 0.0) {
            result.add(intercept);
        }
        return result;
    }

    private static boolean isDefault(RegressionTable regressionTable) {
        if (regressionTable.hasNumericPredictors() || regressionTable.hasCategoricalPredictors() || regressionTable.hasPredictorTerms()) {
            return false;
        }
        Number intercept = regressionTable.requireIntercept();
        return intercept.doubleValue() == 0.0;
    }
}

