/*
 * Decompiled with CFR 0.152.
 */
package com.google.cloud.spanner;

import com.google.api.core.NanoClock;
import com.google.api.gax.retrying.RetrySettings;
import com.google.cloud.grpc.GrpcTransportOptions;
import com.google.cloud.spanner.DatabaseId;
import com.google.cloud.spanner.ErrorCode;
import com.google.cloud.spanner.SessionClient;
import com.google.cloud.spanner.SessionImpl;
import com.google.cloud.spanner.SpannerException;
import com.google.cloud.spanner.SpannerExceptionFactory;
import com.google.cloud.spanner.SpannerImpl;
import com.google.cloud.spanner.SpannerOptions;
import com.google.cloud.spanner.spi.v1.SpannerRpc;
import com.google.common.truth.Truth;
import com.google.spanner.v1.Session;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;

@RunWith(value=Parameterized.class)
public class SessionClientTest {
    /**
     * Channel count for the current parameterized run; populated by the JUnit
     * {@code Parameterized} runner from the rows returned by {@link #data()} and handed out by the
     * mocked {@link SpannerOptions#getNumChannels()} in {@link #setUp()}.
     */
    @Parameterized.Parameter
    public int numChannels;
    /** Fully-qualified name of the database the tests operate on. */
    private final String dbName = "projects/p1/instances/i1/databases/d1";
    /** Mocked Spanner instance; {@link #setUp()} stubs it to return {@link #spannerOptions} and {@link #rpc}. */
    @Mock
    private SpannerImpl spanner;
    /** Mocked RPC layer on which the session create/delete calls are stubbed and verified. */
    @Mock
    private SpannerRpc rpc;
    /** Mocked options object; its accessors are stubbed in {@link #setUp()} and in individual tests. */
    @Mock
    private SpannerOptions spannerOptions;
    /**
     * Captures the options map passed to {@code createSession} so that the same map can be
     * expected on the subsequent {@code deleteSession} call.
     */
    @Captor
    ArgumentCaptor<Map<SpannerRpc.Option, Object>> options;

    @Parameterized.Parameters(name="NumChannels = {0}")
    public static Collection<Object[]> data() {
        return Arrays.asList({1}, {2}, {4}, {8});
    }

    /**
     * Initializes the annotated mocks and stubs the accessors on the mocked {@link SpannerOptions}
     * and {@link SpannerImpl} that are set up identically for every test.
     */
    @Before
    public void setUp() {
        MockitoAnnotations.initMocks(this);

        // Executor factory handed out by the mocked transport options: a fresh two-thread
        // scheduled pool per get(), shut down (without waiting) on release().
        GrpcTransportOptions.ExecutorFactory<ScheduledExecutorService> executorFactory =
            new GrpcTransportOptions.ExecutorFactory<ScheduledExecutorService>() {
                @Override
                public ScheduledExecutorService get() {
                    return new ScheduledThreadPoolExecutor(2);
                }

                @Override
                public void release(ScheduledExecutorService executor) {
                    executor.shutdown();
                }
            };
        GrpcTransportOptions grpcOptions = Mockito.mock(GrpcTransportOptions.class);
        Mockito.when(grpcOptions.getExecutorFactory()).thenReturn(executorFactory);

        // Options mock: channel count comes from the current test parameter.
        Mockito.when(spannerOptions.getTransportOptions()).thenReturn(grpcOptions);
        Mockito.when(spannerOptions.getNumChannels()).thenReturn(numChannels);
        Mockito.when(spannerOptions.getPrefetchChunks()).thenReturn(1);
        Mockito.when(spannerOptions.getRetrySettings()).thenReturn(RetrySettings.newBuilder().build());
        Mockito.when(spannerOptions.getClock()).thenReturn(NanoClock.getDefaultClock());

        // Spanner mock exposes the stubbed options and the mocked RPC layer.
        Mockito.when(spanner.getOptions()).thenReturn(spannerOptions);
        Mockito.when(spanner.getRpc()).thenReturn(rpc);
    }

    /**
     * Creates a single session through the client, closes it, and verifies that the delete RPC is
     * issued for that session with the same options map that was captured on the create RPC.
     */
    @Test
    public void createAndCloseSession() {
        final String expectedName = "projects/p1/instances/i1/databases/d1/sessions/s1";

        // Labels configured on the options must be forwarded to the create call.
        Map<String, String> sessionLabels = new HashMap<>();
        sessionLabels.put("env", "dev");
        Mockito.when(spannerOptions.getSessionLabels()).thenReturn(sessionLabels);

        Session proto =
            Session.newBuilder().setName(expectedName).putAllLabels(sessionLabels).build();
        Mockito.when(rpc.createSession(Mockito.eq(dbName), Mockito.eq(sessionLabels), options.capture()))
            .thenReturn(proto);

        DatabaseId database = DatabaseId.of(dbName);
        try (SessionClient sessionClient = new SessionClient(spanner, database, new TestExecutorFactory())) {
            SessionImpl created = sessionClient.createSession();
            Truth.assertThat(created.getName()).isEqualTo(expectedName);

            created.close();
            Mockito.verify(rpc).deleteSession(expectedName, options.getValue());
        }
    }

