/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator.neural_network;

import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.HasFieldReference;
import org.dmg.pmml.NormContinuous;
import org.dmg.pmml.NormDiscrete;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLAttributes;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.neural_network.Connection;
import org.dmg.pmml.neural_network.NeuralEntity;
import org.dmg.pmml.neural_network.NeuralInput;
import org.dmg.pmml.neural_network.NeuralInputs;
import org.dmg.pmml.neural_network.NeuralLayer;
import org.dmg.pmml.neural_network.NeuralNetwork;
import org.dmg.pmml.neural_network.NeuralOutput;
import org.dmg.pmml.neural_network.NeuralOutputs;
import org.dmg.pmml.neural_network.Neuron;
import org.dmg.pmml.neural_network.PMMLElements;
import org.jpmml.evaluator.CacheUtil;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.EntityUtil;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.FieldValueUtil;
import org.jpmml.evaluator.HasEntityRegistry;
import org.jpmml.evaluator.InvalidAttributeException;
import org.jpmml.evaluator.InvalidElementException;
import org.jpmml.evaluator.InvalidElementListException;
import org.jpmml.evaluator.MisplacedElementException;
import org.jpmml.evaluator.MissingAttributeException;
import org.jpmml.evaluator.MissingElementException;
import org.jpmml.evaluator.MissingFieldException;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.NormalizationUtil;
import org.jpmml.evaluator.Numbers;
import org.jpmml.evaluator.PMMLUtil;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.UnsupportedAttributeException;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueMap;
import org.jpmml.evaluator.neural_network.NeuralNetworkUtil;
import org.jpmml.evaluator.neural_network.NeuronProbabilityDistribution;
import org.jpmml.model.XPathUtil;

/**
 * Evaluator for PMML {@code NeuralNetwork} models.
 *
 * <p>The network is evaluated layer by layer by {@link #evaluateRaw}. Regression targets are
 * then read from exactly one output neuron each, optionally denormalized through a
 * {@link NormContinuous} expression. Classification targets collect one output neuron per
 * category, where each category is named by a {@link NormDiscrete} expression.
 */
