/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.fn.harness.state;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.beam.fn.harness.Cache;
import org.apache.beam.fn.harness.Caches;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
import org.apache.beam.fn.harness.state.FakeBeamFnStateClient;
import org.apache.beam.fn.harness.state.StateFetchingIterators;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.primitives.Ints;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Test;
import org.junit.experimental.runners.Enclosed;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(value=Enclosed.class)
public class StateFetchingIteratorsTest {
    private static BeamFnStateClient fakeStateClient(AtomicInteger callCount, ByteString ... expected) {
        return requestBuilder -> {
            callCount.incrementAndGet();
            if (expected.length == 0) {
                return CompletableFuture.completedFuture(BeamFnApi.StateResponse.newBuilder().setId(requestBuilder.getId()).setGet(BeamFnApi.StateGetResponse.newBuilder()).build());
            }
            ByteString continuationToken = requestBuilder.getGet().getContinuationToken();
            int requestedPosition = 0;
            if (!ByteString.EMPTY.equals((Object)continuationToken)) {
                requestedPosition = Integer.parseInt(continuationToken.toStringUtf8());
            }
            ByteString newContinuationToken = ByteString.EMPTY;
            if (requestedPosition != expected.length - 1) {
                newContinuationToken = ByteString.copyFromUtf8((String)Integer.toString(requestedPosition + 1));
            }
            return CompletableFuture.completedFuture(BeamFnApi.StateResponse.newBuilder().setId(requestBuilder.getId()).setGet(BeamFnApi.StateGetResponse.newBuilder().setData(expected[requestedPosition]).setContinuationToken(newContinuationToken)).build());
        };
    }

    @RunWith(value=JUnit4.class)
    public static class LazyBlockingStateFetchingIteratorTest {
        @Test
        public void testEmpty() throws Exception {
            this.testFetch(ByteString.EMPTY);
        }

        @Test
        public void testNonEmpty() throws Exception {
            this.testFetch(ByteString.copyFromUtf8((String)"A"));
        }

        @Test
        public void testWithLastByteStringBeingEmpty() throws Exception {
            this.testFetch(ByteString.copyFromUtf8((String)"A"), ByteString.EMPTY);
        }

        @Test
        public void testMulti() throws Exception {
            this.testFetch(ByteString.copyFromUtf8((String)"BC"), ByteString.copyFromUtf8((String)"DEF"));
        }

        @Test
        public void testMultiWithEmptyByteStrings() throws Exception {
            this.testFetch(ByteString.EMPTY, ByteString.copyFromUtf8((String)"BC"), ByteString.EMPTY, ByteString.EMPTY, ByteString.copyFromUtf8((String)"DEF"), ByteString.EMPTY);
        }

        @Test
        public void testPrefetchIgnoredWhenExistingPrefetchOngoing() throws Exception {
            final AtomicInteger callCount = new AtomicInteger();
            BeamFnStateClient fakeStateClient = new BeamFnStateClient(){

                public CompletableFuture<BeamFnApi.StateResponse> handle(BeamFnApi.StateRequest.Builder requestBuilder) {
                    callCount.incrementAndGet();
                    return new CompletableFuture<BeamFnApi.StateResponse>();
                }
            };
            StateFetchingIterators.LazyBlockingStateFetchingIterator byteStrings = new StateFetchingIterators.LazyBlockingStateFetchingIterator(fakeStateClient, BeamFnApi.StateRequest.getDefaultInstance());
            Assert.assertEquals(0L, callCount.get());
            byteStrings.prefetch();
            Assert.assertEquals(1L, callCount.get());
            byteStrings.prefetch();
            Assert.assertEquals(1L, callCount.get());
        }

