/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.join;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterators;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.concurrent.MoreFutures;
import io.airlift.units.DataSize;
import io.trino.RowPagesBuilder;
import io.trino.Session;
import io.trino.memory.context.LocalMemoryContext;
import io.trino.operator.Driver;
import io.trino.operator.DriverContext;
import io.trino.operator.HashArraySizeSupplier;
import io.trino.operator.JoinOperatorType;
import io.trino.operator.Operator;
import io.trino.operator.OperatorFactories;
import io.trino.operator.OperatorFactory;
import io.trino.operator.PagesIndex;
import io.trino.operator.PipelineContext;
import io.trino.operator.SpillContext;
import io.trino.operator.TaskContext;
import io.trino.operator.ValuesOperator;
import io.trino.operator.exchange.LocalExchange;
import io.trino.operator.exchange.LocalExchangeSinkOperator;
import io.trino.operator.exchange.LocalExchangeSourceOperator;
import io.trino.operator.join.HashBuilderOperator;
import io.trino.operator.join.InternalJoinFilterFunction;
import io.trino.operator.join.JoinBridgeManager;
import io.trino.operator.join.LookupSourceFactory;
import io.trino.operator.join.LookupSourceProvider;
import io.trino.operator.join.PartitionedLookupSourceFactory;
import io.trino.operator.join.StandardJoinFilterFunction;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.Page;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
import io.trino.spiller.PartitioningSpillerFactory;
import io.trino.spiller.SingleStreamSpiller;
import io.trino.spiller.SingleStreamSpillerFactory;
import io.trino.sql.gen.JoinFilterFunctionCompiler;
import io.trino.sql.planner.NodePartitioningManager;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.type.BlockTypeOperators;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.function.Function;
import java.util.stream.IntStream;

/**
 * Shared helpers for join operator tests.
 *
 * <p>Provides helpers to:
 * <ul>
 *   <li>set up and drive the hash-join build side (local exchange, hash builder operators, build drivers);</li>
 *   <li>create probe-side spilling inner-join operator factories.</li>
 * </ul>
 * It also provides test doubles for join filter functions and for single-stream spilling.
 */
public final class JoinTestUtils {
    /** Number of build partitions used by {@link #setupBuildSide} when a parallel build is requested. */
    private static final int PARTITION_COUNT = 4;
    private static final BlockTypeOperators TYPE_OPERATOR_FACTORY = new BlockTypeOperators(new TypeOperators());

    private JoinTestUtils() {
    }

    /**
     * Creates a spilling inner-join operator factory with {@code outputSingleMatch = false}.
     *
     * @param lookupSourceFactoryManager bridge to the build-side lookup source
     * @param probePages probe-side pages; its types, hash channels and optional hash channel are used
     * @param partitioningSpillerFactory spiller factory handed to the join operator
     * @param hasFilter whether the join has a filter function
     * @return the join operator factory
     */
    public static OperatorFactory innerJoinOperatorFactory(JoinBridgeManager<PartitionedLookupSourceFactory> lookupSourceFactoryManager, RowPagesBuilder probePages, PartitioningSpillerFactory partitioningSpillerFactory, boolean hasFilter) {
        return innerJoinOperatorFactory(lookupSourceFactoryManager, probePages, partitioningSpillerFactory, false, hasFilter);
    }

    /**
     * Creates a spilling inner-join operator factory (operator id 0, plan node "test") for the given probe pages.
     *
     * <p>{@code probePages} must have hash channels configured; otherwise
     * {@code getHashChannels().orElseThrow()} fails.
     *
     * @param lookupSourceFactoryManager bridge to the build-side lookup source
     * @param probePages probe-side pages; its types, hash channels and optional hash channel are used
     * @param partitioningSpillerFactory spiller factory handed to the join operator
     * @param outputSingleMatch forwarded to {@link JoinOperatorType#innerJoin}
     * @param hasFilter whether the join has a filter function
     * @return the join operator factory
     */
    public static OperatorFactory innerJoinOperatorFactory(JoinBridgeManager<PartitionedLookupSourceFactory> lookupSourceFactoryManager, RowPagesBuilder probePages, PartitioningSpillerFactory partitioningSpillerFactory, boolean outputSingleMatch, boolean hasFilter) {
        return OperatorFactories.spillingJoin(
                JoinOperatorType.innerJoin(outputSingleMatch, false),
                0,
                new PlanNodeId("test"),
                lookupSourceFactoryManager,
                hasFilter,
                probePages.getTypes(),
                probePages.getHashChannels().orElseThrow(),
                getHashChannelAsInt(probePages),
                Optional.empty(),
                OptionalInt.of(1),
                partitioningSpillerFactory,
                TYPE_OPERATOR_FACTORY);
    }

