/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.window.matcher;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.operator.window.matcher.ArrayView;
import io.trino.operator.window.matcher.Instruction;
import io.trino.operator.window.matcher.Jump;
import io.trino.operator.window.matcher.MatchLabel;
import io.trino.operator.window.matcher.Program;
import io.trino.operator.window.matcher.Split;
import io.trino.operator.window.pattern.LogicalIndexNavigation;
import io.trino.operator.window.pattern.MatchAggregation;
import io.trino.operator.window.pattern.MatchAggregationPointer;
import io.trino.operator.window.pattern.PhysicalValueAccessor;
import io.trino.operator.window.pattern.PhysicalValuePointer;
import io.trino.sql.planner.LocalExecutionPlanner;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Decides whether two matcher threads are interchangeable at a given program instruction.
 *
 * <p>For every instruction of the {@link Program} the set of labels still reachable from it is
 * precomputed. Two threads are reported equivalent at an instruction when, for every label
 * reachable from it, everything that label's value accessors depend on resolves identically for
 * both threads:
 * <ul>
 *   <li>the row positions referenced by input value pointers,</li>
 *   <li>the labels matched at the positions referenced by classifier pointers,</li>
 *   <li>the positions aggregated by match aggregations (and, for aggregations whose label
 *       dependency involves the classifier over more than one label, the labels at those
 *       positions).</li>
 * </ul>
 *
 * <p>NOTE(review): presumably the matcher uses this to discard redundant threads — confirm at the
 * call site.
 */
public class ThreadEquivalence {
    // Indexed by instruction pointer: labels that can still be matched from that instruction.
    private final List<Set<Integer>> reachableLabels;
    // Indexed by label: navigations whose resolved row positions must be equal between threads.
    private final List<Set<LogicalIndexNavigation>> positionsToCompare;
    // Indexed by label: navigations whose resolved positions must carry equal matched labels.
    private final List<Set<LogicalIndexNavigation>> labelsToCompare;
    // Indexed by label: aggregation indexes whose positions must be equal;
    // null when no label has such an aggregation, so the check can be skipped entirely.
    private final List<List<Integer>> matchAggregationsToComparePositions;
    // Indexed by label: aggregation indexes whose positions and labels must be equal;
    // null when no label has such an aggregation.
    private final List<List<Integer>> matchAggregationsToComparePositionsAndLabels;

    /**
     * Precomputes everything needed by {@link #equivalent}.
     *
     * @param program the compiled pattern program; one reachable-label set is computed per instruction
     * @param accessors value accessors grouped per label (the outer list is indexed by label number)
     * @param labelDependencies label dependencies of match aggregations, indexed by
     *        {@link MatchAggregationPointer#getIndex()}
     */
    public ThreadEquivalence(Program program, List<List<PhysicalValueAccessor>> accessors, List<LocalExecutionPlanner.MatchAggregationLabelDependency> labelDependencies) {
        this.reachableLabels = computeReachableLabels(program);
        // Input values: only navigations that reference at least one label matter. The physical
        // offset is stripped before expansion — presumably because equal base positions imply equal
        // physically-offset rows for both threads.
        this.positionsToCompare = getInputValuePointers(accessors).stream()
                .<Set<LogicalIndexNavigation>>map(pointers -> pointers.stream()
                        .map(PhysicalValuePointer::getLogicalIndexNavigation)
                        .filter(navigation -> !navigation.getLabels().isEmpty())
                        .map(LogicalIndexNavigation::withoutPhysicalOffset)
                        .map(ThreadEquivalence::allPositionsToCompare)
                        .flatMap(Collection::stream)
                        .collect(ImmutableSet.toImmutableSet()))
                .collect(ImmutableList.toImmutableList());
        // Classifier values: the physical offset is kept, since the label at the offset row is the value.
        this.labelsToCompare = getClassifierValuePointers(accessors).stream()
                .<Set<LogicalIndexNavigation>>map(pointers -> pointers.stream()
                        .map(PhysicalValuePointer::getLogicalIndexNavigation)
                        .map(ThreadEquivalence::allPositionsToCompare)
                        .flatMap(Collection::stream)
                        .collect(ImmutableSet.toImmutableSet()))
                .collect(ImmutableList.toImmutableList());
        AggregationIndexes aggregationIndexes = classifyAggregations(accessors, labelDependencies);
        this.matchAggregationsToComparePositions = aggregationIndexes.foundNoClassifierAggregations ? aggregationIndexes.noClassifierAggregations : null;
        this.matchAggregationsToComparePositionsAndLabels = aggregationIndexes.foundClassifierAggregations ? aggregationIndexes.classifierAggregations : null;
    }

