/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.translator.regression;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimaps;
import com.sun.codemodel.JBlock;
import com.sun.codemodel.JClass;
import com.sun.codemodel.JDefinedClass;
import com.sun.codemodel.JExpr;
import com.sun.codemodel.JExpression;
import com.sun.codemodel.JFieldVar;
import com.sun.codemodel.JForEach;
import com.sun.codemodel.JForLoop;
import com.sun.codemodel.JInvocation;
import com.sun.codemodel.JMethod;
import com.sun.codemodel.JMod;
import com.sun.codemodel.JStatement;
import com.sun.codemodel.JType;
import com.sun.codemodel.JVar;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Output;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.TextIndex;
import org.dmg.pmml.regression.CategoricalPredictor;
import org.dmg.pmml.regression.NumericPredictor;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.InvalidElementException;
import org.jpmml.evaluator.ProbabilityDistribution;
import org.jpmml.evaluator.UnsupportedAttributeException;
import org.jpmml.evaluator.UnsupportedElementException;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.VoteDistribution;
import org.jpmml.evaluator.regression.RegressionModelUtil;
import org.jpmml.translator.FieldInfo;
import org.jpmml.translator.FunctionInvocation;
import org.jpmml.translator.IdentifierUtil;
import org.jpmml.translator.JBinaryFileInitializer;
import org.jpmml.translator.MethodScope;
import org.jpmml.translator.ModelTranslator;
import org.jpmml.translator.OperableRef;
import org.jpmml.translator.PMMLObjectUtil;
import org.jpmml.translator.Scope;
import org.jpmml.translator.TextIndexUtil;
import org.jpmml.translator.TranslationContext;
import org.jpmml.translator.ValueBuilder;
import org.jpmml.translator.ValueMapBuilder;

/**
 * Translates a PMML {@link RegressionModel} into generated Java code.
 *
 * <p>Only the {@code regression} and {@code classification} mining functions are supported, and
 * {@code PredictorTerm} elements are rejected. For regression, the model must contain exactly one
 * {@link RegressionTable}, which becomes a single evaluator method returning a {@link Value}.
 * For classification, every table becomes its own evaluator method; the per-table results are
 * collected into a value map keyed by target category and turned into a {@link Classification}.
 */
