/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution.scheduler.faulttolerant;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ListMultimap;
import io.trino.execution.scheduler.OutputDataSizeEstimate;
import io.trino.execution.scheduler.faulttolerant.FaultTolerantPartitioningScheme;
import io.trino.execution.scheduler.faulttolerant.NodeRequirements;
import io.trino.execution.scheduler.faulttolerant.SplitAssigner;
import io.trino.metadata.InternalNode;
import io.trino.metadata.Split;
import io.trino.spi.HostAddress;
import io.trino.spi.connector.CatalogHandle;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.PlanVisitor;
import io.trino.sql.planner.plan.TableWriterNode;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.IntStream;

class HashDistributionSplitAssigner
implements SplitAssigner {
    private final PlanFragmentId fragmentId;
    private final Optional<CatalogHandle> catalogRequirement;
    private final Set<PlanNodeId> replicatedSources;
    private final Set<PlanNodeId> allSources;
    private final FaultTolerantPartitioningScheme sourcePartitioningScheme;
    private final Map<Integer, TaskPartition> sourcePartitionToTaskPartition;
    private final Set<Integer> createdTaskPartitions = new HashSet<Integer>();
    private final Set<PlanNodeId> completedSources = new HashSet<PlanNodeId>();
    private final ListMultimap<PlanNodeId, Split> replicatedSplits = ArrayListMultimap.create();
    private boolean allTaskPartitionsCreated;

    public static HashDistributionSplitAssigner create(Optional<CatalogHandle> catalogRequirement, Set<PlanNodeId> partitionedSources, Set<PlanNodeId> replicatedSources, FaultTolerantPartitioningScheme sourcePartitioningScheme, Map<PlanNodeId, OutputDataSizeEstimate> sourceDataSizeEstimates, PlanFragment fragment, long targetPartitionSizeInBytes, int targetMinTaskCount, int targetMaxTaskCount) {
        if (fragment.getPartitioning().isScaleWriters()) {
            Verify.verify((fragment.getPartitionedSources().isEmpty() && fragment.getRemoteSourceNodes().size() == 1 ? 1 : 0) != 0, (String)"fragments using scale-writers partitioning are expected to have exactly one remote source and no table scans", (Object[])new Object[0]);
        }
        return new HashDistributionSplitAssigner(fragment.getId(), catalogRequirement, partitionedSources, replicatedSources, sourcePartitioningScheme, HashDistributionSplitAssigner.createSourcePartitionToTaskPartition(sourcePartitioningScheme, partitionedSources, sourceDataSizeEstimates, targetPartitionSizeInBytes, targetMinTaskCount, targetMaxTaskCount, sourceId -> fragment.getPartitioning().isScaleWriters(), !HashDistributionSplitAssigner.isWriteFragment(fragment)));
    }

    @VisibleForTesting
    HashDistributionSplitAssigner(PlanFragmentId fragmentId, Optional<CatalogHandle> catalogRequirement, Set<PlanNodeId> partitionedSources, Set<PlanNodeId> replicatedSources, FaultTolerantPartitioningScheme sourcePartitioningScheme, Map<Integer, TaskPartition> sourcePartitionToTaskPartition) {
        this.fragmentId = Objects.requireNonNull(fragmentId, "fragmentId is null");
        this.catalogRequirement = Objects.requireNonNull(catalogRequirement, "catalogRequirement is null");
        this.replicatedSources = ImmutableSet.copyOf((Collection)Objects.requireNonNull(replicatedSources, "replicatedSources is null"));
        this.allSources = ImmutableSet.builder().addAll(partitionedSources).addAll(replicatedSources).build();
        this.sourcePartitioningScheme = Objects.requireNonNull(sourcePartitioningScheme, "sourcePartitioningScheme is null");
        this.sourcePartitionToTaskPartition = ImmutableMap.copyOf(Objects.requireNonNull(sourcePartitionToTaskPartition, "sourcePartitionToTaskPartition is null"));
    }

    @Override
    public SplitAssigner.AssignmentResult assign(PlanNodeId planNodeId, ListMultimap<Integer, Split> splits, boolean noMoreSplits) {
        SplitAssigner.AssignmentResult.Builder assignment = SplitAssigner.AssignmentResult.builder();
        if (!this.allTaskPartitionsCreated) {
            int nextTaskPartitionId = 0;
            for (int sourcePartitionId2 = 0; sourcePartitionId2 < this.sourcePartitioningScheme.getPartitionCount(); ++sourcePartitionId2) {
                TaskPartition taskPartition = this.sourcePartitionToTaskPartition.get(sourcePartitionId2);
                Verify.verify((taskPartition != null ? 1 : 0) != 0, (String)"taskPartition not found for fragment %s plan node %s for sourcePartitionId %s", (Object)this.fragmentId, (Object)planNodeId, (Object)sourcePartitionId2);
                for (SubPartition subPartition : taskPartition.getSubPartitions()) {
                    if (subPartition.isIdAssigned()) continue;
                    int taskPartitionId = nextTaskPartitionId++;
                    subPartition.assignId(taskPartitionId);
                    Optional<HostAddress> hostRequirement = this.sourcePartitioningScheme.getNodeRequirement(sourcePartitionId2).map(InternalNode::getHostAndPort);
                    assignment.addPartition(new SplitAssigner.Partition(taskPartitionId, new NodeRequirements(this.catalogRequirement, hostRequirement, hostRequirement.isEmpty())));
                    this.createdTaskPartitions.add(taskPartitionId);
                }
            }
            assignment.setNoMorePartitions();
            this.allTaskPartitionsCreated = true;
        }
        if (this.replicatedSources.contains(planNodeId)) {
            this.replicatedSplits.putAll((Object)planNodeId, (Iterable)splits.values());
            for (Integer partitionId : this.createdTaskPartitions) {
                assignment.updatePartition(new SplitAssigner.PartitionUpdate(partitionId, planNodeId, false, HashDistributionSplitAssigner.replicatedSourcePartition((List<Split>)ImmutableList.copyOf((Collection)splits.values())), noMoreSplits));
            }
        } else {
            splits.forEach((sourcePartitionId, split) -> {
                TaskPartition taskPartition = this.sourcePartitionToTaskPartition.get(sourcePartitionId);
                Verify.verify((taskPartition != null ? 1 : 0) != 0, (String)"taskPartition not found for fragment %s plan node %s for sourcePartitionId %s", (Object)this.fragmentId, (Object)planNodeId, (Object)sourcePartitionId);
                ImmutableList subPartitions = taskPartition.getSplitBy().isPresent() && taskPartition.getSplitBy().get().equals(planNodeId) ? ImmutableList.of((Object)taskPartition.getNextSubPartition()) : taskPartition.getSubPartitions();
                for (SubPartition subPartition : subPartitions) {
                    assignment.updatePartition(new SplitAssigner.PartitionUpdate(subPartition.getId(), planNodeId, true, (ListMultimap<Integer, Split>)ImmutableListMultimap.of((Object)sourcePartitionId, (Object)split), false));
                }
            });
        }
        if (noMoreSplits) {
            this.completedSources.add(planNodeId);
            for (Integer taskPartition : this.createdTaskPartitions) {
                assignment.updatePartition(new SplitAssigner.PartitionUpdate(taskPartition, planNodeId, false, (ListMultimap<Integer, Split>)ImmutableListMultimap.of(), true));
            }
            if (this.completedSources.containsAll(this.allSources)) {
                for (Integer taskPartition : this.createdTaskPartitions) {
                    assignment.sealPartition(taskPartition);
                }
                this.replicatedSplits.clear();
            }
        }
        return assignment.build();
    }

    public static ListMultimap<Integer, Split> replicatedSourcePartition(List<Split> splits) {
        ImmutableListMultimap.Builder builder = ImmutableListMultimap.builder();
        builder.putAll((Object)0, splits);
        return builder.build();
    }

    @Override
    public SplitAssigner.AssignmentResult finish() {
        Preconditions.checkState((!this.createdTaskPartitions.isEmpty() ? 1 : 0) != 0, (Object)"createdTaskPartitions is not expected to be empty");
        return SplitAssigner.AssignmentResult.builder().build();
    }

    @VisibleForTesting
    static Map<Integer, TaskPartition> createSourcePartitionToTaskPartition(FaultTolerantPartitioningScheme sourcePartitioningScheme, Set<PlanNodeId> partitionedSources, Map<PlanNodeId, OutputDataSizeEstimate> sourceDataSizeEstimates, long targetPartitionSizeInBytes, int targetMinTaskCount, int targetMaxTaskCount, Predicate<PlanNodeId> canSplit, boolean canMerge) {
        int partitionCount = sourcePartitioningScheme.getPartitionCount();
        if (sourcePartitioningScheme.isExplicitPartitionToNodeMappingPresent() || partitionedSources.isEmpty() || !sourceDataSizeEstimates.keySet().containsAll(partitionedSources)) {
            return (Map)IntStream.range(0, partitionCount).boxed().collect(ImmutableMap.toImmutableMap(Function.identity(), key -> new TaskPartition(1, Optional.empty())));
        }
        List partitionedSourcesEstimates = (List)sourceDataSizeEstimates.entrySet().stream().filter(entry -> partitionedSources.contains(entry.getKey())).map(Map.Entry::getValue).collect(ImmutableList.toImmutableList());
        OutputDataSizeEstimate mergedEstimate = OutputDataSizeEstimate.merge(partitionedSourcesEstimates);
        if (targetMaxTaskCount != Integer.MAX_VALUE || targetMinTaskCount != 0) {
            long totalBytes = mergedEstimate.getTotalSizeInBytes();
            if (totalBytes / targetPartitionSizeInBytes > (long)targetMaxTaskCount) {
                targetPartitionSizeInBytes = (totalBytes + (long)targetMaxTaskCount - 1L) / (long)targetMaxTaskCount;
            }
            if (totalBytes / targetPartitionSizeInBytes < (long)targetMinTaskCount) {
                targetPartitionSizeInBytes = Math.max(totalBytes / (long)targetMinTaskCount, 1L);
            }
        }
        ImmutableMap.Builder result = ImmutableMap.builder();
        PriorityQueue<PartitionAssignment> assignments = new PriorityQueue<PartitionAssignment>();
        for (int partitionId = 0; partitionId < partitionCount; ++partitionId) {
            long partitionSizeInBytes = mergedEstimate.getPartitionSizeInBytes(partitionId);
            if (assignments.isEmpty() || ((PartitionAssignment)assignments.peek()).assignedDataSizeInBytes() + partitionSizeInBytes > targetPartitionSizeInBytes || !canMerge) {
                TaskPartition taskPartition = HashDistributionSplitAssigner.createTaskPartition(partitionSizeInBytes, targetPartitionSizeInBytes, partitionedSources, sourceDataSizeEstimates, partitionId, canSplit);
                result.put((Object)partitionId, (Object)taskPartition);
                assignments.add(new PartitionAssignment(taskPartition, partitionSizeInBytes));
                continue;
            }
            PartitionAssignment assignment = (PartitionAssignment)assignments.poll();
            result.put((Object)partitionId, (Object)assignment.taskPartition());
            assignments.add(new PartitionAssignment(assignment.taskPartition(), assignment.assignedDataSizeInBytes() + partitionSizeInBytes));
        }
        return result.buildOrThrow();
    }

    private static TaskPartition createTaskPartition(long partitionSizeInBytes, long targetPartitionSizeInBytes, Set<PlanNodeId> partitionedSources, Map<PlanNodeId, OutputDataSizeEstimate> sourceDataSizeEstimates, int partitionId, Predicate<PlanNodeId> canSplit) {
        PlanNodeId largestSource;
        Map<PlanNodeId, Long> sourceSizes;
        long largestSourceSizeInBytes;
        long remainingSourcesSizeInBytes;
        if (partitionSizeInBytes > targetPartitionSizeInBytes && (remainingSourcesSizeInBytes = partitionSizeInBytes - (largestSourceSizeInBytes = (sourceSizes = HashDistributionSplitAssigner.getSourceSizes(partitionedSources, sourceDataSizeEstimates, partitionId)).get(largestSource = sourceSizes.entrySet().stream().max(Map.Entry.comparingByValue()).map(Map.Entry::getKey).orElseThrow()).longValue())) <= targetPartitionSizeInBytes / 4L && canSplit.test(largestSource)) {
            long targetLargestSourceSizeInBytes = targetPartitionSizeInBytes - remainingSourcesSizeInBytes;
            return new TaskPartition(Math.toIntExact(largestSourceSizeInBytes / targetLargestSourceSizeInBytes) + 1, Optional.of(largestSource));
        }
        return new TaskPartition(1, Optional.empty());
    }

    private static Map<PlanNodeId, Long> getSourceSizes(Set<PlanNodeId> partitionedSources, Map<PlanNodeId, OutputDataSizeEstimate> sourceDataSizeEstimates, int partitionId) {
        return (Map)partitionedSources.stream().collect(ImmutableMap.toImmutableMap(Function.identity(), source -> ((OutputDataSizeEstimate)sourceDataSizeEstimates.get(source)).getPartitionSizeInBytes(partitionId)));
    }

    private static boolean isWriteFragment(PlanFragment fragment) {
        PlanVisitor<Boolean, Void> visitor = new PlanVisitor<Boolean, Void>(){

            @Override
            protected Boolean visitPlan(PlanNode node, Void context) {
                for (PlanNode child : node.getSources()) {
                    if (!child.accept(this, context).booleanValue()) continue;
                    return true;
                }
                return false;
            }

            @Override
            public Boolean visitTableWriter(TableWriterNode node, Void context) {
                return true;
            }
        };
        return fragment.getRoot().accept(visitor, null);
    }

    public String toString() {
        return MoreObjects.toStringHelper((Object)this).add("catalogRequirement", this.catalogRequirement).add("replicatedSources", this.replicatedSources).add("allSources", this.allSources).add("sourcePartitioningScheme", (Object)this.sourcePartitioningScheme).add("sourcePartitionToTaskPartition", this.sourcePartitionToTaskPartition).add("createdTaskPartitions", this.createdTaskPartitions).add("completedSources", this.completedSources).add("replicatedSplits.size()", this.replicatedSplits.size()).add("allTaskPartitionsCreated", this.allTaskPartitionsCreated).toString();
    }

    @VisibleForTesting
    static class TaskPartition {
        private final List<SubPartition> subPartitions;
        private final Optional<PlanNodeId> splitBy;
        private int nextSubPartition;

        private TaskPartition(int subPartitionCount, Optional<PlanNodeId> splitBy) {
            Preconditions.checkArgument((subPartitionCount > 0 ? 1 : 0) != 0, (Object)"subPartitionCount is expected to be greater than zero");
            this.subPartitions = (List)IntStream.range(0, subPartitionCount).mapToObj(i -> new SubPartition()).collect(ImmutableList.toImmutableList());
            Preconditions.checkArgument((subPartitionCount == 1 || splitBy.isPresent() ? 1 : 0) != 0, (Object)"splitBy is expected to be present when subPartitionCount is greater than 1");
            this.splitBy = Objects.requireNonNull(splitBy, "splitBy is null");
        }

        public SubPartition getNextSubPartition() {
            SubPartition result = this.subPartitions.get(this.nextSubPartition);
            this.nextSubPartition = (this.nextSubPartition + 1) % this.subPartitions.size();
            return result;
        }

        public List<SubPartition> getSubPartitions() {
            return this.subPartitions;
        }

        public Optional<PlanNodeId> getSplitBy() {
            return this.splitBy;
        }

        public String toString() {
            return MoreObjects.toStringHelper((Object)this).add("subPartitions", this.subPartitions).add("splitBy", this.splitBy).add("nextSubPartition", this.nextSubPartition).toString();
        }
    }

    @VisibleForTesting
    static class SubPartition {
        private OptionalInt id = OptionalInt.empty();

        SubPartition() {
        }

        public void assignId(int id) {
            Preconditions.checkState((boolean)this.id.isEmpty(), (Object)"id is already assigned");
            this.id = OptionalInt.of(id);
        }

        public boolean isIdAssigned() {
            return this.id.isPresent();
        }

        public int getId() {
            Preconditions.checkState((boolean)this.id.isPresent(), (Object)"id is expected to be assigned");
            return this.id.getAsInt();
        }

        public String toString() {
            return this.id.toString();
        }
    }

    private record PartitionAssignment(TaskPartition taskPartition, long assignedDataSizeInBytes) implements Comparable<PartitionAssignment>
    {
        public PartitionAssignment(TaskPartition taskPartition, long assignedDataSizeInBytes) {
            this.taskPartition = Objects.requireNonNull(taskPartition, "taskPartition is null");
            this.assignedDataSizeInBytes = assignedDataSizeInBytes;
        }

        @Override
        public int compareTo(PartitionAssignment other) {
            return Long.compare(this.assignedDataSizeInBytes, other.assignedDataSizeInBytes);
        }
    }
}

