/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.aggregation;

import io.trino.operator.aggregation.AggregationMask;
import io.trino.operator.aggregation.AggregationMaskBuilder;
import io.trino.operator.aggregation.AggregationMaskCompiler;
import io.trino.operator.aggregation.InterpretedAggregationMaskBuilder;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.ByteArrayBlock;
import io.trino.spi.block.DictionaryBlock;
import io.trino.spi.block.IntArrayBlock;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.block.ShortArrayBlock;
import java.util.Arrays;
import java.util.Optional;
import java.util.function.Supplier;
import java.util.stream.IntStream;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link AggregationMaskBuilder} implementations.
 *
 * <p>Every scenario runs against both the interpreted builder and the bytecode-generated builder
 * from {@link AggregationMaskCompiler}. Each is configured with channel 1 as its only non-null
 * argument channel. Test pages therefore always carry two columns:
 * <ul>
 *   <li>channel 0 is entirely null and must not influence the mask;</li>
 *   <li>channel 1 holds the values whose nulls decide which positions are selected.</li>
 * </ul>
 */
public class TestAggregationMaskCompiler {
    /** Creates a fresh interpreted builder that checks channel 1 for nulls. */
    private static final Supplier<AggregationMaskBuilder> INTERPRETED_MASK_BUILDER_SUPPLIER = () -> new InterpretedAggregationMaskBuilder(1);

    /** Creates a fresh generated builder for the same channel layout as the interpreted one. */
    private static final Supplier<AggregationMaskBuilder> COMPILED_MASK_BUILDER_SUPPLIER = () -> {
        try {
            // the generated class is instantiated reflectively, once per supplier call
            return (AggregationMaskBuilder) AggregationMaskCompiler.generateAggregationMaskBuilder(new int[] {1}).newInstance();
        }
        catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
    };

    @Test
    public void testSupplier() {
        testSupplier(INTERPRETED_MASK_BUILDER_SUPPLIER);
        testSupplier(COMPILED_MASK_BUILDER_SUPPLIER);
    }

    /**
     * Checks instance sharing between builders.
     *
     * <p>Distinct builders must never share masks or selected-position arrays. A single builder
     * returns the same selected-position array from consecutive calls.
     */
    private void testSupplier(Supplier<AggregationMaskBuilder> maskBuilderSupplier) {
        // every supplier call yields an independent builder
        Assertions.assertThat(maskBuilderSupplier.get()).isNotSameAs(maskBuilderSupplier.get());

        Page page = buildSingleColumnPage(5);
        Assertions.assertThat(maskBuilderSupplier.get().buildAggregationMask(page, Optional.empty()))
                .isNotSameAs(maskBuilderSupplier.get().buildAggregationMask(page, Optional.empty()));

        boolean[] nullFlags = new boolean[5];
        nullFlags[1] = true;
        nullFlags[3] = true;
        Page pageWithNulls = buildSingleColumnPage(nullFlags);
        Assertions.assertThat(maskBuilderSupplier.get().buildAggregationMask(pageWithNulls, Optional.empty()))
                .isNotSameAs(maskBuilderSupplier.get().buildAggregationMask(pageWithNulls, Optional.empty()));

        // different builders: different position arrays, identical contents
        Assertions.assertThat(maskBuilderSupplier.get().buildAggregationMask(pageWithNulls, Optional.empty()).getSelectedPositions())
                .isNotSameAs(maskBuilderSupplier.get().buildAggregationMask(pageWithNulls, Optional.empty()).getSelectedPositions());
        Assertions.assertThat(maskBuilderSupplier.get().buildAggregationMask(pageWithNulls, Optional.empty()).getSelectedPositions())
                .isEqualTo(maskBuilderSupplier.get().buildAggregationMask(pageWithNulls, Optional.empty()).getSelectedPositions());

        // the same builder hands back the same position array on consecutive calls
        AggregationMaskBuilder maskBuilder = maskBuilderSupplier.get();
        Assertions.assertThat(maskBuilder.buildAggregationMask(pageWithNulls, Optional.empty()).getSelectedPositions())
                .isSameAs(maskBuilder.buildAggregationMask(pageWithNulls, Optional.empty()).getSelectedPositions());
    }

    @Test
    public void testUnsetNulls() {
        testUnsetNulls(INTERPRETED_MASK_BUILDER_SUPPLIER);
        testUnsetNulls(COMPILED_MASK_BUILDER_SUPPLIER);
    }

