/*
 * Decompiled with CFR 0.152.
 */
package com.amazonaws.samples.connectors.timestream;

import com.amazonaws.samples.connectors.timestream.BatchConverter;
import com.amazonaws.samples.connectors.timestream.TimestreamModelUtils;
import com.amazonaws.samples.connectors.timestream.TimestreamSinkConfig;
import com.amazonaws.samples.connectors.timestream.WriteRequestFailureHandler;
import com.amazonaws.samples.connectors.timestream.metrics.CloudWatchEmittedMetricGroupHelper;
import com.amazonaws.samples.connectors.timestream.metrics.MetricsCollector;
import com.amazonaws.samples.connectors.timestream.metrics.TimestreamSinkMetricGroup;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.file.Path;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletionException;
import java.util.function.Consumer;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.connector.sink2.Sink;
import org.apache.flink.connector.base.sink.writer.AsyncSinkWriter;
import org.apache.flink.connector.base.sink.writer.BufferedRequestState;
import org.apache.flink.connector.base.sink.writer.ElementConverter;
import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.util.InstantiationUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
import software.amazon.awssdk.auth.credentials.SystemPropertyCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.http.crt.AwsCrtAsyncHttpClient;
import software.amazon.awssdk.profiles.ProfileFile;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.timestreamwrite.TimestreamWriteAsyncClient;
import software.amazon.awssdk.services.timestreamwrite.TimestreamWriteAsyncClientBuilder;
import software.amazon.awssdk.services.timestreamwrite.model.Record;
import software.amazon.awssdk.services.timestreamwrite.model.WriteRecordsRequest;

/**
 * Flink {@link AsyncSinkWriter} that buffers Timestream {@link Record}s and writes them in batches
 * through a {@link TimestreamWriteAsyncClient}.
 *
 * <p>Each batch handed over by the base class is turned into a {@link WriteRecordsRequest} by the
 * configured {@link BatchConverter}. A successful write completes the batch; a failed write is passed
 * to the {@link WriteRequestFailureHandler} named in the sink configuration, which decides which
 * records to retry and which to drop. Pre-write, success, exception, retry and drop metrics are
 * recorded through a {@link MetricsCollector}.
 *
 * @param <InputT> type of the stream elements that the element converter turns into {@link Record}s
 */
public class TimestreamSinkWriter<InputT> extends AsyncSinkWriter<InputT, Record> {
    private static final Logger LOG = LoggerFactory.getLogger(TimestreamSinkWriter.class);

    /** Builds one {@link WriteRecordsRequest} from a batch of buffered records. */
    private final BatchConverter batchConverter;
    /** Async Timestream client; released in {@link #close()}. */
    private final TimestreamWriteAsyncClient client;
    /** Pluggable handler that decides how the records of a failed request are retried or dropped. */
    private final WriteRequestFailureHandler failureHandler;
    /** Records metrics for every write attempt. */
    private final MetricsCollector metricsCollector;

    /**
     * Creates the writer together with its Timestream client, failure handler and metrics collector.
     *
     * <p>NOTE: the protected {@code open*}/{@code create*} hooks are invoked from this constructor, so
     * an overriding subclass implementation runs before the subclass's own fields are initialised.
     *
     * @param elementConverter converts each stream element into a Timestream {@link Record}
     * @param batchConverter builds a {@link WriteRecordsRequest} from a batch of records
     * @param context sink initialisation context; its metric group is used for this writer's metrics
     * @param timestreamSinkConfig batching, client, credential and failure-handling settings
     */
    public TimestreamSinkWriter(ElementConverter<InputT, Record> elementConverter, BatchConverter batchConverter, Sink.InitContext context, TimestreamSinkConfig timestreamSinkConfig) {
        super(
                elementConverter,
                context,
                timestreamSinkConfig.getMaxBatchSize(),
                timestreamSinkConfig.getMaxInFlightRequests(),
                timestreamSinkConfig.getMaxBufferedRequests(),
                // NOTE(review): presumably the max batch size in bytes, left effectively unbounded;
                // confirm against the AsyncSinkWriter constructor of the Flink version in use.
                Integer.MAX_VALUE,
                timestreamSinkConfig.getMaxTimeInBufferMS(),
                // NOTE(review): presumably the max record size in bytes, left effectively unbounded.
                Integer.MAX_VALUE);
        this.batchConverter = batchConverter;
        this.client = openAsyncClient(timestreamSinkConfig);
        this.failureHandler = createFailureHandler(timestreamSinkConfig);
        this.metricsCollector = openMetricCollector(context);
    }

    /**
     * Wraps the context's metric group (as extended by {@link CloudWatchEmittedMetricGroupHelper}) in a
     * {@link TimestreamSinkMetricGroup}.
     *
     * @param context sink initialisation context providing the base metric group
     * @return the metric group this writer reports into
     */
    TimestreamSinkMetricGroup createTimestreamSinkMetricGroup(Sink.InitContext context) {
        MetricGroup metricGroup = CloudWatchEmittedMetricGroupHelper.extendMetricGroup(context.metricGroup());
        return new TimestreamSinkMetricGroup(metricGroup);
    }

