/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.fn.harness.state;

import com.google.auto.value.AutoValue;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import org.apache.beam.fn.harness.state.AutoValue_CachingBeamFnStateClient_StateCacheKey;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LoadingCache;

/**
 * A {@link BeamFnStateClient} decorator that caches pages of {@code GET} responses per
 * {@link BeamFnApi.StateKey}, keyed by the runner-supplied cache token and the request's
 * continuation token.
 *
 * <p>State is only cached when a non-empty cache token applies to its key: the user-state token for
 * bag user state, or a per-side-input token for side inputs. All other requests are forwarded to the
 * wrapped client untouched.
 */
public class CachingBeamFnStateClient
implements BeamFnStateClient {
    private final BeamFnStateClient beamFnStateClient;
    private final LoadingCache<BeamFnApi.StateKey, Map<StateCacheKey, BeamFnApi.StateGetResponse>> stateCache;
    private final Map<BeamFnApi.ProcessBundleRequest.CacheToken.SideInput, ByteString> sideInputCacheTokens;
    private final ByteString userStateToken;

    /**
     * @param beamFnStateClient the client that actually performs state requests
     * @param stateCache cache of response pages per state key; entries are obtained via
     *     {@code getUnchecked}, so its loader is expected to supply a mutable map
     * @param cacheTokenList cache tokens for the bundle; the last user-state token wins, and side-input
     *     tokens are indexed by their {@code SideInput} descriptor
     */
    public CachingBeamFnStateClient(BeamFnStateClient beamFnStateClient, LoadingCache<BeamFnApi.StateKey, Map<StateCacheKey, BeamFnApi.StateGetResponse>> stateCache, List<BeamFnApi.ProcessBundleRequest.CacheToken> cacheTokenList) {
        this.beamFnStateClient = beamFnStateClient;
        this.stateCache = stateCache;
        this.sideInputCacheTokens = new HashMap<BeamFnApi.ProcessBundleRequest.CacheToken.SideInput, ByteString>();
        ByteString tempUserStateToken = ByteString.EMPTY;
        for (BeamFnApi.ProcessBundleRequest.CacheToken token : cacheTokenList) {
            if (token.hasUserState()) {
                tempUserStateToken = token.getToken();
            } else if (token.hasSideInput()) {
                this.sideInputCacheTokens.put(token.getSideInput(), token.getToken());
            }
        }
        this.userStateToken = tempUserStateToken;
    }

    /**
     * Handles a state request, serving {@code GET}s from the cache when possible and keeping the cache
     * consistent on {@code APPEND} and {@code CLEAR}.
     *
     * @throws IllegalStateException if the request is for cacheable state but has an unknown request
     *     case
     */
    @Override
    public CompletableFuture<BeamFnApi.StateResponse> handle(BeamFnApi.StateRequest.Builder requestBuilder) {
        BeamFnApi.StateKey stateKey = requestBuilder.getStateKey();
        ByteString cacheToken = this.getCacheToken(stateKey);
        // No cache token means this state may not be cached; pass straight through.
        if (ByteString.EMPTY.equals(cacheToken)) {
            return this.beamFnStateClient.handle(requestBuilder);
        }
        switch (requestBuilder.getRequestCase()) {
            case GET: {
                StateCacheKey cacheKey = StateCacheKey.create(cacheToken, requestBuilder.getGet().getContinuationToken());
                BeamFnApi.StateGetResponse cachedPage = this.stateCache.getUnchecked(stateKey).get(cacheKey);
                if (cachedPage != null) {
                    return CompletableFuture.completedFuture(BeamFnApi.StateResponse.newBuilder().setId(requestBuilder.getId()).setGet(cachedPage).build());
                }
                CompletableFuture<BeamFnApi.StateResponse> response = this.beamFnStateClient.handle(requestBuilder);
                // NOTE(review): this callback runs on whichever thread completes the future and writes
                // into the per-key map; it can also land after a later CLEAR/APPEND for the same key
                // was issued. Confirm callers serialize state access for a key.
                response.thenAccept(stateResponse -> {
                    // Only cache a real GET payload. getGet() on any other response yields the default
                    // instance, which would be cached as an empty final page.
                    if (stateResponse.hasGet()) {
                        this.stateCache.getUnchecked(stateKey).put(cacheKey, stateResponse.getGet());
                    }
                });
                return response;
            }
            case APPEND: {
                CompletableFuture<BeamFnApi.StateResponse> response = this.beamFnStateClient.handle(requestBuilder);
                // Appended values land after the last page, so only the final cached page (the one with
                // no continuation token) is stale; earlier pages remain valid. If nothing is cached for
                // this key there is nothing to invalidate, so avoid loading an empty entry.
                Map<StateCacheKey, BeamFnApi.StateGetResponse> cachedPages = this.stateCache.getIfPresent(stateKey);
                if (cachedPages != null) {
                    cachedPages.values().removeIf(page -> page.getContinuationToken().isEmpty());
                }
                return response;
            }
            case CLEAR: {
                CompletableFuture<BeamFnApi.StateResponse> response = this.beamFnStateClient.handle(requestBuilder);
                // Replace everything cached for this key with a single, empty, final page.
                Map<StateCacheKey, BeamFnApi.StateGetResponse> clearedData = new HashMap<StateCacheKey, BeamFnApi.StateGetResponse>();
                clearedData.put(StateCacheKey.create(cacheToken, ByteString.EMPTY), BeamFnApi.StateGetResponse.getDefaultInstance());
                this.stateCache.put(stateKey, clearedData);
                return response;
            }
            default:
                throw new IllegalStateException(String.format("Unknown request type %s", requestBuilder.getRequestCase()));
        }
    }

    /**
     * Returns the cache token that applies to {@code stateKey}, or {@link ByteString#EMPTY} when the
     * state must not be cached (runner state, unrecognised key types, or side inputs without a token).
     */
    private ByteString getCacheToken(BeamFnApi.StateKey stateKey) {
        if (stateKey.hasBagUserState()) {
            return this.userStateToken;
        }
        BeamFnApi.ProcessBundleRequest.CacheToken.SideInput.Builder sideInputBuilder = BeamFnApi.ProcessBundleRequest.CacheToken.SideInput.newBuilder();
        if (stateKey.hasIterableSideInput()) {
            BeamFnApi.StateKey.IterableSideInput iterableSideInput = stateKey.getIterableSideInput();
            sideInputBuilder.setTransformId(iterableSideInput.getTransformId()).setSideInputId(iterableSideInput.getSideInputId());
        } else if (stateKey.hasMultimapSideInput()) {
            BeamFnApi.StateKey.MultimapSideInput multimapSideInput = stateKey.getMultimapSideInput();
            sideInputBuilder.setTransformId(multimapSideInput.getTransformId()).setSideInputId(multimapSideInput.getSideInputId());
        } else if (stateKey.hasMultimapKeysSideInput()) {
            BeamFnApi.StateKey.MultimapKeysSideInput multimapKeysSideInput = stateKey.getMultimapKeysSideInput();
            sideInputBuilder.setTransformId(multimapKeysSideInput.getTransformId()).setSideInputId(multimapKeysSideInput.getSideInputId());
        } else {
            // Runner state and any other key type are never cached. Do not fall through to a lookup
            // of an empty SideInput descriptor, which could match an unrelated token.
            return ByteString.EMPTY;
        }
        return this.sideInputCacheTokens.getOrDefault(sideInputBuilder.build(), ByteString.EMPTY);
    }

    /** Identifies one cached response page: the cache token plus the request's continuation token. */
    @AutoValue
    public static abstract class StateCacheKey {
        public abstract ByteString getCacheToken();

        public abstract ByteString getContinuationToken();

        static StateCacheKey create(ByteString cacheToken, ByteString continuationToken) {
            return new AutoValue_CachingBeamFnStateClient_StateCacheKey(cacheToken, continuationToken);
        }
    }
}

