/*
 * Decompiled with CFR 0.152.
 */
package ai.timefold.jpyinterpreter.dag;

import ai.timefold.jpyinterpreter.ExceptionBlock;
import ai.timefold.jpyinterpreter.FunctionMetadata;
import ai.timefold.jpyinterpreter.PythonBytecodeToJavaBytecodeTranslator;
import ai.timefold.jpyinterpreter.StackMetadata;
import ai.timefold.jpyinterpreter.dag.BasicBlock;
import ai.timefold.jpyinterpreter.dag.JumpSource;
import ai.timefold.jpyinterpreter.opcodes.Opcode;
import ai.timefold.jpyinterpreter.types.BuiltinTypes;
import ai.timefold.jpyinterpreter.types.errors.PythonBaseException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;

/**
 * Control-flow graph of a function's bytecode, split into {@link BasicBlock}s, together with the
 * {@link StackMetadata} inferred for each operation.
 * <p>
 * A new basic block starts ("leader" instruction) at the first instruction, at every instruction that
 * follows a forced jump, at every instruction flagged as a jump target, and at every instruction the
 * function's exception table lists as a jump target. Stack metadata is propagated along all branches
 * (including exception-handler entries) until a fixed point is reached.
 */
public class FlowGraph {
    BasicBlock initialBlock;
    List<BasicBlock> basicBlockList;
    Map<BasicBlock, List<BasicBlock>> basicBlockToSourcesMap;
    Map<BasicBlock, List<JumpSource>> basicBlockToJumpSourcesMap;
    Map<IndexBranchPair, JumpSource> opcodeIndexToJumpSourceMap;
    List<StackMetadata> stackMetadataForOperations;

    private FlowGraph(BasicBlock initialBlock, List<BasicBlock> basicBlockList, Map<BasicBlock, List<BasicBlock>> basicBlockToSourcesMap, Map<BasicBlock, List<JumpSource>> basicBlockToJumpSourcesMap, Map<IndexBranchPair, JumpSource> opcodeIndexToJumpSourceMap) {
        this.initialBlock = initialBlock;
        this.basicBlockList = basicBlockList;
        this.basicBlockToSourcesMap = basicBlockToSourcesMap;
        this.basicBlockToJumpSourcesMap = basicBlockToJumpSourcesMap;
        this.opcodeIndexToJumpSourceMap = opcodeIndexToJumpSourceMap;
    }

    /**
     * Returns the inferred stack metadata, ordered by ascending bytecode index.
     * <p>
     * The list holds one entry per opcode index, plus an entry for any successor index referenced by a
     * branch (for example, an index past the last opcode). Operations that are never reached are
     * {@link StackMetadata#DEAD_CODE}.
     *
     * @return the stack metadata list computed when this graph was created
     */
    public List<StackMetadata> getStackMetadataForOperations() {
        return this.stackMetadataForOperations;
    }

    /**
     * Calls {@code visitor} for every reachable operation that is an instance of {@code opcodeClass}.
     * <p>
     * Operations whose stack metadata is dead code are skipped.
     *
     * @param opcodeClass the opcode type to visit (subclasses are included)
     * @param visitor receives each matching opcode together with the stack metadata before it executes
     * @param <T> the opcode type being visited
     */
    public <T extends Opcode> void visitOperations(Class<T> opcodeClass, BiConsumer<? super T, StackMetadata> visitor) {
        for (BasicBlock basicBlock : this.basicBlockList) {
            for (Opcode opcode : basicBlock.getBlockOpcodeList()) {
                if (!opcodeClass.isInstance(opcode)) {
                    continue;
                }
                StackMetadata stackMetadata = this.stackMetadataForOperations.get(opcode.getBytecodeIndex());
                if (stackMetadata.isDeadCode()) {
                    continue;
                }
                // Checked cast: isInstance above guarantees this succeeds. It is also required for the
                // call to type-check against BiConsumer<? super T, ...>.
                visitor.accept(opcodeClass.cast(opcode), stackMetadata);
            }
        }
    }

