/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.iceberg;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.trino.plugin.iceberg.IcebergPartitionFunction;
import io.trino.plugin.iceberg.IcebergPartitioningHandle;
import io.trino.plugin.iceberg.IcebergSplit;
import io.trino.plugin.iceberg.PartitionTransforms;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.RowBlock;
import io.trino.spi.connector.BucketFunction;
import io.trino.spi.connector.ConnectorSplit;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.type.TypeOperators;
import java.lang.invoke.MethodHandle;
import java.util.List;
import java.util.Objects;
import java.util.function.ToIntFunction;

public class IcebergBucketFunction
implements BucketFunction,
ToIntFunction<ConnectorSplit> {
    private final int bucketCount;
    private final List<HashFunction> functions;
    private final boolean singleBucketFunction;

    public IcebergBucketFunction(IcebergPartitioningHandle partitioningHandle, TypeOperators typeOperators, int bucketCount) {
        Objects.requireNonNull(partitioningHandle, "partitioningHandle is null");
        Objects.requireNonNull(typeOperators, "typeOperators is null");
        Preconditions.checkArgument((bucketCount > 0 ? 1 : 0) != 0, (String)"Invalid bucketCount: %s", (int)bucketCount);
        this.bucketCount = bucketCount;
        List<IcebergPartitionFunction> partitionFunctions = partitioningHandle.partitionFunctions();
        this.functions = (List)partitionFunctions.stream().map(partitionFunction -> HashFunction.create(partitionFunction, typeOperators)).collect(ImmutableList.toImmutableList());
        this.singleBucketFunction = partitionFunctions.size() == 1 && partitionFunctions.getFirst().transform() == IcebergPartitionFunction.Transform.BUCKET && partitionFunctions.getFirst().size().orElseThrow() == bucketCount;
    }

    public int getBucket(Page page, int position) {
        if (this.singleBucketFunction) {
            long bucket = Objects.requireNonNullElse(this.functions.getFirst().getValue(page, position), 0L);
            Preconditions.checkArgument((0L <= bucket && bucket < (long)this.bucketCount ? 1 : 0) != 0, (String)"Bucket value out of range: %s (bucketCount: %s)", (long)bucket, (int)this.bucketCount);
            return (int)bucket;
        }
        long hash = 0L;
        for (HashFunction function : this.functions) {
            long valueHash = function.computeHash(page, position);
            hash = 31L * hash + valueHash;
        }
        return (int)((hash & Long.MAX_VALUE) % (long)this.bucketCount);
    }

    @Override
    public int applyAsInt(ConnectorSplit split) {
        List<Object> partitionValues = ((IcebergSplit)split).getPartitionValues().orElseThrow(() -> new IllegalArgumentException("Split does not contain partition values"));
        if (this.singleBucketFunction) {
            long bucket = Objects.requireNonNullElse(partitionValues.getFirst(), 0L);
            Preconditions.checkArgument((0L <= bucket && bucket < (long)this.bucketCount ? 1 : 0) != 0, (String)"Bucket value out of range: %s (bucketCount: %s)", (long)bucket, (int)this.bucketCount);
            return (int)bucket;
        }
        long hash = 0L;
        for (int i = 0; i < this.functions.size(); ++i) {
            long valueHash = this.functions.get(i).computeHash(partitionValues.get(i));
            hash = 31L * hash + valueHash;
        }
        return (int)((hash & Long.MAX_VALUE) % (long)this.bucketCount);
    }

    private record HashFunction(List<Integer> dataPath, PartitionTransforms.ValueTransform valueTransform, MethodHandle hashCodeOperator) {
        private HashFunction {
            Objects.requireNonNull(valueTransform, "valueTransform is null");
            Objects.requireNonNull(hashCodeOperator, "hashCodeOperator is null");
        }

        private static HashFunction create(IcebergPartitionFunction partitionFunction, TypeOperators typeOperators) {
            PartitionTransforms.ColumnTransform columnTransform = PartitionTransforms.getColumnTransform(partitionFunction);
            return new HashFunction(partitionFunction.dataPath(), columnTransform.valueTransform(), typeOperators.getHashCodeOperator(columnTransform.type(), InvocationConvention.simpleConvention((InvocationConvention.InvocationReturnConvention)InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL, (InvocationConvention.InvocationArgumentConvention[])new InvocationConvention.InvocationArgumentConvention[]{InvocationConvention.InvocationArgumentConvention.NEVER_NULL})));
        }

        public Object getValue(Page page, int position) {
            Block block = page.getBlock(this.dataPath.getFirst().intValue());
            for (int i = 1; i < this.dataPath.size(); ++i) {
                position = block.getUnderlyingValuePosition(position);
                block = ((RowBlock)block.getUnderlyingValueBlock()).getFieldBlock(this.dataPath.get(i).intValue());
            }
            return this.valueTransform.apply(block, position);
        }

        public long computeHash(Page page, int position) {
            return this.computeHash(this.getValue(page, position));
        }

        private long computeHash(Object value) {
            if (value == null) {
                return 0L;
            }
            try {
                return this.hashCodeOperator.invoke(value);
            }
            catch (Throwable throwable) {
                if (throwable instanceof Error) {
                    Error error = (Error)throwable;
                    throw error;
                }
                if (throwable instanceof RuntimeException) {
                    RuntimeException runtimeException = (RuntimeException)throwable;
                    throw runtimeException;
                }
                throw new RuntimeException(throwable);
            }
        }
    }
}

