/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner;

import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableList;
import com.google.inject.Inject;
import io.airlift.slice.XxHash64;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.connector.CatalogServiceProvider;
import io.trino.execution.scheduler.BucketNodeMap;
import io.trino.execution.scheduler.NodeScheduler;
import io.trino.execution.scheduler.NodeSelector;
import io.trino.metadata.InternalNode;
import io.trino.metadata.Split;
import io.trino.operator.BucketPartitionFunction;
import io.trino.operator.PartitionFunction;
import io.trino.operator.RetryPolicy;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.connector.BucketFunction;
import io.trino.spi.connector.CatalogHandle;
import io.trino.spi.connector.ConnectorBucketNodeMap;
import io.trino.spi.connector.ConnectorNodePartitioningProvider;
import io.trino.spi.connector.ConnectorPartitioningHandle;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
import io.trino.split.EmptySplit;
import io.trino.sql.planner.MergePartitioningHandle;
import io.trino.sql.planner.NodePartitionMap;
import io.trino.sql.planner.PartitioningHandle;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.util.Failures;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.ToIntFunction;
import java.util.stream.IntStream;

/**
 * Resolves {@link PartitioningHandle}s into the objects needed to route data:
 * partition functions, connector bucket functions, bucket-to-node assignments and
 * split-to-bucket functions.
 *
 * <p>System partitioning handles are resolved through the {@link NodeScheduler};
 * every other handle is resolved through the {@link ConnectorNodePartitioningProvider}
 * registered for the handle's catalog.
 */
public class NodePartitioningManager {
    /** Exclusive upper bound on the bucket count accepted from a connector bucket/node map. */
    private static final int MAX_BUCKET_COUNT = 1_000_000;

    private final NodeScheduler nodeScheduler;
    private final TypeOperators typeOperators;
    private final CatalogServiceProvider<ConnectorNodePartitioningProvider> partitioningProvider;

    /**
     * Creates the manager.
     *
     * @param nodeScheduler source of node selectors
     * @param typeOperators passed through to system partition functions
     * @param partitioningProvider per-catalog lookup of connector partitioning providers
     * @throws NullPointerException if any argument is null
     */
    @Inject
    public NodePartitioningManager(NodeScheduler nodeScheduler, TypeOperators typeOperators, CatalogServiceProvider<ConnectorNodePartitioningProvider> partitioningProvider) {
        this.nodeScheduler = Objects.requireNonNull(nodeScheduler, "nodeScheduler is null");
        this.typeOperators = Objects.requireNonNull(typeOperators, "typeOperators is null");
        this.partitioningProvider = Objects.requireNonNull(partitioningProvider, "partitioningProvider is null");
    }

    /**
     * Builds the partition function for a scheme, using the scheme's own bucket-to-partition array.
     *
     * @param session the session the function is created for
     * @param partitioningScheme scheme whose bucket-to-partition mapping must already be set
     * @param partitionChannelTypes types of the partitioning channels
     * @return the partition function for the scheme
     * @throws IllegalArgumentException if the scheme has no bucket-to-partition mapping
     */
    public PartitionFunction getPartitionFunction(Session session, PartitioningScheme partitioningScheme, List<Type> partitionChannelTypes) {
        int[] bucketToPartition = partitioningScheme.getBucketToPartition()
                .orElseThrow(() -> new IllegalArgumentException("Bucket to partition must be set before a partition function can be created"));

        ConnectorPartitioningHandle connectorHandle = partitioningScheme.getPartitioning().getHandle().getConnectorHandle();
        if (connectorHandle instanceof SystemPartitioningHandle) {
            SystemPartitioningHandle systemHandle = (SystemPartitioningHandle) connectorHandle;
            return systemHandle.getPartitionFunction(partitionChannelTypes, partitioningScheme.getHashColumn().isPresent(), bucketToPartition, this.typeOperators);
        }
        if (connectorHandle instanceof MergePartitioningHandle) {
            MergePartitioningHandle mergeHandle = (MergePartitioningHandle) connectorHandle;
            // The merge handle asks back for the partition functions of the schemes it wraps;
            // each of those is built with the same bucket-to-partition array.
            return mergeHandle.getPartitionFunction(
                    (PartitioningScheme scheme, List<Type> types) -> this.getPartitionFunction(session, scheme, types, bucketToPartition),
                    partitionChannelTypes,
                    bucketToPartition);
        }
        return this.getPartitionFunction(session, partitioningScheme, partitionChannelTypes, bucketToPartition);
    }