        @Test
        public void testSeekToContinuationToken() throws Exception {
            BeamFnStateClient fakeStateClient = new BeamFnStateClient(){

                public CompletableFuture<BeamFnApi.StateResponse> handle(BeamFnApi.StateRequest.Builder requestBuilder) {
                    int token = 0;
                    if (!ByteString.EMPTY.equals((Object)requestBuilder.getGet().getContinuationToken())) {
                        token = Integer.parseInt(requestBuilder.getGet().getContinuationToken().toStringUtf8());
                    }
                    return CompletableFuture.completedFuture(BeamFnApi.StateResponse.newBuilder().setGet(BeamFnApi.StateGetResponse.newBuilder().setData(ByteString.copyFromUtf8((String)("value" + token))).setContinuationToken(ByteString.copyFromUtf8((String)Integer.toString(token + 1)))).build());
                }
            };
            StateFetchingIterators.LazyBlockingStateFetchingIterator byteStrings = new StateFetchingIterators.LazyBlockingStateFetchingIterator(fakeStateClient, BeamFnApi.StateRequest.getDefaultInstance());
            Assert.assertEquals(ByteString.copyFromUtf8((String)"value0"), byteStrings.next());
            Assert.assertEquals(ByteString.copyFromUtf8((String)"value1"), byteStrings.next());
            Assert.assertEquals(ByteString.copyFromUtf8((String)"value2"), byteStrings.next());
            byteStrings.seekToContinuationToken(ByteString.EMPTY);
            Assert.assertEquals(ByteString.copyFromUtf8((String)"value0"), byteStrings.next());
            Assert.assertEquals(ByteString.copyFromUtf8((String)"value1"), byteStrings.next());
            Assert.assertEquals(ByteString.copyFromUtf8((String)"value2"), byteStrings.next());
            byteStrings.seekToContinuationToken(ByteString.copyFromUtf8((String)"42"));
            Assert.assertEquals(ByteString.copyFromUtf8((String)"value42"), byteStrings.next());
            Assert.assertEquals(ByteString.copyFromUtf8((String)"value43"), byteStrings.next());
            Assert.assertEquals(ByteString.copyFromUtf8((String)"value44"), byteStrings.next());
        }

        private void testFetch(ByteString ... expected) {
            AtomicInteger callCount = new AtomicInteger();
            BeamFnStateClient fakeStateClient = StateFetchingIteratorsTest.fakeStateClient(callCount, expected);
            StateFetchingIterators.LazyBlockingStateFetchingIterator byteStrings = new StateFetchingIterators.LazyBlockingStateFetchingIterator(fakeStateClient, BeamFnApi.StateRequest.getDefaultInstance());
            Assert.assertEquals(0L, callCount.get());
            Assert.assertFalse(byteStrings.isReady());
            ArrayList<ByteString> results = new ArrayList<ByteString>();
            for (int i = 0; i < expected.length; ++i) {
                if (i % 2 == 0) {
                    byteStrings.prefetch();
                    Assert.assertEquals(i + 1, callCount.get());
                    Assert.assertTrue(byteStrings.isReady());
                }
                Assert.assertTrue(byteStrings.hasNext());
                results.add((ByteString)byteStrings.next());
            }
            Assert.assertFalse(byteStrings.hasNext());
            Assert.assertTrue(byteStrings.isReady());
            Assert.assertEquals(Arrays.asList(expected), results);
        }
    }

    @RunWith(value=JUnit4.class)
    public static class CachingStateIterableTest {
        @Test
        public void testEmpty() throws Exception {
            this.testFetchAndClear(4, new int[0]);
        }

        @Test
        public void testNonEmpty() throws Exception {
            this.testFetchAndClear(4, 0);
        }

        @Test
        public void testMultipleElementsPerChunk() throws Exception {
            this.testFetchAndClear(8, 0, 1, 2, 3, 4, 5);
        }

        @Test
        public void testSingleElementPerChunk() throws Exception {
            this.testFetchAndClear(4, 0, 1, 2, 3, 4, 5);
        }

        @Test
        public void testChunkSmallerThenElementSize() throws Exception {
            this.testFetchAndClear(3, 0, 1, 2, 3, 4, 5);
        }

        @Test
        public void testChunkLargerThenElementSize() throws Exception {
            this.testFetchAndClear(5, 0, 1, 2, 3, 4, 5);
        }