    /**
     * Requests one batch of sessions with distribution over channels enabled and checks that all
     * requested sessions reach the consumer and that the channel hints seen by the stubbed RPC
     * are exactly {@code 0 .. numChannels-1}, each once.
     */
    @Test
    public void batchCreateAndCloseSessions() {
        final String sessionNameFormat = "projects/p1/instances/i1/databases/d1/sessions/s%d";
        final int numSessions = 10;

        final Map<String, String> sessionLabels = new HashMap<>();
        sessionLabels.put("env", "dev");
        Mockito.when(spannerOptions.getSessionLabels()).thenReturn(sessionLabels);

        // Channel hints observed by the stub. Kept in a synchronized list, presumably because the
        // stub may be invoked from several executor threads.
        final List<Long> observedHints = Collections.synchronizedList(new ArrayList<Long>());
        Mockito.when(rpc.batchCreateSessions(Mockito.eq(dbName), Mockito.anyInt(), Mockito.eq(sessionLabels), Mockito.anyMap()))
            .then(invocation -> {
                Map<?, ?> callOptions = invocation.getArgumentAt(3, Map.class);
                Long hint = (Long) callOptions.get(SpannerRpc.Option.CHANNEL_HINT);
                observedHints.add(hint);
                int requested = invocation.getArgumentAt(1, Integer.class);
                // Answer with exactly as many sessions as were requested, named s1..sN.
                List<Session> created = new ArrayList<>();
                for (int index = 1; index <= requested; index++) {
                    Session proto = Session.newBuilder()
                        .setName(String.format(sessionNameFormat, index))
                        .putAllLabels(sessionLabels)
                        .build();
                    created.add(proto);
                }
                return created;
            });

        final AtomicInteger deliveredSessions = new AtomicInteger();
        SessionClient.SessionConsumer consumer = new SessionClient.SessionConsumer() {
            @Override
            public void onSessionReady(SessionImpl session) {
                Truth.assertThat(session.getName()).startsWith("projects/p1/instances/i1/databases/d1/sessions/s");
                deliveredSessions.incrementAndGet();
                session.close();
            }

            @Override
            public void onSessionCreateFailure(Throwable t, int createFailureForSessionCount) {
                // Failures are not recorded in this test.
            }
        };

        DatabaseId database = DatabaseId.of(dbName);
        // Closing the client releases its executor, so the assertions below run afterwards.
        try (SessionClient sessionClient = new SessionClient(spanner, database, new TestExecutorFactory())) {
            sessionClient.asyncBatchCreateSessions(numSessions, true, consumer);
        }

        Truth.assertThat(deliveredSessions.get()).isEqualTo(numSessions);
        Truth.assertThat(observedHints.size()).isEqualTo(spannerOptions.getNumChannels());
        List<Long> expectedHints = new ArrayList<>();
        long channelCount = spannerOptions.getNumChannels();
        for (long hint = 0L; hint < channelCount; hint++) {
            expectedHints.add(hint);
        }
        Truth.assertThat(observedHints).containsExactlyElementsIn(expectedHints);
    }

