/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.dyadranking.algorithm;

import ai.libs.jaicore.basic.FileUtil;
import ai.libs.jaicore.ml.core.exception.ConfigurationException;
import ai.libs.jaicore.ml.core.exception.PredictionException;
import ai.libs.jaicore.ml.core.exception.TrainingException;
import ai.libs.jaicore.ml.core.predictivemodel.ICertaintyProvider;
import ai.libs.jaicore.ml.core.predictivemodel.IOnlineLearner;
import ai.libs.jaicore.ml.core.predictivemodel.IPredictiveModelConfiguration;
import ai.libs.jaicore.ml.dyadranking.Dyad;
import ai.libs.jaicore.ml.dyadranking.algorithm.IPLDyadRanker;
import ai.libs.jaicore.ml.dyadranking.algorithm.IPLNetDyadRankerConfiguration;
import ai.libs.jaicore.ml.dyadranking.algorithm.PLNetLoss;
import ai.libs.jaicore.ml.dyadranking.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.dyadranking.dataset.DyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.dataset.IDyadRankingInstance;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.aeonbits.owner.ConfigFactory;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Trainable;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class PLNetDyadRanker
implements IPLDyadRanker,
IOnlineLearner<IDyadRankingInstance, IDyadRankingInstance, DyadRankingDataset>,
ICertaintyProvider<IDyadRankingInstance, IDyadRankingInstance, DyadRankingDataset> {
    private static final Logger log = LoggerFactory.getLogger(PLNetDyadRanker.class);
    private MultiLayerNetwork plNet;
    private IPLNetDyadRankerConfiguration configuration;
    private int epoch;
    private int iteration;

    public PLNetDyadRanker() {
        this.configuration = (IPLNetDyadRankerConfiguration)ConfigFactory.create(IPLNetDyadRankerConfiguration.class, (Map[])new Map[0]);
    }

    public PLNetDyadRanker(IPLNetDyadRankerConfiguration config) {
        this.configuration = config;
    }

    @Override
    public void train(DyadRankingDataset dataset) throws TrainingException {
        this.train(dataset.toND4j());
    }

    @Override
    public void train(List<INDArray> dataset) {
        this.train(dataset, this.configuration.plNetMaxEpochs(), this.configuration.plNetEarlyStoppingTrainRatio());
        if (this.configuration.plNetEarlyStoppingRetrain()) {
            int maxEpochs = this.epoch;
            this.plNet = null;
            this.train(dataset, maxEpochs, 1.0);
        }
    }

    public void train(DyadRankingDataset dataset, int maxEpochs, double earlyStoppingTrainRatio) {
        this.train(dataset.toND4j(), maxEpochs, earlyStoppingTrainRatio);
    }

    public void train(List<INDArray> dataset, int maxEpochs, double earlyStoppingTrainRatio) {
        List<INDArray> drTrain = dataset.subList(0, (int)(earlyStoppingTrainRatio * (double)dataset.size()));
        List<INDArray> drTest = dataset.subList((int)(earlyStoppingTrainRatio * (double)dataset.size()), dataset.size());
        if (this.plNet == null) {
            int dyadSize = dataset.get(0).columns();
            this.plNet = this.createNetwork(dyadSize);
            this.plNet.init();
        }
        double currentBestScore = Double.POSITIVE_INFINITY;
        MultiLayerNetwork currentBestModel = this.plNet;
        this.epoch = 0;
        this.iteration = 0;
        int patience = 0;
        int earlyStoppingCounter = 0;
        while (!(patience >= this.configuration.plNetEarlyStoppingPatience() && this.configuration.plNetEarlyStoppingPatience() > 0 || this.epoch >= maxEpochs && maxEpochs != 0)) {
            this.tryUpdatingWithMinibatch(drTrain);
            log.debug("plNet params: {}", (Object)this.plNet.params());
            if (++earlyStoppingCounter == this.configuration.plNetEarlyStoppingInterval() && earlyStoppingTrainRatio < 1.0) {
                double avgScore = this.computeAvgError(drTest);
                if (avgScore < currentBestScore) {
                    currentBestScore = avgScore;
                    currentBestModel = this.plNet.clone();
                    log.debug("current best score: {}", (Object)currentBestScore);
                    patience = 0;
                } else {
                    ++patience;
                }
                earlyStoppingCounter = 0;
            }
            ++this.epoch;
        }
        this.plNet = currentBestModel;
    }

    private void tryUpdatingWithMinibatch(List<INDArray> drTrain) {
        int miniBatchSize = this.configuration.plNetMiniBatchSize();
        ArrayList<INDArray> miniBatch = new ArrayList<INDArray>(miniBatchSize);
        for (INDArray dyadRankingInstance : drTrain) {
            miniBatch.add(dyadRankingInstance);
            if (miniBatch.size() != miniBatchSize) continue;
            this.updateWithMinibatch(miniBatch);
            miniBatch.clear();
        }
        if (!miniBatch.isEmpty()) {
            this.updateWithMinibatch(miniBatch);
            miniBatch.clear();
        }
    }

    private INDArray computeScaledGradient(INDArray dyadMatrix) {
        int dyadRankingLength = dyadMatrix.rows();
        List activations = this.plNet.feedForward(dyadMatrix);
        INDArray output = (INDArray)activations.get(activations.size() - 1);
        output = output.transpose();
        INDArray deltaW = Nd4j.zeros((long[])new long[]{this.plNet.params().length()});
        Gradient deltaWk = null;
        MultiLayerNetwork plNetClone = this.plNet.clone();
        for (int k = 0; k < dyadRankingLength; ++k) {
            plNetClone.setInput(dyadMatrix.getRow((long)k));
            plNetClone.feedForward(true, false);
            INDArray lossGradient = PLNetLoss.computeLossGradient(output, k);
            Pair p = plNetClone.backpropGradient(lossGradient, null);
            deltaWk = (Gradient)p.getFirst();
            this.plNet.getUpdater().update((Trainable)this.plNet, deltaWk, this.iteration, this.epoch, 1, LayerWorkspaceMgr.noWorkspaces());
            deltaW.addi(deltaWk.gradient());
        }
        return deltaW;
    }

    private INDArray computeScaledGradient(IDyadRankingInstance instance) {
        ArrayList<INDArray> dyadList = new ArrayList<INDArray>(instance.length());
        for (Dyad dyad : instance) {
            INDArray dyadVector = this.dyadToVector(dyad);
            dyadList.add(dyadVector);
        }
        INDArray dyadMatrix = this.dyadRankingToMatrix(instance);
        List activations = this.plNet.feedForward(dyadMatrix);
        INDArray output = (INDArray)activations.get(activations.size() - 1);
        output = output.transpose();
        INDArray deltaW = Nd4j.zeros((long[])new long[]{this.plNet.params().length()});
        Gradient deltaWk = null;
        MultiLayerNetwork plNetClone = this.plNet.clone();
        for (int k = 0; k < instance.length(); ++k) {
            plNetClone.setInput((INDArray)dyadList.get(k));
            plNetClone.feedForward(true, false);
            INDArray lossGradient = PLNetLoss.computeLossGradient(output, k);
            Pair p = plNetClone.backpropGradient(lossGradient, null);
            deltaWk = (Gradient)p.getFirst();
            this.plNet.getUpdater().update((Trainable)this.plNet, deltaWk, this.iteration, this.epoch, 1, LayerWorkspaceMgr.noWorkspaces());
            deltaW.addi(deltaWk.gradient());
        }
        return deltaW;
    }

    private void updateWithMinibatch(List<INDArray> minibatch) {
        double actualMiniBatchSize = minibatch.size();
        INDArray cumulativeDeltaW = Nd4j.zeros((long[])new long[]{this.plNet.params().length()});
        for (INDArray instance : minibatch) {
            cumulativeDeltaW.addi(this.computeScaledGradient(instance));
        }
        cumulativeDeltaW.muli((Number)(1.0 / actualMiniBatchSize));
        this.plNet.params().subi(cumulativeDeltaW);
        ++this.iteration;
    }

    @Override
    public void update(IDyadRankingInstance instance) throws TrainingException {
        if (this.plNet == null) {
            int dyadSize = instance.getDyadAtPosition(0).getInstance().length() + instance.getDyadAtPosition(0).getAlternative().length();
            this.plNet = this.createNetwork(dyadSize);
            this.plNet.init();
        }
        INDArray deltaW = this.computeScaledGradient(instance);
        this.plNet.params().subi(deltaW);
        ++this.iteration;
    }

    @Override
    public void update(Set<IDyadRankingInstance> instances) throws TrainingException {
        ArrayList<INDArray> minibatch = new ArrayList<INDArray>(instances.size());
        for (IDyadRankingInstance instance : instances) {
            if (this.plNet == null) {
                int dyadSize = instance.getDyadAtPosition(0).getInstance().length() + instance.getDyadAtPosition(0).getAlternative().length();
                this.plNet = this.createNetwork(dyadSize);
                this.plNet.init();
            }
            minibatch.add(instance.toMatrix());
        }
        this.updateWithMinibatch(minibatch);
    }

    @Override
    public IDyadRankingInstance predict(IDyadRankingInstance instance) throws PredictionException {
        if (this.plNet == null) {
            int dyadSize = instance.getDyadAtPosition(0).getInstance().length() + instance.getDyadAtPosition(0).getAlternative().length();
            this.plNet = this.createNetwork(dyadSize);
            this.plNet.init();
        }
        ArrayList<Pair> dyadUtilityPairs = new ArrayList<Pair>(instance.length());
        for (Dyad dyad : instance) {
            INDArray plNetInput = this.dyadToVector(dyad);
            double plNetOutput = this.plNet.output(plNetInput).getDouble(0L);
            dyadUtilityPairs.add(new Pair((Object)dyad, (Object)plNetOutput));
        }
        Collections.sort(dyadUtilityPairs, Comparator.comparing(p -> -((Double)p.getRight()).doubleValue()));
        ArrayList<Dyad> ranking = new ArrayList<Dyad>();
        for (Pair pair : dyadUtilityPairs) {
            ranking.add((Dyad)pair.getLeft());
        }
        return new DyadRankingInstance(ranking);
    }

    @Override
    public List<IDyadRankingInstance> predict(DyadRankingDataset dataset) throws PredictionException {
        ArrayList<IDyadRankingInstance> results = new ArrayList<IDyadRankingInstance>(dataset.size());
        for (IDyadRankingInstance instance : dataset) {
            results.add(this.predict(instance));
        }
        return results;
    }

    private double computeAvgError(List<INDArray> drTest) {
        DescriptiveStatistics stats = new DescriptiveStatistics();
        for (INDArray dyadRankingInstance : drTest) {
            INDArray outputs = this.plNet.output(dyadRankingInstance);
            outputs = outputs.transpose();
            double score = PLNetLoss.computeLoss(outputs).getDouble(0L);
            stats.addValue(score);
        }
        return stats.getMean();
    }

    @Override
    public void setConfiguration(IPredictiveModelConfiguration configuration) throws ConfigurationException {
        if (!(configuration instanceof IPLNetDyadRankerConfiguration)) {
            throw new IllegalArgumentException("The configuration is no PLNetDyadRankerConfiguration!");
        }
        this.configuration = (IPLNetDyadRankerConfiguration)configuration;
    }

    @Override
    public IPredictiveModelConfiguration getConfiguration() {
        return this.configuration;
    }

    private MultiLayerNetwork createNetwork(int numInputs) {
        if (this.configuration.plNetHiddenNodes().isEmpty()) {
            throw new IllegalArgumentException("There must be at least one hidden layer in specified in the config file!");
        }
        NeuralNetConfiguration.ListBuilder configBuilder = new NeuralNetConfiguration.Builder().seed((long)this.configuration.plNetSeed()).updater((IUpdater)new Adam(this.configuration.plNetLearningRate())).list();
        String activation = this.configuration.plNetActivationFunction();
        int inputsFirstHiddenLayer = this.configuration.plNetHiddenNodes().get(0);
        configBuilder.layer(0, (Layer)((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)new DenseLayer.Builder().nIn(numInputs)).nOut(inputsFirstHiddenLayer)).weightInit(WeightInit.SIGMOID_UNIFORM)).activation(Activation.fromString((String)activation))).hasBias(true).build());
        List<Integer> hiddenNodes = this.configuration.plNetHiddenNodes();
        for (int i = 0; i < hiddenNodes.size() - 1; ++i) {
            int numIn = hiddenNodes.get(i);
            int numOut = hiddenNodes.get(i + 1);
            configBuilder.layer(i + 1, (Layer)((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)new DenseLayer.Builder().nIn(numIn)).nOut(numOut)).weightInit(WeightInit.SIGMOID_UNIFORM)).activation(Activation.fromString((String)activation))).hasBias(true).build());
        }
        configBuilder.layer(hiddenNodes.size(), (Layer)((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)new DenseLayer.Builder().nIn(hiddenNodes.get(hiddenNodes.size() - 1).intValue())).nOut(1)).weightInit(WeightInit.UNIFORM)).activation(Activation.IDENTITY)).hasBias(true).build());
        MultiLayerConfiguration multiLayerConfig = configBuilder.build();
        return new MultiLayerNetwork(multiLayerConfig);
    }

    private INDArray dyadToVector(Dyad dyad) {
        INDArray instanceOfDyad = Nd4j.create((double[])dyad.getInstance().asArray());
        INDArray alternativeOfDyad = Nd4j.create((double[])dyad.getAlternative().asArray());
        return Nd4j.hstack((INDArray[])new INDArray[]{instanceOfDyad, alternativeOfDyad});
    }

    private INDArray dyadRankingToMatrix(IDyadRankingInstance drInstance) {
        ArrayList<INDArray> dyadList = new ArrayList<INDArray>(drInstance.length());
        for (Dyad dyad : drInstance) {
            INDArray dyadVector = this.dyadToVector(dyad);
            dyadList.add(dyadVector);
        }
        INDArray dyadMatrix = Nd4j.vstack(dyadList);
        return dyadMatrix;
    }

    public void createNetworkFromDl4jConfigFile(File configFile) {
        MultiLayerNetwork network;
        String json = "";
        try {
            json = FileUtil.readFileAsString((File)configFile);
        }
        catch (IOException e) {
            log.error(e.getMessage());
        }
        MultiLayerConfiguration config = MultiLayerConfiguration.fromJson((String)json);
        this.plNet = network = new MultiLayerNetwork(config);
    }

    public void saveModelToFile(String filePath) throws IOException {
        if (this.plNet == null) {
            throw new IllegalStateException("Cannot save untrained model.");
        }
        File locationToSave = new File(filePath + ".zip");
        ModelSerializer.writeModel((Model)this.plNet, (File)locationToSave, (boolean)true);
    }

    public void loadModelFromFile(String filePath) throws IOException {
        MultiLayerNetwork restored;
        this.plNet = restored = ModelSerializer.restoreMultiLayerNetwork((String)filePath);
    }

    public MultiLayerNetwork getPlNet() {
        return this.plNet;
    }

    public int getEpoch() {
        return this.epoch;
    }

    @Override
    public double getCertainty(IDyadRankingInstance queryInstance) {
        if (queryInstance.length() != 2) {
            throw new IllegalArgumentException("Can only provide certainty for pairs of dyads!");
        }
        ArrayList<Pair> dyadUtilityPairs = new ArrayList<Pair>(queryInstance.length());
        for (Dyad dyad : queryInstance) {
            INDArray plNetInput = this.dyadToVector(dyad);
            double plNetOutput = this.plNet.output(plNetInput).getDouble(0L);
            dyadUtilityPairs.add(new Pair((Object)dyad, (Object)plNetOutput));
        }
        return Math.abs((Double)((Pair)dyadUtilityPairs.get(0)).getRight() - (Double)((Pair)dyadUtilityPairs.get(1)).getRight());
    }

    public IDyadRankingInstance getPairWithLeastCertainty(IDyadRankingInstance drInstance) {
        if (this.plNet == null) {
            int dyadSize = drInstance.getDyadAtPosition(0).getInstance().length() + drInstance.getDyadAtPosition(0).getAlternative().length();
            this.plNet = this.createNetwork(dyadSize);
            this.plNet.init();
        }
        if (drInstance.length() < 2) {
            throw new IllegalArgumentException("The query instance must contain at least 2 dyads!");
        }
        ArrayList<Pair> dyadUtilityPairs = new ArrayList<Pair>(drInstance.length());
        for (Dyad dyad : drInstance) {
            INDArray plNetInput = this.dyadToVector(dyad);
            double plNetOutput = this.plNet.output(plNetInput).getDouble(0L);
            dyadUtilityPairs.add(new Pair((Object)dyad, (Object)plNetOutput));
        }
        Collections.sort(dyadUtilityPairs, Comparator.comparing(p -> -((Double)p.getRight()).doubleValue()));
        int indexOfPairWithLeastCertainty = 0;
        double currentlyLowestCertainty = Double.MAX_VALUE;
        for (int i = 0; i < dyadUtilityPairs.size() - 1; ++i) {
            double currentCertainty = Math.abs((Double)((Pair)dyadUtilityPairs.get(i)).getRight() - (Double)((Pair)dyadUtilityPairs.get(i + 1)).getRight());
            if (!(currentCertainty < currentlyLowestCertainty)) continue;
            currentlyLowestCertainty = currentCertainty;
            indexOfPairWithLeastCertainty = i;
        }
        LinkedList<Dyad> leastCertainDyads = new LinkedList<Dyad>();
        leastCertainDyads.add((Dyad)((Pair)dyadUtilityPairs.get(indexOfPairWithLeastCertainty)).getLeft());
        leastCertainDyads.add((Dyad)((Pair)dyadUtilityPairs.get(indexOfPairWithLeastCertainty + 1)).getLeft());
        return new DyadRankingInstance(leastCertainDyads);
    }

    public double getProbabilityOfTopRanking(IDyadRankingInstance drInstance) {
        return this.getProbabilityOfTopKRanking(drInstance, drInstance.length());
    }

    private List<Pair<Dyad, Double>> getDyadUtilityPairsForInstance(IDyadRankingInstance drInstance) {
        if (this.plNet == null) {
            int dyadSize = drInstance.getDyadAtPosition(0).getInstance().length() + drInstance.getDyadAtPosition(0).getAlternative().length();
            this.plNet = this.createNetwork(dyadSize);
            this.plNet.init();
        }
        ArrayList<Pair<Dyad, Double>> dyadUtilityPairs = new ArrayList<Pair<Dyad, Double>>(drInstance.length());
        for (Dyad dyad : drInstance) {
            INDArray plNetInput = this.dyadToVector(dyad);
            double plNetOutput = this.plNet.output(plNetInput).getDouble(0L);
            dyadUtilityPairs.add((Pair<Dyad, Double>)new Pair((Object)dyad, (Object)plNetOutput));
        }
        return dyadUtilityPairs;
    }

    private List<Pair<Dyad, Double>> getSortedDyadUtilityPairsForInstance(IDyadRankingInstance drInstance) {
        List<Pair<Dyad, Double>> dyadUtilityPairs = this.getDyadUtilityPairsForInstance(drInstance);
        Collections.sort(dyadUtilityPairs, Comparator.comparing(p -> -((Double)p.getRight()).doubleValue()));
        return dyadUtilityPairs;
    }

    public double getProbabilityOfTopKRanking(IDyadRankingInstance drInstance, int k) {
        List<Pair<Dyad, Double>> dyadUtilityPairs = this.getSortedDyadUtilityPairsForInstance(drInstance);
        double currentProbability = 1.0;
        for (int i = 0; i < Integer.min(k, dyadUtilityPairs.size()); ++i) {
            double sumOfRemainingSkills = 0.0;
            for (int j = i; j < Integer.min(k, dyadUtilityPairs.size()); ++j) {
                sumOfRemainingSkills += Math.exp((Double)dyadUtilityPairs.get(j).getRight());
            }
            if (sumOfRemainingSkills != 0.0) {
                currentProbability *= Math.exp((Double)dyadUtilityPairs.get(i).getRight()) / sumOfRemainingSkills;
                continue;
            }
            currentProbability = Double.NaN;
        }
        return currentProbability;
    }

    public double getLogProbabilityOfTopRanking(IDyadRankingInstance drInstance) {
        return this.getLogProbabilityOfTopKRanking(drInstance, Integer.MAX_VALUE);
    }

    public double getLogProbabilityOfTopKRanking(IDyadRankingInstance drInstance, int k) {
        List<Pair<Dyad, Double>> dyadUtilityPairs = this.getSortedDyadUtilityPairsForInstance(drInstance);
        double currentProbability = 0.0;
        for (int i = 0; i < Integer.min(k, dyadUtilityPairs.size()); ++i) {
            double sumOfRemainingSkills = 0.0;
            for (int j = i; j < Integer.min(k, dyadUtilityPairs.size()); ++j) {
                sumOfRemainingSkills += Math.exp((Double)dyadUtilityPairs.get(j).getRight());
            }
            currentProbability += (Double)dyadUtilityPairs.get(i).getRight() - Math.log(sumOfRemainingSkills);
        }
        return currentProbability;
    }

    public double getProbabilityRanking(IDyadRankingInstance drInstance) {
        List<Pair<Dyad, Double>> dyadUtilityPairs = this.getDyadUtilityPairsForInstance(drInstance);
        double currentProbability = 1.0;
        for (int i = 0; i < dyadUtilityPairs.size(); ++i) {
            double sumOfRemainingSkills = 0.0;
            for (int j = i; j < dyadUtilityPairs.size(); ++j) {
                sumOfRemainingSkills += Math.exp((Double)dyadUtilityPairs.get(j).getRight());
            }
            if (sumOfRemainingSkills != 0.0) {
                currentProbability *= Math.exp((Double)dyadUtilityPairs.get(i).getRight()) / sumOfRemainingSkills;
                continue;
            }
            currentProbability = Double.NaN;
        }
        return currentProbability;
    }

    public double getLogProbabilityRanking(IDyadRankingInstance drInstance) {
        List<Pair<Dyad, Double>> dyadUtilityPairs = this.getDyadUtilityPairsForInstance(drInstance);
        double currentProbability = 0.0;
        for (int i = 0; i < dyadUtilityPairs.size(); ++i) {
            double sumOfRemainingSkills = 0.0;
            for (int j = i; j < dyadUtilityPairs.size(); ++j) {
                sumOfRemainingSkills += ((Double)dyadUtilityPairs.get(j).getRight()).doubleValue();
            }
            currentProbability += (Double)dyadUtilityPairs.get(i).getRight() - sumOfRemainingSkills;
        }
        return currentProbability;
    }

    public double getSkillForDyad(Dyad dyad) {
        if (this.plNet == null) {
            return Double.NaN;
        }
        INDArray plNetInput = this.dyadToVector(dyad);
        return this.plNet.output(plNetInput).getDouble(0L);
    }
}

