/*
 * Decompiled with CFR 0.152.
 */
package opennlp.dl.doccat;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.io.File;
import java.io.IOException;
import java.nio.LongBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.stream.IntStream;
import opennlp.dl.AbstractDL;
import opennlp.dl.InferenceOptions;
import opennlp.dl.Tokens;
import opennlp.dl.doccat.scoring.ClassificationScoringStrategy;
import opennlp.tools.doccat.DocumentCategorizer;
import opennlp.tools.tokenize.WordpieceTokenizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link DocumentCategorizer} backed by an ONNX model.
 *
 * <p>The input text is split on whitespace into word chunks, which may overlap. Each chunk is
 * wordpiece-tokenized and run through the model. The first model output of each run is
 * softmax-normalized. The per-chunk score vectors are then combined by the configured
 * {@link ClassificationScoringStrategy}.
 */
public class DocumentCategorizerDL
extends AbstractDL
implements DocumentCategorizer {
    private static final Logger logger = LoggerFactory.getLogger(DocumentCategorizerDL.class);
    /** Maps a model output index to its category name. */
    private final Map<Integer, String> categories;
    /** Combines the per-chunk score vectors into a single score vector. */
    private final ClassificationScoringStrategy classificationScoringStrategy;
    /** GPU, model-input and document-splitting options. */
    private final InferenceOptions inferenceOptions;

    /**
     * Creates a categorizer for the given model and vocabulary.
     *
     * @param model the ONNX model file
     * @param vocabulary the vocabulary file, loaded via {@code loadVocab}
     * @param categories maps each model output index to a category name
     * @param classificationScoringStrategy combines per-chunk scores into one score vector
     * @param inferenceOptions GPU, model-input and document-splitting options
     * @throws IOException if loading the vocabulary fails (presumably raised by {@code loadVocab})
     * @throws OrtException if the ONNX Runtime session cannot be created
     */
    public DocumentCategorizerDL(File model, File vocabulary, Map<Integer, String> categories, ClassificationScoringStrategy classificationScoringStrategy, InferenceOptions inferenceOptions) throws IOException, OrtException {
        this.env = OrtEnvironment.getEnvironment();
        final OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        if (inferenceOptions.isGpu()) {
            sessionOptions.addCUDA(inferenceOptions.getGpuDeviceId());
        }
        this.session = this.env.createSession(model.getPath(), sessionOptions);
        this.vocab = this.loadVocab(vocabulary);
        this.tokenizer = new WordpieceTokenizer(this.vocab.keySet());
        this.categories = categories;
        this.classificationScoringStrategy = classificationScoringStrategy;
        this.inferenceOptions = inferenceOptions;
    }

    /**
     * Categorizes a document.
     *
     * <p>The elements of {@code strings} are joined with single spaces and treated as one
     * document. A single-element array is therefore classified exactly as given.
     *
     * @param strings the document, either as one string or as its tokens
     * @return the combined score vector produced by the scoring strategy, or an empty array if
     *     no text is given or inference fails (the failure is logged)
     */
    public double[] categorize(String[] strings) {
        if (strings == null || strings.length == 0) {
            logger.warn("No text was provided for document classification");
            return new double[0];
        }
        try {
            final String text = String.join(" ", strings);
            final List<double[]> scores = new ArrayList<>();
            for (Tokens chunk : this.tokenize(text)) {
                scores.add(this.scoreChunk(chunk));
            }
            return this.classificationScoringStrategy.score(scores);
        }
        catch (Exception ex) {
            // Boundary of the public API: report the failure and return the documented empty result.
            logger.error("Unable to perform document classification inference", ex);
            return new double[0];
        }
    }

    /**
     * Categorizes a document. The extra information is ignored.
     *
     * @param strings the document, either as one string or as its tokens
     * @param map ignored
     * @return the same result as {@link #categorize(String[])}
     */
    public double[] categorize(String[] strings, Map<String, Object> map) {
        return this.categorize(strings);
    }

    /**
     * Returns the name of the highest-scoring category.
     *
     * @param doubles a score vector indexed by model output
     * @return the category mapped to the index of the maximum score; if several indices tie,
     *     the last one wins. For an empty vector this is whatever the categories map holds
     *     for index -1, normally {@code null}.
     */
    public String getBestCategory(double[] doubles) {
        return this.categories.get(this.maxIndex(doubles));
    }

    /**
     * Returns the model output index of a category.
     *
     * @param s the category name
     * @return its index, or -1 if the name is unknown
     */
    public int getIndex(String s) {
        return this.getKey(s);
    }

    /**
     * Returns the category name for a model output index.
     *
     * @param i the model output index
     * @return the category name, or {@code null} if the index is not mapped
     */
    public String getCategory(int i) {
        return this.categories.get(i);
    }

    /** Returns the number of configured categories. */
    public int getNumberOfCategories() {
        return this.categories.size();
    }

    /**
     * Formats every score as {@code name[score]}, in model output order, separated by two spaces.
     * Scores are printed with four decimals. An index with no mapped category is printed as
     * its number.
     *
     * @param doubles a score vector indexed by model output
     * @return the formatted scores, or an empty string for a {@code null} or empty vector
     */
    public String getAllResults(double[] doubles) {
        if (doubles == null) {
            return "";
        }
        final StringBuilder out = new StringBuilder();
        for (int i = 0; i < doubles.length; i++) {
            if (i > 0) {
                out.append("  ");
            }
            final String name = this.categories.get(i);
            out.append(name != null ? name : String.valueOf(i))
               .append('[')
               .append(String.format(Locale.ROOT, "%.4f", doubles[i]))
               .append(']');
        }
        return out.toString();
    }

    /**
     * Categorizes a document and maps each category name to its score.
     *
     * @param strings the document, either as one string or as its tokens
     * @return category name to score. Categories whose index has no score (for example after a
     *     failed inference) are omitted.
     */
    public Map<String, Double> scoreMap(String[] strings) {
        final double[] scores = this.categorize(strings);
        final Map<String, Double> scoreMap = new HashMap<>();
        for (Map.Entry<Integer, String> entry : this.categories.entrySet()) {
            final int index = entry.getKey();
            if (index >= 0 && index < scores.length) {
                scoreMap.put(entry.getValue(), scores[index]);
            }
        }
        return scoreMap;
    }

    /**
     * Categorizes a document and groups the category names by score, in ascending score order.
     *
     * @param strings the document, either as one string or as its tokens
     * @return score to the set of category names with that score. Categories whose index has no
     *     score are omitted.
     */
    public SortedMap<Double, Set<String>> sortedScoreMap(String[] strings) {
        final double[] scores = this.categorize(strings);
        final SortedMap<Double, Set<String>> scoreMap = new TreeMap<>();
        for (Map.Entry<Integer, String> entry : this.categories.entrySet()) {
            final int index = entry.getKey();
            if (index >= 0 && index < scores.length) {
                scoreMap.computeIfAbsent(scores[index], k -> new HashSet<>()).add(entry.getValue());
            }
        }
        return scoreMap;
    }

    /** Linear reverse lookup of a category name; returns -1 when the name is not present. */
    private int getKey(String value) {
        for (Map.Entry<Integer, String> entry : this.categories.entrySet()) {
            if (entry.getValue().equals(value)) {
                return entry.getKey();
            }
        }
        return -1;
    }

    /**
     * Runs the model on one chunk and returns the softmax of the first row of its first output.
     * All native tensors and results are closed before returning.
     */
    private double[] scoreChunk(Tokens chunk) throws OrtException {
        final Map<String, OnnxTensor> inputs = new HashMap<>();
        try {
            inputs.put("input_ids", this.createInputTensor(chunk.ids()));
            if (this.inferenceOptions.isIncludeAttentionMask()) {
                inputs.put("attention_mask", this.createInputTensor(chunk.mask()));
            }
            if (this.inferenceOptions.isIncludeTokenTypeIds()) {
                inputs.put("token_type_ids", this.createInputTensor(chunk.types()));
            }
            try (OrtSession.Result result = this.session.run(inputs)) {
                // Assumes the first output is a [1 x N] float matrix of logits. TODO confirm for the model.
                // getValue() copies into Java arrays, so closing the result afterwards is safe.
                final float[][] logits = (float[][]) result.get(0).getValue();
                return this.softmax(logits[0]);
            }
        }
        finally {
            for (OnnxTensor tensor : inputs.values()) {
                tensor.close();
            }
        }
    }

    /** Wraps a token-level vector as a {@code [1 x length]} int64 input tensor. */
    private OnnxTensor createInputTensor(long[] values) throws OrtException {
        return OnnxTensor.createTensor(this.env, LongBuffer.wrap(values), new long[]{1L, values.length});
    }

    /**
     * Wordpiece-tokenizes one chunk of text and builds its model inputs.
     *
     * <p>The inputs are the vocabulary ids, an all-ones attention mask and all-zero token type ids.
     *
     * @throws IllegalStateException if a produced token is not in the vocabulary
     */
    private Tokens toTokens(String text) {
        final String[] tokens = this.tokenizer.tokenize(text);
        final long[] ids = new long[tokens.length];
        for (int x = 0; x < tokens.length; ++x) {
            final Integer id = (Integer) this.vocab.get(tokens[x]);
            if (id == null) {
                throw new IllegalStateException("Token is not in the vocabulary: " + tokens[x]);
            }
            ids[x] = id;
        }
        final long[] mask = new long[ids.length];
        Arrays.fill(mask, 1L);
        // Java zero-initializes arrays, so every token type id is already 0.
        final long[] types = new long[ids.length];
        return new Tokens(tokens, ids, mask, types);
    }

    /**
     * Splits the text on whitespace into chunks of at most {@code documentSplitSize} words.
     *
     * <p>Each chunk starts {@code documentSplitSize - splitOverlapSize} words after the previous
     * one. Each chunk is then tokenized with {@link #toTokens(String)}.
     *
     * @throws IllegalStateException if the split size is not positive, or if the overlap is not
     *     smaller than the split size (the chunking would otherwise never advance)
     */
    private List<Tokens> tokenize(String text) {
        final int splitSize = this.inferenceOptions.getDocumentSplitSize();
        final int overlap = this.inferenceOptions.getSplitOverlapSize();
        if (splitSize < 1) {
            throw new IllegalStateException("Document split size must be positive but was " + splitSize);
        }
        final int step = splitSize - overlap;
        if (step < 1) {
            throw new IllegalStateException("Split overlap size (" + overlap
                    + ") must be smaller than the document split size (" + splitSize + ")");
        }
        final List<Tokens> chunks = new ArrayList<>();
        final String[] words = text.split("\\s+");
        for (int start = 0; start < words.length; start += step) {
            final int end = Math.min(start + splitSize, words.length);
            final String group = String.join(" ", Arrays.copyOfRange(words, start, end));
            chunks.add(this.toTokens(group));
        }
        return chunks;
    }

    /**
     * Computes a numerically stable softmax.
     *
     * <p>The maximum is subtracted before exponentiating, so large logits cannot overflow to
     * {@code Infinity} and produce {@code NaN} scores.
     */
    private double[] softmax(float[] input) {
        double max = Double.NEGATIVE_INFINITY;
        for (float value : input) {
            max = Math.max(max, value);
        }
        final double[] output = new double[input.length];
        double sum = 0.0;
        for (int x = 0; x < input.length; ++x) {
            output[x] = Math.exp(input[x] - max);
            sum += output[x];
        }
        for (int x = 0; x < output.length; ++x) {
            output[x] /= sum;
        }
        return output;
    }

    /** Returns the index of the largest element (last one on ties), or -1 for an empty array. */
    private int maxIndex(double[] arr) {
        return IntStream.range(0, arr.length).reduce((i, j) -> arr[i] > arr[j] ? i : j).orElse(-1);
    }
}

