/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution;

import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Iterators;
import com.google.common.collect.Multimap;
import io.airlift.concurrent.Threads;
import io.airlift.slice.SizeOf;
import io.trino.Session;
import io.trino.client.NodeVersion;
import io.trino.execution.MockRemoteTaskFactory;
import io.trino.execution.NodeTaskMap;
import io.trino.execution.RemoteTask;
import io.trino.execution.StageId;
import io.trino.execution.TaskId;
import io.trino.execution.scheduler.FlatNetworkTopology;
import io.trino.execution.scheduler.NetworkLocation;
import io.trino.execution.scheduler.NetworkTopology;
import io.trino.execution.scheduler.NodeScheduler;
import io.trino.execution.scheduler.NodeSchedulerConfig;
import io.trino.execution.scheduler.NodeSelector;
import io.trino.execution.scheduler.NodeSelectorFactory;
import io.trino.execution.scheduler.TopologyAwareNodeSelectorConfig;
import io.trino.execution.scheduler.TopologyAwareNodeSelectorFactory;
import io.trino.execution.scheduler.UniformNodeSelectorFactory;
import io.trino.jmh.Benchmarks;
import io.trino.metadata.InMemoryNodeManager;
import io.trino.metadata.InternalNode;
import io.trino.metadata.InternalNodeManager;
import io.trino.metadata.Split;
import io.trino.spi.HostAddress;
import io.trino.spi.connector.ConnectorSplit;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.TestingHandles;
import io.trino.testing.TestingSession;
import io.trino.util.FinalizerService;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OperationsPerInvocation;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;

/**
 * JMH benchmark for split placement.
 *
 * <p>Each invocation feeds {@value #SPLITS} splits, in batches of at most {@value #SPLIT_BATCH_SIZE},
 * through a {@link NodeSelector} whose node manager knows about {@value #NODES} worker nodes, and
 * reports the average time per split. The {@code policy} parameter of {@link BenchmarkData}
 * chooses the uniform selector, a topology-aware selector with a flat topology, or a
 * topology-aware selector with a two-level rack/machine topology.
 */
@State(value=Scope.Thread)
@OutputTimeUnit(value=TimeUnit.MICROSECONDS)
@Fork(value=1)
@Warmup(iterations=10, time=500, timeUnit=TimeUnit.MILLISECONDS)
@Measurement(iterations=10, time=500, timeUnit=TimeUnit.MILLISECONDS)
@BenchmarkMode(value={Mode.AverageTime})
public class BenchmarkNodeScheduler {
    /** Max splits per node for the scheduler config; also the number of splits started per assignment. */
    private static final int MAX_SPLITS_PER_NODE = 100;
    /** Min pending splits per task for the scheduler config; also part of each task's initial backlog. */
    private static final int MAX_PENDING_SPLITS_PER_TASK_PER_NODE = 50;
    /** Number of worker nodes, and therefore of mock remote tasks. */
    private static final int NODES = 200;
    /** Number of distinct data hosts that split addresses are drawn from. */
    private static final int DATA_NODES = 10000;
    /** Number of racks that host names are spread across. */
    private static final int RACKS = 400;
    /** Total number of splits scheduled per benchmark invocation. */
    private static final int SPLITS = 23200;
    /** Maximum number of splits offered to the node selector in one call. */
    private static final int SPLIT_BATCH_SIZE = 100;
    /** Number of splits marked finished on one task (round-robin) after every scheduling round. */
    private static final int SPLITS_FINISHED_PER_ROUND = 2;

    /**
     * Schedules every split in {@code data} onto the mock remote tasks.
     *
     * @param data per-thread benchmark state (node selector, tasks, splits)
     * @return the remote task list, returned so JMH keeps the work observable
     */
    @Benchmark
    @OperationsPerInvocation(SPLITS)
    public Object benchmark(BenchmarkData data) {
        Map<InternalNode, MockRemoteTaskFactory.MockRemoteTask> taskMap = data.getTaskMap();
        List<RemoteTask> remoteTasks = ImmutableList.copyOf(taskMap.values());
        Iterator<MockRemoteTaskFactory.MockRemoteTask> finishingTask = Iterators.cycle(taskMap.values());
        Iterator<Split> splits = data.getSplits().iterator();
        Set<Split> batch = new HashSet<>();
        while (splits.hasNext() || !batch.isEmpty()) {
            Multimap<InternalNode, Split> assignments = data.getNodeSelector().computeAssignments(batch, remoteTasks).getAssignments();
            for (InternalNode node : assignments.keySet()) {
                MockRemoteTaskFactory.MockRemoteTask remoteTask = taskMap.get(node);
                remoteTask.addSplits(ImmutableMultimap.<PlanNodeId, Split>builder()
                        .putAll(new PlanNodeId("sourceId"), assignments.get(node))
                        .build());
                remoteTask.startSplits(MAX_SPLITS_PER_NODE);
            }
            // Drop the placed splits; any unassigned ones stay in the batch and are retried next round.
            if (assignments.size() == batch.size()) {
                batch.clear();
            } else {
                batch.removeAll(assignments.values());
            }
            while (batch.size() < SPLIT_BATCH_SIZE && splits.hasNext()) {
                batch.add(splits.next());
            }
            // Free capacity on one task per round so the scheduler eventually has room again.
            finishingTask.next().finishSplits(SPLITS_FINISHED_PER_ROUND);
        }
        return remoteTasks;
    }