    /**
     * Issues {@code 2 * numChannels} separate batch requests with per-request distribution
     * disabled and checks that every session is delivered and that the distinct channel hints
     * seen by the stubbed RPC are exactly {@code 0 .. 2*numChannels-1}.
     */
    @Test
    public void batchCreateSessionsDistributesMultipleRequestsOverChannels() {
        final String sessionNameFormat = "projects/p1/instances/i1/databases/d1/sessions/s%d";
        final int numSessions = 10;

        final Map<String, String> sessionLabels = Collections.emptyMap();
        Mockito.when(spannerOptions.getSessionLabels()).thenReturn(sessionLabels);

        // Distinct channel hints observed by the stub.
        final Set<Long> observedHints = Collections.synchronizedSet(new HashSet<Long>());
        Mockito.when(rpc.batchCreateSessions(Mockito.eq(dbName), Mockito.anyInt(), Mockito.eq(sessionLabels), Mockito.anyMap()))
            .then(invocation -> {
                Map<?, ?> callOptions = invocation.getArgumentAt(3, Map.class);
                Long hint = (Long) callOptions.get(SpannerRpc.Option.CHANNEL_HINT);
                observedHints.add(hint);
                int requested = invocation.getArgumentAt(1, Integer.class);
                // Answer with exactly as many sessions as were requested, named s1..sN.
                List<Session> created = new ArrayList<>();
                for (int index = 1; index <= requested; index++) {
                    Session proto = Session.newBuilder()
                        .setName(String.format(sessionNameFormat, index))
                        .putAllLabels(sessionLabels)
                        .build();
                    created.add(proto);
                }
                return created;
            });

        final AtomicInteger deliveredSessions = new AtomicInteger();
        SessionClient.SessionConsumer consumer = new SessionClient.SessionConsumer() {
            @Override
            public void onSessionReady(SessionImpl session) {
                Truth.assertThat(session.getName()).startsWith("projects/p1/instances/i1/databases/d1/sessions/s");
                deliveredSessions.incrementAndGet();
                session.close();
            }

            @Override
            public void onSessionCreateFailure(Throwable t, int createFailureForSessionCount) {
                // Failures are not recorded in this test.
            }
        };

        final int numBatches = spannerOptions.getNumChannels() * 2;
        DatabaseId database = DatabaseId.of(dbName);
        // Closing the client releases its executor, so the assertions below run afterwards.
        try (SessionClient sessionClient = new SessionClient(spanner, database, new TestExecutorFactory())) {
            int issued = 0;
            while (issued < numBatches) {
                sessionClient.asyncBatchCreateSessions(numSessions, false, consumer);
                issued++;
            }
        }

        Truth.assertThat(deliveredSessions.get()).isEqualTo(numSessions * numBatches);
        final int expectedHintCount = spannerOptions.getNumChannels() * 2;
        Truth.assertThat(observedHints.size()).isEqualTo(expectedHintCount);
        List<Long> expectedHints = new ArrayList<>();
        for (long hint = 0L; hint < (long) expectedHintCount; hint++) {
            expectedHints.add(hint);
        }
        Truth.assertThat(observedHints).containsExactlyElementsIn(expectedHints);
    }

    /**
     * Runs a distributed batch-create while the stubbed RPC fails with
     * {@code RESOURCE_EXHAUSTED} on a varying set of channel hints.
     *
     * <p>The failing set is evolved per {@link AddRemoveSetException} strategy: SET fails only the
     * current channel, ADD accumulates channels, and REMOVE starts with every channel failing and
     * takes the current channel out each iteration. After each request the test checks that one
     * failure callback was received per failing channel, that a minimum number of sessions was
     * still delivered, and that delivered plus failed sessions add up to the requested total.
     */
    @Test
    public void batchCreateSessionsWithExceptions() {
        final String sessionNameFormat = "projects/p1/instances/i1/databases/d1/sessions/s%d";
        final int numSessions = 10;

        for (AddRemoveSetException behavior : AddRemoveSetException.values()) {
            final List<Long> errorOnChannels = new ArrayList<>();
            if (behavior == AddRemoveSetException.REMOVE) {
                // REMOVE shrinks the failing set, so it starts out with every channel failing.
                for (int c = 0; c < this.spannerOptions.getNumChannels(); ++c) {
                    errorOnChannels.add(Long.valueOf(c));
                }
            }
            for (int errorOnChannel = 0; errorOnChannel < this.spannerOptions.getNumChannels(); ++errorOnChannel) {
                switch (behavior) {
                    case SET:
                        errorOnChannels.clear();
                        // Intentional fall-through: after clearing, SET adds the current channel.
                    case ADD:
                        errorOnChannels.add(Long.valueOf(errorOnChannel));
                        break;
                    case REMOVE:
                        // The list holds Long values; removing a boxed Integer would never match
                        // an element (Integer.equals(Long) is false) and would leave the list
                        // untouched, so the key must be a Long.
                        errorOnChannels.remove(Long.valueOf(errorOnChannel));
                        break;
                    default:
                        throw new IllegalStateException();
                }

                DatabaseId db = DatabaseId.of(this.dbName);
                Mockito.when(this.rpc.batchCreateSessions(Mockito.eq(this.dbName), Mockito.anyInt(), Mockito.anyMap(), Mockito.anyMap())).then(invocation -> {
                    Map<?, ?> callOptions = invocation.getArgumentAt(3, Map.class);
                    Long channelHint = (Long) callOptions.get(SpannerRpc.Option.CHANNEL_HINT);
                    if (errorOnChannels.contains(channelHint)) {
                        throw SpannerExceptionFactory.newSpannerException(ErrorCode.RESOURCE_EXHAUSTED, "could not create any more sessions");
                    }
                    int sessionCount = invocation.getArgumentAt(1, Integer.class);
                    List<Session> res = new ArrayList<>();
                    for (int i = 1; i <= sessionCount; ++i) {
                        res.add(Session.newBuilder().setName(String.format(sessionNameFormat, i)).build());
                    }
                    return res;
                });

                final AtomicInteger errorForSessionsCount = new AtomicInteger();
                final AtomicInteger errorCount = new AtomicInteger();
                final AtomicInteger returnedSessionCount = new AtomicInteger();
                SessionClient.SessionConsumer consumer = new SessionClient.SessionConsumer() {
                    @Override
                    public void onSessionReady(SessionImpl session) {
                        Truth.assertThat(session.getName()).startsWith("projects/p1/instances/i1/databases/d1/sessions/s");
                        returnedSessionCount.incrementAndGet();
                        session.close();
                    }

                    @Override
                    public void onSessionCreateFailure(Throwable t, int createFailureForSessionCount) {
                        Truth.assertThat(t).isInstanceOf(SpannerException.class);
                        SpannerException e = (SpannerException) t;
                        Truth.assertThat(e.getErrorCode()).isEqualTo(ErrorCode.RESOURCE_EXHAUSTED);
                        errorCount.incrementAndGet();
                        errorForSessionsCount.addAndGet(createFailureForSessionCount);
                    }
                };

                // Closing the client releases its executor, so the assertions below run afterwards.
                try (SessionClient client = new SessionClient(this.spanner, db, new TestExecutorFactory())) {
                    client.asyncBatchCreateSessions(numSessions, true, consumer);
                }

                // One failure callback per failing channel.
                Truth.assertThat(errorCount.get()).isEqualTo(errorOnChannels.size());
                // Lower bound on delivered sessions: total minus the per-channel share of every
                // failing channel, minus the remainder of the division over channels.
                Truth.assertThat(returnedSessionCount.get()).isAtLeast(numSessions - (numSessions / this.spannerOptions.getNumChannels() * errorOnChannels.size() + numSessions % this.spannerOptions.getNumChannels()));
                // Every requested session is accounted for, either delivered or reported failed.
                Truth.assertThat(returnedSessionCount.get() + errorForSessionsCount.get()).isEqualTo(numSessions);
            }
        }
    }