public class RegressionModelTranslator
extends ModelTranslator<RegressionModel> {
    /** Modifiers for generated helper fields ({@code private static final}). */
    private static final int CONSTANT_FIELD_MODS = JMod.PRIVATE | JMod.STATIC | JMod.FINAL;

    /**
     * Creates a translator for the given regression model.
     *
     * @param pmml the enclosing PMML document
     * @param regressionModel the model to translate
     * @throws UnsupportedAttributeException if the mining function is neither regression nor classification
     * @throws UnsupportedElementException if any regression table contains predictor terms
     */
    public RegressionModelTranslator(PMML pmml, RegressionModel regressionModel) {
        super(pmml, regressionModel);
        MiningFunction miningFunction = regressionModel.getMiningFunction();
        switch (miningFunction) {
            case REGRESSION:
            case CLASSIFICATION:
                break;
            default:
                throw new UnsupportedAttributeException(regressionModel, miningFunction);
        }
        for (RegressionTable regressionTable : regressionModel.getRegressionTables()) {
            ensureNoPredictorTerms(regressionTable);
        }
    }

    /**
     * Generates the evaluator method for a regression model.
     *
     * <p>The model must hold exactly one regression table ({@link Iterables#getOnlyElement} throws otherwise).
     *
     * @return the generated method, which returns the (optionally normalized) {@link Value}
     */
    @Override
    public JMethod translateRegressor(TranslationContext context) {
        RegressionModel regressionModel = getModel();
        List<RegressionTable> regressionTables = regressionModel.getRegressionTables();
        Map<FieldName, FieldInfo> fieldInfos = getFieldInfos(new HashSet<>(regressionTables));
        RegressionTable regressionTable = Iterables.getOnlyElement(regressionTables);

        JMethod evaluateMethod = createEvaluatorMethod(Value.class, regressionTable, true, context);
        // Push outside of try: if pushing fails, there is nothing to pop.
        context.pushScope(new MethodScope(evaluateMethod));
        try {
            ValueBuilder valueBuilder = translateRegressionTable(regressionTable, fieldInfos, context);
            computeValue(valueBuilder, regressionModel, context);
        } finally {
            context.popScope();
        }
        return evaluateMethod;
    }

    /**
     * Generates the evaluator method for a classification model.
     *
     * <p>Each regression table is translated into its own helper method; the outer method stores every
     * helper's result in a value map under the table's target category, then builds the classification.
     *
     * @return the generated outer method, which returns a {@link Classification}
     */
    @Override
    public JMethod translateClassifier(TranslationContext context) {
        RegressionModel regressionModel = getModel();
        List<RegressionTable> regressionTables = regressionModel.getRegressionTables();
        Map<FieldName, FieldInfo> fieldInfos = getFieldInfos(new HashSet<>(regressionTables));

        JMethod evaluateListMethod = createEvaluatorMethod(Classification.class, regressionTables, true, context);
        context.pushScope(new MethodScope(evaluateListMethod));
        try {
            ValueMapBuilder valueMapBuilder = new ValueMapBuilder(context).construct("values");
            for (RegressionTable regressionTable : regressionTables) {
                JMethod evaluateMethod = createEvaluatorMethod(Value.class, regressionTable, true, context);
                context.pushScope(new MethodScope(evaluateMethod));
                try {
                    ValueBuilder valueBuilder = translateRegressionTable(regressionTable, fieldInfos, context);
                    context._return((JExpression)valueBuilder.getVariable());
                } finally {
                    context.popScope();
                }
                valueMapBuilder.update("put", regressionTable.getTargetCategory(), createEvaluatorMethodInvocation(evaluateMethod, context));
            }
            computeClassification(valueMapBuilder, regressionModel, context);
        } finally {
            context.popScope();
        }
        return evaluateListMethod;
    }

    /**
     * Emits the tail of a regressor method: applies the model's normalization method (unless it is
     * {@code NONE}) via {@code RegressionModelUtil.normalizeRegressionResult}, then returns the value variable.
     */
    public static void computeValue(ValueBuilder valueBuilder, RegressionModel regressionModel, TranslationContext context) {
        RegressionModel.NormalizationMethod normalizationMethod = regressionModel.getNormalizationMethod();
        if (normalizationMethod != RegressionModel.NormalizationMethod.NONE) {
            valueBuilder.staticUpdate(RegressionModelUtil.class, "normalizeRegressionResult", normalizationMethod);
        }
        context._return((JExpression)valueBuilder.getVariable());
    }

    /**
     * Emits the tail of a classifier method: converts the per-category values into probabilities and
     * returns them wrapped in a distribution object.
     *
     * <p>Two tables use {@code computeBinomialProbabilities}; more than two use
     * {@code computeMultinomialProbabilities}. A {@link ProbabilityDistribution} is returned only when the
     * model's {@link Output} declares exactly as many {@code PROBABILITY} output fields as there are
     * regression tables; otherwise a {@link VoteDistribution} is returned.
     *
     * @throws InvalidElementException if the model has fewer than two regression tables
     */
    public static void computeClassification(ValueMapBuilder valueMapBuilder, RegressionModel regressionModel, TranslationContext context) {
        RegressionModel.NormalizationMethod normalizationMethod = regressionModel.getNormalizationMethod();
        int tableCount = regressionModel.getRegressionTables().size();

        if (tableCount == 2) {
            valueMapBuilder.staticUpdate(RegressionModelUtil.class, "computeBinomialProbabilities", normalizationMethod);
        } else if (tableCount > 2) {
            valueMapBuilder.staticUpdate(RegressionModelUtil.class, "computeMultinomialProbabilities", normalizationMethod);
        } else {
            throw new InvalidElementException(regressionModel);
        }

        boolean probabilistic = false;
        Output output = regressionModel.getOutput();
        if (output != null && output.hasOutputFields()) {
            long probabilityFieldCount = output.getOutputFields().stream()
                .filter(outputField -> outputField.getResultFeature() == ResultFeature.PROBABILITY)
                .count();
            probabilistic = (probabilityFieldCount == tableCount);
        }

        JInvocation classificationExpr = probabilistic
            ? context._new(ProbabilityDistribution.class, valueMapBuilder)
            : context._new(VoteDistribution.class, valueMapBuilder);
        context._return((JExpression)classificationExpr);
    }

    /**
     * Emits code that accumulates one regression table into a freshly declared {@code result} value.
     *
     * <p>Numeric predictors, term-frequency predictors, categorical predictors and finally the intercept
     * (when non-zero) are added through the value's {@code add} method.
     *
     * @return the builder holding the accumulated value variable
     * @throws UnsupportedElementException if the table contains predictor terms
     */
    public static ValueBuilder translateRegressionTable(RegressionTable regressionTable, Map<FieldName, FieldInfo> fieldInfos, TranslationContext context) {
        // Reject unsupported content before any code is emitted.
        ensureNoPredictorTerms(regressionTable);

        ValueBuilder valueBuilder = new ValueBuilder(context)
            .declare(IdentifierUtil.create("result", regressionTable), context.getValueFactoryVariable().newValue());

        if (regressionTable.hasNumericPredictors()) {
            addNumericPredictors(regressionTable, valueBuilder, fieldInfos, context);
        }
        if (regressionTable.hasCategoricalPredictors()) {
            addCategoricalPredictors(regressionTable, valueBuilder, fieldInfos, context);
        }

        Number intercept = regressionTable.getIntercept();
        if (intercept != null && intercept.doubleValue() != 0.0) {
            valueBuilder.update("add", intercept);
        }
        return valueBuilder;
    }

    /** Throws {@link UnsupportedElementException} (pointing at the first term) if the table has predictor terms. */
    private static void ensureNoPredictorTerms(RegressionTable regressionTable) {
        if (regressionTable.hasPredictorTerms()) {
            throw new UnsupportedElementException((PMMLObject)Iterables.getFirst(regressionTable.getPredictorTerms(), null));
        }
    }

    /**
     * Adds plain numeric predictors directly; predictors whose field is a Tf/TfIdf function invocation
     * are grouped by text field and emitted together by {@link #addTermFrequencies}.
     */
    private static void addNumericPredictors(RegressionTable regressionTable, ValueBuilder valueBuilder, Map<FieldName, FieldInfo> fieldInfos, TranslationContext context) {
        ListMultimap<FieldName, FunctionInvocationPredictor> tfTerms = ArrayListMultimap.create();

        for (NumericPredictor numericPredictor : regressionTable.getNumericPredictors()) {
            FieldInfo fieldInfo = getFieldInfo(numericPredictor, fieldInfos);
            FunctionInvocation functionInvocation = fieldInfo.getFunctionInvocation();

            if (functionInvocation instanceof FunctionInvocation.Tf || functionInvocation instanceof FunctionInvocation.TfIdf) {
                FunctionInvocationPredictor tfTerm = new FunctionInvocationPredictor(numericPredictor, functionInvocation);
                tfTerms.put(tfTerm.getTf().getTextField(), tfTerm);
                continue;
            }

            Number coefficient = numericPredictor.getCoefficient();
            Integer exponent = numericPredictor.getExponent();
            OperableRef operableRef = context.ensureOperableVariable(fieldInfo);

            // Pick the shortest add(...) overload that still carries all information.
            if (exponent != null && exponent.intValue() != 1) {
                valueBuilder.update("add", coefficient, operableRef.getVariable(), exponent);
            } else if (coefficient.doubleValue() != 1.0) {
                valueBuilder.update("add", coefficient, operableRef.getVariable());
            } else {
                valueBuilder.update("add", operableRef.getVariable());
            }
        }

        addTermFrequencies(regressionTable, valueBuilder, Multimaps.asMap(tfTerms), fieldInfos, context);
    }

    /**
     * Adds categorical predictors. For each field, a helper method returning the coefficient of the
     * field's current category (or {@code null}) is generated; the coefficient is added only when non-null.
     */
    private static void addCategoricalPredictors(RegressionTable regressionTable, ValueBuilder valueBuilder, Map<FieldName, FieldInfo> fieldInfos, TranslationContext context) {
        // LinkedHashMap keeps fields in document order, so generated code is emitted in that order.
        Map<FieldName, List<CategoricalPredictor>> fieldCategoricalPredictors = regressionTable.getCategoricalPredictors().stream()
            .collect(Collectors.groupingBy(CategoricalPredictor::getField, LinkedHashMap::new, Collectors.toList()));

        JBlock block = context.block();

        for (Map.Entry<FieldName, List<CategoricalPredictor>> entry : fieldCategoricalPredictors.entrySet()) {
            FieldName name = entry.getKey();
            List<CategoricalPredictor> categoricalPredictors = entry.getValue();
            FieldInfo fieldInfo = getFieldInfo(name, fieldInfos);

            JMethod evaluateCategoryMethod = createEvaluatorMethod(Number.class, categoricalPredictors, false, context);
            context.pushScope(new MethodScope(evaluateCategoryMethod));
            try {
                OperableRef operableRef = context.ensureOperableVariable(fieldInfo);
                Map<Object, Number> categoryValues = categoricalPredictors.stream()
                    .collect(Collectors.toMap(CategoricalPredictor::getValue, CategoricalPredictor::getCoefficient));
                context._return((JExpression)operableRef.getVariable(), categoryValues, null);
            } finally {
                context.popScope();
            }

            JVar categoryValueVar = context.declare(Number.class, IdentifierUtil.create("lookup", name), (JExpression)createEvaluatorMethodInvocation(evaluateCategoryMethod, context));
            JBlock thenBlock = block._if(categoryValueVar.ne(JExpr._null()))._then();
            context.pushScope(new Scope(thenBlock));
            try {
                valueBuilder.update("add", categoryValueVar);
            } finally {
                context.popScope();
            }
        }
    }

    /**
     * Emits term-frequency contributions, one group per text field.
     *
     * <p>Per field, the terms, coefficients and (if any differ from 1) TfIdf weights are stored as generated
     * resources; at runtime a term frequency table is computed and each found term contributes
     * {@code coefficient * [weight *] frequency} via the value's {@code add} method.
     */
    private static void addTermFrequencies(RegressionTable regressionTable, ValueBuilder valueBuilder, Map<FieldName, List<FunctionInvocationPredictor>> tfTerms, Map<FieldName, FieldInfo> fieldInfos, TranslationContext context) {
        if (tfTerms.isEmpty()) {
            return;
        }

        JDefinedClass owner = context.getOwner();
        JBinaryFileInitializer resourceInitializer = new JBinaryFileInitializer(IdentifierUtil.create(RegressionTable.class.getSimpleName(), regressionTable) + ".data", context);

        Function<FunctionInvocationPredictor, TextIndex> textIndexFunction = tfTerm -> tfTerm.getTf().getTextIndex();
        Function<FunctionInvocationPredictor, List<String>> termFunction = tfTerm -> tfTerm.getTf().getTermTokens();
        Function<FunctionInvocationPredictor, Number> coefficientFunction = tfTerm -> tfTerm.getNumericPredictor().getCoefficient();
        Function<FunctionInvocationPredictor, Number> weightFunction = tfTerm -> {
            FunctionInvocation.TfIdf tfIdf = tfTerm.getTfIdf();
            Number weight = (tfIdf != null) ? tfIdf.getWeight() : null;
            // Defensive: a missing TfIdf/weight is the neutral weight 1 (coefficient * 1 * frequency).
            return (weight != null) ? weight : Double.valueOf(1.0);
        };

        for (Map.Entry<FieldName, List<FunctionInvocationPredictor>> entry : tfTerms.entrySet()) {
            FieldName name = entry.getKey();
            List<FunctionInvocationPredictor> predictors = entry.getValue();

            // All predictors over one text field must share a single TextIndex.
            Set<TextIndex> textIndexes = predictors.stream().map(textIndexFunction).collect(Collectors.toSet());
            TextIndex textIndex = Iterables.getOnlyElement(textIndexes);
            TextIndex localTextIndex = TextIndexUtil.toLocalTextIndex(textIndex, name);

            JFieldVar textIndexVar = owner.field(CONSTANT_FIELD_MODS, context.ref(TextIndex.class), IdentifierUtil.create("textIndex", regressionTable, name), (JExpression)PMMLObjectUtil.createObject((PMMLObject)localTextIndex, context));

            @SuppressWarnings("rawtypes")
            List[] terms = predictors.stream().map(termFunction).toArray(List[]::new);
            JFieldVar termsVar = resourceInitializer.initStringLists(IdentifierUtil.create("terms", regressionTable, name), terms);

            // termIndices: term -> position in the terms/coefficients/weights arrays.
            JFieldVar termIndicesVar = owner.field(CONSTANT_FIELD_MODS, context.ref(Map.class).narrow(Arrays.asList(context.ref(List.class).narrow(String.class), context.ref(Integer.class))), IdentifierUtil.create("termIndices", regressionTable, name), JExpr._new((JClass)context.ref(LinkedHashMap.class).narrow(Collections.emptyList())));

            JForLoop termIndicesForLoop = new JForLoop();
            JVar termIndicesLoopVar = termIndicesForLoop.init(context._ref(Integer.TYPE), "i", JExpr.lit(0));
            termIndicesForLoop.test(termIndicesLoopVar.lt(JExpr.lit(terms.length)));
            termIndicesForLoop.update(termIndicesLoopVar.incr());
            termIndicesForLoop.body().add(termIndicesVar.invoke("put").arg(termsVar.invoke("get").arg(termIndicesLoopVar)).arg(termIndicesLoopVar));
            resourceInitializer.add((JStatement)termIndicesForLoop);

            Number[] coefficients = predictors.stream().map(coefficientFunction).toArray(Number[]::new);
            JFieldVar coefficientsVar = resourceInitializer.initNumbers(IdentifierUtil.create("coefficients", regressionTable, name), MathContext.DOUBLE, coefficients);

            Number[] weights = predictors.stream().map(weightFunction).toArray(Number[]::new);
            JFieldVar weightsVar = null;
            // Only materialize weights when at least one actually changes the result.
            if (Arrays.stream(weights).anyMatch(weight -> weight.doubleValue() != 1.0)) {
                weightsVar = resourceInitializer.initNumbers(IdentifierUtil.create("weights", regressionTable, name), MathContext.DOUBLE, weights);
            }

            int maxLength = Arrays.stream(terms).mapToInt(List::size).max().orElseThrow(NoSuchElementException::new);

            JVar termFrequencyTableVar = (JVar)TextIndexUtil.computeTermFrequencyTable(null, localTextIndex, (JExpression)textIndexVar, (JExpression)termIndicesVar.invoke("keySet"), maxLength, context);

            JClass entryType = context.ref(Map.Entry.class).narrow(((JClass)termFrequencyTableVar.type()).getTypeParameters());
            // Name is unique per (table, field) so that several text fields do not declare clashing locals in one block.
            JVar entriesVar = context.declare(context.ref(Collection.class).narrow(entryType), IdentifierUtil.create("entries", regressionTable, name), termFrequencyTableVar.invoke("entrySet"));

            JForEach entriesForEach = context.block().forEach(entryType, "entry", entriesVar);
            context.pushScope(new Scope(entriesForEach.body()));
            try {
                JVar termVar = context.declare(context.ref(List.class).narrow(String.class), "term", entriesForEach.var().invoke("getKey"));
                JVar frequencyVar = context.declare(context.ref(Integer.class), "frequency", entriesForEach.var().invoke("getValue"));
                JVar indexVar = context.declare(context.ref(Integer.class), "termIndex", termIndicesVar.invoke("get").arg(termVar));
                JVar coefficientVar = context.declare(context.ref(Number.class), "coefficient", coefficientsVar.invoke("get").arg(indexVar));

                if (weightsVar != null) {
                    JVar weightVar = context.declare(context.ref(Number.class), "weight", weightsVar.invoke("get").arg(indexVar));
                    valueBuilder.update("add", coefficientVar, weightVar, frequencyVar);
                } else {
                    valueBuilder.update("add", coefficientVar, frequencyVar);
                }
            } finally {
                context.popScope();
            }
        }
    }

    /** Pairs a numeric predictor with the Tf/TfIdf function invocation that defines its field. */
    private static class FunctionInvocationPredictor {
        private final NumericPredictor numericPredictor;
        private final FunctionInvocation functionInvocation;

        private FunctionInvocationPredictor(NumericPredictor numericPredictor, FunctionInvocation functionInvocation) {
            this.numericPredictor = numericPredictor;
            this.functionInvocation = functionInvocation;
        }

        public NumericPredictor getNumericPredictor() {
            return this.numericPredictor;
        }

        /** Returns the term-frequency part of the invocation (delegates to {@code TextIndexUtil.asTf}). */
        public FunctionInvocation.Tf getTf() {
            return TextIndexUtil.asTf(this.functionInvocation);
        }

        /** Returns the invocation as TfIdf (delegates to {@code TextIndexUtil.asTfIdf}); may be null for plain Tf — TODO confirm. */
        public FunctionInvocation.TfIdf getTfIdf() {
            return TextIndexUtil.asTfIdf(this.functionInvocation);
        }
    }
}