    /**
     * Checks that positions where channel 1 is null are removed from the mask.
     *
     * <p>No mask block is supplied. Both flat and run-length-encoded channel-1 columns are used.
     * One builder is reused for the whole run, so state left over from earlier pages must not
     * leak into later results.
     */
    private void testUnsetNulls(Supplier<AggregationMaskBuilder> maskBuilderSupplier) {
        AggregationMaskBuilder maskBuilder = maskBuilderSupplier.get();
        assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(0), Optional.empty()), 0);

        for (int positionCount = 7; positionCount < 10; positionCount++) {
            assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPageRle(positionCount, Optional.of(true)), Optional.empty()), positionCount);
            assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.empty()), positionCount);

            // a null array with no null entries still selects everything
            boolean[] nullFlags = new boolean[positionCount];
            assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(nullFlags), Optional.empty()), positionCount);

            Arrays.fill(nullFlags, true);
            nullFlags[1] = false;
            nullFlags[3] = false;
            nullFlags[5] = false;
            assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(nullFlags), Optional.empty()), positionCount, 1, 3, 5);

            nullFlags[3] = true;
            assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(nullFlags), Optional.empty()), positionCount, 1, 5);

            nullFlags[2] = false;
            assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(nullFlags), Optional.empty()), positionCount, 1, 2, 5);

            assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPageRle(positionCount, Optional.empty()), Optional.empty()), positionCount);
            assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPageRle(positionCount, Optional.of(false)), Optional.empty()), positionCount);
            assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPageRle(positionCount, Optional.of(true)), Optional.empty()), positionCount);
        }
    }

    @Test
    public void testApplyMask() {
        testApplyMask(INTERPRETED_MASK_BUILDER_SUPPLIER);
        testApplyMask(COMPILED_MASK_BUILDER_SUPPLIER);
    }

    /**
     * Checks that a byte mask block keeps only the positions whose mask value is 1.
     *
     * <p>The mask is supplied as a flat block, a dictionary-backed block and a run-length-encoded
     * block.
     */
    private void testApplyMask(Supplier<AggregationMaskBuilder> maskBuilderSupplier) {
        AggregationMaskBuilder maskBuilder = maskBuilderSupplier.get();
        for (int positionCount = 7; positionCount < 10; positionCount++) {
            assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockRle(positionCount, (byte) 1))), positionCount);

            byte[] mask = new byte[positionCount];
            Arrays.fill(mask, (byte) 1);
            assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, mask))), positionCount);
            assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockAsDictionary(positionCount, mask))), positionCount);

            Arrays.fill(mask, (byte) 0);
            mask[1] = 1;
            mask[3] = 1;
            mask[5] = 1;
            assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, mask))), positionCount, 1, 3, 5);
            assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockAsDictionary(positionCount, mask))), positionCount, 1, 3, 5);

            mask[3] = 0;
            assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, mask))), positionCount, 1, 5);
            assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockAsDictionary(positionCount, mask))), positionCount, 1, 5);

            mask[2] = 1;
            assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, mask))), positionCount, 1, 2, 5);
            assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockAsDictionary(positionCount, mask))), positionCount, 1, 2, 5);

            assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockRle(positionCount, (byte) 1))), positionCount);
            assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockRle(positionCount, (byte) 0))), positionCount);
        }
    }

    @Test
    public void testApplyMaskNulls() {
        testApplyMaskNulls(INTERPRETED_MASK_BUILDER_SUPPLIER);
        testApplyMaskNulls(COMPILED_MASK_BUILDER_SUPPLIER);
    }

    /**
     * Checks that null entries in the mask block deselect their positions.
     *
     * <p>Every non-null mask value here is 1. The mask is supplied as both flat and
     * run-length-encoded blocks.
     */
    private void testApplyMaskNulls(Supplier<AggregationMaskBuilder> maskBuilderSupplier) {
        AggregationMaskBuilder maskBuilder = maskBuilderSupplier.get();
        for (int positionCount = 7; positionCount < 10; positionCount++) {
            byte[] mask = new byte[positionCount];
            Arrays.fill(mask, (byte) 1);
            assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, mask))), positionCount);
            assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockAsDictionary(positionCount, mask))), positionCount);

            boolean[] nullFlags = new boolean[positionCount];
            assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockNulls(nullFlags))), positionCount);

            Arrays.fill(nullFlags, true);
            nullFlags[1] = false;
            nullFlags[3] = false;
            nullFlags[5] = false;
            assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockNulls(nullFlags))), positionCount, 1, 3, 5);

            nullFlags[3] = true;
            assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockNulls(nullFlags))), positionCount, 1, 5);

            nullFlags[1] = true;
            nullFlags[5] = true;
            assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockNulls(nullFlags))), positionCount);

            assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockNullsRle(positionCount, false))), positionCount);
            assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockNullsRle(positionCount, true))), positionCount);
        }
    }

    /** Wraps {@code mask} in a flat {@link ByteArrayBlock} with no null array. */
    private static Block createMaskBlock(int positionCount, byte[] mask) {
        return new ByteArrayBlock(positionCount, Optional.empty(), mask);
    }

    /** Repeats the single byte {@code mask} {@code positionCount} times as a run-length-encoded block. */
    private static Block createMaskBlockRle(int positionCount, byte mask) {
        return RunLengthEncodedBlock.create(createMaskBlock(1, new byte[] {mask}), positionCount);
    }

    /**
     * Builds a mask block holding the same values as {@code mask}, reached through a dictionary.
     *
     * <p>The dictionary has twice as many entries. Each mask value sits at an even index and is
     * followed by its inverse at the next odd index. Selecting the even positions through
     * {@link Block#getPositions} then yields the original values. Presumably the result is
     * dictionary-encoded rather than flat — confirm against {@code getPositions}.
     */
    private static Block createMaskBlockAsDictionary(int positionCount, byte[] mask) {
        byte[] interleaved = new byte[positionCount * 2];
        for (int position = 0; position < positionCount; position++) {
            interleaved[position * 2] = mask[position];
            interleaved[position * 2 + 1] = (byte) (mask[position] == 0 ? 1 : 0);
        }
        Block block = DictionaryBlock.create(positionCount * 2, new ByteArrayBlock(positionCount * 2, Optional.empty(), interleaved), IntStream.range(0, positionCount * 2).toArray());
        return block.getPositions(IntStream.range(0, positionCount).map(i -> i * 2).toArray(), 0, positionCount);
    }

    /**
     * Builds a byte mask block in which every value is 1.
     *
     * <p>Positions flagged in {@code nulls} are null.
     */
    private static Block createMaskBlockNulls(boolean[] nulls) {
        int positionCount = nulls.length;
        byte[] mask = new byte[positionCount];
        Arrays.fill(mask, (byte) 1);
        return new ByteArrayBlock(positionCount, Optional.of(nulls), mask);
    }

    /** Repeats a single mask value of 1, null when {@code nullValue} is true, as a run-length-encoded block. */
    private static Block createMaskBlockNullsRle(int positionCount, boolean nullValue) {
        return RunLengthEncodedBlock.create(createMaskBlockNulls(new boolean[] {nullValue}), positionCount);
    }

    /**
     * Builds the channel-0 column: a short column in which every position is null.
     *
     * <p>The builders only check channel 1, so these nulls must not change the resulting mask.
     */
    private static Block createIgnoredColumn(int positionCount) {
        boolean[] ignoredColumnNulls = new boolean[positionCount];
        Arrays.fill(ignoredColumnNulls, true);
        return new ShortArrayBlock(positionCount, Optional.of(ignoredColumnNulls), new short[positionCount]);
    }

    /** Builds a page whose channel 1 is a flat int column with no null array. */
    private static Page buildSingleColumnPage(int positionCount) {
        return new Page(new Block[] {
                createIgnoredColumn(positionCount),
                new IntArrayBlock(positionCount, Optional.empty(), new int[positionCount])});
    }

    /** Builds a page whose channel 1 is a flat int column with the given null flags. */
    private static Page buildSingleColumnPage(boolean[] nulls) {
        int positionCount = nulls.length;
        return new Page(new Block[] {
                createIgnoredColumn(positionCount),
                new IntArrayBlock(positionCount, Optional.of(nulls), new int[positionCount])});
    }

    /**
     * Builds a page whose channel 1 repeats a single int value {@code positionCount} times.
     *
     * <p>The channel is run-length encoded.
     *
     * @param positionCount number of positions in the page
     * @param nullValue empty for a value block without a null array; otherwise whether the
     *        repeated value is null
     */
    private static Page buildSingleColumnPageRle(int positionCount, Optional<Boolean> nullValue) {
        // reconstructed: the decompiler could not translate this method and emitted a stub that always threw
        Optional<boolean[]> valueIsNull = nullValue.map(isNull -> new boolean[] {isNull});
        Block rleColumn = RunLengthEncodedBlock.create(new IntArrayBlock(1, valueIsNull, new int[1]), positionCount);
        return new Page(new Block[] {createIgnoredColumn(positionCount), rleColumn});
    }

    /**
     * Asserts that the mask selects every one of {@code expectedPositionCount} positions.
     *
     * <p>Also asserts that the mask exposes no explicit position list.
     */
    private static void assertAggregationMaskAll(AggregationMask aggregationMask, int expectedPositionCount) {
        Assertions.assertThat(aggregationMask.isSelectAll()).isTrue();
        Assertions.assertThat(aggregationMask.isSelectNone()).isEqualTo(expectedPositionCount == 0);
        Assertions.assertThat(aggregationMask.getPositionCount()).isEqualTo(expectedPositionCount);
        Assertions.assertThat(aggregationMask.getSelectedPositionCount()).isEqualTo(expectedPositionCount);
        Assertions.assertThatThrownBy(aggregationMask::getSelectedPositions).isInstanceOf(IllegalStateException.class);
    }

    /**
     * Asserts that the mask selects exactly {@code expectedPositions}, in order.
     *
     * <p>Only a prefix of the selected-positions array is compared, because the array may be a
     * reused buffer that is longer than the selected count.
     */
    private static void assertAggregationMaskPositions(AggregationMask aggregationMask, int expectedPositionCount, int... expectedPositions) {
        Assertions.assertThat(aggregationMask.isSelectAll()).isFalse();
        Assertions.assertThat(aggregationMask.isSelectNone()).isEqualTo(expectedPositions.length == 0);
        Assertions.assertThat(aggregationMask.getPositionCount()).isEqualTo(expectedPositionCount);
        Assertions.assertThat(aggregationMask.getSelectedPositionCount()).isEqualTo(expectedPositions.length);
        if (expectedPositions.length > 0) {
            Assertions.assertThat(aggregationMask.getSelectedPositions()).startsWith(expectedPositions);
        }
    }
}