    /**
     * Stubs the RPC to return at most five sessions per call regardless of how many were
     * requested, then asks for 100 sessions and expects all 100 to reach the consumer. The
     * expected total can only be reached if the client keeps requesting the shortfall.
     */
    @Test
    public void batchCreateSessionsServerReturnsLessSessionsPerBatch() {
        final int maxSessionsPerBatch = 5;
        final int numSessions = 100;
        final String sessionNameFormat = "projects/p1/instances/i1/databases/d1/sessions/s%d";

        Mockito.when(rpc.batchCreateSessions(Mockito.eq(dbName), Mockito.anyInt(), Mockito.anyMap(), Mockito.anyMap()))
            .then(invocation -> {
                int requested = invocation.getArgumentAt(1, Integer.class);
                // Cap the response size at the per-batch maximum.
                int granted = Math.min(maxSessionsPerBatch, requested);
                List<Session> created = new ArrayList<>();
                for (int index = 1; index <= granted; index++) {
                    created.add(Session.newBuilder().setName(String.format(sessionNameFormat, index)).build());
                }
                return created;
            });

        final AtomicInteger deliveredSessions = new AtomicInteger();
        SessionClient.SessionConsumer consumer = new SessionClient.SessionConsumer() {
            @Override
            public void onSessionReady(SessionImpl session) {
                Truth.assertThat(session.getName()).startsWith("projects/p1/instances/i1/databases/d1/sessions/s");
                deliveredSessions.incrementAndGet();
                session.close();
            }

            @Override
            public void onSessionCreateFailure(Throwable t, int createFailureForSessionCount) {
                // Failures are not recorded in this test.
            }
        };

        DatabaseId database = DatabaseId.of(dbName);
        // Closing the client releases its executor, so the assertion below runs afterwards.
        try (SessionClient sessionClient = new SessionClient(spanner, database, new TestExecutorFactory())) {
            sessionClient.asyncBatchCreateSessions(numSessions, true, consumer);
        }

        Truth.assertThat(deliveredSessions.get()).isEqualTo(numSessions);
    }

    private static enum AddRemoveSetException {
        SET,
        ADD,
        REMOVE;

    }

    private final class TestExecutorFactory
    implements GrpcTransportOptions.ExecutorFactory<ScheduledExecutorService> {
        private TestExecutorFactory() {
        }

        public ScheduledExecutorService get() {
            return Executors.newScheduledThreadPool(((SpannerOptions)SessionClientTest.this.spanner.getOptions()).getNumChannels());
        }

        public void release(ScheduledExecutorService executor) {
            executor.shutdown();
            try {
                executor.awaitTermination(10000L, TimeUnit.SECONDS);
            }
            catch (InterruptedException e) {
                throw new RuntimeException(e);
            }
        }
    }
}