    /**
     * Creates one build driver per build partition and registers the drivers and their hash builder operators on
     * {@code buildSideSetup}.
     *
     * <p>Each driver reads from the build-side local exchange source and feeds a {@link HashBuilderOperator}. All
     * drivers are added to a new pipeline context (pipeline id 1) of {@code taskContext}.
     *
     * @param buildSideSetup build-side state, typically produced by {@link #setupBuildSide}
     * @param taskContext task context that receives the new build pipeline
     */
    public static void instantiateBuildDrivers(BuildSideSetup buildSideSetup, TaskContext taskContext) {
        PipelineContext buildPipeline = taskContext.addPipelineContext(1, true, true, false);
        List<Driver> buildDrivers = new ArrayList<>();
        List<HashBuilderOperator> buildOperators = new ArrayList<>();
        for (int i = 0; i < buildSideSetup.getPartitionCount(); i++) {
            DriverContext buildDriverContext = buildPipeline.addDriverContext();
            // The build operator is created before the source operator, matching the original creation order.
            HashBuilderOperator buildOperator = buildSideSetup.getBuildOperatorFactory().createOperator(buildDriverContext);
            Operator sourceOperator = buildSideSetup.getBuildSideSourceOperatorFactory().createOperator(buildDriverContext);
            buildDrivers.add(Driver.createDriver(buildDriverContext, sourceOperator, buildOperator));
            buildOperators.add(buildOperator);
        }
        buildSideSetup.setDriversAndOperators(buildDrivers, buildOperators);
    }

    /**
     * Prepares the build side of a hash join.
     *
     * <p>The method proceeds in three steps:
     * <ol>
     *   <li>It loads all of {@code buildPages} into a local exchange that is hash-partitioned
     *       ({@code FIXED_HASH_DISTRIBUTION}) on the build hash channels.</li>
     *   <li>It creates the exchange source operator factory, the partitioned lookup source factory and the hash
     *       builder operator factory.</li>
     *   <li>It returns them bundled as a {@link BuildSideSetup}.</li>
     * </ol>
     * Build drivers are not created here; see {@link #instantiateBuildDrivers}.
     *
     * @param nodePartitioningManager used by the local exchange
     * @param parallelBuild if true, the build uses {@link #PARTITION_COUNT} partitions; otherwise one partition
     * @param taskContext task context used for the session and the preload pipeline (pipeline id 0)
     * @param buildPages build-side pages; must have hash channels configured
     * @param filterFunction optional join filter, wrapped in a {@link StandardJoinFilterFunction} factory
     * @param spillEnabled whether the hash builder may spill
     * @param singleStreamSpillerFactory spiller factory for the hash builder
     * @return the build-side setup
     */
    public static BuildSideSetup setupBuildSide(NodePartitioningManager nodePartitioningManager, boolean parallelBuild, TaskContext taskContext, RowPagesBuilder buildPages, Optional<InternalJoinFilterFunction> filterFunction, boolean spillEnabled, SingleStreamSpillerFactory singleStreamSpillerFactory) {
        Optional<JoinFilterFunctionCompiler.JoinFilterFunctionFactory> filterFunctionFactory = filterFunction.map(function -> (session, addresses, pages) -> new StandardJoinFilterFunction(function, addresses, pages));
        int partitionCount = parallelBuild ? PARTITION_COUNT : 1;
        List<Integer> hashChannels = buildPages.getHashChannels().orElseThrow();
        List<Type> types = buildPages.getTypes();
        List<Type> hashChannelTypes = hashChannels.stream().map(types::get).collect(ImmutableList.toImmutableList());

        LocalExchange localExchange = new LocalExchange(
                nodePartitioningManager,
                taskContext.getSession(),
                partitionCount,
                SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION,
                hashChannels,
                hashChannelTypes,
                buildPages.getHashChannel(),
                DataSize.of(32L, DataSize.Unit.MEGABYTE),
                TYPE_OPERATOR_FACTORY,
                DataSize.of(32L, DataSize.Unit.MEGABYTE));
        loadBuildPagesIntoExchange(taskContext, buildPages, localExchange);

        LocalExchangeSourceOperator.LocalExchangeSourceOperatorFactory sourceOperatorFactory = new LocalExchangeSourceOperator.LocalExchangeSourceOperatorFactory(0, new PlanNodeId("source"), localExchange);
        // All build columns are output columns, in their original order.
        JoinBridgeManager<PartitionedLookupSourceFactory> lookupSourceFactoryManager = JoinBridgeManager.lookupAllAtOnce(new PartitionedLookupSourceFactory(
                types,
                ImmutableList.copyOf(types),
                hashChannelTypes,
                partitionCount,
                false,
                TYPE_OPERATOR_FACTORY));
        HashBuilderOperator.HashBuilderOperatorFactory buildOperatorFactory = new HashBuilderOperator.HashBuilderOperatorFactory(
                1,
                new PlanNodeId("build"),
                lookupSourceFactoryManager,
                rangeList(types.size()),
                hashChannels,
                getHashChannelAsInt(buildPages),
                filterFunctionFactory,
                Optional.empty(),
                ImmutableList.of(),
                100,
                new PagesIndex.TestingFactory(false),
                spillEnabled,
                singleStreamSpillerFactory,
                HashArraySizeSupplier.incrementalLoadFactorHashArraySizeSupplier(taskContext.getSession()));
        return new BuildSideSetup(lookupSourceFactoryManager, buildOperatorFactory, sourceOperatorFactory, partitionCount);
    }