    /**
     * Builds the partition function for a scheme with an explicitly supplied bucket-to-partition array.
     *
     * <p>System handles produce their own function; any other handle is served by the
     * connector's bucket function wrapped in a {@link BucketPartitionFunction}, with the
     * bucket count taken from the length of {@code bucketToPartition}.
     *
     * @param session the session the function is created for
     * @param partitioningScheme scheme describing the partitioning
     * @param partitionChannelTypes types of the partitioning channels
     * @param bucketToPartition partition id for each bucket
     * @return the partition function
     */
    public PartitionFunction getPartitionFunction(Session session, PartitioningScheme partitioningScheme, List<Type> partitionChannelTypes, int[] bucketToPartition) {
        PartitioningHandle partitioningHandle = partitioningScheme.getPartitioning().getHandle();
        ConnectorPartitioningHandle connectorHandle = partitioningHandle.getConnectorHandle();
        if (connectorHandle instanceof SystemPartitioningHandle) {
            SystemPartitioningHandle systemHandle = (SystemPartitioningHandle) connectorHandle;
            return systemHandle.getPartitionFunction(partitionChannelTypes, partitioningScheme.getHashColumn().isPresent(), bucketToPartition, this.typeOperators);
        }
        BucketFunction bucketFunction = this.getBucketFunction(session, partitioningHandle, partitionChannelTypes, bucketToPartition.length);
        return new BucketPartitionFunction(bucketFunction, bucketToPartition);
    }

    /**
     * Asks the handle's connector for its bucket function.
     *
     * @param session the session the function is created for
     * @param partitioningHandle handle that must carry a catalog handle and a transaction handle
     * @param partitionChannelTypes types of the partitioning channels
     * @param bucketCount number of buckets the function should produce
     * @return the connector's bucket function, never null
     * @throws IllegalStateException if the handle has no catalog handle
     * @throws IllegalArgumentException if the connector returns no bucket function
     */
    public BucketFunction getBucketFunction(Session session, PartitioningHandle partitioningHandle, List<Type> partitionChannelTypes, int bucketCount) {
        CatalogHandle catalogHandle = requiredCatalogHandle(partitioningHandle);
        ConnectorNodePartitioningProvider provider = this.getPartitioningProvider(catalogHandle);
        // NOTE(review): this call uses session.toConnectorSession() without the catalog handle, whereas
        // getConnectorBucketNodeMap and getSplitToBucket pass catalogHandle — confirm whether that is intentional.
        BucketFunction bucketFunction = provider.getBucketFunction(
                partitioningHandle.getTransactionHandle().orElseThrow(),
                session.toConnectorSession(),
                partitioningHandle.getConnectorHandle(),
                partitionChannelTypes,
                bucketCount);
        Preconditions.checkArgument(bucketFunction != null, "No bucket function for partitioning: %s", partitioningHandle);
        return bucketFunction;
    }

    /**
     * Resolves the node partition map for a handle without a requested partition count.
     *
     * @param session the session to resolve nodes for
     * @param partitioningHandle the handle to resolve
     * @return the node partition map
     */
    public NodePartitionMap getNodePartitioningMap(Session session, PartitioningHandle partitioningHandle) {
        return this.getNodePartitioningMap(session, partitioningHandle, new HashMap<>(), new AtomicReference<>(), Optional.empty());
    }

