/*
 * Decompiled with CFR 0.152.
 */
package studio.mevera.imperat.command.flags;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import studio.mevera.imperat.command.CommandUsage;
import studio.mevera.imperat.command.flags.FlagExtractor;
import studio.mevera.imperat.command.parameters.CommandParameter;
import studio.mevera.imperat.context.Context;
import studio.mevera.imperat.context.FlagData;
import studio.mevera.imperat.context.Source;
import studio.mevera.imperat.exception.UnknownFlagException;

/**
 * Default {@link FlagExtractor} that decodes a raw flag string into the flags it names.
 *
 * <p>Every name and alias of the flag parameters declared by the {@link CommandUsage} is indexed
 * in a character trie. Extraction walks the input left to right and, at each position, consumes
 * the longest indexed name/alias that matches there. For example, with flags {@code a} and
 * {@code b} (and no flag {@code ab}), the input {@code "ab"} yields both flags. Matching is greedy
 * and never backtracks.
 *
 * <p>{@link #insertFlag} mutates the trie in place, and this class performs no synchronization.
 *
 * @param <S> the command source type
 */
final class FlagExtractorImpl<S extends Source>
implements FlagExtractor<S> {
    private final CommandUsage<S> usage;
    private final FlagTrie<S> flagTrie;

    /**
     * Creates an extractor indexing every flag parameter of {@code usage}.
     *
     * @param usage the usage whose flag parameters are recognised
     * @throws NullPointerException if {@code usage} is {@code null}
     */
    FlagExtractorImpl(CommandUsage<S> usage) {
        this.usage = Objects.requireNonNull(usage, "CommandUsage cannot be null");
        this.flagTrie = this.buildFlagTrie();
    }

    /**
     * Registers an additional flag, under its name and each of its aliases.
     *
     * <p>If a name or alias is already registered, it is re-bound to {@code flagData}.
     *
     * @param flagData the flag to make recognisable
     */
    @Override
    public void insertFlag(FlagData<S> flagData) {
        registerFlag(this.flagTrie, flagData);
    }

    /**
     * Decodes {@code rawInput} into the flags it names.
     *
     * @param rawInput the flag text to decode; {@code null} or empty yields an empty set
     * @param ctx      the context handed to {@link UnknownFlagException} on failure
     * @return the matched flags in first-match order, without duplicates
     * @throws UnknownFlagException if any character cannot be consumed by a known name or alias
     */
    @Override
    public Set<FlagData<S>> extract(String rawInput, Context<S> ctx) throws UnknownFlagException {
        if (rawInput == null || rawInput.isEmpty()) {
            return Collections.emptySet();
        }
        return this.parseFlags(rawInput, ctx);
    }

    /** Builds the trie from the usage's flag parameters, in declaration order. */
    private FlagTrie<S> buildFlagTrie() {
        FlagTrie<S> trie = new FlagTrie<>();
        // An insertion-ordered set de-duplicates flags while keeping declaration order.
        // That makes the "last registration wins" rule for a shared name/alias deterministic.
        Set<FlagData<S>> allFlags = this.usage.getParameters().stream()
                .filter(CommandParameter::isFlag)
                .map(parameter -> parameter.asFlagParameter().flagData())
                .collect(Collectors.toCollection(LinkedHashSet::new));
        for (FlagData<S> flagData : allFlags) {
            registerFlag(trie, flagData);
        }
        return trie;
    }

    /** Indexes {@code flagData} in {@code trie} under its name and each alias. */
    private static <S extends Source> void registerFlag(FlagTrie<S> trie, FlagData<S> flagData) {
        trie.insert(flagData.name(), flagData);
        for (String alias : flagData.aliases()) {
            trie.insert(alias, flagData);
        }
    }

    /**
     * Greedily consumes {@code input} with longest-match lookups.
     *
     * <p>Every character that starts no known name/alias is recorded. If any were recorded, a
     * single {@link UnknownFlagException} lists them all, separated by {@code ", "}.
     */
    private Set<FlagData<S>> parseFlags(String input, Context<S> context) throws UnknownFlagException {
        Set<FlagData<S>> extractedFlags = new LinkedHashSet<>(3);
        var unmatchedParts = new ArrayList<String>();
        int position = 0;
        while (position < input.length()) {
            MatchResult<S> match = this.flagTrie.findLongestMatch(input, position);
            if (match.isFound()) {
                extractedFlags.add(match.flagData());
                position += match.matchLength();
            } else {
                // Record the offending character and resume matching at the next one.
                unmatchedParts.add(String.valueOf(input.charAt(position)));
                position++;
            }
        }
        if (!unmatchedParts.isEmpty()) {
            throw new UnknownFlagException(String.join(", ", unmatchedParts), context);
        }
        return extractedFlags;
    }

    /** Character trie mapping flag names/aliases to their {@link FlagData}. */
    private static final class FlagTrie<S extends Source> {
        private final TrieNode<S> root = new TrieNode<>();

        /**
         * Binds {@code key} to {@code flagData}, replacing any previous binding for the same key.
         * An empty key is ignored: it could only mark the root, which lookups never match.
         */
        void insert(String key, FlagData<S> flagData) {
            if (key.isEmpty()) {
                return;
            }
            TrieNode<S> current = this.root;
            for (char c : key.toCharArray()) {
                current = current.children.computeIfAbsent(c, k -> new TrieNode<>());
            }
            current.flagData = flagData;
            current.isEndOfFlag = true;
        }

        /**
         * Finds the longest key that is a prefix of {@code input.substring(startPos)}.
         *
         * @return the match, or a result whose {@link MatchResult#isFound()} is {@code false}
         */
        MatchResult<S> findLongestMatch(String input, int startPos) {
            TrieNode<S> current = this.root;
            FlagData<S> lastMatchedFlag = null;
            int lastMatchLength = 0;
            for (int i = startPos; i < input.length(); i++) {
                current = current.children.get(input.charAt(i));
                if (current == null) {
                    break; // no key continues with this character
                }
                if (current.isEndOfFlag) {
                    lastMatchedFlag = current.flagData;
                    lastMatchLength = i - startPos + 1;
                }
            }
            return new MatchResult<>(lastMatchedFlag, lastMatchLength);
        }
    }

    /** Outcome of a trie lookup: the matched flag and how many characters it consumed. */
    private record MatchResult<S extends Source>(FlagData<S> flagData, int matchLength) {
        boolean isFound() {
            return this.flagData != null && this.matchLength > 0;
        }
    }

    /** A trie node; terminal nodes carry the flag bound to the key ending there. */
    private static final class TrieNode<S extends Source> {
        final Map<Character, TrieNode<S>> children = new HashMap<>();
        FlagData<S> flagData;
        boolean isEndOfFlag = false;
    }
}