    /**
     * Builds the flow graph for {@code opcodeList} and computes the stack metadata of every operation.
     *
     * @param functionMetadata metadata of the function being translated; its exception table
     *        contributes block leaders and handler entry points
     * @param initialStackMetadata stack metadata at bytecode index 0
     * @param opcodeList the function's opcodes in bytecode order
     * @return the populated flow graph
     * @throws IllegalArgumentException if {@code opcodeList} is empty
     * @throws IllegalStateException if stack metadata cannot be computed or unified
     */
    public static FlowGraph createFlowGraph(FunctionMetadata functionMetadata, StackMetadata initialStackMetadata, List<Opcode> opcodeList) {
        if (opcodeList.isEmpty()) {
            throw new IllegalArgumentException("Cannot create a flow graph from an empty opcode list.");
        }
        List<Integer> leaderIndexList = findLeaderIndices(functionMetadata, opcodeList);

        // Each block spans from its leader up to (but excluding) the next leader, or to the end of the list.
        List<BasicBlock> basicBlockList = new ArrayList<>(leaderIndexList.size());
        Map<Integer, BasicBlock> jumpTargetToBasicBlock = new HashMap<>();
        for (int i = 0; i < leaderIndexList.size(); ++i) {
            int blockStart = leaderIndexList.get(i);
            int blockEnd = (i + 1 < leaderIndexList.size()) ? leaderIndexList.get(i + 1) : opcodeList.size();
            BasicBlock basicBlock = new BasicBlock(blockStart, opcodeList.subList(blockStart, blockEnd));
            basicBlockList.add(basicBlock);
            jumpTargetToBasicBlock.put(blockStart, basicBlock);
        }
        BasicBlock initialBlock = basicBlockList.get(0);

        Map<BasicBlock, List<BasicBlock>> basicBlockToSourcesMap = new HashMap<>();
        Map<BasicBlock, List<JumpSource>> basicBlockToJumpSourcesMap = new HashMap<>();
        Map<IndexBranchPair, JumpSource> opcodeIndexToJumpSourceMap = new HashMap<>();
        for (BasicBlock basicBlock : basicBlockList) {
            basicBlockToSourcesMap.put(basicBlock, new ArrayList<>());
        }
        for (BasicBlock basicBlock : basicBlockList) {
            for (Opcode opcode : basicBlock.getBlockOpcodeList()) {
                List<Integer> successorIndexList = opcode.getPossibleNextBytecodeIndexList();
                for (int branch = 0; branch < successorIndexList.size(); ++branch) {
                    int jumpTarget = successorIndexList.get(branch);
                    // Forward successors inside the same block (e.g. plain fall-through) are not block edges.
                    if (basicBlock.containsIndex(jumpTarget) && jumpTarget > opcode.getBytecodeIndex()) {
                        continue;
                    }
                    // NOTE(review): if jumpTarget is not a block leader (e.g. an index past the last
                    // opcode), this is null and the edge is recorded under a null HashMap key.
                    BasicBlock jumpTargetBlock = jumpTargetToBasicBlock.get(jumpTarget);
                    JumpSource jumpSource = new JumpSource(basicBlock);
                    basicBlockToSourcesMap.computeIfAbsent(jumpTargetBlock, key -> new ArrayList<>()).add(basicBlock);
                    basicBlockToJumpSourcesMap.computeIfAbsent(jumpTargetBlock, key -> new ArrayList<>()).add(jumpSource);
                    opcodeIndexToJumpSourceMap.put(new IndexBranchPair(opcode.getBytecodeIndex(), branch), jumpSource);
                }
            }
        }
        FlowGraph out = new FlowGraph(initialBlock, basicBlockList, basicBlockToSourcesMap, basicBlockToJumpSourcesMap, opcodeIndexToJumpSourceMap);
        out.computeStackMetadataForOperations(functionMetadata, initialStackMetadata);
        return out;
    }

    /** Returns, in ascending order, the indices of instructions that start a new basic block. */
    private static List<Integer> findLeaderIndices(FunctionMetadata functionMetadata, List<Opcode> opcodeList) {
        List<Integer> leaderIndexList = new ArrayList<>();
        // Starts as true so that instruction 0 always begins a block.
        boolean wasPreviousInstructionForcedJump = true;
        for (int i = 0; i < opcodeList.size(); ++i) {
            Opcode opcode = opcodeList.get(i);
            if (wasPreviousInstructionForcedJump || opcode.isJumpTarget()
                    || functionMetadata.pythonCompiledFunction.co_exceptiontable.containsJumpTarget(i)) {
                leaderIndexList.add(i);
            }
            wasPreviousInstructionForcedJump = opcode.isForcedJump();
        }
        return leaderIndexList;
    }