    /**
     * Checks whether two threads are equivalent at instruction {@code pointer}.
     *
     * @param firstThread id of the first thread
     * @param firstLabels labels matched so far by the first thread
     * @param firstAggregations match aggregations of the first thread, indexed by aggregation index
     * @param secondThread id of the second thread
     * @param secondLabels labels matched so far by the second thread
     * @param secondAggregations match aggregations of the second thread, indexed by aggregation index
     * @param pointer program instruction at which equivalence is evaluated
     * @return {@code true} if both ids denote the same thread, if no labels have been matched yet, or
     *         if the threads agree on every value observable by labels reachable from {@code pointer}
     * @throws IllegalArgumentException if the label arrays differ in length or {@code pointer} lies
     *         outside the program
     */
    public boolean equivalent(int firstThread, ArrayView firstLabels, MatchAggregation[] firstAggregations, int secondThread, ArrayView secondLabels, MatchAggregation[] secondAggregations, int pointer) {
        Preconditions.checkArgument(firstLabels.length() == secondLabels.length(), "matched labels for compared threads differ in length");
        Preconditions.checkArgument(pointer >= 0 && pointer < reachableLabels.size(), "instruction pointer out of program bounds");
        if (firstThread == secondThread || firstLabels.length() == 0) {
            return true;
        }
        Set<Integer> labels = reachableLabels.get(pointer);

        // Input values read by reachable labels must come from the same rows.
        for (LogicalIndexNavigation navigation : unionForLabels(labels, positionsToCompare)) {
            if (resolvePosition(navigation, firstLabels) != resolvePosition(navigation, secondLabels)) {
                return false;
            }
        }

        // Classifier values: both resolve or neither does (-1); when resolved, the labels must agree.
        // The positions themselves may differ — only the label found there is observable.
        for (LogicalIndexNavigation navigation : unionForLabels(labels, labelsToCompare)) {
            int firstPosition = resolvePosition(navigation, firstLabels);
            int secondPosition = resolvePosition(navigation, secondLabels);
            boolean firstResolved = firstPosition != -1;
            if (firstResolved != (secondPosition != -1)) {
                return false;
            }
            if (firstResolved && firstLabels.get(firstPosition) != secondLabels.get(secondPosition)) {
                return false;
            }
        }

        if (matchAggregationsToComparePositions != null
                && !sameAggregatedPositions(unionForLabels(labels, matchAggregationsToComparePositions), firstLabels, firstAggregations, secondLabels, secondAggregations, false)) {
            return false;
        }
        if (matchAggregationsToComparePositionsAndLabels != null
                && !sameAggregatedPositions(unionForLabels(labels, matchAggregationsToComparePositionsAndLabels), firstLabels, firstAggregations, secondLabels, secondAggregations, true)) {
            return false;
        }
        return true;
    }

    /**
     * Collects, without duplicates, the entries registered for each of the given labels.
     */
    private static <T> Set<T> unionForLabels(Set<Integer> labels, List<? extends Collection<T>> valuesByLabel) {
        Set<T> union = new HashSet<>();
        for (int label : labels) {
            union.addAll(valuesByLabel.get(label));
        }
        return union;
    }