    public static void main(String[] args) throws Exception {
        Benchmarks.benchmark(BenchmarkNodeScheduler.class).run();
    }

    /** Returns a topology config naming the two location segments produced by {@link BenchmarkNetworkTopology}. */
    private static TopologyAwareNodeSelectorConfig getBenchmarkNetworkTopologyConfig() {
        return new TopologyAwareNodeSelectorConfig().setLocationSegmentNames(ImmutableList.of("rack", "machine"));
    }

    /**
     * Builds the address {@code host<N>.rack<N % RACKS>} on port 1 for the given host index.
     *
     * @param host host index
     * @return address whose host text encodes both the host and its rack
     */
    private static HostAddress addressForHost(int host) {
        int rack = Integer.hashCode(host) % RACKS;
        return HostAddress.fromParts("host" + host + ".rack" + rack, 1);
    }

    /** Per-thread benchmark state: the node selector under test, one mock task per node, and the splits. */
    @State(value=Scope.Thread)
    public static class BenchmarkData {
        @Param(value={"uniform", "benchmark", "topology"})
        private String policy = "uniform";
        private final FinalizerService finalizerService = new FinalizerService();
        private final Map<InternalNode, MockRemoteTaskFactory.MockRemoteTask> taskMap = new HashMap<>();
        private final List<Split> splits = new ArrayList<>();
        private ExecutorService remoteTaskExecutor;
        private ScheduledExecutorService remoteTaskScheduledExecutor;
        private NodeSelector nodeSelector;

        /**
         * Creates the worker nodes, one pre-loaded mock task per node, the split workload, and the
         * node selector for the configured {@code policy}.
         */
        @Setup
        public void setup() {
            finalizerService.start();
            NodeTaskMap nodeTaskMap = new NodeTaskMap(finalizerService);

            ImmutableList.Builder<InternalNode> nodeBuilder = ImmutableList.builder();
            for (int i = 0; i < NODES; i++) {
                nodeBuilder.add(new InternalNode("node" + i, URI.create("http://" + addressForHost(i).getHostText()), NodeVersion.UNKNOWN, false));
            }
            List<InternalNode> nodes = nodeBuilder.build();

            // Kept in fields so tearDown() can shut them down instead of leaking threads across trials.
            remoteTaskExecutor = Executors.newCachedThreadPool(Threads.daemonThreadsNamed("remoteTaskExecutor-%s"));
            remoteTaskScheduledExecutor = Executors.newScheduledThreadPool(2, Threads.daemonThreadsNamed("remoteTaskScheduledExecutor-%s"));
            MockRemoteTaskFactory remoteTaskFactory = new MockRemoteTaskFactory(remoteTaskExecutor, remoteTaskScheduledExecutor);

            for (int i = 0; i < nodes.size(); i++) {
                InternalNode node = nodes.get(i);
                ImmutableList.Builder<Split> initialSplits = ImmutableList.builder();
                for (int j = 0; j < MAX_SPLITS_PER_NODE + MAX_PENDING_SPLITS_PER_TASK_PER_NODE; j++) {
                    initialSplits.add(new Split(TestingHandles.TEST_CATALOG_HANDLE, new TestSplitRemote(i)));
                }
                TaskId taskId = new TaskId(new StageId("test", 1), i, 0);
                MockRemoteTaskFactory.MockRemoteTask remoteTask = remoteTaskFactory.createTableScanTask(taskId, node, initialSplits.build(), nodeTaskMap.createPartitionedSplitCountTracker(node, taskId));
                nodeTaskMap.addTask(node, remoteTask);
                taskMap.put(node, remoteTask);
            }

            for (int i = 0; i < SPLITS; i++) {
                splits.add(new Split(TestingHandles.TEST_CATALOG_HANDLE, new TestSplitRemote(ThreadLocalRandom.current().nextInt(DATA_NODES))));
            }

            // Register the workers with the node manager. An empty manager leaves the selector with no
            // eligible nodes, because the coordinator is excluded by getNodeSchedulerConfig().
            InternalNodeManager nodeManager = new InMemoryNodeManager(nodes.toArray(new InternalNode[0]));
            NodeScheduler nodeScheduler = new NodeScheduler(getNodeSelectorFactory(nodeManager, nodeTaskMap));
            Session session = TestingSession.testSessionBuilder()
                    .setSystemProperty("max_unacknowledged_splits_per_task", Integer.toString(Integer.MAX_VALUE))
                    .build();
            nodeSelector = nodeScheduler.createNodeSelector(session, Optional.of(TestingHandles.TEST_CATALOG_HANDLE));
        }

