/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator.neural_network;

import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Entity;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.NormContinuous;
import org.dmg.pmml.NormDiscrete;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.TypeDefinitionField;
import org.dmg.pmml.neural_network.Connection;
import org.dmg.pmml.neural_network.NeuralInput;
import org.dmg.pmml.neural_network.NeuralInputs;
import org.dmg.pmml.neural_network.NeuralLayer;
import org.dmg.pmml.neural_network.NeuralNetwork;
import org.dmg.pmml.neural_network.NeuralOutput;
import org.dmg.pmml.neural_network.NeuralOutputs;
import org.dmg.pmml.neural_network.Neuron;
import org.jpmml.evaluator.CacheUtil;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.EntityProbabilityDistribution;
import org.jpmml.evaluator.EntityUtil;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.HasEntityRegistry;
import org.jpmml.evaluator.InvalidFeatureException;
import org.jpmml.evaluator.InvalidResultException;
import org.jpmml.evaluator.MissingFieldException;
import org.jpmml.evaluator.ModelEvaluationContext;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.NormalizationUtil;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.UnsupportedFeatureException;

/**
 * Evaluator for PMML {@code NeuralNetwork} models.
 *
 * <p>Evaluation first computes every {@link NeuralInput} from its derived field. It then
 * propagates values layer by layer. Each neuron's output is the layer's (or, as a fallback, the
 * network's) activation function applied to the weighted sum of its incoming connections plus
 * its optional bias. After that, the layer's (or network's) normalization method is applied to
 * the whole layer. The values of the output neurons are finally mapped to target fields through
 * the {@link NeuralOutput} elements.
 *
 * <p>Only the regression and classification mining functions are supported.
 */