        @Test
        public void testAppend() throws Exception {
            int[] expected = new int[]{0, 1, 2, 3, 4, 5};
            StateFetchingIterators.CachingStateIterable<Integer> iterable = this.create(5, expected);
            iterable.append(Ints.asList((int[])new int[]{42, 43}));
            this.verifyFetch((PrefetchableIterator<Integer>)iterable.iterator(), expected);
            iterable = this.create(5, expected);
            iterable.append(Ints.asList((int[])new int[]{42, 43}));
            Boolean ignored = iterable.iterator().hasNext();
            this.verifyFetch((PrefetchableIterator<Integer>)iterable.iterator(), expected);
            iterable = this.create(5, expected);
            this.verifyFetch((PrefetchableIterator<Integer>)iterable.iterator(), expected);
            iterable.append(Ints.asList((int[])new int[]{42, 43}));
            this.verifyFetch((PrefetchableIterator<Integer>)iterable.iterator(), 0, 1, 2, 3, 4, 5, 42, 43);
        }

        @Test
        public void testRemove() throws Exception {
            int[] expected = new int[]{0, 1, 2, 3, 4, 5};
            HashSet<Object> toRemove = new HashSet<Object>();
            toRemove.add(BigEndianIntegerCoder.of().structuralValue((Object)2));
            toRemove.add(BigEndianIntegerCoder.of().structuralValue((Object)4));
            StateFetchingIterators.CachingStateIterable<Integer> iterable = this.create(5, expected);
            iterable.remove(toRemove);
            this.verifyFetch((PrefetchableIterator<Integer>)iterable.iterator(), expected);
            iterable = this.create(5, expected);
            iterable.remove(toRemove);
            Boolean ignored = iterable.iterator().hasNext();
            this.verifyFetch((PrefetchableIterator<Integer>)iterable.iterator(), expected);
            iterable = this.create(5, expected);
            this.verifyFetch((PrefetchableIterator<Integer>)iterable.iterator(), expected);
            iterable.remove(toRemove);
            this.verifyFetch((PrefetchableIterator<Integer>)iterable.iterator(), 0, 1, 3, 5);
        }

        @Test
        public void testCacheEvictionOrphansIteratorAndAllowsForIteratorToRejoin() throws Exception {
            int[] expected = new int[]{0, 1, 2, 3, 4, 5};
            BeamFnApi.StateRequest requestForFirstChunk = BeamFnApi.StateRequest.newBuilder().setStateKey(BeamFnApi.StateKey.newBuilder().setBagUserState(BeamFnApi.StateKey.BagUserState.newBuilder().setTransformId("transformId").setUserStateId("stateId").setKey(ByteString.copyFromUtf8((String)"key")).setWindow(ByteString.copyFromUtf8((String)"window")))).setGet(BeamFnApi.StateGetRequest.getDefaultInstance()).build();
            FakeBeamFnStateClient fakeStateClient = new FakeBeamFnStateClient(BigEndianIntegerCoder.of(), ImmutableMap.of((Object)requestForFirstChunk.getStateKey(), (Object)Ints.asList((int[])expected)), 4);
            Cache cache = Caches.eternal();
            StateFetchingIterators.CachingStateIterable iterable = new StateFetchingIterators.CachingStateIterable(cache, (BeamFnStateClient)fakeStateClient, requestForFirstChunk, (Coder)BigEndianIntegerCoder.of());
            this.verifyFetch((PrefetchableIterator<Integer>)iterable.iterator(), expected);
            MatcherAssert.assertThat((StateFetchingIterators.CachingStateIterable.Blocks)cache.peek((Object)StateFetchingIterators.IterableCacheKey.INSTANCE), Matchers.is(Matchers.instanceOf(StateFetchingIterators.CachingStateIterable.BlocksPrefix.class)));
            int stateRequestCount = fakeStateClient.getCallCount();
            PrefetchableIterator iterator = iterable.iterator();
            Assert.assertEquals(0L, ((Integer)iterator.next()).intValue());
            Assert.assertEquals(stateRequestCount, fakeStateClient.getCallCount());
            stateRequestCount = fakeStateClient.getCallCount();
            cache.remove((Object)StateFetchingIterators.IterableCacheKey.INSTANCE);
            Assert.assertEquals(1L, ((Integer)iterator.next()).intValue());
            Assert.assertEquals(stateRequestCount += 2, fakeStateClient.getCallCount());
            Assert.assertEquals(2L, ((Integer)iterator.next()).intValue());
            Assert.assertEquals(++stateRequestCount, fakeStateClient.getCallCount());
            Assert.assertNull(cache.peek((Object)StateFetchingIterators.IterableCacheKey.INSTANCE));
            stateRequestCount = fakeStateClient.getCallCount();
            PrefetchableIterator iterator2 = iterable.iterator();
            Assert.assertEquals(0L, ((Integer)iterator2.next()).intValue());
            MatcherAssert.assertThat((StateFetchingIterators.CachingStateIterable.Blocks)cache.peek((Object)StateFetchingIterators.IterableCacheKey.INSTANCE), Matchers.is(Matchers.instanceOf(StateFetchingIterators.CachingStateIterable.BlocksPrefix.class)));
            Assert.assertTrue(stateRequestCount < fakeStateClient.getCallCount());
            Assert.assertEquals(1L, ((Integer)iterator2.next()).intValue());
            Assert.assertEquals(2L, ((Integer)iterator2.next()).intValue());
            Assert.assertEquals(3L, ((Integer)iterator2.next()).intValue());
            Assert.assertEquals(4L, ((Integer)iterator2.next()).intValue());
            stateRequestCount = fakeStateClient.getCallCount();
            Assert.assertEquals(3L, ((Integer)iterator.next()).intValue());
            Assert.assertEquals(stateRequestCount, fakeStateClient.getCallCount());
        }