    /**
     * Resolves the node partition map for a handle.
     *
     * @param session the session to resolve nodes for
     * @param partitioningHandle the handle to resolve
     * @param partitionCount number of nodes to select for fixed system partitioning; when empty,
     *        the session's max hash partition count is used
     * @return the node partition map
     */
    public NodePartitionMap getNodePartitioningMap(Session session, PartitioningHandle partitioningHandle, Optional<Integer> partitionCount) {
        return this.getNodePartitioningMap(session, partitioningHandle, new HashMap<>(), new AtomicReference<>(), partitionCount);
    }

    /**
     * Recursive worker behind the public {@code getNodePartitioningMap} overloads.
     *
     * <p>The two caches are created once per public call and handed down unchanged when a
     * {@link MergePartitioningHandle} asks for the maps of the handles it wraps, so those
     * nested resolutions reuse the same node choices.
     *
     * @param bucketToNodeCache arbitrary bucket-to-node assignments, keyed by bucket count
     * @param systemPartitioningCache node list already chosen for fixed system partitioning, if any
     * @throws NullPointerException if {@code session} or {@code partitioningHandle} is null
     */
    private NodePartitionMap getNodePartitioningMap(Session session, PartitioningHandle partitioningHandle, Map<Integer, List<InternalNode>> bucketToNodeCache, AtomicReference<List<InternalNode>> systemPartitioningCache, Optional<Integer> partitionCount) {
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(partitioningHandle, "partitioningHandle is null");

        ConnectorPartitioningHandle connectorHandle = partitioningHandle.getConnectorHandle();
        if (connectorHandle instanceof SystemPartitioningHandle) {
            List<InternalNode> partitionToNode = this.systemBucketToNode(session, partitioningHandle, systemPartitioningCache, partitionCount);
            return new NodePartitionMap(partitionToNode, split -> {
                throw new UnsupportedOperationException("System distribution does not support source splits " + partitioningHandle);
            });
        }
        if (connectorHandle instanceof MergePartitioningHandle) {
            MergePartitioningHandle mergeHandle = (MergePartitioningHandle) connectorHandle;
            return mergeHandle.getNodePartitioningMap(handle -> this.getNodePartitioningMap(session, (PartitioningHandle) handle, bucketToNodeCache, systemPartitioningCache, partitionCount));
        }

        List<InternalNode> bucketToNode = this.connectorBucketToNode(session, partitioningHandle, bucketToNodeCache, systemPartitioningCache, partitionCount);

        // Give every distinct node a dense partition id, in order of first appearance in the bucket list.
        int[] bucketToPartition = new int[bucketToNode.size()];
        BiMap<InternalNode, Integer> nodeToPartition = HashBiMap.create();
        int nextPartitionId = 0;
        for (int bucket = 0; bucket < bucketToNode.size(); bucket++) {
            InternalNode node = bucketToNode.get(bucket);
            Integer partitionId = nodeToPartition.get(node);
            if (partitionId == null) {
                partitionId = nextPartitionId;
                nextPartitionId++;
                nodeToPartition.put(node, partitionId);
            }
            bucketToPartition[bucket] = partitionId;
        }

        BiMap<Integer, InternalNode> partitionIdToNode = nodeToPartition.inverse();
        List<InternalNode> partitionToNode = IntStream.range(0, nodeToPartition.size())
                .mapToObj(partitionId -> partitionIdToNode.get(partitionId))
                .collect(ImmutableList.toImmutableList());
        return new NodePartitionMap(partitionToNode, bucketToPartition, this.getSplitToBucket(session, partitioningHandle, bucketToNode.size()));
    }