    /**
     * Computes the stack metadata on entry to an exception handler.
     * <p>
     * Starting from {@code previousStackMetadata}, or {@code initialStackMetadata} if the protected
     * region is dead code, the stack is padded with NONE_TYPE temporaries or popped until it has the
     * exception block's recorded depth. The method then pushes an INT_TYPE (only when the block pushes
     * the last index) followed by the base exception type.
     */
    private static StackMetadata getExceptionStackMetadata(ExceptionBlock exceptionBlock, FunctionMetadata functionMetadata, StackMetadata initialStackMetadata, StackMetadata previousStackMetadata) {
        // NOTE(review): identity comparison; assumes DEAD_CODE is a singleton — confirm.
        if (previousStackMetadata == StackMetadata.DEAD_CODE) {
            previousStackMetadata = initialStackMetadata;
        }
        while (previousStackMetadata.getStackSize() < exceptionBlock.getStackDepth()) {
            previousStackMetadata = previousStackMetadata.pushTemp(BuiltinTypes.NONE_TYPE);
        }
        while (previousStackMetadata.getStackSize() > exceptionBlock.getStackDepth()) {
            previousStackMetadata = previousStackMetadata.pop();
        }
        if (exceptionBlock.isPushLastIndex()) {
            return previousStackMetadata.pushTemp(BuiltinTypes.INT_TYPE).pushTemp(PythonBaseException.BASE_EXCEPTION_TYPE);
        }
        return previousStackMetadata.pushTemp(PythonBaseException.BASE_EXCEPTION_TYPE);
    }

    /**
     * Propagates stack metadata through the graph until nothing changes, then stores the result in
     * {@link #stackMetadataForOperations}.
     */
    private void computeStackMetadataForOperations(FunctionMetadata functionMetadata, StackMetadata initialStackMetadata) {
        Map<Integer, StackMetadata> opcodeIndexToStackMetadata = new HashMap<>();
        opcodeIndexToStackMetadata.put(0, initialStackMetadata);

        // Seeding pass: this gives every opcode index an entry (DEAD_CODE if not yet reached), so the
        // fixed-point loop below always finds one. Change tracking is irrelevant here.
        for (BasicBlock basicBlock : this.basicBlockList) {
            for (Opcode opcode : basicBlock.getBlockOpcodeList()) {
                propagateToSuccessors(functionMetadata, opcodeIndexToStackMetadata, opcode);
            }
        }
        mergeExceptionHandlerEntries(functionMetadata, initialStackMetadata, opcodeIndexToStackMetadata);

        boolean hasChanged;
        do {
            hasChanged = false;
            for (BasicBlock basicBlock : this.basicBlockList) {
                hasChanged |= unifyBlockEntryWithJumpSources(basicBlock, opcodeIndexToStackMetadata);
                for (Opcode opcode : basicBlock.getBlockOpcodeList()) {
                    hasChanged |= propagateToSuccessors(functionMetadata, opcodeIndexToStackMetadata, opcode);
                }
            }
            hasChanged |= mergeExceptionHandlerEntries(functionMetadata, initialStackMetadata, opcodeIndexToStackMetadata);
        } while (hasChanged);

        this.stackMetadataForOperations = opcodeIndexToStackMetadata.entrySet().stream()
                .sorted(Map.Entry.comparingByKey())
                .map(Map.Entry::getValue)
                .collect(Collectors.toList());
    }

    /**
     * Unifies a block's entry metadata with the metadata recorded on each of its jump sources.
     *
     * @return whether the block's entry metadata changed
     */
    private boolean unifyBlockEntryWithJumpSources(BasicBlock basicBlock, Map<Integer, StackMetadata> opcodeIndexToStackMetadata) {
        StackMetadata originalMetadata = opcodeIndexToStackMetadata.get(basicBlock.startAtIndex);
        StackMetadata newMetadata = originalMetadata;
        for (JumpSource jumpSource : this.basicBlockToJumpSourcesMap.getOrDefault(basicBlock, Collections.emptyList())) {
            newMetadata = newMetadata.unifyWith(jumpSource.getStackMetadata());
        }
        opcodeIndexToStackMetadata.put(basicBlock.startAtIndex, newMetadata);
        return !newMetadata.equals(originalMetadata);
    }

    /**
     * Computes the stack metadata after {@code opcode} for each of its branches.
     * <p>
     * When {@code currentStackMetadata} is dead code, every branch is dead code as well.
     *
     * @throws IllegalStateException wrapping any failure raised by the opcode
     */
    private static List<StackMetadata> computeSuccessorStackMetadata(FunctionMetadata functionMetadata, Opcode opcode, StackMetadata currentStackMetadata, int branchCount) {
        try {
            if (currentStackMetadata.isDeadCode()) {
                return Collections.nCopies(branchCount, StackMetadata.DEAD_CODE);
            }
            return opcode.getStackMetadataAfterInstructionForBranches(functionMetadata, currentStackMetadata);
        }
        catch (Throwable t) {
            throw new IllegalStateException("Failed to calculate successor stack metadata for opcode (" + opcode + ") with prior stack metadata (" + currentStackMetadata + ").", t);
        }
    }

