/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.join;

import io.airlift.slice.SizeOf;
import io.airlift.units.DataSize;
import io.trino.operator.HashArraySizeSupplier;
import io.trino.operator.PagesHashStrategy;
import io.trino.operator.SyntheticAddress;
import io.trino.operator.join.PagesHash;
import io.trino.operator.join.PositionLinks;
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.block.Block;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.objects.ObjectArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/**
 * Open-addressing hash table over the rows referenced by {@code addresses}.
 * <p>
 * Each slot of {@code keys} is either {@code -1} (empty) or an index into {@code addresses}.
 * Collisions are resolved by linear probing, wrapping with {@code & mask}.
 * <p>
 * When a row being inserted compares equal to a row already in the table, the two are chained with
 * {@code PositionLinks.FactoryBuilder.link}, and the existing slot is overwritten with the value
 * {@code link} returns. No new slot is consumed, so each distinct key occupies one slot.
 * <p>
 * Rows for which {@code PagesHashStrategy.isPositionNull} returns true are never inserted.
 * <p>
 * The low byte of every row's full hash is kept in {@code positionToHashes}. It serves as a cheap
 * pre-filter before the full comparison in {@link PagesHashStrategy}.
 */
public final class DefaultPagesHash
        implements PagesHash {
    private static final int INSTANCE_SIZE = SizeOf.instanceSize(DefaultPagesHash.class);
    // Upper bound on the scratch buffer of full hashes used while building the table.
    private static final DataSize CACHE_SIZE = DataSize.of(128, DataSize.Unit.KILOBYTE);

    // Synthetic addresses (slice index + position, decoded via SyntheticAddress), indexed by hash position.
    private final LongArrayList addresses;
    private final PagesHashStrategy pagesHashStrategy;
    // hashSize - 1. Probing wraps with "& mask", which assumes getHashArraySize returns a power of two (TODO confirm).
    private final int mask;
    // Hash table slots: -1 means empty; otherwise an index into addresses.
    private final int[] keys;
    // Retained size of the arrays and strategy, excluding this object's own header (see getInMemorySizeInBytes).
    private final long size;
    // Low byte of the full hash for each index in addresses.
    private final byte[] positionToHashes;

    /**
     * Builds the table over every position in {@code addresses}.
     *
     * @param addresses synthetic addresses of the rows to index; must not be null
     * @param pagesHashStrategy hashes and compares rows; must not be null
     * @param positionLinks receives a link call for every row whose key equals a row already inserted
     * @param hashArraySizeSupplier chooses the number of hash slots for a given position count
     */
    public DefaultPagesHash(LongArrayList addresses, PagesHashStrategy pagesHashStrategy, PositionLinks.FactoryBuilder positionLinks, HashArraySizeSupplier hashArraySizeSupplier) {
        this.addresses = Objects.requireNonNull(addresses, "addresses is null");
        this.pagesHashStrategy = Objects.requireNonNull(pagesHashStrategy, "pagesHashStrategy is null");
        int positionCount = addresses.size();
        int hashSize = hashArraySizeSupplier.getHashArraySize(positionCount);
        this.mask = hashSize - 1;
        this.keys = new int[hashSize];
        Arrays.fill(this.keys, -1);
        this.positionToHashes = new byte[positionCount];

        // Hashes are computed and inserted in steps whose size is bounded by CACHE_SIZE, so the
        // full-hash scratch buffer stays small. "+ 1" keeps the buffer non-empty when there are no positions.
        int positionsInStep = Math.min(positionCount + 1, (int) CACHE_SIZE.toBytes() / Integer.SIZE);
        long[] positionToFullHashes = new long[positionsInStep];
        for (int stepBeginPosition = 0; stepBeginPosition < positionCount; stepBeginPosition += positionsInStep) {
            int stepSize = Math.min(positionsInStep, positionCount - stepBeginPosition);
            extractHashes(positionToFullHashes, stepBeginPosition, stepSize);
            indexPages(positionLinks, positionToFullHashes, stepBeginPosition, stepSize);
        }

        this.size = SizeOf.sizeOf(addresses.elements())
                + pagesHashStrategy.getSizeInBytes()
                + SizeOf.sizeOf(keys)
                + SizeOf.sizeOf(positionToHashes);
    }

    /**
     * Computes full hashes for positions {@code [stepBeginPosition, stepBeginPosition + stepSize)}.
     * Each hash goes into {@code positionToFullHashes} starting at index 0, and its low byte is
     * recorded in {@code positionToHashes}.
     */
    private void extractHashes(long[] positionToFullHashes, int stepBeginPosition, int stepSize) {
        for (int batchIndex = 0; batchIndex < stepSize; batchIndex++) {
            int addressIndex = stepBeginPosition + batchIndex;
            long hash = readHashPosition(addressIndex);
            positionToFullHashes[batchIndex] = hash;
            positionToHashes[addressIndex] = (byte) hash;
        }
    }

    /**
     * Inserts every non-null position of the current step into the table.
     * Full hashes are taken from {@code positionToFullHashes}, as filled by {@link #extractHashes}.
     */
    private void indexPages(PositionLinks.FactoryBuilder positionLinks, long[] positionToFullHashes, int stepBeginPosition, int stepSize) {
        for (int position = 0; position < stepSize; position++) {
            int realPosition = stepBeginPosition + position;
            if (isPositionNull(realPosition)) {
                continue;
            }
            long hash = positionToFullHashes[position];
            int pos = PagesHash.getHashPosition(hash, mask);
            insertValue(positionLinks, realPosition, (byte) hash, pos);
        }
    }

    /**
     * Linear-probes from slot {@code pos} for an empty slot or an equal key.
     * On an equal key, the new row is linked to the existing one and that slot is overwritten with
     * the value returned by {@code link}. Otherwise the row is stored in the first empty slot found.
     */
    private void insertValue(PositionLinks.FactoryBuilder positionLinks, int realPosition, byte hash, int pos) {
        while (keys[pos] != -1) {
            int currentKey = keys[pos];
            if (hash == positionToHashes[currentKey] && positionEqualsPositionIgnoreNulls(currentKey, realPosition)) {
                realPosition = positionLinks.link(realPosition, currentKey);
                break;
            }
            pos = (pos + 1) & mask;
        }
        keys[pos] = realPosition;
    }

    @Override
    public int getPositionCount() {
        return addresses.size();
    }

    @Override
    public long getInMemorySizeInBytes() {
        return INSTANCE_SIZE + size;
    }

    @Override
    public int getAddressIndex(int position, Page hashChannelsPage) {
        return getAddressIndex(position, hashChannelsPage, pagesHashStrategy.hashRow(position, hashChannelsPage));
    }

    /**
     * Looks up row {@code rightPosition} of {@code hashChannelsPage} using its precomputed hash.
     *
     * @return the matching index into {@code addresses}, or {@code -1} if there is no match
     */
    @Override
    public int getAddressIndex(int rightPosition, Page hashChannelsPage, long rawHash) {
        byte hashByte = (byte) rawHash;
        int pos = PagesHash.getHashPosition(rawHash, mask);
        while (keys[pos] != -1) {
            if (positionEqualsCurrentRowIgnoreNulls(keys[pos], hashByte, rightPosition, hashChannelsPage)) {
                return keys[pos];
            }
            pos = (pos + 1) & mask;
        }
        return -1;
    }

    /**
     * Batch lookup that hashes each requested row of {@code hashChannelsPage} first.
     * <p>
     * The hash buffer is indexed by page position, so it is sized from the largest requested position.
     * {@code positions} therefore need not be sorted, and an empty array yields an empty result.
     */
    @Override
    public int[] getAddressIndex(int[] positions, Page hashChannelsPage) {
        if (positions.length == 0) {
            return new int[0];
        }
        int maxPosition = 0;
        for (int position : positions) {
            maxPosition = Math.max(maxPosition, position);
        }
        long[] hashes = new long[maxPosition + 1];
        for (int position : positions) {
            hashes[position] = pagesHashStrategy.hashRow(position, hashChannelsPage);
        }
        return getAddressIndex(positions, hashChannelsPage, hashes);
    }

    /**
     * Batch lookup with precomputed hashes.
     * <p>
     * {@code rawHashes} is indexed by page position, not by index into {@code positions}.
     * Entry {@code i} of the result is the matching index into {@code addresses} for
     * {@code positions[i]}, or {@code -1} if there is no match.
     */
    @Override
    public int[] getAddressIndex(int[] positions, Page hashChannelsPage, long[] rawHashes) {
        int positionCount = positions.length;
        int[] hashPositions = calculateHashPositions(positions, rawHashes, positionCount);
        int[] result = new int[positionCount];
        Arrays.fill(result, -1);

        // Read the initial slot for every position before doing any comparisons.
        int[] foundKeys = new int[positionCount];
        findPositions(positionCount, hashPositions, foundKeys);

        // Only positions whose initial slot is occupied can match; the rest stay -1.
        int[] candidates = new int[positionCount];
        int candidateCount = 0;
        for (int i = 0; i < positionCount; i++) {
            if (foundKeys[i] != -1) {
                candidates[candidateCount++] = i;
            }
        }

        // checkFoundPositions compacts `candidates` in place to those that still need probing.
        int remainingCount = checkFoundPositions(positions, hashChannelsPage, rawHashes, candidates, candidateCount, result, foundKeys);
        findRemainingPositions(positions, hashChannelsPage, rawHashes, hashPositions, result, remainingCount, candidates);
        return result;
    }

    /**
     * Continues linear probing for indexes in {@code remaining[0, remainingCount)}, whose initial
     * slot did not match. On a match, the result entry is set.
     */
    private void findRemainingPositions(int[] positions, Page hashChannelsPage, long[] rawHashes, int[] hashPositions, int[] result, int remainingCount, int[] remaining) {
        for (int i = 0; i < remainingCount; i++) {
            int index = remaining[i];
            int position = positions[index];
            byte hashByte = (byte) rawHashes[position];
            // The initial slot was already compared in checkFoundPositions; resume at the next one.
            int slot = (hashPositions[index] + 1) & mask;
            while (keys[slot] != -1) {
                if (positionEqualsCurrentRowIgnoreNulls(keys[slot], hashByte, position, hashChannelsPage)) {
                    result[index] = keys[slot];
                    break;
                }
                slot = (slot + 1) & mask;
            }
        }
    }

    /**
     * Compares each candidate against the key in its initial slot.
     * On a match, the result entry is set. Otherwise the candidate index is written back into
     * {@code found} at position {@code remainingCount}. This is safe in place because
     * {@code remainingCount <= i}, so unread entries are never overwritten.
     *
     * @return the number of candidates that still need probing
     */
    private int checkFoundPositions(int[] positions, Page hashChannelsPage, long[] rawHashes, int[] found, int foundCount, int[] result, int[] foundKeys) {
        int remainingCount = 0;
        for (int i = 0; i < foundCount; i++) {
            int index = found[i];
            int position = positions[index];
            if (positionEqualsCurrentRowIgnoreNulls(foundKeys[index], (byte) rawHashes[position], position, hashChannelsPage)) {
                result[index] = foundKeys[index];
            } else {
                found[remainingCount++] = index;
            }
        }
        return remainingCount;
    }

    /** Copies the key stored in each position's initial slot (possibly -1) into {@code foundKeys}. */
    private void findPositions(int positionCount, int[] hashPositions, int[] foundKeys) {
        for (int i = 0; i < positionCount; i++) {
            foundKeys[i] = keys[hashPositions[i]];
        }
    }

    /** Maps each requested position's raw hash to its initial slot in {@code keys}. */
    private int[] calculateHashPositions(int[] positions, long[] rawHashes, int positionCount) {
        int[] hashPositions = new int[positionCount];
        for (int i = 0; i < positionCount; i++) {
            hashPositions[i] = PagesHash.getHashPosition(rawHashes[positions[i]], mask);
        }
        return hashPositions;
    }

    /** Appends the row at index {@code position} of {@code addresses} to {@code pageBuilder}. */
    @Override
    public void appendTo(long position, PageBuilder pageBuilder, int outputChannelOffset) {
        long pageAddress = addresses.getLong(Math.toIntExact(position));
        pagesHashStrategy.appendTo(
                SyntheticAddress.decodeSliceIndex(pageAddress),
                SyntheticAddress.decodePosition(pageAddress),
                pageBuilder,
                outputChannelOffset);
    }

    private boolean isPositionNull(int position) {
        long pageAddress = addresses.getLong(position);
        return pagesHashStrategy.isPositionNull(
                SyntheticAddress.decodeSliceIndex(pageAddress),
                SyntheticAddress.decodePosition(pageAddress));
    }

    private long readHashPosition(int position) {
        long pageAddress = addresses.getLong(position);
        return pagesHashStrategy.hashPosition(
                SyntheticAddress.decodeSliceIndex(pageAddress),
                SyntheticAddress.decodePosition(pageAddress));
    }

    /**
     * Returns whether the stored row {@code leftPosition} equals row {@code rightPosition} of {@code rightPage}.
     * Returns false immediately when the stored low hash byte differs from {@code rawHash}.
     */
    private boolean positionEqualsCurrentRowIgnoreNulls(int leftPosition, byte rawHash, int rightPosition, Page rightPage) {
        if (positionToHashes[leftPosition] != rawHash) {
            return false;
        }
        long pageAddress = addresses.getLong(leftPosition);
        return pagesHashStrategy.positionEqualsRowIgnoreNulls(
                SyntheticAddress.decodeSliceIndex(pageAddress),
                SyntheticAddress.decodePosition(pageAddress),
                rightPosition,
                rightPage);
    }

    /** Returns whether two stored rows, given as indexes into {@code addresses}, have equal keys. */
    private boolean positionEqualsPositionIgnoreNulls(int leftPosition, int rightPosition) {
        long leftPageAddress = addresses.getLong(leftPosition);
        long rightPageAddress = addresses.getLong(rightPosition);
        return pagesHashStrategy.positionEqualsPositionIgnoreNulls(
                SyntheticAddress.decodeSliceIndex(leftPageAddress),
                SyntheticAddress.decodePosition(leftPageAddress),
                SyntheticAddress.decodeSliceIndex(rightPageAddress),
                SyntheticAddress.decodePosition(rightPageAddress));
    }

    /**
     * Estimates retained memory for a table over {@code positionCount} rows, before building it.
     * <p>
     * Only the backing array of the first channel is measured, and that size is multiplied by the
     * channel count. This assumes all channel lists have equally sized backing arrays (TODO confirm).
     */
    public static long getEstimatedRetainedSizeInBytes(int positionCount, HashArraySizeSupplier hashArraySizeSupplier, LongArrayList addresses, List<ObjectArrayList<Block>> channels, long blocksSizeInBytes) {
        long channelArraysSize = channels.isEmpty()
                ? 0L
                : SizeOf.sizeOf((Object[]) channels.get(0).elements()) * channels.size();
        return SizeOf.sizeOf(addresses.elements())
                + channelArraysSize
                + blocksSizeInBytes
                + SizeOf.sizeOfIntArray(hashArraySizeSupplier.getHashArraySize(positionCount))
                + SizeOf.sizeOfByteArray(positionCount);
    }
}

