/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.internal.kernel.api.helpers.traversal.ppbfs;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.LongPredicate;
import java.util.function.Predicate;
import org.eclipse.collections.api.block.function.Function0;
import org.eclipse.collections.api.tuple.primitive.LongObjectPair;
import org.neo4j.collection.trackable.HeapTracking;
import org.neo4j.collection.trackable.HeapTrackingArrayList;
import org.neo4j.collection.trackable.HeapTrackingCollections;
import org.neo4j.collection.trackable.HeapTrackingLongArrayList;
import org.neo4j.collection.trackable.HeapTrackingUnifiedMap;
import org.neo4j.function.Predicates;
import org.neo4j.graphdb.Direction;
import org.neo4j.internal.helpers.collection.PrefetchingIterator;
import org.neo4j.internal.kernel.api.KernelReadTracer;
import org.neo4j.internal.kernel.api.RelationshipTraversalEntities;
import org.neo4j.internal.kernel.api.helpers.traversal.ReversedRelTraversalEntities;
import org.neo4j.internal.kernel.api.helpers.traversal.ppbfs.FoundNodes;
import org.neo4j.internal.kernel.api.helpers.traversal.ppbfs.GlobalState;
import org.neo4j.internal.kernel.api.helpers.traversal.ppbfs.MREValidator;
import org.neo4j.internal.kernel.api.helpers.traversal.ppbfs.NodeState;
import org.neo4j.internal.kernel.api.helpers.traversal.ppbfs.PGPathPropagatingBFS;
import org.neo4j.internal.kernel.api.helpers.traversal.ppbfs.SearchMode;
import org.neo4j.internal.kernel.api.helpers.traversal.ppbfs.TraversalDirection;
import org.neo4j.internal.kernel.api.helpers.traversal.ppbfs.TraversalPathModeFactory;
import org.neo4j.internal.kernel.api.helpers.traversal.ppbfs.TwoWaySignpost;
import org.neo4j.internal.kernel.api.helpers.traversal.ppbfs.hooks.PPBFSHooks;
import org.neo4j.internal.kernel.api.helpers.traversal.productgraph.MultiRelationshipExpansion;
import org.neo4j.internal.kernel.api.helpers.traversal.productgraph.NodeJuxtaposition;
import org.neo4j.internal.kernel.api.helpers.traversal.productgraph.ProductGraphTraversalCursor;
import org.neo4j.internal.kernel.api.helpers.traversal.productgraph.RelationshipExpansion;
import org.neo4j.internal.kernel.api.helpers.traversal.productgraph.RelationshipPredicate;
import org.neo4j.internal.kernel.api.helpers.traversal.productgraph.State;
import org.neo4j.kernel.impl.util.ValueUtils;
import org.neo4j.memory.HeapEstimator;
import org.neo4j.memory.MemoryTracker;
import org.neo4j.storageengine.api.RelationshipSelection;
import org.neo4j.values.virtual.VirtualRelationshipValue;

/**
 * Performs the expansion steps of the path-propagating BFS over the product graph.
 *
 * <p>Each call to {@link #expand()} expands the current frontier held by {@link FoundNodes} in the direction
 * that {@link FoundNodes#getNextExpansionDirection()} selects. Single-relationship transitions are followed
 * through the {@link ProductGraphTraversalCursor}. Multi-relationship expansions are scheduled while the
 * frontier is walked and then resolved with a bounded depth-first search ({@link #multiHopDFS}). Every
 * (node, state) pair that is reached goes through {@link #encounter} and is linked to its predecessor with a
 * {@link TwoWaySignpost}.
 *
 * <p>The relationship lookups and predicate results used by the multi-hop search are memoised in heap-tracked
 * caches. They are never cleared in this class.
 */