        @Test
        public void testBlocksPrefixShrinkage() throws Exception {
            List<StateFetchingIterators.CachingStateIterable.Block> originalBlocks = Arrays.asList(StateFetchingIterators.CachingStateIterable.Block.fromValues(Arrays.asList("A"), null), StateFetchingIterators.CachingStateIterable.Block.fromValues(Arrays.asList("B"), null), StateFetchingIterators.CachingStateIterable.Block.fromValues(Arrays.asList("C"), null), StateFetchingIterators.CachingStateIterable.Block.fromValues(Arrays.asList("D"), null), StateFetchingIterators.CachingStateIterable.Block.fromValues(Arrays.asList("E"), null), StateFetchingIterators.CachingStateIterable.Block.fromValues(Arrays.asList("F"), null));
            StateFetchingIterators.CachingStateIterable.BlocksPrefix blocks = new StateFetchingIterators.CachingStateIterable.BlocksPrefix(originalBlocks);
            StateFetchingIterators.CachingStateIterable.BlocksPrefix abcBlocks = blocks.shrink();
            StateFetchingIterators.CachingStateIterable.BlocksPrefix aBlocks = abcBlocks.shrink();
            MatcherAssert.assertThat(abcBlocks.getBlocks(), Matchers.contains(originalBlocks.get(0), originalBlocks.get(1), originalBlocks.get(2)));
            MatcherAssert.assertThat(aBlocks.getBlocks(), Matchers.contains(originalBlocks.get(0)));
            Assert.assertNull(aBlocks.shrink());
        }