    /**
     * Instantiates the failure handler class named in the configuration, using the thread context class
     * loader, and opens it with this writer's fatal-exception consumer and the failure-handler config.
     *
     * @param timestreamSinkConfig configuration holding the failure-handler class name and settings
     * @return the opened failure handler
     */
    protected WriteRequestFailureHandler createFailureHandler(TimestreamSinkConfig timestreamSinkConfig) {
        WriteRequestFailureHandler instance = InstantiationUtil.instantiate(
                timestreamSinkConfig.getFailureHandlerConfig().getFailureHandlerClass(),
                WriteRequestFailureHandler.class,
                Thread.currentThread().getContextClassLoader());
        instance.open(getFatalExceptionCons(), timestreamSinkConfig.getFailureHandlerConfig());
        return instance;
    }

    /**
     * Creates the metrics collector backed by {@link #createTimestreamSinkMetricGroup}.
     *
     * @param context sink initialisation context providing the base metric group
     * @return a new metrics collector
     */
    @VisibleForTesting
    protected MetricsCollector openMetricCollector(Sink.InitContext context) {
        return new MetricsCollector(createTimestreamSinkMetricGroup(context));
    }

    /**
     * Builds the async Timestream client from the write-client and credential configuration.
     *
     * <p>The CRT HTTP client is supplied as a <em>builder</em>, so the SDK creates and owns it and
     * closes it together with the returned service client. An HTTP client passed as an instance via
     * {@code httpClient(...)} is never closed by the SDK, so it would leak on every writer.
     *
     * @param timestreamSinkConfig client timeout/retry/concurrency, region, endpoint and credential settings
     * @return a new client; the caller is responsible for closing it
     * @throws IllegalArgumentException if the configured endpoint override is not a valid URI
     */
    @VisibleForTesting
    protected TimestreamWriteAsyncClient openAsyncClient(TimestreamSinkConfig timestreamSinkConfig) {
        ClientOverrideConfiguration overrideConfiguration = ClientOverrideConfiguration.builder()
                .apiCallAttemptTimeout(timestreamSinkConfig.getWriteClientConfig().getRequestTimeout())
                .retryPolicy(RetryPolicy.builder()
                        .numRetries(timestreamSinkConfig.getWriteClientConfig().getMaxErrorRetry())
                        .build())
                .build();

        TimestreamWriteAsyncClientBuilder asyncClientBuilder = TimestreamWriteAsyncClient.builder()
                .overrideConfiguration(overrideConfiguration)
                .httpClientBuilder(AwsCrtAsyncHttpClient.builder()
                        .maxConcurrency(timestreamSinkConfig.getWriteClientConfig().getMaxConcurrency()))
                .region(Region.of(timestreamSinkConfig.getWriteClientConfig().getRegion()))
                .credentialsProvider(getCredentialProvider(
                        timestreamSinkConfig.getCredentialsProviderType(),
                        timestreamSinkConfig.getCredentialConfig()));

        String endpointOverride = timestreamSinkConfig.getWriteClientConfig().getEndpointOverride();
        if (endpointOverride != null) {
            asyncClientBuilder = asyncClientBuilder.endpointOverride(parseEndpointOverride(endpointOverride));
        }
        LOG.debug("AmazonTimestreamWriteAsync client constructed.");
        return asyncClientBuilder.build();
    }

    /**
     * Maps the configured credential-provider type to an AWS SDK credentials provider.
     *
     * @param credentialType which provider to use
     * @param credentialConfig profile name / profile file path, used only for {@code PROFILE}
     * @return the credentials provider
     * @throws IllegalArgumentException if the type is not handled by the switch
     */
    private AwsCredentialsProvider getCredentialProvider(TimestreamSinkConfig.CredentialProviderType credentialType, TimestreamSinkConfig.CredentialConfig credentialConfig) {
        switch (credentialType) {
            case ENV_VAR:
                return EnvironmentVariableCredentialsProvider.create();
            case SYS_PROP:
                return SystemPropertyCredentialsProvider.create();
            case PROFILE: {
                String profileName = credentialConfig.getProfileName();
                String profileConfigPath = credentialConfig.getProfileConfigPath();
                if (profileConfigPath == null) {
                    return ProfileCredentialsProvider.create(profileName);
                }
                // ProfileFile.Builder requires an explicit file type; without it build() fails.
                // CREDENTIALS matches the convention of Flink's AWS connector for a profile path
                // (sections named [name] rather than [profile name]).
                ProfileFile profileFile = ProfileFile.builder()
                        .type(ProfileFile.Type.CREDENTIALS)
                        .content(Path.of(profileConfigPath))
                        .build();
                return ProfileCredentialsProvider.builder()
                        .profileName(profileName)
                        .profileFile(profileFile)
                        .build();
            }
            case AUTO:
                return DefaultCredentialsProvider.create();
            default:
                break;
        }
        throw new IllegalArgumentException("Credential provider not supported: " + credentialType);
    }