    /**
     * Propagates the stack metadata after {@code opcode} to each successor index.
     * <p>
     * For each branch, the metadata is recorded on the branch's {@link JumpSource} (if the branch is a
     * block edge) and unified into the successor's entry.
     *
     * @return whether any successor's entry changed
     * @throws IllegalStateException if a successor's existing metadata cannot be unified
     */
    private boolean propagateToSuccessors(FunctionMetadata functionMetadata, Map<Integer, StackMetadata> opcodeIndexToStackMetadata, Opcode opcode) {
        StackMetadata currentStackMetadata = opcodeIndexToStackMetadata.computeIfAbsent(opcode.getBytecodeIndex(), k -> StackMetadata.DEAD_CODE);
        List<Integer> branchList = opcode.getPossibleNextBytecodeIndexList();
        List<StackMetadata> nextStackMetadataList = computeSuccessorStackMetadata(functionMetadata, opcode, currentStackMetadata, branchList.size());
        boolean hasChanged = false;
        for (int branch = 0; branch < branchList.size(); ++branch) {
            int nextBytecodeIndex = branchList.get(branch);
            StackMetadata nextStackMetadata = nextStackMetadataList.get(branch);
            JumpSource jumpSource = this.opcodeIndexToJumpSourceMap.get(new IndexBranchPair(opcode.getBytecodeIndex(), branch));
            if (jumpSource != null) {
                jumpSource.setStackMetadata(nextStackMetadata);
            }
            StackMetadata originalMetadata = opcodeIndexToStackMetadata.get(nextBytecodeIndex);
            StackMetadata mergedMetadata;
            try {
                mergedMetadata = opcodeIndexToStackMetadata.merge(nextBytecodeIndex, nextStackMetadata, StackMetadata::unifyWith);
            }
            catch (IllegalArgumentException e) {
                throw new IllegalStateException("Cannot unify branch (" + branch + ": to index " + nextBytecodeIndex + ") stack metadata (" + nextStackMetadata + ") for source opcode (" + opcode + ") with prior stack metadata (" + originalMetadata + "): different stack sizes;\n" + PythonBytecodeToJavaBytecodeTranslator.getPythonBytecodeListing(functionMetadata.pythonCompiledFunction), e);
            }
            hasChanged |= !mergedMetadata.equals(originalMetadata);
        }
        return hasChanged;
    }

    /**
     * Unifies each exception handler's entry metadata into the handler's target instruction.
     * <p>
     * The entry metadata is derived from the metadata at the protected region's first instruction.
     *
     * @return whether any handler's entry metadata changed
     * @throws IllegalStateException if a handler's metadata cannot be computed or unified
     */
    private static boolean mergeExceptionHandlerEntries(FunctionMetadata functionMetadata, StackMetadata initialStackMetadata, Map<Integer, StackMetadata> opcodeIndexToStackMetadata) {
        boolean hasChanged = false;
        for (ExceptionBlock exceptionBlock : functionMetadata.pythonCompiledFunction.co_exceptiontable.getEntries()) {
            try {
                StackMetadata originalMetadata = opcodeIndexToStackMetadata.get(exceptionBlock.getTargetInstruction());
                StackMetadata blockStartMetadata = opcodeIndexToStackMetadata.getOrDefault(exceptionBlock.getBlockStartInstructionInclusive(), StackMetadata.DEAD_CODE);
                StackMetadata handlerEntryMetadata = getExceptionStackMetadata(exceptionBlock, functionMetadata, initialStackMetadata, blockStartMetadata);
                StackMetadata mergedMetadata = opcodeIndexToStackMetadata.merge(exceptionBlock.getTargetInstruction(), handlerEntryMetadata, StackMetadata::unifyWith);
                hasChanged |= !mergedMetadata.equals(originalMetadata);
            }
            catch (IllegalArgumentException e) {
                throw new IllegalStateException("Cannot unify block starting at " + exceptionBlock.getTargetInstruction() + ": different stack sizes;\n" + PythonBytecodeToJavaBytecodeTranslator.getPythonBytecodeListing(functionMetadata.pythonCompiledFunction), e);
            }
        }
        return hasChanged;
    }

    /** Identifies one branch of one opcode: (bytecode index of the opcode, branch number). */
    private static class IndexBranchPair {
        final int index;
        final int branch;

        public IndexBranchPair(int index, int branch) {
            this.index = index;
            this.branch = branch;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            IndexBranchPair that = (IndexBranchPair) o;
            return this.index == that.index && this.branch == that.branch;
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.index, this.branch);
        }

        @Override
        public String toString() {
            return "IndexBranchPair{index=" + this.index + ", branch=" + this.branch + "}";
        }
    }
}