    /**
     * Feeds the build pages into {@code localExchange}.
     *
     * <p>It drives a values -> local-exchange-sink pipeline (pipeline id 0) until that pipeline finishes.
     */
    private static void loadBuildPagesIntoExchange(TaskContext taskContext, RowPagesBuilder buildPages, LocalExchange localExchange) {
        DriverContext collectDriverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext();
        ValuesOperator.ValuesOperatorFactory valuesOperatorFactory = new ValuesOperator.ValuesOperatorFactory(0, new PlanNodeId("values"), buildPages.build());
        LocalExchangeSinkOperator.LocalExchangeSinkOperatorFactory sinkOperatorFactory = new LocalExchangeSinkOperator.LocalExchangeSinkOperatorFactory(localExchange.createSinkFactory(), 1, new PlanNodeId("sink"), Function.identity());
        Driver sourceDriver = Driver.createDriver(collectDriverContext, valuesOperatorFactory.createOperator(collectDriverContext), sinkOperatorFactory.createOperator(collectDriverContext));
        valuesOperatorFactory.noMoreOperators();
        sinkOperatorFactory.noMoreOperators();
        sinkOperatorFactory.localPlannerComplete();
        while (!sourceDriver.isFinished()) {
            sourceDriver.processUntilBlocked();
        }
    }

    /**
     * Builds the lookup source and then hands the build drivers to {@code executor}.
     *
     * <p>It first steps every build driver one iteration at a time until the lookup source provider future
     * completes, then closes the obtained provider. Finally, it schedules each build driver on {@code executor} via
     * {@link #runDriverInThread}.
     *
     * <p>NOTE(review): if {@code buildSideSetup} has no build drivers, the wait loop cannot make progress.
     *
     * @param executor executor that keeps running the build drivers after the lookup source is built
     * @param buildSideSetup build-side state whose drivers were created by {@link #instantiateBuildDrivers}
     * @throws NullPointerException if {@code buildSideSetup} is null
     * @throws IllegalStateException if the build drivers have not been instantiated yet
     */
    public static void buildLookupSource(ExecutorService executor, BuildSideSetup buildSideSetup) {
        Objects.requireNonNull(buildSideSetup, "buildSideSetup is null");
        LookupSourceFactory lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager().getJoinBridge();
        ListenableFuture<LookupSourceProvider> lookupSourceProvider = lookupSourceFactory.createLookupSourceProvider();
        List<Driver> buildDrivers = buildSideSetup.getBuildDrivers();
        while (!lookupSourceProvider.isDone()) {
            for (Driver buildDriver : buildDrivers) {
                buildDriver.processForNumberOfIterations(1);
            }
        }
        MoreFutures.getFutureValue(lookupSourceProvider).close();
        for (Driver buildDriver : buildDrivers) {
            runDriverInThread(executor, buildDriver);
        }
    }