    /**
     * Parses the configured endpoint override.
     *
     * @param endpointOverride endpoint URI as configured
     * @return the parsed URI
     * @throws IllegalArgumentException if the value is not a valid URI (the parse error is kept as cause)
     */
    private URI parseEndpointOverride(String endpointOverride) {
        try {
            return new URI(endpointOverride);
        } catch (URISyntaxException uriSyntaxException) {
            throw new IllegalArgumentException("Invalid EndpointOverride Config: " + endpointOverride, uriSyntaxException);
        }
    }

    /**
     * Converts the batch into a {@link WriteRecordsRequest} and sends it asynchronously.
     *
     * <p>If sending throws synchronously, every record of the batch is handed back to the base class
     * for retry.
     *
     * @param requestEntries records of this batch
     * @param requestResult receives the records to retry (empty list when the batch succeeded)
     */
    @Override
    protected void submitRequestEntries(List<Record> requestEntries, Consumer<List<Record>> requestResult) {
        WriteRecordsRequest request = batchConverter.apply(requestEntries);
        LOG.debug("Sending WriteRecordsRequest with {} records to Timestream...", request.records().size());
        metricsCollector.collectPreWriteMetrics(request);
        try {
            asyncWriteRecords(requestEntries, requestResult, request);
        } catch (Exception t) {
            LOG.error("Unexpected exception occurred when sending records to Timestream. Retrying all records.", t);
            metricsCollector.collectExceptionMetrics(t);
            requestResult.accept(requestEntries);
        }
    }

    /**
     * Issues the write and resolves the batch when the returned future completes.
     *
     * <p>Anything thrown inside the completion callback would otherwise be captured by the future
     * derived from {@code whenComplete}, which nobody observes, and the batch would never be resolved.
     * Such failures are therefore reported through the fatal-exception consumer.
     */
    private void asyncWriteRecords(List<Record> requestEntries, Consumer<List<Record>> requestResult, WriteRecordsRequest request) {
        client.writeRecords(request).whenComplete((response, err) -> {
            try {
                if (err == null) {
                    LOG.trace("Timestream writeRecordsAsync onSuccess: {} -> {}", request, response);
                    metricsCollector.collectSuccessMetrics(request);
                    requestResult.accept(Collections.emptyList());
                } else {
                    handleWriteError(unwrapCompletionException(err), requestEntries, requestResult, request);
                }
            } catch (RuntimeException callbackFailure) {
                LOG.error("Unexpected exception while handling the Timestream write result.", callbackFailure);
                getFatalExceptionCons().accept(callbackFailure);
            }
        });
    }

    /**
     * Routes a failed write. {@link Exception}s are counted and handed to the failure handler, whose
     * retry/drop callbacks also record metrics. Any other throwable (e.g. an {@link Error}) is reported
     * as fatal.
     */
    private void handleWriteError(Throwable cause, List<Record> requestEntries, Consumer<List<Record>> requestResult, WriteRecordsRequest request) {
        if (!(cause instanceof Exception)) {
            getFatalExceptionCons().accept(new Exception(cause));
            return;
        }
        Exception exception = (Exception) cause;
        metricsCollector.collectExceptionMetrics(exception);
        Consumer<List<Record>> retryWithMetrics = records -> {
            metricsCollector.collectRetries(records);
            requestResult.accept(records);
        };
        Consumer<List<Record>> dropWithMetrics = records -> metricsCollector.collectDropped(records, request);
        failureHandler.onWriteError(requestEntries, request, exception, retryWithMetrics, dropWithMetrics);
    }

    /** Returns the cause of a {@link CompletionException} if it has one, otherwise {@code err} itself. */
    private static Throwable unwrapCompletionException(Throwable err) {
        if (err instanceof CompletionException && err.getCause() != null) {
            return err.getCause();
        }
        return err;
    }

    /** Size of a record in bytes as computed by {@link TimestreamModelUtils#getRecordSizeInBytes}. */
    @Override
    protected long getSizeInBytes(Record requestEntry) {
        return TimestreamModelUtils.getRecordSizeInBytes(requestEntry);
    }

    /**
     * Flushes the buffer with {@code flush(true)} before delegating to the base snapshot.
     *
     * @param checkpointId id of the checkpoint being taken
     * @return the buffered state produced by the base class
     * @throws RuntimeException if interrupted while flushing; the thread's interrupt flag is restored
     */
    @Override
    public List<BufferedRequestState<Record>> snapshotState(long checkpointId) {
        try {
            flush(true);
        } catch (InterruptedException e) {
            // Restore the interrupt status so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted while flushing buffer during snapshotState", e);
        }
        return super.snapshotState(checkpointId);
    }

    /**
     * Closes the writer and releases the Timestream client. The SDK-managed HTTP client is closed
     * along with it.
     */
    @Override
    public void close() {
        try {
            super.close();
        } finally {
            if (client != null) {
                client.close();
            }
        }
    }
}

