/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.aggregation;

import io.trino.operator.aggregation.AggregationMask;
import io.trino.spi.block.Block;
import io.trino.spi.block.ByteArrayBlock;
import io.trino.spi.block.IntArrayBlock;
import io.trino.spi.block.RunLengthEncodedBlock;
import java.util.Arrays;
import java.util.Optional;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link AggregationMask}: the select-all state, narrowing the selection by null
 * positions ({@code unselectNullPositions}) and by byte mask blocks ({@code applyMaskBlock}),
 * for both flat and run-length-encoded input blocks.
 */
public class TestAggregationMask {
    @Test
    public void testUnsetNulls() {
        AggregationMask aggregationMask = AggregationMask.createSelectAll(0);
        assertAggregationMaskAll(aggregationMask, 0);

        for (int positionCount = 7; positionCount < 10; positionCount++) {
            aggregationMask.reset(positionCount);
            assertAggregationMaskAll(aggregationMask, positionCount);

            // Neither a block without a null array nor one whose null flags are all false unselects anything.
            aggregationMask.unselectNullPositions(intBlock(Optional.empty(), positionCount));
            assertAggregationMaskAll(aggregationMask, positionCount);

            boolean[] nullFlags = new boolean[positionCount];
            aggregationMask.unselectNullPositions(intBlock(Optional.of(nullFlags), positionCount));
            assertAggregationMaskAll(aggregationMask, positionCount);

            // Only positions 1, 3 and 5 are non-null.
            Arrays.fill(nullFlags, true);
            nullFlags[1] = false;
            nullFlags[3] = false;
            nullFlags[5] = false;
            aggregationMask.unselectNullPositions(intBlock(Optional.of(nullFlags), positionCount));
            assertAggregationMaskPositions(aggregationMask, positionCount, 1, 3, 5);

            nullFlags[3] = true;
            aggregationMask.unselectNullPositions(intBlock(Optional.of(nullFlags), positionCount));
            assertAggregationMaskPositions(aggregationMask, positionCount, 1, 5);

            // Every position is null, so nothing remains selected.
            nullFlags[1] = true;
            nullFlags[5] = true;
            aggregationMask.unselectNullPositions(intBlock(Optional.of(nullFlags), positionCount));
            assertAggregationMaskPositions(aggregationMask, positionCount);

            // reset() restores the select-all state; then repeat the checks with run-length-encoded blocks.
            aggregationMask.reset(positionCount);
            assertAggregationMaskAll(aggregationMask, positionCount);

            aggregationMask.unselectNullPositions(runLength(intBlock(Optional.empty(), 1), positionCount));
            assertAggregationMaskAll(aggregationMask, positionCount);

            aggregationMask.unselectNullPositions(runLength(intBlock(Optional.of(new boolean[] {false}), 1), positionCount));
            assertAggregationMaskAll(aggregationMask, positionCount);

            aggregationMask.unselectNullPositions(runLength(intBlock(Optional.of(new boolean[] {true}), 1), positionCount));
            assertAggregationMaskPositions(aggregationMask, positionCount);
        }
    }

    @Test
    public void testApplyMask() {
        AggregationMask aggregationMask = AggregationMask.createSelectAll(0);
        assertAggregationMaskAll(aggregationMask, 0);

        for (int positionCount = 7; positionCount < 10; positionCount++) {
            aggregationMask.reset(positionCount);
            assertAggregationMaskAll(aggregationMask, positionCount);

            // An all-ones mask keeps every position selected.
            byte[] mask = new byte[positionCount];
            Arrays.fill(mask, (byte) 1);
            aggregationMask.applyMaskBlock(byteBlock(Optional.empty(), mask));
            assertAggregationMaskAll(aggregationMask, positionCount);

            // Only positions 1, 3 and 5 are set in the mask.
            Arrays.fill(mask, (byte) 0);
            mask[1] = 1;
            mask[3] = 1;
            mask[5] = 1;
            aggregationMask.applyMaskBlock(byteBlock(Optional.empty(), mask));
            assertAggregationMaskPositions(aggregationMask, positionCount, 1, 3, 5);

            mask[3] = 0;
            aggregationMask.applyMaskBlock(byteBlock(Optional.empty(), mask));
            assertAggregationMaskPositions(aggregationMask, positionCount, 1, 5);

            // An all-zero mask leaves nothing selected.
            mask[1] = 0;
            mask[5] = 0;
            aggregationMask.applyMaskBlock(byteBlock(Optional.empty(), mask));
            assertAggregationMaskPositions(aggregationMask, positionCount);

            // reset() restores the select-all state; then repeat the checks with run-length-encoded blocks.
            aggregationMask.reset(positionCount);
            assertAggregationMaskAll(aggregationMask, positionCount);

            aggregationMask.applyMaskBlock(runLength(byteBlock(Optional.empty(), new byte[] {1}), positionCount));
            assertAggregationMaskAll(aggregationMask, positionCount);

            aggregationMask.applyMaskBlock(runLength(byteBlock(Optional.empty(), new byte[] {0}), positionCount));
            assertAggregationMaskPositions(aggregationMask, positionCount);
        }
    }