    /**
     * Runs {@code driver} on {@code executor}.
     *
     * <p>Each submitted task calls {@code processUntilBlocked()} once and then resubmits itself, until the driver
     * reports it is finished.
     *
     * <p>A {@link TrinoException} marks the driver context as failed and stops rescheduling. Other exceptions are not
     * caught here and propagate out of the submitted task.
     *
     * @param executor executor that runs the driver
     * @param driver driver to run
     */
    public static void runDriverInThread(ExecutorService executor, Driver driver) {
        executor.execute(() -> {
            if (driver.isFinished()) {
                return;
            }
            try {
                driver.processUntilBlocked();
            }
            catch (TrinoException e) {
                driver.getDriverContext().failed(e);
                return;
            }
            runDriverInThread(executor, driver);
        });
    }

    /**
     * Converts the builder's optional hash channel into an {@link OptionalInt}.
     *
     * @param probePages pages builder (probe or build side)
     * @return the hash channel, or {@link OptionalInt#empty()} if none is configured
     */
    public static OptionalInt getHashChannelAsInt(RowPagesBuilder probePages) {
        return probePages.getHashChannel().map(OptionalInt::of).orElse(OptionalInt.empty());
    }

    /** Returns the immutable list {@code [0, 1, ..., endExclusive - 1]}, which is empty for non-positive input. */
    private static List<Integer> rangeList(int endExclusive) {
        return IntStream.range(0, endExclusive).boxed().collect(ImmutableList.toImmutableList());
    }

    /**
     * Holds the build-side state of a hash join test.
     *
     * <p>The factories and the partition count are fixed at construction. The build drivers and build operators are
     * attached later through {@link #setDriversAndOperators}.
     */
    public static class BuildSideSetup {
        private final JoinBridgeManager<PartitionedLookupSourceFactory> lookupSourceFactoryManager;
        private final HashBuilderOperator.HashBuilderOperatorFactory buildOperatorFactory;
        private final LocalExchangeSourceOperator.LocalExchangeSourceOperatorFactory buildSideSourceOperatorFactory;
        private final int partitionCount;
        // Set by setDriversAndOperators; null until then.
        private List<Driver> buildDrivers;
        private List<HashBuilderOperator> buildOperators;

        /**
         * @param lookupSourceFactoryManager bridge to the build-side lookup source factory
         * @param buildOperatorFactory factory for the per-partition hash builder operators
         * @param buildSideSourceOperatorFactory factory for the per-partition local exchange source operators
         * @param partitionCount number of build partitions (one build driver per partition)
         * @throws NullPointerException if any factory argument is null
         */
        public BuildSideSetup(JoinBridgeManager<PartitionedLookupSourceFactory> lookupSourceFactoryManager, HashBuilderOperator.HashBuilderOperatorFactory buildOperatorFactory, LocalExchangeSourceOperator.LocalExchangeSourceOperatorFactory buildSideSourceOperatorFactory, int partitionCount) {
            this.lookupSourceFactoryManager = Objects.requireNonNull(lookupSourceFactoryManager, "lookupSourceFactoryManager is null");
            this.buildOperatorFactory = Objects.requireNonNull(buildOperatorFactory, "buildOperatorFactory is null");
            this.buildSideSourceOperatorFactory = Objects.requireNonNull(buildSideSourceOperatorFactory, "buildSideSourceOperatorFactory is null");
            this.partitionCount = partitionCount;
        }

        /**
         * Stores immutable copies of the build drivers and their hash builder operators.
         *
         * @throws IllegalArgumentException if the two lists differ in size
         */
        public void setDriversAndOperators(List<Driver> buildDrivers, List<HashBuilderOperator> buildOperators) {
            Preconditions.checkArgument(
                    buildDrivers.size() == buildOperators.size(),
                    "buildDrivers size (%s) does not match buildOperators size (%s)",
                    buildDrivers.size(),
                    buildOperators.size());
            this.buildDrivers = ImmutableList.copyOf(buildDrivers);
            this.buildOperators = ImmutableList.copyOf(buildOperators);
        }