final class BFSExpander
implements AutoCloseable {
    private final MemoryTracker mt;
    private final PPBFSHooks hooks;
    private final GlobalState globalState;
    private final ProductGraphTraversalCursor pgCursor;
    private final CachingRelCursor relCursor;
    private final long intoTarget;
    private final TraversalPathModeFactory tracker;
    // Reused scratch list holding the NFA states of the data-graph node currently being expanded.
    private final HeapTrackingArrayList<State> statesList;
    private final FoundNodes foundNodes;
    private final HeapTrackingUnifiedMap<CachedRelPredicate, Boolean> relPredicateCache;
    private final HeapTrackingUnifiedMap<CachedNodePredicate, Boolean> nodePredicateCache;

    public BFSExpander(FoundNodes foundNodes, GlobalState globalState, ProductGraphTraversalCursor pgCursor, ProductGraphTraversalCursor.DataGraphRelationshipCursor relCursor, long intoTarget, int nfaStateCount, TraversalPathModeFactory tracker) {
        this.mt = globalState.mt;
        this.hooks = globalState.hooks;
        this.globalState = globalState;
        this.pgCursor = pgCursor;
        this.relCursor = new CachingRelCursor(relCursor, this.mt);
        this.intoTarget = intoTarget;
        this.statesList = HeapTrackingArrayList.newArrayList(nfaStateCount, this.mt);
        this.foundNodes = foundNodes;
        this.tracker = tracker;
        this.relPredicateCache = HeapTrackingCollections.newMap(this.mt);
        this.nodePredicateCache = HeapTrackingCollections.newMap(this.mt);
    }

    /**
     * Marks {@code node} as discovered in {@code direction} and follows its node-juxtaposition transitions.
     *
     * <p>The node is reported to the hooks, buffered in {@link FoundNodes} and flagged as discovered. For each
     * juxtaposition of the node's state whose end-state predicate accepts the node id, the same data-graph node is
     * encountered in the juxtaposition's other state. The two node states are then linked with a node signpost.
     *
     * @param node the newly discovered node state
     * @param direction the direction of the search that discovered it
     */
    public void discover(NodeState node, TraversalDirection direction) {
        this.hooks.discover(node, direction);
        this.foundNodes.addToBuffer(node);
        node.discover(direction);
        State state = node.state();
        for (NodeJuxtaposition nj : state.getNodeJuxtapositions(direction)) {
            if (!nj.endState(direction).test(node.id())) {
                continue;
            }
            switch (direction) {
                case FORWARD: {
                    NodeState nextNode = this.encounter(node.id(), nj.targetState(), direction);
                    TwoWaySignpost.NodeSignpost signpost = TwoWaySignpost.fromNodeJuxtaposition(this.mt, node, nextNode, this.foundNodes.forwardDepth(), this.tracker.lengths());
                    // Duplicate source signposts are only filtered when the search is not unidirectional.
                    if (this.globalState.searchMode == SearchMode.Unidirectional || !nextNode.hasSourceSignpost(signpost)) {
                        nextNode.addSourceSignpost(signpost, this.foundNodes.forwardDepth());
                    }
                    break;
                }
                case BACKWARD: {
                    NodeState nextNode = this.encounter(node.id(), nj.sourceState(), direction);
                    TwoWaySignpost.NodeSignpost signpost = TwoWaySignpost.fromNodeJuxtaposition(this.mt, nextNode, node, this.tracker.lengths());
                    if (!nextNode.hasTargetSignpost(signpost)) {
                        // The signpost runs nextNode -> node, so it is registered on node (its target side).
                        TwoWaySignpost.NodeSignpost addedSignpost = node.upsertSourceSignpost(signpost);
                        addedSignpost.setMinTargetDistance(this.foundNodes.backwardDepth(), PGPathPropagatingBFS.Phase.Expansion);
                    }
                    break;
                }
            }
        }
    }

    /**
     * Returns the node state for {@code (nodeId, state)}, creating and discovering it if it has not been found yet.
     *
     * <p>In bidirectional searches, an already-known node state that has not been seen from {@code direction} is
     * discovered again for that direction.
     *
     * @param nodeId the data-graph node id
     * @param state the NFA state paired with the node
     * @param direction the direction of the search performing the encounter
     * @return the existing or newly created node state, never {@code null}
     */
    public NodeState encounter(long nodeId, State state, TraversalDirection direction) {
        NodeState nodeState = this.foundNodes.get(nodeId, state.id());
        if (nodeState == null) {
            nodeState = new NodeState(this.globalState, nodeId, state, this.intoTarget, this.tracker.lengths());
            this.discover(nodeState, direction);
        } else if (this.globalState.searchMode == SearchMode.Bidirectional && !nodeState.hasBeenSeen(direction)) {
            this.discover(nodeState, direction);
        }
        return nodeState;
    }

    /**
     * Evaluates a relationship predicate, memoising the result per (predicate, relationship id, direction).
     *
     * <p>{@link RelationshipPredicate#ALWAYS_TRUE} short-circuits without touching the cache. For backward
     * traversal the predicate sees a {@link ReversedRelTraversalEntities} view of {@code rel}.
     */
    private boolean cachedRelPredicate(Predicate<RelationshipTraversalEntities> predicate, RelationshipTraversalEntities rel, TraversalDirection dir) {
        if (predicate == RelationshipPredicate.ALWAYS_TRUE) {
            return true;
        }
        CachedRelPredicate key = new CachedRelPredicate(predicate, rel.relationshipReference(), dir);
        return this.relPredicateCache.getIfAbsentPut(key, () -> {
            // Account for the cache key that is about to be stored.
            this.mt.allocateHeap(CachedRelPredicate.SHALLOW_SIZE);
            RelationshipTraversalEntities oriented = dir.isForward() ? rel : new ReversedRelTraversalEntities(rel);
            return predicate.test(oriented);
        });
    }

    /**
     * Evaluates a node predicate, memoising the result per (predicate, node id).
     *
     * <p>{@link Predicates#ALWAYS_TRUE_LONG} short-circuits without touching the cache.
     */
    private boolean cachedNodePredicate(LongPredicate predicate, long node) {
        if (predicate == Predicates.ALWAYS_TRUE_LONG) {
            return true;
        }
        CachedNodePredicate key = new CachedNodePredicate(predicate, node);
        return this.nodePredicateCache.getIfAbsentPut(key, () -> {
            // Account for the cache key that is about to be stored.
            this.mt.allocateHeap(CachedNodePredicate.SHALLOW_SIZE);
            return predicate.test(node);
        });
    }

    /**
     * Maps a 1-based DFS depth to an index into a path-ordered array of length {@code arrayLength}.
     *
     * <p>Forward searches fill the array from the front. Backward searches fill it from the end, so the array
     * always reads in path order.
     */
    private static int pathIndex(int arrayLength, int depth, TraversalDirection direction) {
        return direction.isBackward() ? arrayLength - depth : depth - 1;
    }

    /**
     * Enumerates every way to traverse {@code expansion} from {@code startNode} with an explicit-stack DFS.
     *
     * <p>{@code nodeTree[d]} holds the pending candidate nodes at depth {@code d}. {@code relTree[d]} holds, in
     * lockstep, the relationships that lead to the entries of {@code nodeTree[d + 1]}. {@code rels} and
     * {@code nodes} hold the relationships and interior nodes of the path currently being explored, in path
     * order. Each complete path that satisfies the expansion's compound predicate yields a multi-rel signpost
     * (see {@link #addMultiRelSignpost}).
     *
     * <p>The per-depth lists are heap-tracked. They are closed before returning so their tracked memory is
     * released.
     *
     * <p>NOTE(review): assumes {@code expansion.length() >= 1}; {@code nodes} is sized {@code length - 1}.
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // generic arrays cannot be created; relTree only ever holds VirtualRelationshipValue
    private void multiHopDFS(NodeState startNode, MultiRelationshipExpansion expansion, TraversalDirection direction) {
        int length = expansion.length();
        VirtualRelationshipValue[] rels = new VirtualRelationshipValue[length];
        long[] nodes = new long[length - 1];
        HeapTrackingLongArrayList[] nodeTree = new HeapTrackingLongArrayList[length + 1];
        HeapTrackingArrayList[] relTree = new HeapTrackingArrayList[length];
        try {
            nodeTree[0] = HeapTrackingLongArrayList.newLongArrayList(1, this.mt);
            nodeTree[0].add(startNode.id());
            int depth = 0;
            MREValidator mreValidator = this.tracker.mreValidator();
            while (depth != -1) {
                assert depth <= length : "Multi-hop depth first search should never exceed total expansion length";
                if (nodeTree[depth] == null || nodeTree[depth].isEmpty()) {
                    // Level exhausted: clear its path slots and backtrack.
                    if (depth > 0) {
                        rels[pathIndex(rels.length, depth, direction)] = null;
                        if (depth <= nodes.length) {
                            nodes[pathIndex(nodes.length, depth, direction)] = 0L;
                        }
                    }
                    --depth;
                    continue;
                }
                if (depth == length) {
                    // A full-length path: check it as a whole and record it.
                    long endNode = nodeTree[depth].removeLast();
                    VirtualRelationshipValue lastRel = (VirtualRelationshipValue) relTree[depth - 1].removeLast();
                    rels[pathIndex(rels.length, depth, direction)] = lastRel;
                    long start = direction.isBackward() ? endNode : startNode.id();
                    long end = direction.isBackward() ? startNode.id() : endNode;
                    if (expansion.compoundPredicate().test(start, rels, nodes, end)) {
                        this.addMultiRelSignpost(startNode, endNode, rels, nodes, expansion, direction);
                    }
                    continue;
                }
                long node = nodeTree[depth].removeLast();
                if (depth > 0) {
                    VirtualRelationshipValue rel = (VirtualRelationshipValue) relTree[depth - 1].removeLast();
                    rels[pathIndex(rels.length, depth, direction)] = rel;
                    if (depth <= nodes.length) {
                        nodes[pathIndex(nodes.length, depth, direction)] = node;
                    }
                }
                MultiRelationshipExpansion.Rel relHop = expansion.rel(depth, direction);
                LongPredicate nodePredicate = expansion.nodePredicate(depth, direction);
                boolean canExpand = false;
                // The caching iterator must be exhausted so that its cache entry is complete.
                Iterator<RelationshipTraversalEntities> it = this.relCursor.iterator(node, relHop.types(), relHop.getDirection(direction));
                while (it.hasNext()) {
                    RelationshipTraversalEntities candidate = it.next();
                    if (!mreValidator.validateRelationships(direction, depth, rels, candidate)
                            || !this.cachedRelPredicate(relHop.predicate(), candidate, direction)
                            || !this.cachedNodePredicate(nodePredicate, candidate.otherNodeReference())) {
                        continue;
                    }
                    if (nodeTree[depth + 1] == null) {
                        nodeTree[depth + 1] = HeapTrackingLongArrayList.newLongArrayList(this.mt);
                    }
                    nodeTree[depth + 1].add(candidate.otherNodeReference());
                    if (relTree[depth] == null) {
                        relTree[depth] = HeapTrackingArrayList.newArrayList(this.mt);
                    }
                    relTree[depth].add(ValueUtils.fromRelationshipCursor(candidate));
                    canExpand = true;
                }
                if (canExpand) {
                    ++depth;
                }
            }
        } finally {
            // Release tracked memory; signposts only keep copies (relIds / nodes.clone()) of the DFS state.
            for (HeapTrackingLongArrayList level : nodeTree) {
                if (level != null) {
                    level.close();
                }
            }
            for (HeapTrackingArrayList level : relTree) {
                if (level != null) {
                    level.close();
                }
            }
        }
    }

    /**
     * Records one complete multi-hop path found by {@link #multiHopDFS}.
     *
     * <p>The far end is encountered in the expansion's end state, and a multi-rel signpost is created from
     * snapshots of {@code rels} (as ids) and {@code nodes}. Forward: the signpost is added to the far end as a
     * source signpost, with duplicates filtered only for non-unidirectional searches. Backward: it is upserted
     * on {@code startNode} unless the far end already holds it as a target signpost.
     */
    private void addMultiRelSignpost(NodeState startNode, long endNode, VirtualRelationshipValue[] rels, long[] nodes, MultiRelationshipExpansion expansion, TraversalDirection direction) {
        NodeState nextNode = this.encounter(endNode, expansion.endState(direction), direction);
        long[] relIds = new long[rels.length];
        for (int i = 0; i < rels.length; ++i) {
            relIds[i] = rels[i].id();
        }
        switch (direction) {
            case FORWARD: {
                TwoWaySignpost.MultiRelSignpost signpost = TwoWaySignpost.fromMultiRel(this.mt, startNode, relIds, nodes.clone(), expansion, nextNode, this.foundNodes.forwardDepth(), this.tracker.lengths());
                if (this.globalState.searchMode == SearchMode.Unidirectional || !nextNode.hasSourceSignpost(signpost)) {
                    nextNode.addSourceSignpost(signpost, this.foundNodes.forwardDepth());
                }
                break;
            }
            case BACKWARD: {
                TwoWaySignpost.MultiRelSignpost signpost = TwoWaySignpost.fromMultiRel(this.mt, nextNode, relIds, nodes.clone(), expansion, startNode, this.tracker.lengths());
                if (!nextNode.hasTargetSignpost(signpost)) {
                    TwoWaySignpost.MultiRelSignpost addedSignpost = startNode.upsertSourceSignpost(signpost);
                    addedSignpost.setMinTargetDistance(this.foundNodes.backwardDepth(), PGPathPropagatingBFS.Phase.Expansion);
                }
                break;
            }
        }
    }

    /**
     * Performs one BFS expansion step.
     *
     * <p>The step opens the found-nodes buffer and asks {@link FoundNodes} for the direction to expand. It then
     * expands every frontier node in that direction and runs the multi-hop searches scheduled for it. Finally it
     * commits the buffer.
     */
    public void expand() {
        this.foundNodes.openBuffer();
        TraversalDirection direction = this.foundNodes.getNextExpansionDirection();
        this.hooks.expand(direction, this.foundNodes);
        for (LongObjectPair pair : this.foundNodes.frontier(direction).keyValuesView()) {
            long dbNodeId = pair.getOne();
            @SuppressWarnings("unchecked")
            HeapTrackingArrayList<NodeState> statesById = (HeapTrackingArrayList<NodeState>) pair.getTwo();
            this.expandFrontierNode(dbNodeId, statesById, direction);
        }
        FoundNodes.ScheduledExpansion scheduled = this.foundNodes.dequeueScheduled(direction);
        while (scheduled != null) {
            this.multiHopDFS(scheduled.start(), scheduled.expansion(), direction);
            scheduled = this.foundNodes.dequeueScheduled(direction);
        }
        this.foundNodes.commitBuffer(direction);
    }

    /**
     * Expands one data-graph node of the frontier across all of its NFA states.
     *
     * <p>Every multi-relationship expansion of each non-null state is scheduled at depth
     * {@code foundNodes.depth(direction) - 1 + mre.length()}. The single-relationship transitions produced by
     * {@link ProductGraphTraversalCursor} are then linked with relationship signposts.
     *
     * @param dbNodeId the data-graph node id
     * @param statesById the node's states indexed by NFA state id; entries may be {@code null}
     * @param direction the expansion direction
     */
    private void expandFrontierNode(long dbNodeId, HeapTrackingArrayList<NodeState> statesById, TraversalDirection direction) {
        this.statesList.clear();
        for (NodeState nodeState : statesById) {
            if (nodeState == null) {
                continue;
            }
            this.statesList.add(nodeState.state());
            for (MultiRelationshipExpansion mre : nodeState.state().getMultiRelationshipExpansions(direction)) {
                int depth = this.foundNodes.depth(direction) - 1 + mre.length();
                this.foundNodes.enqueueScheduled(depth, nodeState, mre, direction);
            }
        }
        this.hooks.expandNode(dbNodeId, this.statesList, direction);
        this.pgCursor.setNodeAndStates(dbNodeId, (List<State>) this.statesList, direction);
        while (this.pgCursor.next()) {
            long foundNode = this.pgCursor.otherNodeReference();
            RelationshipExpansion re = this.pgCursor.relationshipExpansion();
            switch (direction) {
                case FORWARD: {
                    NodeState nextNode = this.encounter(foundNode, re.targetState(), direction);
                    NodeState node = statesById.get(re.sourceState().id());
                    TwoWaySignpost.RelSignpost signpost = TwoWaySignpost.fromRelExpansion(this.mt, node, this.pgCursor.relationshipReference(), nextNode, re, this.foundNodes.forwardDepth(), this.tracker.lengths());
                    if (this.globalState.searchMode == SearchMode.Unidirectional || !nextNode.hasSourceSignpost(signpost)) {
                        nextNode.addSourceSignpost(signpost, this.foundNodes.forwardDepth());
                    }
                    break;
                }
                case BACKWARD: {
                    NodeState nextNode = this.encounter(foundNode, re.sourceState(), direction);
                    NodeState node = statesById.get(re.targetState().id());
                    TwoWaySignpost.RelSignpost signpost = TwoWaySignpost.fromRelExpansion(this.mt, nextNode, this.pgCursor.relationshipReference(), node, re, this.tracker.lengths());
                    if (!nextNode.hasTargetSignpost(signpost)) {
                        TwoWaySignpost.RelSignpost addedSignpost = node.upsertSourceSignpost(signpost);
                        addedSignpost.setMinTargetDistance(this.foundNodes.backwardDepth(), PGPathPropagatingBFS.Phase.Expansion);
                    }
                    break;
                }
            }
        }
    }

    public void setTracer(KernelReadTracer tracer) {
        this.pgCursor.setTracer(tracer);
    }

    /**
     * Closes the product-graph cursor and the scratch state list.
     *
     * <p>NOTE(review): the predicate caches and the relationship cache are not closed here. Their tracked heap is
     * presumably reclaimed by the owner of the memory tracker; confirm.
     */
    @Override
    public void close() throws Exception {
        this.pgCursor.close();
        this.statesList.close();
    }

    /**
     * Wraps a data-graph relationship cursor and memoises the relationships returned per (node, types, direction).
     *
     * <p>The first lookup for a key iterates the underlying cursor and records each relationship as it is
     * produced. Later lookups replay the recorded list. An iterator that is not exhausted therefore leaves a
     * partial cache entry, so callers must drain it.
     */
    private static class CachingRelCursor {
        private final ProductGraphTraversalCursor.DataGraphRelationshipCursor relCursor;
        private final HeapTracking.Map<CachedNode, List<RelationshipTraversalEntities>> cache;
        private final MemoryTracker mt;

        public CachingRelCursor(ProductGraphTraversalCursor.DataGraphRelationshipCursor relCursor, MemoryTracker mt) {
            this.relCursor = relCursor;
            this.cache = HeapTrackingCollections.newMap(mt);
            this.mt = mt;
        }

        /**
         * Returns the relationships of {@code node} matching {@code types} and {@code direction}, served from the
         * cache when available.
         */
        public Iterator<RelationshipTraversalEntities> iterator(long node, int[] types, Direction direction) {
            CachedNode cacheKey = new CachedNode(node, types, direction);
            List<RelationshipTraversalEntities> cached = this.cache.get(cacheKey);
            if (cached != null) {
                return cached.iterator();
            }
            final HeapTrackingArrayList<RelationshipTraversalEntities> currentCache = HeapTrackingCollections.newArrayList(this.mt);
            this.mt.allocateHeap(CachedNode.SHALLOW_SIZE);
            this.cache.put(cacheKey, currentCache);
            this.relCursor.setNode(node, RelationshipSelection.selection(types, direction));
            return new PrefetchingIterator<RelationshipTraversalEntities>() {

                @Override
                protected RelationshipTraversalEntities fetchNextOrNull() {
                    if (relCursor.nextRelationship()) {
                        // Snapshot the cursor position; the cursor itself is reused on the next call.
                        CachedRel snapshot = new CachedRel(relCursor.relationshipReference(), relCursor.type(), relCursor.sourceNodeReference(), relCursor.targetNodeReference(), relCursor.otherNodeReference(), relCursor.originNodeReference());
                        mt.allocateHeap(CachedRel.SHALLOW_SIZE);
                        currentCache.add(snapshot);
                        return snapshot;
                    }
                    return null;
                }
            };
        }

        /** Cache key; compares {@code types} by content rather than by array identity. */
        private record CachedNode(long nodeId, int[] types, Direction direction) {
            public static final long SHALLOW_SIZE = HeapEstimator.shallowSizeOfInstance(CachedNode.class);

            @Override
            public boolean equals(Object o) {
                if (this == o) {
                    return true;
                }
                if (o == null || this.getClass() != o.getClass()) {
                    return false;
                }
                CachedNode that = (CachedNode)o;
                return this.nodeId == that.nodeId && Arrays.equals(this.types, that.types) && this.direction == that.direction;
            }

            @Override
            public int hashCode() {
                return Objects.hash(this.nodeId, Arrays.hashCode(this.types), this.direction);
            }
        }

        /** Immutable snapshot of one relationship-cursor position. */
        private record CachedRel(long relationshipReference, int type, long sourceNodeReference, long targetNodeReference, long otherNodeReference, long originNodeReference) implements RelationshipTraversalEntities
        {
            public static final long SHALLOW_SIZE = HeapEstimator.shallowSizeOfInstance(CachedRel.class);
        }
    }

    /** Key for memoised relationship-predicate results. */
    record CachedRelPredicate(Predicate<RelationshipTraversalEntities> predicate, long rel, TraversalDirection dir) {
        public static final long SHALLOW_SIZE = HeapEstimator.shallowSizeOfInstance(CachedRelPredicate.class);
    }

    /** Key for memoised node-predicate results. */
    record CachedNodePredicate(LongPredicate predicate, long node) {
        public static final long SHALLOW_SIZE = HeapEstimator.shallowSizeOfInstance(CachedNodePredicate.class);
    }
}