    /**
     * Determines the node for every bucket of a connector (non-system, non-merge) handle.
     *
     * <p>Without a connector bucket/node map the fixed system distribution is used. With a
     * fixed mapping the connector's nodes are used as-is. Otherwise an arbitrary assignment is
     * computed; it is cached by bucket count only, so handles with the same bucket count that
     * share {@code bucketToNodeCache} receive the same assignment.
     *
     * @throws IllegalArgumentException if the connector reports {@value #MAX_BUCKET_COUNT} buckets or more
     */
    private List<InternalNode> connectorBucketToNode(Session session, PartitioningHandle partitioningHandle, Map<Integer, List<InternalNode>> bucketToNodeCache, AtomicReference<List<InternalNode>> systemPartitioningCache, Optional<Integer> partitionCount) {
        Optional<ConnectorBucketNodeMap> optionalMap = this.getConnectorBucketNodeMap(session, partitioningHandle);
        if (optionalMap.isEmpty()) {
            return this.systemBucketToNode(session, SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION, systemPartitioningCache, partitionCount);
        }

        ConnectorBucketNodeMap connectorBucketNodeMap = optionalMap.get();
        int connectorBucketCount = connectorBucketNodeMap.getBucketCount();
        Preconditions.checkArgument(connectorBucketCount < MAX_BUCKET_COUNT, "Too many buckets in partitioning: %s", connectorBucketCount);

        if (connectorBucketNodeMap.hasFixedMapping()) {
            List<InternalNode> fixedMapping = getFixedMapping(connectorBucketNodeMap);
            Verify.verify(fixedMapping.size() == connectorBucketCount, "Fixed mapping size does not match bucket count");
            return fixedMapping;
        }

        CatalogHandle catalogHandle = requiredCatalogHandle(partitioningHandle);
        return bucketToNodeCache.computeIfAbsent(
                connectorBucketCount,
                bucketCount -> createArbitraryBucketToNode(connectorBucketNodeMap.getCacheKeyHint(), this.getAllNodes(session, catalogHandle), bucketCount));
    }

    /**
     * Selects the nodes for a system partitioning handle.
     *
     * @param nodesCache holds the node list chosen for {@code FIXED} partitioning; it is filled on
     *        first use and reused afterwards, regardless of {@code partitionCount}
     * @param partitionCount number of nodes to pick for {@code FIXED}; defaults to the session's
     *        max hash partition count
     * @return a non-empty node list
     * @throws IllegalArgumentException for a system partitioning other than
     *         {@code COORDINATOR_ONLY}, {@code SINGLE} or {@code FIXED}
     */
    private List<InternalNode> systemBucketToNode(Session session, PartitioningHandle partitioningHandle, AtomicReference<List<InternalNode>> nodesCache, Optional<Integer> partitionCount) {
        SystemPartitioningHandle.SystemPartitioning partitioning = ((SystemPartitioningHandle) partitioningHandle.getConnectorHandle()).getPartitioning();
        NodeSelector nodeSelector = this.nodeScheduler.createNodeSelector(session, Optional.empty());

        List<InternalNode> nodes = switch (partitioning) {
            case COORDINATOR_ONLY -> ImmutableList.of(nodeSelector.selectCurrentNode());
            case SINGLE -> nodeSelector.selectRandomNodes(1);
            case FIXED -> {
                List<InternalNode> value = nodesCache.get();
                if (value == null) {
                    value = nodeSelector.selectRandomNodes(partitionCount.orElse(SystemSessionProperties.getMaxHashPartitionCount(session)));
                    nodesCache.set(value);
                }
                yield value;
            }
            default -> throw new IllegalArgumentException("Unsupported plan distribution " + partitioning);
        };
        Failures.checkCondition(!nodes.isEmpty(), StandardErrorCode.NO_NODES_AVAILABLE, "No worker nodes available");
        return nodes;
    }