        @Test
        public void testBlocksWeight() throws Exception {
            List<StateFetchingIterators.CachingStateIterable.Block> originalBlocks = Arrays.asList(StateFetchingIterators.CachingStateIterable.Block.mutatedBlock(Arrays.asList("A"), (long)10L), StateFetchingIterators.CachingStateIterable.Block.mutatedBlock(Arrays.asList("B"), (long)0x3FFFFFFFFFFFFFFFL), StateFetchingIterators.CachingStateIterable.Block.mutatedBlock(Arrays.asList("C"), (long)0x3FFFFFFFFFFFFFFFL), StateFetchingIterators.CachingStateIterable.Block.mutatedBlock(Arrays.asList("D"), (long)5L));
            StateFetchingIterators.CachingStateIterable.BlocksPrefix blocks = new StateFetchingIterators.CachingStateIterable.BlocksPrefix(originalBlocks.subList(0, 2));
            Assert.assertEquals(0x4000000000000009L, blocks.getWeight());
            StateFetchingIterators.CachingStateIterable.BlocksPrefix blocksOverflow = new StateFetchingIterators.CachingStateIterable.BlocksPrefix(originalBlocks);
            Assert.assertEquals(Long.MAX_VALUE, blocksOverflow.getWeight());
        }

        private StateFetchingIterators.CachingStateIterable<Integer> create(int chunkSize, int ... values) {
            BeamFnApi.StateRequest requestForFirstChunk = BeamFnApi.StateRequest.newBuilder().setStateKey(BeamFnApi.StateKey.newBuilder().setBagUserState(BeamFnApi.StateKey.BagUserState.newBuilder().setTransformId("transformId").setUserStateId("stateId").setKey(ByteString.copyFromUtf8((String)"key")).setWindow(ByteString.copyFromUtf8((String)"window")))).setGet(BeamFnApi.StateGetRequest.getDefaultInstance()).build();
            FakeBeamFnStateClient fakeStateClient = new FakeBeamFnStateClient(BigEndianIntegerCoder.of(), ImmutableMap.of((Object)requestForFirstChunk.getStateKey(), (Object)Ints.asList((int[])values)), chunkSize);
            StateFetchingIterators.CachingStateIterable iterable = new StateFetchingIterators.CachingStateIterable(Caches.eternal(), (BeamFnStateClient)fakeStateClient, requestForFirstChunk, (Coder)BigEndianIntegerCoder.of());
            PrefetchableIterator ignored = iterable.iterator();
            Assert.assertEquals(0L, fakeStateClient.getCallCount());
            return iterable;
        }

        private void testFetchAndClear(int chunkSize, int ... expected) throws Exception {
            PrefetchableIterator iterator = this.create(chunkSize, expected).iterator();
            Assert.assertFalse(iterator.isReady());
            this.verifyFetch((PrefetchableIterator<Integer>)iterator, expected);
            this.verifyClear(chunkSize, expected);
        }

        private void verifyFetch(PrefetchableIterator<Integer> iterator, int ... expected) {
            ArrayList<Integer> results = new ArrayList<Integer>();
            for (int i = 0; i < expected.length; ++i) {
                Assert.assertTrue(iterator.hasNext());
                results.add((Integer)iterator.next());
            }
            Assert.assertFalse(iterator.hasNext());
            Assert.assertTrue(iterator.isReady());
            Assert.assertEquals(Ints.asList((int[])expected), results);
        }

        private void verifyClear(int chunkSize, int ... expected) throws Exception {
            StateFetchingIterators.CachingStateIterable<Integer> iterable = this.create(chunkSize, expected);
            iterable.clearAndAppend(Ints.asList((int[])new int[]{42, 43}));
            Assert.assertTrue(iterable.iterator().isReady());
            this.verifyFetch((PrefetchableIterator<Integer>)iterable.iterator(), 42, 43);
            iterable = this.create(chunkSize, expected);
            Boolean ignored = iterable.iterator().hasNext();
            iterable.clearAndAppend(Ints.asList((int[])new int[]{42, 43}));
            Assert.assertTrue(iterable.iterator().isReady());
            this.verifyFetch((PrefetchableIterator<Integer>)iterable.iterator(), 42, 43);
            iterable = this.create(chunkSize, expected);
            this.verifyFetch((PrefetchableIterator<Integer>)iterable.iterator(), expected);
            iterable.clearAndAppend(Ints.asList((int[])new int[]{42, 43}));
            Assert.assertTrue(iterable.iterator().isReady());
            this.verifyFetch((PrefetchableIterator<Integer>)iterable.iterator(), 42, 43);
        }
    }
}