        /** Stops the finalizer service and the executors created in {@link #setup()}. */
        @TearDown
        public void tearDown() {
            finalizerService.destroy();
            if (remoteTaskExecutor != null) {
                remoteTaskExecutor.shutdownNow();
            }
            if (remoteTaskScheduledExecutor != null) {
                remoteTaskScheduledExecutor.shutdownNow();
            }
        }

        /** Builds the scheduler config shared by all policies, excluding the coordinator from scheduling. */
        private NodeSchedulerConfig getNodeSchedulerConfig() {
            return new NodeSchedulerConfig()
                    .setMaxSplitsPerNode(MAX_SPLITS_PER_NODE)
                    .setIncludeCoordinator(false)
                    .setNodeSchedulerPolicy(policy)
                    .setMinPendingSplitsPerTask(MAX_PENDING_SPLITS_PER_TASK_PER_NODE);
        }

        /**
         * Creates the node selector factory for the current {@code policy}.
         *
         * @throws IllegalStateException if {@code policy} is not one of the benchmark's parameter values
         */
        private NodeSelectorFactory getNodeSelectorFactory(InternalNodeManager nodeManager, NodeTaskMap nodeTaskMap) {
            NodeSchedulerConfig nodeSchedulerConfig = getNodeSchedulerConfig();
            switch (policy) {
                case "uniform":
                    return new UniformNodeSelectorFactory(nodeManager, nodeSchedulerConfig, nodeTaskMap);
                case "topology":
                    return new TopologyAwareNodeSelectorFactory(new FlatNetworkTopology(), nodeManager, nodeSchedulerConfig, nodeTaskMap, new TopologyAwareNodeSelectorConfig());
                case "benchmark":
                    return new TopologyAwareNodeSelectorFactory(new BenchmarkNetworkTopology(), nodeManager, nodeSchedulerConfig, nodeTaskMap, getBenchmarkNetworkTopologyConfig());
                default:
                    throw new IllegalStateException("Unknown node scheduler policy: " + policy);
            }
        }

        public Map<InternalNode, MockRemoteTaskFactory.MockRemoteTask> getTaskMap() {
            return taskMap;
        }

        public NodeSelector getNodeSelector() {
            return nodeSelector;
        }

        public List<Split> getSplits() {
            return splits;
        }
    }

    /** Remotely accessible split whose only address is the data host given at construction. */
    private static class TestSplitRemote implements ConnectorSplit {
        private static final int INSTANCE_SIZE = SizeOf.instanceSize(TestSplitRemote.class);
        private final List<HostAddress> hosts;

        public TestSplitRemote(int dataHost) {
            this.hosts = ImmutableList.of(addressForHost(dataHost));
        }

        @Override
        public List<HostAddress> getAddresses() {
            return hosts;
        }

        @Override
        public Object getInfo() {
            return this;
        }

        @Override
        public long getRetainedSizeInBytes() {
            return (long) INSTANCE_SIZE + SizeOf.estimatedSizeOf(hosts, HostAddress::getRetainedSizeInBytes);
        }
    }

    /**
     * Topology that splits a host name on '.' and reverses the parts.
     *
     * <p>For example, {@code host7.rack7} becomes the location {@code [rack7, host7]}.
     */
    private static class BenchmarkNetworkTopology implements NetworkTopology {
        @Override
        public NetworkLocation locate(HostAddress address) {
            List<String> parts = new ArrayList<>(Splitter.on('.').splitToList(address.getHostText()));
            Collections.reverse(parts);
            return new NetworkLocation(parts);
        }
    }
}