public class NeuralNetworkEvaluator
extends ModelEvaluator<NeuralNetwork>
implements HasEntityRegistry<NeuralEntity> {
    /** Lazily built lookup from target field name to the NeuralOutput elements referencing it. */
    private transient Map<FieldName, List<NeuralOutput>> neuralOutputMap = null;

    /** Lazily resolved registry of neural inputs and neurons, keyed by id. */
    private transient BiMap<String, NeuralEntity> entityRegistry = null;

    /**
     * Shared cache of entity registries, keyed by model element. Each registry lists all
     * neural inputs first, then the neurons of every layer in document order.
     */
    private static final LoadingCache<NeuralNetwork, BiMap<String, NeuralEntity>> entityCache = CacheUtil.buildLoadingCache(new CacheLoader<NeuralNetwork, BiMap<String, NeuralEntity>>(){

        @Override
        public BiMap<String, NeuralEntity> load(NeuralNetwork neuralNetwork) {
            // Inputs and neurons share one builder (and one counter starting at 1), so they
            // live in a single id space. How EntityUtil.put uses the counter is not visible
            // here; presumably it is for entities without an explicit id. TODO confirm.
            ImmutableBiMap.Builder<String, NeuralEntity> builder = new ImmutableBiMap.Builder<>();
            AtomicInteger index = new AtomicInteger(1);
            for (NeuralInput neuralInput : neuralNetwork.getNeuralInputs()) {
                builder = EntityUtil.put(neuralInput, index, builder);
            }
            for (NeuralLayer neuralLayer : neuralNetwork.getNeuralLayers()) {
                for (Neuron neuron : neuralLayer.getNeurons()) {
                    builder = EntityUtil.put(neuron, index, builder);
                }
            }
            return builder.build();
        }
    });

    /**
     * Creates an evaluator for the NeuralNetwork model found in the given PMML document.
     *
     * @param pmml the PMML document
     */
    public NeuralNetworkEvaluator(PMML pmml) {
        this(pmml, PMMLUtil.findModel(pmml, NeuralNetwork.class));
    }

    /**
     * Creates an evaluator for the given NeuralNetwork model.
     *
     * @param pmml the enclosing PMML document
     * @param neuralNetwork the model to evaluate
     * @throws MissingElementException if NeuralInputs, NeuralLayer or NeuralOutputs are absent or empty
     */
    public NeuralNetworkEvaluator(PMML pmml, NeuralNetwork neuralNetwork) {
        super(pmml, neuralNetwork);
        NeuralInputs neuralInputs = neuralNetwork.getNeuralInputs();
        if (neuralInputs == null) {
            throw new MissingElementException(neuralNetwork, PMMLElements.NEURALNETWORK_NEURALINPUTS);
        }
        if (!neuralInputs.hasNeuralInputs()) {
            throw new MissingElementException(neuralInputs, PMMLElements.NEURALINPUTS_NEURALINPUTS);
        }
        if (!neuralNetwork.hasNeuralLayers()) {
            throw new MissingElementException(neuralNetwork, PMMLElements.NEURALNETWORK_NEURALLAYERS);
        }
        NeuralOutputs neuralOutputs = neuralNetwork.getNeuralOutputs();
        if (neuralOutputs == null) {
            throw new MissingElementException(neuralNetwork, PMMLElements.NEURALNETWORK_NEURALOUTPUTS);
        }
        if (!neuralOutputs.hasNeuralOutputs()) {
            throw new MissingElementException(neuralOutputs, PMMLElements.NEURALOUTPUTS_NEURALOUTPUTS);
        }
    }

    @Override
    public String getSummary() {
        return "Neural network";
    }

    /**
     * Returns the id-to-entity registry for this model.
     * The registry is fetched from the shared cache on first use and memoized.
     */
    @Override
    public BiMap<String, NeuralEntity> getEntityRegistry() {
        if (this.entityRegistry == null) {
            this.entityRegistry = this.getValue(entityCache);
        }
        return this.entityRegistry;
    }

    /**
     * Evaluates regression targets.
     *
     * <p>If any network input is missing, every target receives its default regression result.
     * Otherwise each target must be referenced by exactly one NeuralOutput. That output neuron's
     * value is copied, denormalized if the output expression is a NormContinuous, and passed to
     * TargetUtil. With no target fields at all, {@code null} is returned.
     */
    @Override
    protected <V extends Number> Map<FieldName, ?> evaluateRegression(ValueFactory<V> valueFactory, EvaluationContext context) {
        NeuralNetwork neuralNetwork = getModel();
        List<TargetField> targetFields = getTargetFields();
        ValueMap<String, V> values = evaluateRaw(valueFactory, context);
        if (values == null) {
            if (targetFields.size() == 1) {
                return TargetUtil.evaluateRegressionDefault(valueFactory, targetFields.get(0));
            }
            Map<FieldName, Object> results = new LinkedHashMap<>();
            for (TargetField targetField : targetFields) {
                results.putAll(TargetUtil.evaluateRegressionDefault(valueFactory, targetField));
            }
            return results;
        }
        Map<FieldName, List<NeuralOutput>> outputsByField = getNeuralOutputMap();
        Map<FieldName, Object> results = null;
        for (TargetField targetField : targetFields) {
            List<NeuralOutput> neuralOutputs = outputsByField.get(targetField.getFieldName());
            if (neuralOutputs == null) {
                throw new InvalidElementException(neuralNetwork);
            }
            if (neuralOutputs.size() != 1) {
                throw new InvalidElementListException(neuralOutputs);
            }
            NeuralOutput neuralOutput = neuralOutputs.get(0);
            String id = neuralOutput.getOutputNeuron();
            if (id == null) {
                throw new MissingAttributeException(neuralOutput, org.dmg.pmml.neural_network.PMMLAttributes.NEURALOUTPUT_OUTPUTNEURON);
            }
            Value<V> value = values.get(id);
            if (value == null) {
                throw new InvalidAttributeException(neuralOutput, org.dmg.pmml.neural_network.PMMLAttributes.NEURALOUTPUT_OUTPUTNEURON, id);
            }
            // Copy first: denormalize mutates in place, and the neuron's raw value must stay intact.
            value = value.copy();
            Expression expression = getOutputExpression(neuralOutput);
            if (expression instanceof NormContinuous) {
                NormalizationUtil.denormalize((NormContinuous)expression, value);
            } else if (!(expression instanceof FieldRef)) {
                throw new MisplacedElementException(expression);
            }
            if (targetFields.size() == 1) {
                return TargetUtil.evaluateRegression(targetField, value);
            }
            if (results == null) {
                results = new LinkedHashMap<>();
            }
            results.putAll(TargetUtil.evaluateRegression(targetField, value));
        }
        return results;
    }

    /**
     * Evaluates classification targets.
     *
     * <p>If any network input is missing, every target receives its default classification
     * result. Otherwise each target's NeuralOutputs must all use NormDiscrete expressions. Each
     * NormDiscrete value names the category whose probability is that output neuron's value.
     * With no target fields at all, {@code null} is returned.
     */
    @Override
    protected <V extends Number> Map<FieldName, ? extends Classification<?, V>> evaluateClassification(ValueFactory<V> valueFactory, EvaluationContext context) {
        NeuralNetwork neuralNetwork = getModel();
        List<TargetField> targetFields = getTargetFields();
        ValueMap<String, V> values = evaluateRaw(valueFactory, context);
        if (values == null) {
            if (targetFields.size() == 1) {
                return TargetUtil.evaluateClassificationDefault(valueFactory, targetFields.get(0));
            }
            Map<FieldName, Classification<?, V>> results = new LinkedHashMap<>();
            for (TargetField targetField : targetFields) {
                results.putAll(TargetUtil.evaluateClassificationDefault(valueFactory, targetField));
            }
            return results;
        }
        Map<FieldName, List<NeuralOutput>> outputsByField = getNeuralOutputMap();
        final BiMap<String, NeuralEntity> registry = getEntityRegistry();
        Map<FieldName, Classification<?, V>> results = null;
        for (TargetField targetField : targetFields) {
            List<NeuralOutput> neuralOutputs = outputsByField.get(targetField.getFieldName());
            if (neuralOutputs == null) {
                throw new InvalidElementException(neuralNetwork);
            }
            NeuronProbabilityDistribution<V> result = new NeuronProbabilityDistribution<V>(new ValueMap<>(2 * neuralOutputs.size())){

                @Override
                public BiMap<String, NeuralEntity> getEntityRegistry() {
                    return registry;
                }
            };
            for (NeuralOutput neuralOutput : neuralOutputs) {
                String id = neuralOutput.getOutputNeuron();
                if (id == null) {
                    throw new MissingAttributeException(neuralOutput, org.dmg.pmml.neural_network.PMMLAttributes.NEURALOUTPUT_OUTPUTNEURON);
                }
                NeuralEntity entity = registry.get(id);
                if (entity == null) {
                    throw new InvalidAttributeException(neuralOutput, org.dmg.pmml.neural_network.PMMLAttributes.NEURALOUTPUT_OUTPUTNEURON, id);
                }
                Value<V> value = values.get(id);
                if (value == null) {
                    throw new InvalidAttributeException(neuralOutput, org.dmg.pmml.neural_network.PMMLAttributes.NEURALOUTPUT_OUTPUTNEURON, id);
                }
                Expression expression = getOutputExpression(neuralOutput);
                if (!(expression instanceof NormDiscrete)) {
                    throw new MisplacedElementException(expression);
                }
                NormDiscrete normDiscrete = (NormDiscrete)expression;
                Object targetCategory = normDiscrete.getValue();
                if (targetCategory == null) {
                    throw new MissingAttributeException(normDiscrete, PMMLAttributes.NORMDISCRETE_VALUE);
                }
                result.put(entity, targetCategory, value);
            }
            if (targetFields.size() == 1) {
                return TargetUtil.evaluateClassification(targetField, result);
            }
            if (results == null) {
                results = new LinkedHashMap<>();
            }
            results.putAll(TargetUtil.evaluateClassification(targetField, result));
        }
        return results;
    }

    /**
     * Returns the expression of a NeuralOutput's DerivedField.
     *
     * <p>A FieldRef pointing at a DataField is returned unchanged. A FieldRef pointing at a
     * DerivedField is replaced by that field's expression (one level of indirection only).
     *
     * @throws MissingElementException if the NeuralOutput has no DerivedField
     * @throws MissingFieldException if a FieldRef names an unresolvable field
     * @throws InvalidAttributeException if a FieldRef names a field that is neither a DataField nor a DerivedField
     */
    private Expression getOutputExpression(NeuralOutput neuralOutput) {
        DerivedField derivedField = neuralOutput.getDerivedField();
        if (derivedField == null) {
            throw new MissingElementException(neuralOutput, PMMLElements.NEURALOUTPUT_DERIVEDFIELD);
        }
        Expression expression = ExpressionUtil.ensureExpression(derivedField);
        if (!(expression instanceof FieldRef)) {
            return expression;
        }
        FieldRef fieldRef = (FieldRef)expression;
        FieldName name = fieldRef.getField();
        if (name == null) {
            throw new MissingAttributeException(fieldRef, PMMLAttributes.FIELDREF_FIELD);
        }
        Field<?> field = this.resolveField(name);
        if (field == null) {
            throw new MissingFieldException(name, fieldRef);
        }
        if (field instanceof DataField) {
            return expression;
        }
        if (field instanceof DerivedField) {
            return ExpressionUtil.ensureExpression((DerivedField)field);
        }
        throw new InvalidAttributeException(fieldRef, PMMLAttributes.FIELDREF_FIELD, name);
    }

    /**
     * Runs a forward pass over all layers.
     *
     * @return every input's and neuron's value keyed by id, or {@code null} if any input
     *         evaluates to a missing value
     * @throws MissingAttributeException if a required activation function, threshold or width is absent
     * @throws InvalidAttributeException if a Connection references an id not computed yet
     * @throws UnsupportedAttributeException for an unsupported activation function or normalization method
     */
    private <V extends Number> ValueMap<String, V> evaluateRaw(ValueFactory<V> valueFactory, EvaluationContext context) {
        NeuralNetwork neuralNetwork = getModel();
        BiMap<String, NeuralEntity> registry = getEntityRegistry();
        ValueMap<String, V> result = new ValueMap<>(2 * registry.size());
        for (NeuralInput neuralInput : neuralNetwork.getNeuralInputs()) {
            DerivedField derivedField = neuralInput.getDerivedField();
            if (derivedField == null) {
                throw new MissingElementException(neuralInput, PMMLElements.NEURALINPUT_DERIVEDFIELD);
            }
            FieldValue value = ExpressionUtil.evaluateTypedExpressionContainer(derivedField, context);
            if (FieldValueUtil.isMissing(value)) {
                // A single missing input makes the network unevaluable; callers fall back to defaults.
                return null;
            }
            result.put(neuralInput.getId(), valueFactory.newValue(value.asNumber()));
        }
        for (NeuralLayer neuralLayer : neuralNetwork.getNeuralLayers()) {
            // Layer-level attributes override network-level ones. 'activationLocatable' records which
            // element the activation function came from, for error reporting.
            PMMLObject activationLocatable = neuralLayer;
            NeuralNetwork.ActivationFunction activationFunction = neuralLayer.getActivationFunction();
            if (activationFunction == null) {
                activationLocatable = neuralNetwork;
                activationFunction = neuralNetwork.getActivationFunction();
            }
            if (activationFunction == null) {
                throw new MissingAttributeException(neuralNetwork, org.dmg.pmml.neural_network.PMMLAttributes.NEURALNETWORK_ACTIVATIONFUNCTION);
            }
            Number threshold = neuralLayer.getThreshold();
            if (threshold == null) {
                threshold = neuralNetwork.getThreshold();
            }
            Number altitude = neuralLayer.getAltitude();
            if (altitude == null) {
                altitude = neuralNetwork.getAltitude();
            }
            Number width = neuralLayer.getWidth();
            if (width == null) {
                width = neuralNetwork.getWidth();
            }
            // Validate once per layer; afterwards activationFunction is one of the supported constants.
            switch (activationFunction) {
                case THRESHOLD:
                    if (threshold == null) {
                        throw new MissingAttributeException(neuralNetwork, org.dmg.pmml.neural_network.PMMLAttributes.NEURALNETWORK_THRESHOLD);
                    }
                    break;
                case LOGISTIC:
                case TANH:
                case IDENTITY:
                case EXPONENTIAL:
                case RECIPROCAL:
                case SQUARE:
                case GAUSS:
                case SINE:
                case COSINE:
                case ELLIOTT:
                case ARCTAN:
                case RECTIFIER:
                case RADIAL_BASIS:
                    break;
                default:
                    throw new UnsupportedAttributeException(activationLocatable, activationFunction);
            }
            boolean radialBasis = (activationFunction == NeuralNetwork.ActivationFunction.RADIAL_BASIS);
            List<Neuron> neurons = neuralLayer.getNeurons();
            List<Value<V>> outputs = new ArrayList<>(neurons.size());
            for (Neuron neuron : neurons) {
                Value<V> output = valueFactory.newValue();
                List<Connection> connections = neuron.getConnections();
                for (Connection connection : connections) {
                    String id = connection.getFrom();
                    if (id == null) {
                        throw new MissingAttributeException(connection, org.dmg.pmml.neural_network.PMMLAttributes.CONNECTION_FROM);
                    }
                    Number weight = connection.getWeight();
                    if (weight == null) {
                        throw new MissingAttributeException(connection, org.dmg.pmml.neural_network.PMMLAttributes.CONNECTION_WEIGHT);
                    }
                    Value<V> input = result.get(id);
                    if (input == null) {
                        throw new InvalidAttributeException(connection, org.dmg.pmml.neural_network.PMMLAttributes.CONNECTION_FROM, id);
                    }
                    if (radialBasis) {
                        // Accumulate (input - weight)^2 on a copy so the upstream value is not mutated.
                        output.add(input.copy().subtract(weight).square());
                    } else {
                        output.add(weight, input.getValue());
                    }
                }
                if (radialBasis) {
                    Number neuronWidth = neuron.getWidth();
                    if (neuronWidth == null) {
                        if (width == null) {
                            throw new MissingAttributeException(neuralNetwork, org.dmg.pmml.neural_network.PMMLAttributes.NEURALNETWORK_WIDTH);
                        }
                        neuronWidth = width;
                    }
                    // output = exp(sum / (-2 * width^2) + fanIn * ln(altitude))
                    Value<V> denominator = valueFactory.newValue(neuronWidth).square().multiply(Numbers.DOUBLE_MINUS_TWO);
                    output.divide(denominator);
                    // An absent altitude is treated as 1.0, for which the ln term vanishes.
                    if (altitude != null && altitude.doubleValue() != 1.0) {
                        Value<V> altitudeTerm = valueFactory.newValue(altitude).ln().multiply(connections.size());
                        output.add(altitudeTerm);
                    }
                    output.exp();
                } else {
                    Number neuronBias = neuron.getBias();
                    if (neuronBias != null) {
                        output.add(neuronBias);
                    }
                    NeuralNetworkUtil.activateNeuronOutput(activationFunction, threshold, output);
                }
                result.put(neuron.getId(), output);
                outputs.add(output);
            }
            PMMLObject normalizationLocatable = neuralLayer;
            NeuralNetwork.NormalizationMethod normalizationMethod = neuralLayer.getNormalizationMethod();
            if (normalizationMethod == null) {
                normalizationLocatable = neuralNetwork;
                normalizationMethod = neuralNetwork.getNormalizationMethod();
            }
            switch (normalizationMethod) {
                case NONE:
                case SIMPLEMAX:
                case SOFTMAX:
                    // Normalizes the Value objects in place; they are the same objects stored in 'result'.
                    NeuralNetworkUtil.normalizeNeuralLayerOutputs(normalizationMethod, outputs);
                    break;
                default:
                    throw new UnsupportedAttributeException(normalizationLocatable, normalizationMethod);
            }
        }
        return result;
    }

    /** Returns the target-field-to-NeuralOutputs lookup, building and memoizing it on first use. */
    private Map<FieldName, List<NeuralOutput>> getNeuralOutputMap() {
        if (this.neuralOutputMap == null) {
            this.neuralOutputMap = this.parseNeuralOutputs();
        }
        return this.neuralOutputMap;
    }

    /**
     * Groups NeuralOutput elements by the field their output expression references.
     *
     * @throws MisplacedElementException if an output expression does not reference a field
     * @throws MissingAttributeException if the referencing expression has no field attribute
     */
    private Map<FieldName, List<NeuralOutput>> parseNeuralOutputs() {
        NeuralNetwork neuralNetwork = getModel();
        ArrayListMultimap<FieldName, NeuralOutput> result = ArrayListMultimap.create();
        for (NeuralOutput neuralOutput : neuralNetwork.getNeuralOutputs()) {
            Expression expression = getOutputExpression(neuralOutput);
            if (!(expression instanceof HasFieldReference)) {
                throw new MisplacedElementException(expression);
            }
            FieldName name = ((HasFieldReference<?>)expression).getField();
            if (name == null) {
                throw new MissingAttributeException(MissingAttributeException.formatMessage(XPathUtil.formatElement(expression.getClass()) + "@field"), expression);
            }
            result.put(name, neuralOutput);
        }
        // ListMultimap.asMap() is documented to return List-valued collections, so this cast is safe.
        @SuppressWarnings("unchecked")
        Map<FieldName, List<NeuralOutput>> listView = (Map<FieldName, List<NeuralOutput>>)(Map<FieldName, ?>)result.asMap();
        return listView;
    }
}

