/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.join;

import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.io.Closer;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import io.trino.annotation.NotThreadSafe;
import io.trino.operator.InterpretedHashGenerator;
import io.trino.operator.exchange.LocalPartitionGenerator;
import io.trino.operator.join.LookupSource;
import io.trino.operator.join.OuterPositionIterator;
import io.trino.operator.join.TrackingLookupSourceSupplier;
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
import jakarta.annotation.Nullable;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;

/**
 * A {@link LookupSource} composed of several per-partition lookup sources.
 *
 * <p>Routing: a probe row is sent to one partition chosen by a {@link LocalPartitionGenerator}
 * from the hash of the probe's hash channels.
 *
 * <p>Encoding: join positions handed back to callers are encoded. The partition index occupies
 * the low {@code shiftSize} bits, and the position within that partition occupies the remaining
 * high bits. Negative join positions reported by a partition are returned unchanged (they are not
 * encoded).
 *
 * <p>Threading: instances are annotated {@link NotThreadSafe}.
 *
 * <p>Outer tracking: when created with {@code outer == true}, every instance produced by the same
 * supplier shares one table of "visited" flags. Positions never passed to {@link #appendTo} can
 * then be enumerated through the supplier's {@link OuterPositionIterator}.
 */
@NotThreadSafe
public class PartitionedLookupSource
implements LookupSource {
    private final LookupSource[] lookupSources;
    private final LocalPartitionGenerator partitionGenerator;
    // Mask extracting the partition index from an encoded join position (partitionCount - 1).
    // The mask is only correct when the partition count is a power of two.
    // NOTE(review): presumably guaranteed by callers / LocalPartitionGenerator — confirm.
    private final int partitionMask;
    // Number of low bits reserved for the partition index in an encoded join position:
    // numberOfTrailingZeros(partitionCount) + 1.
    private final int shiftSize;
    // Non-null only when outer-position tracking was requested.
    @Nullable
    private final OuterPositionTracker outerPositionTracker;
    private boolean closed;

    /**
     * Creates a supplier of {@link PartitionedLookupSource}s over the given partitions.
     *
     * @param partitions one supplier per partition. Every supplier is invoked each time a new
     *        lookup source is requested. When {@code outer} is true, every supplier is also invoked
     *        once here to size the visited-position table.
     * @param hashChannelTypes types of the hash channels, used to compute the partitioning hash
     * @param outer when {@code true}, positions passed to {@link #appendTo} are recorded. The
     *        returned supplier's {@code getOuterPositionIterator()} then yields the positions that
     *        were never recorded.
     * @param typeOperators operators used to build the hash generator
     * @return a tracking supplier when {@code outer} is true, otherwise a non-tracking supplier
     */
    public static TrackingLookupSourceSupplier createPartitionedLookupSourceSupplier(final List<Supplier<LookupSource>> partitions, final List<Type> hashChannelTypes, boolean outer, final TypeOperators typeOperators) {
        if (!outer) {
            return TrackingLookupSourceSupplier.nonTracking(
                    () -> new PartitionedLookupSource(resolvePartitions(partitions), hashChannelTypes, Optional.empty(), typeOperators));
        }

        final OuterPositionTracker.Factory outerPositionTrackerFactory = new OuterPositionTracker.Factory(partitions);
        return new TrackingLookupSourceSupplier() {
            @Override
            public LookupSource getLookupSource() {
                return new PartitionedLookupSource(resolvePartitions(partitions), hashChannelTypes, Optional.of(outerPositionTrackerFactory.create()), typeOperators);
            }

            @Override
            public OuterPositionIterator getOuterPositionIterator() {
                return outerPositionTrackerFactory.getOuterPositionIterator();
            }
        };
    }

    // Invokes every partition supplier once, in order, and returns the results as an immutable list.
    private static List<LookupSource> resolvePartitions(List<Supplier<LookupSource>> partitions) {
        return partitions.stream()
                .map(Supplier::get)
                .collect(ImmutableList.toImmutableList());
    }

    private PartitionedLookupSource(List<? extends LookupSource> lookupSources, List<Type> hashChannelTypes, Optional<OuterPositionTracker> outerPositionTracker, TypeOperators typeOperators) {
        int partitionCount = lookupSources.size();
        this.lookupSources = lookupSources.toArray(new LookupSource[0]);
        this.partitionGenerator = new LocalPartitionGenerator(InterpretedHashGenerator.createPagePrefixHashGenerator(hashChannelTypes, typeOperators), partitionCount);
        this.partitionMask = partitionCount - 1;
        this.shiftSize = Integer.numberOfTrailingZeros(partitionCount) + 1;
        this.outerPositionTracker = outerPositionTracker.orElse(null);
    }

    /** Returns {@code true} only if every partition is empty. */
    @Override
    public boolean isEmpty() {
        return Arrays.stream(lookupSources).allMatch(LookupSource::isEmpty);
    }

    /** Returns the sum of the join position counts of all partitions. */
    @Override
    public long getJoinPositionCount() {
        return Arrays.stream(lookupSources).mapToLong(LookupSource::getJoinPositionCount).sum();
    }

    /** Returns the sum of the in-memory sizes of all partitions. */
    @Override
    public long getInMemorySizeInBytes() {
        return Arrays.stream(lookupSources).mapToLong(LookupSource::getInMemorySizeInBytes).sum();
    }

    /** Computes the raw hash of the probe row and delegates to the raw-hash overload. */
    @Override
    public long getJoinPosition(int position, Page hashChannelsPage, Page allChannelsPage) {
        return getJoinPosition(position, hashChannelsPage, allChannelsPage, partitionGenerator.getRawHash(hashChannelsPage, position));
    }

    /**
     * Looks the probe row up in the partition selected by {@code rawHash}.
     *
     * @return the encoded join position, or the partition's negative result unchanged
     */
    @Override
    public long getJoinPosition(int position, Page hashChannelsPage, Page allChannelsPage, long rawHash) {
        int partition = partitionGenerator.getPartition(rawHash);
        long joinPosition = lookupSources[partition].getJoinPosition(position, hashChannelsPage, allChannelsPage, rawHash);
        if (joinPosition < 0L) {
            return joinPosition;
        }
        return encodePartitionedJoinPosition(partition, Math.toIntExact(joinPosition));
    }

    /**
     * Advances within the partition encoded in {@code currentJoinPosition}.
     *
     * @return the next encoded join position, or the partition's negative result unchanged
     */
    @Override
    public long getNextJoinPosition(long currentJoinPosition, int probePosition, Page allProbeChannelsPage) {
        int partition = decodePartition(currentJoinPosition);
        long nextJoinPosition = lookupSources[partition].getNextJoinPosition(decodeJoinPosition(currentJoinPosition), probePosition, allProbeChannelsPage);
        if (nextJoinPosition < 0L) {
            return nextJoinPosition;
        }
        return encodePartitionedJoinPosition(partition, Math.toIntExact(nextJoinPosition));
    }

    /** Decodes {@code currentJoinPosition} and asks the owning partition whether it is eligible. */
    @Override
    public boolean isJoinPositionEligible(long currentJoinPosition, int probePosition, Page allProbeChannelsPage) {
        int partition = decodePartition(currentJoinPosition);
        return lookupSources[partition].isJoinPositionEligible(decodeJoinPosition(currentJoinPosition), probePosition, allProbeChannelsPage);
    }

    /**
     * Appends the build row at the encoded position to {@code pageBuilder}.
     *
     * <p>When outer tracking is enabled, the position is also marked as visited.
     */
    @Override
    public void appendTo(long partitionedJoinPosition, PageBuilder pageBuilder, int outputChannelOffset) {
        int partition = decodePartition(partitionedJoinPosition);
        int joinPosition = decodeJoinPosition(partitionedJoinPosition);
        lookupSources[partition].appendTo(joinPosition, pageBuilder, outputChannelOffset);
        if (outerPositionTracker != null) {
            outerPositionTracker.positionVisited(partition, joinPosition);
        }
    }

    /** Strips the partition bits from an encoded join position. */
    @Override
    public long joinPositionWithinPartition(long joinPosition) {
        return decodeJoinPosition(joinPosition);
    }

    /**
     * Closes all partitions and, when tracking, commits the outer position tracker.
     *
     * <p>Subsequent calls are no-ops once a close has completed without throwing.
     */
    @Override
    public void close() {
        if (closed) {
            return;
        }
        // Guava's Closer closes registered resources in reverse registration order. The partitions
        // are therefore closed before the tracker commits.
        try (Closer closer = Closer.create()) {
            if (outerPositionTracker != null) {
                closer.register(outerPositionTracker::commit);
            }
            for (LookupSource lookupSource : lookupSources) {
                closer.register(lookupSource);
            }
        }
        catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        closed = true;
    }

    private int decodePartition(long partitionedJoinPosition) {
        return (int) (partitionedJoinPosition & partitionMask);
    }

    private int decodeJoinPosition(long partitionedJoinPosition) {
        return Math.toIntExact(partitionedJoinPosition >>> shiftSize);
    }

    private long encodePartitionedJoinPosition(int partition, int joinPosition) {
        return ((long) joinPosition << shiftSize) | partition;
    }

    /**
     * Records visited build positions into a table shared by all trackers from the same
     * {@link Factory}.
     *
     * <p>The shared {@code referenceCount} counts trackers that have recorded at least one
     * position but have not yet committed (committing happens when the owning lookup source is
     * closed). {@link Factory#getOuterPositionIterator()} requires that count to be zero.
     */
    private static class OuterPositionTracker {
        private final boolean[][] visitedPositions;
        private final AtomicBoolean finished;
        private final AtomicLong referenceCount;
        // Whether this tracker has incremented referenceCount.
        private boolean written;

        private OuterPositionTracker(boolean[][] visitedPositions, AtomicBoolean finished, AtomicLong referenceCount) {
            this.visitedPositions = visitedPositions;
            this.finished = finished;
            this.referenceCount = referenceCount;
        }

        /** Marks {@code position} of {@code partition} as visited. */
        public void positionVisited(int partition, int position) {
            if (!written) {
                written = true;
                Verify.verify(!finished.get(), "position visited after outer position iteration started");
                referenceCount.incrementAndGet();
            }
            visitedPositions[partition][position] = true;
        }

        /** Releases this tracker's reference, if it took one. */
        public void commit() {
            if (written) {
                referenceCount.decrementAndGet();
            }
        }

        public static class Factory {
            private final LookupSource[] lookupSources;
            private final boolean[][] visitedPositions;
            private final AtomicBoolean finished = new AtomicBoolean();
            private final AtomicLong referenceCount = new AtomicLong();

            /**
             * Resolves every partition once and allocates one visited flag per join position of
             * each partition.
             */
            public Factory(List<Supplier<LookupSource>> partitions) {
                this.lookupSources = partitions.stream()
                        .map(Supplier::get)
                        .toArray(LookupSource[]::new);
                this.visitedPositions = Arrays.stream(lookupSources)
                        .mapToLong(LookupSource::getJoinPositionCount)
                        .mapToInt(Math::toIntExact)
                        .mapToObj(boolean[]::new)
                        .toArray(boolean[][]::new);
            }

            /** Creates a tracker that writes into this factory's shared visited table. */
            public OuterPositionTracker create() {
                return new OuterPositionTracker(visitedPositions, finished, referenceCount);
            }

            /**
             * Returns an iterator over the positions that were never visited.
             *
             * <p>Fails unless every tracker that recorded positions has committed. Once called,
             * further {@code positionVisited} calls by fresh trackers fail.
             */
            public OuterPositionIterator getOuterPositionIterator() {
                long activeWriters = referenceCount.get();
                Verify.verify(activeWriters == 0L, "outer position iterator requested while %s lookup sources with visited positions are not yet closed", activeWriters);
                finished.set(true);
                return new PartitionedLookupOuterPositionIterator(lookupSources, visitedPositions);
            }
        }
    }

    /**
     * Walks all partitions in order and appends each unvisited build row, one per call.
     *
     * <p>Cursor state is guarded by the iterator's monitor.
     */
    private static class PartitionedLookupOuterPositionIterator
    implements OuterPositionIterator {
        private final LookupSource[] lookupSources;
        private final boolean[][] visitedPositions;
        @GuardedBy("this")
        private int currentSource;
        @GuardedBy("this")
        private int currentPosition;

        public PartitionedLookupOuterPositionIterator(LookupSource[] lookupSources, boolean[][] visitedPositions) {
            this.lookupSources = lookupSources;
            this.visitedPositions = visitedPositions;
        }

        /**
         * Appends the next unvisited build row, if any.
         *
         * @return {@code true} if a row was appended, {@code false} when all partitions are exhausted
         */
        @Override
        public synchronized boolean appendToNext(PageBuilder pageBuilder, int outputChannelOffset) {
            while (currentSource < lookupSources.length) {
                boolean[] visited = visitedPositions[currentSource];
                while (currentPosition < visited.length) {
                    if (!visited[currentPosition]) {
                        lookupSources[currentSource].appendTo(currentPosition, pageBuilder, outputChannelOffset);
                        // Advance only after a successful append, so a failed append can be retried.
                        currentPosition++;
                        return true;
                    }
                    currentPosition++;
                }
                currentPosition = 0;
                currentSource++;
            }
            return false;
        }
    }
}