public class NeuralNetworkEvaluator
extends ModelEvaluator<NeuralNetwork>
implements HasEntityRegistry<Entity> {
    /** Per-instance handle on the shared {@link #entityCache} entry; resolved lazily. */
    private transient BiMap<String, Entity> entityRegistry = null;

    /**
     * Shared cache that maps a model to a registry of its entities, keyed by entity id. The
     * registry holds the neural inputs first, then the neurons of every layer in document order.
     */
    private static final LoadingCache<NeuralNetwork, BiMap<String, Entity>> entityCache = CacheUtil.buildLoadingCache(new CacheLoader<NeuralNetwork, BiMap<String, Entity>>(){

        @Override
        public BiMap<String, Entity> load(NeuralNetwork neuralNetwork) {
            // Typed as Entity (not NeuralInput) so that both inputs and neurons can be registered.
            ImmutableBiMap.Builder<String, Entity> builder = new ImmutableBiMap.Builder<String, Entity>();
            // One counter is shared by inputs and neurons; it is handed to EntityUtil.put,
            // presumably to derive implicit ids (TODO confirm against EntityUtil).
            AtomicInteger index = new AtomicInteger(1);

            NeuralInputs neuralInputs = neuralNetwork.getNeuralInputs();
            for (NeuralInput neuralInput : neuralInputs) {
                builder = EntityUtil.put(neuralInput, index, builder);
            }

            List<NeuralLayer> neuralLayers = neuralNetwork.getNeuralLayers();
            for (NeuralLayer neuralLayer : neuralLayers) {
                List<Neuron> neurons = neuralLayer.getNeurons();
                for (Neuron neuron : neurons) {
                    builder = EntityUtil.put(neuron, index, builder);
                }
            }
            return builder.build();
        }
    });

    /**
     * Creates an evaluator for the {@code NeuralNetwork} model selected from {@code pmml}.
     *
     * @param pmml the PMML document containing the model
     */
    public NeuralNetworkEvaluator(PMML pmml) {
        this(pmml, NeuralNetworkEvaluator.selectModel(pmml, NeuralNetwork.class));
    }

    /**
     * Creates an evaluator for the given model.
     *
     * @param pmml the PMML document containing the model
     * @param neuralNetwork the model to evaluate
     * @throws InvalidFeatureException if the neural inputs, neural layers or neural outputs
     *     are missing or empty
     */
    public NeuralNetworkEvaluator(PMML pmml, NeuralNetwork neuralNetwork) {
        super(pmml, neuralNetwork);
        NeuralInputs neuralInputs = neuralNetwork.getNeuralInputs();
        if (neuralInputs == null) {
            throw new InvalidFeatureException((PMMLObject)neuralNetwork);
        }
        if (!neuralInputs.hasNeuralInputs()) {
            throw new InvalidFeatureException((PMMLObject)neuralInputs);
        }
        if (!neuralNetwork.hasNeuralLayers()) {
            throw new InvalidFeatureException((PMMLObject)neuralNetwork);
        }
        NeuralOutputs neuralOutputs = neuralNetwork.getNeuralOutputs();
        if (neuralOutputs == null) {
            throw new InvalidFeatureException((PMMLObject)neuralNetwork);
        }
        if (!neuralOutputs.hasNeuralOutputs()) {
            throw new InvalidFeatureException((PMMLObject)neuralOutputs);
        }
    }

    @Override
    public String getSummary() {
        return "Neural network";
    }

    /**
     * Returns the id-to-entity registry (neural inputs and neurons) of this model, loading it
     * from the shared cache on first use.
     */
    @Override
    public BiMap<String, Entity> getEntityRegistry() {
        if (this.entityRegistry == null) {
            this.entityRegistry = this.getValue(entityCache);
        }
        return this.entityRegistry;
    }

    /**
     * Evaluates the model and post-processes the predictions with the model's output fields.
     *
     * @throws InvalidResultException if the model is not scorable
     * @throws UnsupportedFeatureException if the mining function is neither regression nor
     *     classification
     */
    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext context) {
        NeuralNetwork neuralNetwork = (NeuralNetwork)this.getModel();
        if (!neuralNetwork.isScorable()) {
            throw new InvalidResultException((PMMLObject)neuralNetwork);
        }
        // Declared with a wildcard: the two branches return differently-typed maps.
        Map<FieldName, ?> predictions;
        MiningFunction miningFunction = neuralNetwork.getMiningFunction();
        switch (miningFunction) {
            case REGRESSION: {
                predictions = this.evaluateRegression(context);
                break;
            }
            case CLASSIFICATION: {
                predictions = this.evaluateClassification(context);
                break;
            }
            default: {
                throw new UnsupportedFeatureException((PMMLObject)neuralNetwork, (Enum<?>)miningFunction);
            }
        }
        return OutputUtil.evaluate(predictions, context);
    }

    /**
     * Maps the output neurons to regression targets.
     *
     * <p>A {@code FieldRef} output takes the neuron value as-is. A {@code NormContinuous} output
     * denormalizes the neuron value. The target-level post-processing is applied afterwards.
     * If any neural input is missing, the target's default regression result is returned.
     */
    private Map<FieldName, ?> evaluateRegression(ModelEvaluationContext context) {
        NeuralNetwork neuralNetwork = (NeuralNetwork)this.getModel();
        Map<String, Double> entityOutputs = this.evaluateRaw(context);
        if (entityOutputs == null) {
            return TargetUtil.evaluateRegressionDefault(context);
        }
        Map<FieldName, Object> result = new LinkedHashMap<FieldName, Object>();
        NeuralOutputs neuralOutputs = neuralNetwork.getNeuralOutputs();
        for (NeuralOutput neuralOutput : neuralOutputs) {
            Expression expression = this.getOutputExpression(neuralOutput);
            if (expression instanceof FieldRef) {
                FieldRef fieldRef = (FieldRef)expression;
                result.put(fieldRef.getField(), getOutputNeuronValue(neuralOutput, entityOutputs));
            } else if (expression instanceof NormContinuous) {
                NormContinuous normContinuous = (NormContinuous)expression;
                Double value = NormalizationUtil.denormalize(normContinuous, getOutputNeuronValue(neuralOutput, entityOutputs));
                result.put(normContinuous.getField(), value);
            } else {
                throw new UnsupportedFeatureException((PMMLObject)expression);
            }
        }
        List<TargetField> targetFields = this.getTargetFields();
        for (TargetField targetField : targetFields) {
            FieldName name = targetField.getName();
            result.put(name, TargetUtil.evaluateRegressionInternal(targetField, result.get(name), context));
        }
        return result;
    }

    /**
     * Maps the output neurons to classification targets.
     *
     * <p>Every output must be a {@code NormDiscrete}. Its neuron value becomes the probability
     * of the referenced category, and categories are grouped per target field. The target-level
     * post-processing is applied afterwards. If any neural input is missing, the target's
     * default classification result is returned.
     */
    private Map<FieldName, ? extends Classification> evaluateClassification(ModelEvaluationContext context) {
        NeuralNetwork neuralNetwork = (NeuralNetwork)this.getModel();
        BiMap<String, Entity> entityRegistry = this.getEntityRegistry();
        Map<String, Double> entityOutputs = this.evaluateRaw(context);
        if (entityOutputs == null) {
            return TargetUtil.evaluateClassificationDefault(context);
        }
        Map<FieldName, Classification> result = new LinkedHashMap<FieldName, Classification>();
        NeuralOutputs neuralOutputs = neuralNetwork.getNeuralOutputs();
        for (NeuralOutput neuralOutput : neuralOutputs) {
            Entity entity = entityRegistry.get(neuralOutput.getOutputNeuron());
            if (entity == null) {
                // The outputNeuron attribute does not refer to a registered input or neuron
                throw new InvalidFeatureException((PMMLObject)neuralOutput);
            }
            Expression expression = this.getOutputExpression(neuralOutput);
            if (!(expression instanceof NormDiscrete)) {
                throw new UnsupportedFeatureException((PMMLObject)expression);
            }
            NormDiscrete normDiscrete = (NormDiscrete)expression;
            FieldName name = normDiscrete.getField();
            // Safe: this method only ever stores EntityProbabilityDistribution<Entity> values
            @SuppressWarnings("unchecked")
            EntityProbabilityDistribution<Entity> values = (EntityProbabilityDistribution<Entity>)result.get(name);
            if (values == null) {
                values = new EntityProbabilityDistribution<Entity>(entityRegistry);
                result.put(name, values);
            }
            values.put(entity, normDiscrete.getValue(), getOutputNeuronValue(neuralOutput, entityOutputs));
        }
        List<TargetField> targetFields = this.getTargetFields();
        for (TargetField targetField : targetFields) {
            FieldName name = targetField.getName();
            result.put(name, TargetUtil.evaluateClassificationInternal(targetField, result.get(name), context));
        }
        return result;
    }

    /**
     * Returns the computed value of the entity referenced by {@code neuralOutput}'s
     * {@code outputNeuron} attribute.
     *
     * @throws InvalidFeatureException if the reference does not match any neural input or neuron
     */
    private static Double getOutputNeuronValue(NeuralOutput neuralOutput, Map<String, Double> entityOutputs) {
        Double value = entityOutputs.get(neuralOutput.getOutputNeuron());
        if (value == null) {
            throw new InvalidFeatureException((PMMLObject)neuralOutput);
        }
        return value;
    }

    /**
     * Resolves the expression that links a neural output to its target field.
     *
     * <p>A {@code FieldRef} to a {@code DataField} is returned unchanged. A {@code FieldRef} to a
     * {@code DerivedField} is replaced by that field's own expression (one level only). Any
     * other expression is returned as-is.
     *
     * @throws InvalidFeatureException if the derived field or an expression is missing, or the
     *     referenced field is neither a data field nor a derived field
     * @throws MissingFieldException if the referenced field cannot be resolved
     */
    private Expression getOutputExpression(NeuralOutput neuralOutput) {
        DerivedField derivedField = neuralOutput.getDerivedField();
        if (derivedField == null) {
            throw new InvalidFeatureException((PMMLObject)neuralOutput);
        }
        Expression expression = derivedField.getExpression();
        if (expression == null) {
            throw new InvalidFeatureException((PMMLObject)derivedField);
        }
        if (expression instanceof FieldRef) {
            FieldRef fieldRef = (FieldRef)expression;
            FieldName name = fieldRef.getField();
            TypeDefinitionField field = this.resolveField(name);
            if (field == null) {
                throw new MissingFieldException(name, (PMMLObject)fieldRef);
            }
            if (field instanceof DataField) {
                return expression;
            }
            if (field instanceof DerivedField) {
                DerivedField targetDerivedField = (DerivedField)field;
                Expression targetExpression = targetDerivedField.getExpression();
                if (targetExpression == null) {
                    throw new InvalidFeatureException((PMMLObject)targetDerivedField);
                }
                return targetExpression;
            }
            throw new InvalidFeatureException((PMMLObject)fieldRef);
        }
        return expression;
    }

    /**
     * Runs the forward pass of the network.
     *
     * @return a map from entity id to value, holding every neural input and every neuron's
     *     (normalized) output; or {@code null} if any neural input evaluates to a missing value
     * @throws InvalidFeatureException if a connection refers to an id that has not been computed
     *     yet (unknown id, or a neuron of the same or a later layer)
     */
    private Map<String, Double> evaluateRaw(EvaluationContext context) {
        NeuralNetwork neuralNetwork = (NeuralNetwork)this.getModel();
        BiMap<String, Entity> entityRegistry = this.getEntityRegistry();
        Map<String, Double> result = new HashMap<String, Double>(entityRegistry.size());
        NeuralInputs neuralInputs = neuralNetwork.getNeuralInputs();
        for (NeuralInput neuralInput : neuralInputs) {
            DerivedField derivedField = neuralInput.getDerivedField();
            FieldValue value = ExpressionUtil.evaluate(derivedField, context);
            if (value == null) {
                return null;
            }
            result.put(neuralInput.getId(), value.asDouble());
        }
        // Scratch map for the current layer, reused across layers; it is copied into result.
        Map<String, Double> outputs = new HashMap<String, Double>();
        List<NeuralLayer> neuralLayers = neuralNetwork.getNeuralLayers();
        for (NeuralLayer neuralLayer : neuralLayers) {
            outputs.clear();
            List<Neuron> neurons = neuralLayer.getNeurons();
            // Index loops in this hot path avoid allocating iterators per neuron/connection.
            for (int i = 0; i < neurons.size(); ++i) {
                Neuron neuron = neurons.get(i);
                double z = 0.0;
                List<Connection> connections = neuron.getConnections();
                for (int j = 0; j < connections.size(); ++j) {
                    Connection connection = connections.get(j);
                    Double input = result.get(connection.getFrom());
                    if (input == null) {
                        throw new InvalidFeatureException((PMMLObject)connection);
                    }
                    z += input * connection.getWeight();
                }
                Double bias = neuron.getBias();
                if (bias != null) {
                    z += bias.doubleValue();
                }
                outputs.put(neuron.getId(), this.activation(z, neuralLayer));
            }
            this.normalizeNeuronOutputs(neuralLayer, outputs);
            result.putAll(outputs);
        }
        return result;
    }

    /**
     * Normalizes the outputs of one layer in place, using the layer's normalization method or,
     * when the layer does not set one, the network's.
     *
     * <p>If neither sets it, no normalization is applied (PMML's default is {@code none}).
     *
     * @throws UnsupportedFeatureException for normalization methods other than none, simplemax
     *     and softmax
     */
    private void normalizeNeuronOutputs(NeuralLayer neuralLayer, Map<String, Double> values) {
        NeuralNetwork neuralNetwork = (NeuralNetwork)this.getModel();
        PMMLObject locatable = neuralLayer;
        NeuralNetwork.NormalizationMethod normalizationMethod = neuralLayer.getNormalizationMethod();
        if (normalizationMethod == null) {
            locatable = neuralNetwork;
            normalizationMethod = neuralNetwork.getNormalizationMethod();
        }
        if (normalizationMethod == null) {
            return;
        }
        switch (normalizationMethod) {
            case NONE: {
                break;
            }
            case SIMPLEMAX: {
                Classification.normalize(values);
                break;
            }
            case SOFTMAX: {
                Classification.normalizeSoftMax(values);
                break;
            }
            default: {
                throw new UnsupportedFeatureException(locatable, (Enum<?>)normalizationMethod);
            }
        }
    }

    /**
     * Applies the layer's activation function (or, when the layer does not set one, the
     * network's) to the weighted input sum {@code z}.
     *
     * @throws InvalidFeatureException if no activation function is set, or the threshold
     *     function is used without a threshold
     * @throws UnsupportedFeatureException for activation functions not handled below
     */
    private double activation(double z, NeuralLayer neuralLayer) {
        NeuralNetwork neuralNetwork = (NeuralNetwork)this.getModel();
        PMMLObject locatable = neuralLayer;
        NeuralNetwork.ActivationFunction activationFunction = neuralLayer.getActivationFunction();
        if (activationFunction == null) {
            locatable = neuralNetwork;
            activationFunction = neuralNetwork.getActivationFunction();
        }
        if (activationFunction == null) {
            throw new InvalidFeatureException((PMMLObject)neuralLayer);
        }
        switch (activationFunction) {
            case THRESHOLD: {
                Double threshold = neuralLayer.getThreshold();
                if (threshold == null) {
                    threshold = neuralNetwork.getThreshold();
                }
                if (threshold == null) {
                    throw new InvalidFeatureException((PMMLObject)neuralLayer);
                }
                return z > threshold ? 1.0 : 0.0;
            }
            case LOGISTIC: {
                return 1.0 / (1.0 + Math.exp(-z));
            }
            case TANH: {
                return Math.tanh(z);
            }
            case IDENTITY: {
                return z;
            }
            case EXPONENTIAL: {
                return Math.exp(z);
            }
            case RECIPROCAL: {
                return 1.0 / z;
            }
            case SQUARE: {
                return z * z;
            }
            case GAUSS: {
                return Math.exp(-(z * z));
            }
            case SINE: {
                return Math.sin(z);
            }
            case COSINE: {
                return Math.cos(z);
            }
            case ELLIOTT: {
                return z / (1.0 + Math.abs(z));
            }
            case ARCTAN: {
                return Math.atan(z);
            }
            case RECTIFIER: {
                return Math.max(0.0, z);
            }
            default: {
                break;
            }
        }
        throw new UnsupportedFeatureException(locatable, (Enum<?>)activationFunction);
    }
}

