/*
 * Decompiled with CFR 0.152.
 */
package opennlp.dl.namefinder;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.LongBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import opennlp.dl.InferenceOptions;
import opennlp.dl.SpanEnd;
import opennlp.dl.Tokens;
import opennlp.tools.namefind.TokenNameFinder;
import opennlp.tools.sentdetect.SentenceDetector;
import opennlp.tools.tokenize.Tokenizer;
import opennlp.tools.tokenize.WordpieceTokenizer;
import opennlp.tools.util.Span;

public class NameFinderDL
implements TokenNameFinder {
    public static final String INPUT_IDS = "input_ids";
    public static final String ATTENTION_MASK = "attention_mask";
    public static final String TOKEN_TYPE_IDS = "token_type_ids";
    public static final String I_PER = "I-PER";
    public static final String B_PER = "B-PER";
    public static final String SEPARATOR = "[SEP]";
    private static final String CHARS_TO_REPLACE = "##";
    protected final OrtSession session;
    private final SentenceDetector sentenceDetector;
    private final Map<Integer, String> ids2Labels;
    private final Tokenizer tokenizer;
    private final Map<String, Integer> vocab;
    private final InferenceOptions inferenceOptions;
    protected final OrtEnvironment env = OrtEnvironment.getEnvironment();

    public NameFinderDL(File model, File vocabulary, Map<Integer, String> ids2Labels, SentenceDetector sentenceDetector) throws Exception {
        this(model, vocabulary, ids2Labels, new InferenceOptions(), sentenceDetector);
    }

    public NameFinderDL(File model, File vocabulary, Map<Integer, String> ids2Labels, InferenceOptions inferenceOptions, SentenceDetector sentenceDetector) throws Exception {
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        if (inferenceOptions.isGpu()) {
            sessionOptions.addCUDA(inferenceOptions.getGpuDeviceId());
        }
        this.session = this.env.createSession(model.getPath(), sessionOptions);
        this.ids2Labels = ids2Labels;
        this.vocab = this.loadVocab(vocabulary);
        this.tokenizer = new WordpieceTokenizer(this.vocab.keySet());
        this.inferenceOptions = inferenceOptions;
        this.sentenceDetector = sentenceDetector;
    }

    public Span[] find(String[] input) {
        String[] sentences;
        LinkedList<Span> spans = new LinkedList<Span>();
        String text = String.join((CharSequence)" ", input);
        for (String sentence : sentences = this.sentenceDetector.sentDetect((CharSequence)text)) {
            List<Tokens> wordpieceTokens = this.tokenize(sentence);
            for (Tokens tokens : wordpieceTokens) {
                try {
                    HashMap<String, OnnxTensor> inputs = new HashMap<String, OnnxTensor>();
                    inputs.put(INPUT_IDS, OnnxTensor.createTensor((OrtEnvironment)this.env, (LongBuffer)LongBuffer.wrap(tokens.getIds()), (long[])new long[]{1L, tokens.getIds().length}));
                    if (this.inferenceOptions.isIncludeAttentionMask()) {
                        inputs.put(ATTENTION_MASK, OnnxTensor.createTensor((OrtEnvironment)this.env, (LongBuffer)LongBuffer.wrap(tokens.getMask()), (long[])new long[]{1L, tokens.getMask().length}));
                    }
                    if (this.inferenceOptions.isIncludeTokenTypeIds()) {
                        inputs.put(TOKEN_TYPE_IDS, OnnxTensor.createTensor((OrtEnvironment)this.env, (LongBuffer)LongBuffer.wrap(tokens.getTypes()), (long[])new long[]{1L, tokens.getTypes().length}));
                    }
                    float[][][] v = (float[][][])this.session.run(inputs).get(0).getValue();
                    int characterStart = 0;
                    String[] toks = tokens.getTokens();
                    for (int x = 0; x < v[0].length; ++x) {
                        String spanText;
                        float[] arr = v[0][x];
                        int maxIndex = this.maxIndex(arr);
                        String label = this.ids2Labels.get(maxIndex);
                        double confidence = arr[maxIndex];
                        if (!B_PER.equals(label)) continue;
                        SpanEnd spanEnd = this.findSpanEnd(v, x, this.ids2Labels, toks);
                        if (spanEnd.getIndex() != -1) {
                            StringBuilder sb = new StringBuilder();
                            int end = spanEnd.getIndex();
                            for (int i = x; i <= end; ++i) {
                                if (toks[i + 1].startsWith(CHARS_TO_REPLACE)) {
                                    sb.append(toks[i]).append(toks[i + 1].replace(CHARS_TO_REPLACE, ""));
                                    if (!toks[i + 2].startsWith(CHARS_TO_REPLACE)) {
                                        sb.append(" ");
                                    }
                                    ++i;
                                    continue;
                                }
                                sb.append(toks[i].replace(CHARS_TO_REPLACE, ""));
                                if (".".equals(toks[i + 1])) continue;
                                sb.append(" ");
                            }
                            spanText = NameFinderDL.findByRegex(text, sb.toString().trim()).trim();
                        } else {
                            spanText = toks[x];
                        }
                        if (SEPARATOR.equals(spanText) || (characterStart = text.indexOf(spanText = spanText.replace(CHARS_TO_REPLACE, ""), characterStart)) == -1) continue;
                        int characterEnd = characterStart + spanText.length();
                        spans.add(new Span(characterStart, characterEnd, spanText, confidence));
                        ++characterStart;
                    }
                }
                catch (OrtException ex) {
                    throw new RuntimeException("Error performing namefinder inference: " + ex.getMessage(), ex);
                }
            }
        }
        return spans.toArray(new Span[0]);
    }

    public void clearAdaptiveData() {
    }

    private SpanEnd findSpanEnd(float[][][] v, int startIndex, Map<Integer, String> id2Labels, String[] tokens) {
        int x;
        int index = -1;
        int characterEnd = 0;
        for (x = startIndex + 1; x < v[0].length; ++x) {
            float[] arr = v[0][x];
            String nextTokenClassification = id2Labels.get(this.maxIndex(arr));
            if (I_PER.equals(nextTokenClassification)) continue;
            index = x - 1;
            break;
        }
        for (x = 1; x <= index && x < tokens.length; ++x) {
            characterEnd += tokens[x].length();
        }
        return new SpanEnd(index, characterEnd += index - 1);
    }

    private int maxIndex(float[] arr) {
        double max = Double.NEGATIVE_INFINITY;
        int index = -1;
        for (int x = 0; x < arr.length; ++x) {
            if (!((double)arr[x] > max)) continue;
            index = x;
            max = arr[x];
        }
        return index;
    }

    private static String findByRegex(String text, String span) {
        String regex = span.replaceAll(" ", "\\\\s+").replaceAll("\\)", "\\\\)").replaceAll("\\(", "\\\\(");
        Pattern pattern = Pattern.compile(regex, 2);
        Matcher matcher = pattern.matcher(text);
        if (matcher.find()) {
            return matcher.group(0);
        }
        return span;
    }

    private List<Tokens> tokenize(String text) {
        LinkedList<Tokens> t = new LinkedList<Tokens>();
        String[] whitespaceTokenized = text.split("\\s+");
        for (int start = 0; start < whitespaceTokenized.length; start += this.inferenceOptions.getDocumentSplitSize()) {
            int end = start + this.inferenceOptions.getDocumentSplitSize();
            if (end > whitespaceTokenized.length) {
                end = whitespaceTokenized.length;
            }
            String group = String.join((CharSequence)" ", Arrays.copyOfRange(whitespaceTokenized, start, end));
            start -= this.inferenceOptions.getSplitOverlapSize();
            String[] tokens = this.tokenizer.tokenize(group);
            int[] ids = new int[tokens.length];
            for (int x = 0; x < tokens.length; ++x) {
                ids[x] = this.vocab.get(tokens[x]);
            }
            long[] lids = Arrays.stream(ids).mapToLong(i -> i).toArray();
            long[] mask = new long[ids.length];
            Arrays.fill(mask, 1L);
            long[] types = new long[ids.length];
            Arrays.fill(types, 0L);
            t.add(new Tokens(tokens, lids, mask, types));
        }
        return t;
    }

    private Map<String, Integer> loadVocab(File vocab) throws IOException {
        HashMap<String, Integer> v = new HashMap<String, Integer>();
        try (BufferedReader br = new BufferedReader(new FileReader(vocab.getPath()));){
            String line = br.readLine();
            int x = 0;
            while (line != null) {
                line = br.readLine();
                v.put(line, ++x);
            }
        }
        return v;
    }
}