    /**
     * Builds the bucket/node map for a connector handle.
     *
     * <p>The bucket count comes from the connector when it supplies a bucket/node map and from
     * {@link #getDefaultBucketCount} otherwise. A fixed connector mapping is used directly; in
     * every other case buckets are assigned to the catalog's nodes using the connector's cache
     * key hint, or a random seed when there is none.
     *
     * @param session the session to resolve nodes for
     * @param partitioningHandle handle that must carry a catalog handle
     * @return the bucket/node map
     */
    public BucketNodeMap getBucketNodeMap(Session session, PartitioningHandle partitioningHandle) {
        Optional<ConnectorBucketNodeMap> bucketNodeMap = this.getConnectorBucketNodeMap(session, partitioningHandle);
        int bucketCount = bucketNodeMap.map(ConnectorBucketNodeMap::getBucketCount)
                .orElseGet(() -> this.getDefaultBucketCount(session, partitioningHandle));
        ToIntFunction<Split> splitToBucket = this.getSplitToBucket(session, partitioningHandle, bucketCount);

        if (bucketNodeMap.map(ConnectorBucketNodeMap::hasFixedMapping).orElse(false)) {
            return new BucketNodeMap(splitToBucket, getFixedMapping(bucketNodeMap.get()));
        }

        // Draw a random seed only when the connector supplies no cache key hint.
        long seed = bucketNodeMap.map(ConnectorBucketNodeMap::getCacheKeyHint)
                .orElseGet(() -> ThreadLocalRandom.current().nextLong());
        List<InternalNode> nodes = this.getAllNodes(session, requiredCatalogHandle(partitioningHandle));
        return new BucketNodeMap(splitToBucket, createArbitraryBucketToNode(seed, nodes, bucketCount));
    }

    /**
     * Bucket count used when the connector does not declare one: the larger of 8192 and three
     * times either the catalog's node count (retry policy other than {@code TASK}) or the
     * fault-tolerant execution max partition count (retry policy {@code TASK}).
     */
    private int getDefaultBucketCount(Session session, PartitioningHandle partitioningHandle) {
        int remoteBucketCount;
        if (SystemSessionProperties.getRetryPolicy(session) != RetryPolicy.TASK) {
            remoteBucketCount = this.getNodeCount(session, partitioningHandle) * 3;
        } else {
            remoteBucketCount = SystemSessionProperties.getFaultTolerantExecutionMaxPartitionCount(session) * 3;
        }
        int localBucketCount = 8192;
        return Math.max(remoteBucketCount, localBucketCount);
    }

    /**
     * Counts the nodes offered by the node selector for the handle's catalog.
     *
     * @throws IllegalStateException if the handle has no catalog handle
     */
    public int getNodeCount(Session session, PartitioningHandle partitioningHandle) {
        return this.getAllNodes(session, requiredCatalogHandle(partitioningHandle)).size();
    }

    /** Returns all nodes reported by a node selector created for the given catalog. */
    private List<InternalNode> getAllNodes(Session session, CatalogHandle catalogHandle) {
        return this.nodeScheduler.createNodeSelector(session, Optional.of(catalogHandle)).allNodes();
    }