    /**
     * Compares, for each given aggregation index, the positions aggregated by both threads.
     *
     * @param compareLabels when {@code true}, the labels matched at each aggregated position must
     *        also be equal
     * @return {@code true} if all compared aggregations agree
     */
    private static boolean sameAggregatedPositions(Set<Integer> aggregationIndexes, ArrayView firstLabels, MatchAggregation[] firstAggregations, ArrayView secondLabels, MatchAggregation[] secondAggregations, boolean compareLabels) {
        for (int aggregationIndex : aggregationIndexes) {
            ArrayView firstPositions = firstAggregations[aggregationIndex].getAllPositions(firstLabels);
            ArrayView secondPositions = secondAggregations[aggregationIndex].getAllPositions(secondLabels);
            if (firstPositions.length() != secondPositions.length()) {
                return false;
            }
            for (int i = 0; i < firstPositions.length(); i++) {
                int position = firstPositions.get(i);
                if (position != secondPositions.get(i)) {
                    return false;
                }
                // Positions are equal here, so the same index is valid in both label arrays.
                if (compareLabels && firstLabels.get(position) != secondLabels.get(position)) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Resolves {@code navigation} against a thread's matched labels, using the last matched row as the
     * current row and the whole label array as the match. A result of -1 is treated by
     * {@link #equivalent} as "does not resolve".
     * NOTE(review): exact semantics of the trailing arguments live in
     * {@link LogicalIndexNavigation#resolvePosition} — confirm there.
     */
    private static int resolvePosition(LogicalIndexNavigation navigation, ArrayView labels) {
        return navigation.resolvePosition(labels.length() - 1, labels, 0, labels.length(), 0);
    }

    /**
     * Computes, for every instruction of the program, the labels reachable from it.
     */
    private static List<Set<Integer>> computeReachableLabels(Program program) {
        List<Set<Integer>> result = new ArrayList<>(program.size());
        for (int instructionIndex = 0; instructionIndex < program.size(); instructionIndex++) {
            // Each start instruction gets a fresh visited array: sets are computed independently.
            result.add(reachableLabels(program, instructionIndex, new boolean[program.size()]));
        }
        return result;
    }

    /**
     * Depth-first collection of the labels matched by any MATCH_LABEL instruction reachable from
     * {@code instructionIndex}, following fall-through, jump and split edges. {@code visited}
     * breaks cycles; a revisited instruction contributes nothing because its labels were already
     * collected earlier in the same traversal.
     */
    private static Set<Integer> reachableLabels(Program program, int instructionIndex, boolean[] visited) {
        Set<Integer> labels = new HashSet<>();
        if (visited[instructionIndex]) {
            return labels;
        }
        visited[instructionIndex] = true;
        Instruction instruction = program.at(instructionIndex);
        switch (instruction.type()) {
            case MATCH_LABEL:
                labels.add(((MatchLabel) instruction).getLabel());
                labels.addAll(reachableLabels(program, instructionIndex + 1, visited));
                break;
            case JUMP:
                labels.addAll(reachableLabels(program, ((Jump) instruction).getTarget(), visited));
                break;
            case SPLIT:
                labels.addAll(reachableLabels(program, ((Split) instruction).getFirst(), visited));
                labels.addAll(reachableLabels(program, ((Split) instruction).getSecond(), visited));
                break;
            case MATCH_START:
            case MATCH_END:
            case SAVE:
                labels.addAll(reachableLabels(program, instructionIndex + 1, visited));
                break;
            default:
                // Any other instruction type is not followed and contributes no labels
                // (presumably the terminal instruction — confirm against Instruction.Type).
                break;
        }
        return labels;
    }

    /**
     * Selects, per label, the physical value pointers that read input channels, i.e. excluding the
     * special channels -1 (classifier, see {@link #getClassifierValuePointers}) and -2
     * (NOTE(review): presumably MATCH_NUMBER — confirm against PhysicalValuePointer).
     */
    private static List<List<PhysicalValuePointer>> getInputValuePointers(List<List<PhysicalValueAccessor>> valuePointers) {
        return valuePointers.stream()
                .<List<PhysicalValuePointer>>map(pointerList -> pointerList.stream()
                        .filter(PhysicalValuePointer.class::isInstance)
                        .map(PhysicalValuePointer.class::cast)
                        .filter(pointer -> pointer.getSourceChannel() != -1 && pointer.getSourceChannel() != -2)
                        .collect(ImmutableList.toImmutableList()))
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Selects, per label, the physical value pointers reading the classifier (source channel -1).
     */
    private static List<List<PhysicalValuePointer>> getClassifierValuePointers(List<List<PhysicalValueAccessor>> valuePointers) {
        return valuePointers.stream()
                .<List<PhysicalValuePointer>>map(pointerList -> pointerList.stream()
                        .filter(PhysicalValuePointer.class::isInstance)
                        .map(PhysicalValuePointer.class::cast)
                        .filter(pointer -> pointer.getSourceChannel() == -1)
                        .collect(ImmutableList.toImmutableList()))
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Splits, per label, the referenced match aggregations into those compared by positions only
     * (classifier not involved, or dependent on a single label) and those that must also be compared
     * by labels.
     */
    private static AggregationIndexes classifyAggregations(List<List<PhysicalValueAccessor>> valuePointers, List<LocalExecutionPlanner.MatchAggregationLabelDependency> labelDependencies) {
        ImmutableList.Builder<List<Integer>> noClassifierAggregations = ImmutableList.builder();
        boolean foundNoClassifierAggregations = false;
        ImmutableList.Builder<List<Integer>> classifierAggregations = ImmutableList.builder();
        boolean foundClassifierAggregations = false;
        for (List<PhysicalValueAccessor> pointerList : valuePointers) {
            ImmutableList.Builder<Integer> noClassifierAggregationIndexes = ImmutableList.builder();
            ImmutableList.Builder<Integer> classifierAggregationIndexes = ImmutableList.builder();
            for (PhysicalValueAccessor pointer : pointerList) {
                if (!(pointer instanceof MatchAggregationPointer)) {
                    continue;
                }
                int aggregationIndex = ((MatchAggregationPointer) pointer).getIndex();
                LocalExecutionPlanner.MatchAggregationLabelDependency labelDependency = labelDependencies.get(aggregationIndex);
                // With a single label, the classifier value is fixed, so positions alone suffice.
                if (!labelDependency.isClassifierInvolved() || labelDependency.getLabels().size() == 1) {
                    foundNoClassifierAggregations = true;
                    noClassifierAggregationIndexes.add(aggregationIndex);
                }
                else {
                    foundClassifierAggregations = true;
                    classifierAggregationIndexes.add(aggregationIndex);
                }
            }
            noClassifierAggregations.add(noClassifierAggregationIndexes.build());
            classifierAggregations.add(classifierAggregationIndexes.build());
        }
        return new AggregationIndexes(foundNoClassifierAggregations, noClassifierAggregations.build(), foundClassifierAggregations, classifierAggregations.build());
    }

    /**
     * Expands a navigation into every navigation whose resolved position must agree between threads.
     * For a non-{@code isLast()} navigation that is the navigation itself. For an {@code isLast()}
     * navigation it is every logical offset from 0 up to its own, plus, without logical offset, every
     * physical offset strictly between its own physical offset and 0.
     * NOTE(review): presumably because rows matched later can shift what a LAST-relative navigation
     * resolves to.
     */
    private static List<LogicalIndexNavigation> allPositionsToCompare(LogicalIndexNavigation navigation) {
        if (!navigation.isLast()) {
            return ImmutableList.of(navigation);
        }
        List<LogicalIndexNavigation> result = new ArrayList<>();
        for (int offset = 0; offset <= navigation.getLogicalOffset(); offset++) {
            result.add(navigation.withLogicalOffset(offset));
        }
        for (int tail = navigation.getPhysicalOffset() + 1; tail < 0; tail++) {
            result.add(navigation.withoutLogicalOffset().withPhysicalOffset(tail));
        }
        return result;
    }

    /**
     * Per-label aggregation indexes produced by {@link #classifyAggregations}, plus flags telling
     * whether any label references an aggregation of each kind.
     */
    private static final class AggregationIndexes {
        final boolean foundNoClassifierAggregations;
        final List<List<Integer>> noClassifierAggregations;
        final boolean foundClassifierAggregations;
        final List<List<Integer>> classifierAggregations;

        AggregationIndexes(boolean foundNoClassifierAggregations, List<List<Integer>> noClassifierAggregations, boolean foundClassifierAggregations, List<List<Integer>> classifierAggregations) {
            this.foundNoClassifierAggregations = foundNoClassifierAggregations;
            this.noClassifierAggregations = noClassifierAggregations;
            this.foundClassifierAggregations = foundClassifierAggregations;
            this.classifierAggregations = classifierAggregations;
        }
    }
}