        public JoinBridgeManager<PartitionedLookupSourceFactory> getLookupSourceFactoryManager() {
            return this.lookupSourceFactoryManager;
        }

        public HashBuilderOperator.HashBuilderOperatorFactory getBuildOperatorFactory() {
            return this.buildOperatorFactory;
        }

        public LocalExchangeSourceOperator.LocalExchangeSourceOperatorFactory getBuildSideSourceOperatorFactory() {
            return this.buildSideSourceOperatorFactory;
        }

        public int getPartitionCount() {
            return this.partitionCount;
        }

        /** @throws IllegalStateException if {@link #setDriversAndOperators} has not been called yet */
        public List<Driver> getBuildDrivers() {
            Preconditions.checkState(this.buildDrivers != null, "buildDrivers is not initialized yet");
            return this.buildDrivers;
        }

        /** @throws IllegalStateException if {@link #setDriversAndOperators} has not been called yet */
        public List<HashBuilderOperator> getBuildOperators() {
            Preconditions.checkState(this.buildOperators != null, "buildOperators is not initialized yet");
            return this.buildOperators;
        }
    }

    /** {@link InternalJoinFilterFunction} that delegates every call to a caller-supplied {@link Lambda}. */
    public static class TestInternalJoinFilterFunction
    implements InternalJoinFilterFunction {
        private final Lambda lambda;

        /** @throws NullPointerException if {@code lambda} is null */
        public TestInternalJoinFilterFunction(Lambda lambda) {
            this.lambda = Objects.requireNonNull(lambda, "lambda is null");
        }

        @Override
        public boolean filter(int leftPosition, Page leftPage, int rightPosition, Page rightPage) {
            return this.lambda.filter(leftPosition, leftPage, rightPosition, rightPage);
        }

        /** Filter predicate with the same shape as {@link InternalJoinFilterFunction#filter}. */
        @FunctionalInterface
        public interface Lambda {
            boolean filter(int leftPosition, Page leftPage, int rightPosition, Page rightPage);
        }
    }

    /**
     * {@link SingleStreamSpillerFactory} whose spillers keep spilled pages in an in-memory list.
     *
     * <p>Failures can be injected:
     * <ul>
     *   <li>{@link #failSpill()} makes subsequent {@code spill} calls return a failed future.</li>
     *   <li>{@link #failUnspill()} makes {@code getSpilledPages} throw and {@code getAllSpilledPages} return a failed
     *       future.</li>
     * </ul>
     */
    public static class DummySpillerFactory
    implements SingleStreamSpillerFactory {
        // volatile, presumably so flags set by the test thread are seen by driver threads — confirm with callers.
        private volatile boolean failSpill;
        private volatile boolean failUnspill;

        public void failSpill() {
            this.failSpill = true;
        }

        public void failUnspill() {
            this.failUnspill = true;
        }

        @Override
        public SingleStreamSpiller create(List<Type> types, SpillContext spillContext, LocalMemoryContext memoryContext) {
            return new SingleStreamSpiller() {
                // Cleared by the first read of the spilled pages or by close(); spill() is rejected afterwards.
                private boolean writing = true;
                private final List<Page> spills = new ArrayList<>();

                @Override
                public ListenableFuture<Void> spill(Iterator<Page> pageIterator) {
                    Preconditions.checkState(this.writing, "writing already finished");
                    if (failSpill) {
                        return Futures.immediateFailedFuture(new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "Spill failed"));
                    }
                    Iterators.addAll(this.spills, pageIterator);
                    return Futures.immediateVoidFuture();
                }

                @Override
                public Iterator<Page> getSpilledPages() {
                    if (failUnspill) {
                        throw new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "Unspill failed");
                    }
                    this.writing = false;
                    return Iterators.unmodifiableIterator(this.spills.iterator());
                }

                @Override
                public long getSpilledPagesInMemorySize() {
                    return this.spills.stream().mapToLong(Page::getSizeInBytes).sum();
                }

                @Override
                public ListenableFuture<List<Page>> getAllSpilledPages() {
                    if (failUnspill) {
                        return Futures.immediateFailedFuture(new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "Unspill failed"));
                    }
                    this.writing = false;
                    return Futures.immediateFuture(ImmutableList.copyOf(this.spills));
                }

                @Override
                public void close() {
                    this.writing = false;
                }
            };
        }
    }
}