    /**
     * Copies the connector's fixed mapping into an immutable list of {@link InternalNode}s.
     *
     * @throws ClassCastException if an entry is not an {@code InternalNode}
     */
    private static List<InternalNode> getFixedMapping(ConnectorBucketNodeMap connectorBucketNodeMap) {
        return connectorBucketNodeMap.getFixedMapping().stream()
                .map(InternalNode.class::cast)
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Asks the handle's connector for its bucket/node mapping.
     *
     * @return the connector's mapping, or empty when the connector provides none
     * @throws IllegalStateException if the handle has no catalog handle
     */
    public Optional<ConnectorBucketNodeMap> getConnectorBucketNodeMap(Session session, PartitioningHandle partitioningHandle) {
        CatalogHandle catalogHandle = requiredCatalogHandle(partitioningHandle);
        ConnectorNodePartitioningProvider provider = this.getPartitioningProvider(catalogHandle);
        return provider.getBucketNodeMapping(
                partitioningHandle.getTransactionHandle().orElseThrow(),
                session.toConnectorSession(catalogHandle),
                partitioningHandle.getConnectorHandle());
    }

    /**
     * Builds the function that maps a split to its bucket.
     *
     * <p>Splits whose connector split is an {@link EmptySplit} always map to bucket 0; all
     * other splits are delegated to the connector's split bucket function.
     *
     * @param bucketCount bucket count passed to the connector
     * @throws IllegalStateException if the handle has no catalog handle
     * @throws IllegalArgumentException if the connector returns no split bucket function
     */
    public ToIntFunction<Split> getSplitToBucket(Session session, PartitioningHandle partitioningHandle, int bucketCount) {
        CatalogHandle catalogHandle = requiredCatalogHandle(partitioningHandle);
        ConnectorNodePartitioningProvider provider = this.getPartitioningProvider(catalogHandle);
        // 'var' keeps the provider's parameterised return type instead of a raw ToIntFunction;
        // the connector split type is not imported in this file.
        var splitBucketFunction = provider.getSplitBucketFunction(
                partitioningHandle.getTransactionHandle().orElseThrow(),
                session.toConnectorSession(catalogHandle),
                partitioningHandle.getConnectorHandle(),
                bucketCount);
        Preconditions.checkArgument(splitBucketFunction != null, "No partitioning %s", partitioningHandle);
        return split -> split.getConnectorSplit() instanceof EmptySplit
                ? 0
                : splitBucketFunction.applyAsInt(split.getConnectorSplit());
    }

    /** Looks up the partitioning provider registered for the catalog. */
    private ConnectorNodePartitioningProvider getPartitioningProvider(CatalogHandle catalogHandle) {
        return this.partitioningProvider.getService(Objects.requireNonNull(catalogHandle, "catalogHandle is null"));
    }

    /**
     * Returns the handle's catalog handle.
     *
     * @throws IllegalStateException if the handle has none
     */
    private static CatalogHandle requiredCatalogHandle(PartitioningHandle partitioningHandle) {
        return partitioningHandle.getCatalogHandle()
                .orElseThrow(() -> new IllegalStateException("No catalog handle for partitioning handle: " + partitioningHandle));
    }

    /**
     * Assigns each bucket to one of the given nodes, deterministically for a given seed.
     *
     * <p>For every bucket a hash of the seed and bucket number is combined with each node's
     * {@code longHashCode()}, and the node with the highest resulting weight wins
     * (highest-random-weight selection).
     *
     * @param seed seed mixed into every bucket hash
     * @param nodes candidate nodes, must not be empty
     * @param bucketCount number of buckets, must be positive
     * @return an immutable list holding the chosen node for each bucket index
     * @throws NullPointerException if {@code nodes} is null
     * @throws IllegalArgumentException if {@code nodes} is empty or {@code bucketCount} is not positive
     */
    private static List<InternalNode> createArbitraryBucketToNode(long seed, List<InternalNode> nodes, int bucketCount) {
        Objects.requireNonNull(nodes, "nodes is null");
        Preconditions.checkArgument(!nodes.isEmpty(), "nodes is empty");
        Preconditions.checkArgument(bucketCount > 0, "bucketCount must be greater than zero");

        ImmutableList.Builder<InternalNode> bucketAssignments = ImmutableList.builderWithExpectedSize(bucketCount);
        for (int bucket = 0; bucket < bucketCount; bucket++) {
            long bucketHash = XxHash64.hash(seed, bucket);
            InternalNode bestNode = null;
            long highestWeight = Long.MIN_VALUE;
            for (InternalNode node : nodes) {
                long weight = XxHash64.hash(node.longHashCode(), bucketHash);
                // ">=" so that a weight of Long.MIN_VALUE is still accepted and, on ties,
                // the node appearing later in the list wins.
                if (weight >= highestWeight) {
                    highestWeight = weight;
                    bestNode = node;
                }
            }
            bucketAssignments.add(Objects.requireNonNull(bestNode));
        }
        return bucketAssignments.build();
    }
}