    @Test
    public void testApplyMaskNulls() {
        AggregationMask aggregationMask = AggregationMask.createSelectAll(0);
        assertAggregationMaskAll(aggregationMask, 0);

        for (int positionCount = 7; positionCount < 10; positionCount++) {
            aggregationMask.reset(positionCount);
            assertAggregationMaskAll(aggregationMask, positionCount);

            // Every mask value is 1, so only the null flags decide which positions stay selected.
            byte[] mask = new byte[positionCount];
            Arrays.fill(mask, (byte) 1);
            aggregationMask.applyMaskBlock(byteBlock(Optional.empty(), mask));
            assertAggregationMaskAll(aggregationMask, positionCount);

            boolean[] nullFlags = new boolean[positionCount];
            aggregationMask.applyMaskBlock(byteBlock(Optional.of(nullFlags), mask));
            assertAggregationMaskAll(aggregationMask, positionCount);

            // Null mask entries end up unselected even though their value is 1.
            Arrays.fill(nullFlags, true);
            nullFlags[1] = false;
            nullFlags[3] = false;
            nullFlags[5] = false;
            aggregationMask.applyMaskBlock(byteBlock(Optional.of(nullFlags), mask));
            assertAggregationMaskPositions(aggregationMask, positionCount, 1, 3, 5);

            nullFlags[3] = true;
            aggregationMask.applyMaskBlock(byteBlock(Optional.of(nullFlags), mask));
            assertAggregationMaskPositions(aggregationMask, positionCount, 1, 5);

            nullFlags[1] = true;
            nullFlags[5] = true;
            aggregationMask.applyMaskBlock(byteBlock(Optional.of(nullFlags), mask));
            assertAggregationMaskPositions(aggregationMask, positionCount);

            // reset() restores the select-all state; then repeat the checks with run-length-encoded blocks.
            aggregationMask.reset(positionCount);
            assertAggregationMaskAll(aggregationMask, positionCount);

            aggregationMask.applyMaskBlock(runLength(byteBlock(Optional.empty(), new byte[] {1}), positionCount));
            assertAggregationMaskAll(aggregationMask, positionCount);

            aggregationMask.applyMaskBlock(runLength(byteBlock(Optional.of(new boolean[] {false}), new byte[] {1}), positionCount));
            assertAggregationMaskAll(aggregationMask, positionCount);

            aggregationMask.applyMaskBlock(runLength(byteBlock(Optional.of(new boolean[] {true}), new byte[] {1}), positionCount));
            assertAggregationMaskPositions(aggregationMask, positionCount);
        }
    }

    /** Builds an {@link IntArrayBlock} of zero values with the given null flags. */
    private static Block intBlock(Optional<boolean[]> valueIsNull, int positionCount) {
        return new IntArrayBlock(positionCount, valueIsNull, new int[positionCount]);
    }

    /** Builds a {@link ByteArrayBlock} over {@code values} (one position per element) with the given null flags. */
    private static Block byteBlock(Optional<boolean[]> valueIsNull, byte[] values) {
        return new ByteArrayBlock(values.length, valueIsNull, values);
    }

    /** Repeats the single-position {@code value} block {@code positionCount} times as a run-length-encoded block. */
    private static Block runLength(Block value, int positionCount) {
        return RunLengthEncodedBlock.create(value, positionCount);
    }

    /**
     * Asserts that the mask is in the select-all state over {@code expectedPositionCount} positions.
     * In this state {@code getSelectedPositions()} is expected to throw {@link IllegalStateException}.
     */
    private static void assertAggregationMaskAll(AggregationMask aggregationMask, int expectedPositionCount) {
        Assertions.assertThat(aggregationMask.isSelectAll()).isTrue();
        // A select-all mask over zero positions also counts as select-none.
        Assertions.assertThat(aggregationMask.isSelectNone()).isEqualTo(expectedPositionCount == 0);
        Assertions.assertThat(aggregationMask.getPositionCount()).isEqualTo(expectedPositionCount);
        Assertions.assertThat(aggregationMask.getSelectedPositionCount()).isEqualTo(expectedPositionCount);
        Assertions.assertThatThrownBy(() -> aggregationMask.getSelectedPositions()).isInstanceOf(IllegalStateException.class);
    }

    /**
     * Asserts that the mask is not select-all and that exactly {@code expectedPositions} are selected
     * out of {@code expectedPositionCount}. Pass no positions to assert that nothing is selected.
     */
    private static void assertAggregationMaskPositions(AggregationMask aggregationMask, int expectedPositionCount, int... expectedPositions) {
        Assertions.assertThat(aggregationMask.isSelectAll()).isFalse();
        Assertions.assertThat(aggregationMask.isSelectNone()).isEqualTo(expectedPositions.length == 0);
        Assertions.assertThat(aggregationMask.getPositionCount()).isEqualTo(expectedPositionCount);
        Assertions.assertThat(aggregationMask.getSelectedPositionCount()).isEqualTo(expectedPositions.length);
        if (expectedPositions.length > 0) {
            // startsWith rather than containsExactly: the selected-positions array may be longer
            // than the selected count, so only its leading entries are meaningful.
            Assertions.assertThat(aggregationMask.getSelectedPositions()).startsWith(expectedPositions);
        }
    }
}

