/*
 * Decompiled with CFR 0.152.
 */
package io.kestra.jdbc.runner;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.CaseFormat;
import com.google.common.collect.Iterables;
import io.kestra.core.exceptions.DeserializationException;
import io.kestra.core.metrics.MetricRegistry;
import io.kestra.core.models.executions.Execution;
import io.kestra.core.queues.MessageTooBigException;
import io.kestra.core.queues.QueueException;
import io.kestra.core.queues.QueueInterface;
import io.kestra.core.queues.QueueService;
import io.kestra.core.queues.UnsupportedMessageException;
import io.kestra.core.utils.Either;
import io.kestra.core.utils.ExecutorsUtils;
import io.kestra.core.utils.IdUtils;
import io.kestra.core.utils.Rethrow;
import io.kestra.jdbc.JdbcMapper;
import io.kestra.jdbc.JdbcTableConfigs;
import io.kestra.jdbc.JooqDSLContextWrapper;
import io.kestra.jdbc.repository.AbstractJdbcRepository;
import io.kestra.jdbc.runner.JdbcQueueIndexer;
import io.kestra.jdbc.runner.MessageProtectionConfiguration;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Timer;
import io.micronaut.context.ApplicationContext;
import io.micronaut.context.annotation.ConfigurationProperties;
import io.micronaut.transaction.exceptions.CannotCreateTransactionException;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;
import lombok.Generated;
import org.jooq.Condition;
import org.jooq.DSLContext;
import org.jooq.Field;
import org.jooq.JSONB;
import org.jooq.OrderField;
import org.jooq.Record;
import org.jooq.Result;
import org.jooq.SelectConditionStep;
import org.jooq.SelectField;
import org.jooq.SelectLimitPercentStep;
import org.jooq.Table;
import org.jooq.exception.DataException;
import org.jooq.impl.DSL;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class JdbcQueue<T>
implements QueueInterface<T> {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(JdbcQueue.class);
    private static final int MAX_ASYNC_THREADS = Runtime.getRuntime().availableProcessors();
    protected static final ObjectMapper MAPPER = JdbcMapper.of();
    private final ExecutorService poolExecutor;
    private final ExecutorService asyncPoolExecutor;
    protected final QueueService queueService;
    protected final Class<T> cls;
    protected final JooqDSLContextWrapper dslContextWrapper;
    protected final Configuration configuration;
    protected final MessageProtectionConfiguration messageProtectionConfiguration;
    private final MetricRegistry metricRegistry;
    protected final Table<Record> table;
    protected final JdbcQueueIndexer jdbcQueueIndexer;
    private final boolean immediateRepoll;
    private final AtomicBoolean isClosed = new AtomicBoolean(false);
    private final AtomicBoolean isPaused = new AtomicBoolean(false);
    private final Counter bigMessageCounter;

    public JdbcQueue(Class<T> cls, ApplicationContext applicationContext) {
        ExecutorsUtils executorsUtils = (ExecutorsUtils)applicationContext.getBean(ExecutorsUtils.class);
        this.poolExecutor = executorsUtils.cachedThreadPool("jdbc-queue-" + cls.getSimpleName());
        this.asyncPoolExecutor = executorsUtils.maxCachedThreadPool(MAX_ASYNC_THREADS, "jdbc-queue-async-" + cls.getSimpleName());
        this.queueService = (QueueService)applicationContext.getBean(QueueService.class);
        this.cls = cls;
        this.dslContextWrapper = (JooqDSLContextWrapper)applicationContext.getBean(JooqDSLContextWrapper.class);
        this.configuration = (Configuration)applicationContext.getBean(Configuration.class);
        this.messageProtectionConfiguration = (MessageProtectionConfiguration)applicationContext.getBean(MessageProtectionConfiguration.class);
        this.metricRegistry = (MetricRegistry)applicationContext.getBean(MetricRegistry.class);
        JdbcTableConfigs jdbcTableConfigs = (JdbcTableConfigs)applicationContext.getBean(JdbcTableConfigs.class);
        this.table = DSL.table((String)jdbcTableConfigs.tableConfig("queues").table());
        this.jdbcQueueIndexer = (JdbcQueueIndexer)applicationContext.getBean(JdbcQueueIndexer.class);
        this.immediateRepoll = applicationContext.getProperty("kestra.jdbc.queues.immediate-repoll", Boolean.class).orElse(true);
        this.bigMessageCounter = this.metricRegistry.counter("queue.big_message.count", "Total number of big messages", new String[]{"class_name", this.queueType()});
    }

    protected Map<Field<Object>, Object> produceFields(String consumerGroup, String key, T message) throws QueueException {
        byte[] bytes;
        try {
            bytes = MAPPER.writeValueAsBytes(message);
        }
        catch (JsonProcessingException e) {
            throw new QueueException("Unable to serialize the message", (Throwable)e);
        }
        if (this.messageProtectionConfiguration.enabled && bytes.length >= this.messageProtectionConfiguration.limit) {
            Execution execution;
            this.bigMessageCounter.increment();
            if (!(message instanceof Execution) || !(execution = (Execution)message).getState().isTerminated()) {
                throw new MessageTooBigException("Message of size " + bytes.length + " has exceeded the configured limit of " + this.messageProtectionConfiguration.limit);
            }
        }
        HashMap<Field<Object>, Object> fields = new HashMap<Field<Object>, Object>();
        fields.put(AbstractJdbcRepository.field("type"), this.queueType());
        fields.put(AbstractJdbcRepository.field("key"), key != null ? key : IdUtils.create());
        fields.put(AbstractJdbcRepository.field("value"), JSONB.valueOf((String)new String(bytes)));
        if (consumerGroup != null) {
            fields.put(AbstractJdbcRepository.field("consumer_group"), consumerGroup);
        }
        return fields;
    }

    private void produce(String consumerGroup, String key, T message, Boolean skipIndexer) throws QueueException {
        String[] stringArray;
        if (log.isTraceEnabled()) {
            log.trace("New message: topic '{}', value {}", (Object)this.queueType(), message);
        }
        Map<Field<Object>, Object> fields = this.produceFields(consumerGroup, key, message);
        try {
            this.dslContextWrapper.transaction(configuration -> {
                DSLContext context = DSL.using((org.jooq.Configuration)configuration);
                if (!skipIndexer.booleanValue()) {
                    this.jdbcQueueIndexer.accept(context, message);
                }
                context.insertInto(this.table).set(fields).execute();
            });
        }
        catch (DataException e) {
            if (e.getMessage() != null && e.getMessage().contains("ERROR: unsupported Unicode escape sequence")) {
                throw new UnsupportedMessageException(e.getMessage(), (Throwable)e);
            }
            throw new QueueException("Unable to emit a message to the queue", (Throwable)e);
        }
        if (consumerGroup == null) {
            String[] stringArray2 = new String[2];
            stringArray2[0] = "queue_type";
            stringArray = stringArray2;
            stringArray2[1] = this.queueType();
        } else {
            String[] stringArray3 = new String[4];
            stringArray3[0] = "queue_type";
            stringArray3[1] = this.queueType();
            stringArray3[2] = "consumer_group";
            stringArray = stringArray3;
            stringArray3[3] = consumerGroup;
        }
        String[] tags = stringArray;
        this.metricRegistry.counter("queue.produce.count", "Total number of produced messages", tags).increment();
    }

    public void emitOnly(String consumerGroup, T message) throws QueueException {
        this.produce(consumerGroup, this.queueService.key(message), message, true);
    }

    public void emit(String consumerGroup, T message) throws QueueException {
        this.produce(consumerGroup, this.queueService.key(message), message, false);
    }

    public void emitAsync(String consumerGroup, List<T> messages) throws QueueException {
        this.asyncPoolExecutor.submit(Rethrow.throwRunnable(() -> messages.forEach(Rethrow.throwConsumer(message -> this.emit(consumerGroup, message)))));
    }

    public void delete(String consumerGroup, T message) throws QueueException {
    }

    public void deleteByKey(String key) throws QueueException {
        this.dslContextWrapper.transaction(configuration -> {
            int deleted = DSL.using((org.jooq.Configuration)configuration).delete(this.table).where(this.buildTypeCondition(this.queueType())).and(AbstractJdbcRepository.field("key").eq((Object)key)).execute();
            log.debug("Cleaned {} records for key {}", (Object)deleted, (Object)key);
        });
    }

    protected String queueType() {
        return this.cls.getName();
    }

    public void deleteByKeys(List<String> keys) throws QueueException {
        Iterables.partition(keys, (int)100).forEach(batch -> this.dslContextWrapper.transaction(configuration -> {
            int deleted = DSL.using((org.jooq.Configuration)configuration).delete(this.table).where(this.buildTypeCondition(this.queueType())).and(AbstractJdbcRepository.field("key").in((Collection)batch)).execute();
            log.debug("Cleaned {} records for keys {}", (Object)deleted, batch);
        }));
    }

    protected Result<Record> receiveFetch(DSLContext ctx, String consumerGroup, Integer offset) {
        return this.receiveFetch(ctx, consumerGroup, offset, true);
    }

    protected Result<Record> receiveFetch(DSLContext ctx, String consumerGroup, Integer offset, boolean forUpdate) {
        SelectLimitPercentStep limitSelect;
        SelectConditionStep select = ctx.select(AbstractJdbcRepository.field("value"), AbstractJdbcRepository.field("offset")).from(this.table).where(this.buildTypeCondition(this.queueType()));
        if (offset != 0) {
            select = select.and(AbstractJdbcRepository.field("offset").gt((Object)offset));
        }
        select = consumerGroup != null ? select.and(AbstractJdbcRepository.field("consumer_group").eq((Object)consumerGroup)) : select.and(AbstractJdbcRepository.field("consumer_group").isNull());
        SelectLimitPercentStep configuredSelect = limitSelect = select.orderBy((OrderField)AbstractJdbcRepository.field("offset").asc()).limit((Number)this.configuration.getPollSize());
        if (forUpdate) {
            configuredSelect = limitSelect.forUpdate().skipLocked();
        }
        return (Result)configuredSelect.fetchMany().getFirst();
    }

    protected Result<Record> receiveFetch(DSLContext ctx, String consumerGroup, String queueType) {
        return this.receiveFetch(ctx, consumerGroup, queueType, true);
    }

    protected abstract Result<Record> receiveFetch(DSLContext var1, String var2, String var3, boolean var4);

    protected abstract void updateGroupOffsets(DSLContext var1, String var2, String var3, List<Integer> var4);

    protected abstract Condition buildTypeCondition(String var1);

    public Runnable receive(String consumerGroup, Consumer<Either<T, DeserializationException>> consumer, boolean forUpdate) {
        String[] stringArray;
        if (consumerGroup == null) {
            String[] stringArray2 = new String[2];
            stringArray2[0] = "queue_type";
            stringArray = stringArray2;
            stringArray2[1] = this.queueType();
        } else {
            String[] stringArray3 = new String[4];
            stringArray3[0] = "queue_type";
            stringArray3[1] = this.queueType();
            stringArray3[2] = "consumer_group";
            stringArray = stringArray3;
            stringArray3[3] = consumerGroup;
        }
        String[] tags = stringArray;
        AtomicInteger pollSize = new AtomicInteger();
        this.metricRegistry.gauge("queue.poll.size", "Size of a poll to the queue (message batch size)", (Number)pollSize, tags);
        AtomicInteger maxOffset = new AtomicInteger();
        this.dslContextWrapper.transaction(configuration -> {
            SelectConditionStep select = DSL.using((org.jooq.Configuration)configuration).select((SelectField)DSL.max(AbstractJdbcRepository.field("offset")).as("max")).from(this.table).where(this.buildTypeCondition(this.queueType()));
            select = consumerGroup != null ? select.and(AbstractJdbcRepository.field("consumer_group").eq((Object)consumerGroup)) : select.and(AbstractJdbcRepository.field("consumer_group").isNull());
            Integer integer = (Integer)select.fetchAny("max", Integer.class);
            if (integer != null) {
                maxOffset.set(integer);
            }
        });
        Timer timer = this.metricRegistry.timer("queue.receive.duration", "Queue duration to receive and consume a batch of messages", tags);
        return this.poll(() -> timer.record(() -> {
            Result fetch = (Result)this.dslContextWrapper.transactionResult(configuration -> {
                DSLContext ctx = DSL.using((org.jooq.Configuration)configuration);
                Result<Record> result = this.receiveFetch(ctx, consumerGroup, maxOffset.get(), forUpdate);
                if (!result.isEmpty()) {
                    List offsets = result.map(record -> (Integer)record.get("offset", Integer.class));
                    maxOffset.set((Integer)offsets.getLast());
                }
                return result;
            });
            this.send((Result<Record>)fetch, consumer);
            pollSize.set(fetch.size());
            return fetch.size();
        }));
    }

    public Runnable receive(String consumerGroup, Class<?> queueType, Consumer<Either<T, DeserializationException>> consumer, boolean forUpdate) {
        return this.receiveImpl(consumerGroup, queueType, (dslContext, eithers) -> eithers.forEach(consumer), false, forUpdate);
    }

    public Runnable receiveBatch(Class<?> queueType, Consumer<List<Either<T, DeserializationException>>> consumer) {
        return this.receiveBatch(null, queueType, consumer);
    }

    public Runnable receiveBatch(String consumerGroup, Class<?> queueType, Consumer<List<Either<T, DeserializationException>>> consumer) {
        return this.receiveBatch(consumerGroup, queueType, consumer, true);
    }

    public Runnable receiveBatch(String consumerGroup, Class<?> queueType, Consumer<List<Either<T, DeserializationException>>> consumer, boolean forUpdate) {
        return this.receiveImpl(consumerGroup, queueType, (dslContext, eithers) -> consumer.accept((List)eithers), false, forUpdate);
    }

    public Runnable receiveTransaction(String consumerGroup, Class<?> queueType, BiConsumer<DSLContext, List<Either<T, DeserializationException>>> consumer) {
        return this.receiveImpl(consumerGroup, queueType, consumer, true, true);
    }

    public Runnable receiveImpl(String consumerGroup, Class<?> queueType, BiConsumer<DSLContext, List<Either<T, DeserializationException>>> consumer, Boolean inTransaction, boolean forUpdate) {
        String[] stringArray;
        String queueName = this.queueName(queueType);
        if (consumerGroup == null) {
            String[] stringArray2 = new String[4];
            stringArray2[0] = "queue_type";
            stringArray2[1] = this.queueType();
            stringArray2[2] = "consumer";
            stringArray = stringArray2;
            stringArray2[3] = queueName;
        } else {
            String[] stringArray3 = new String[6];
            stringArray3[0] = "queue_type";
            stringArray3[1] = this.queueType();
            stringArray3[2] = "consumer";
            stringArray3[3] = queueName;
            stringArray3[4] = "consumer_group";
            stringArray = stringArray3;
            stringArray3[5] = consumerGroup;
        }
        String[] tags = stringArray;
        AtomicInteger pollSize = new AtomicInteger();
        this.metricRegistry.gauge("queue.poll.size", "Size of a poll to the queue (message batch size)", (Number)pollSize, tags);
        Timer timer = this.metricRegistry.timer("queue.receive.duration", "Queue duration to receive and consume a batch of messages", tags);
        return this.poll(() -> timer.record(() -> {
            Result fetch = (Result)this.dslContextWrapper.transactionResult(configuration -> {
                DSLContext ctx = DSL.using((org.jooq.Configuration)configuration);
                Result<Record> result = this.receiveFetch(ctx, consumerGroup, queueName, forUpdate);
                if (!result.isEmpty() && inTransaction.booleanValue()) {
                    consumer.accept(ctx, this.map(result));
                    this.updateGroupOffsets(ctx, consumerGroup, queueName, result.map(record -> (Integer)record.get("offset", Integer.class)));
                }
                return result;
            });
            if (!inTransaction.booleanValue()) {
                consumer.accept(null, this.map((Result<Record>)fetch));
                this.dslContextWrapper.transaction(configuration -> this.updateGroupOffsets(DSL.using((org.jooq.Configuration)configuration), consumerGroup, queueName, fetch.map(record -> (Integer)record.get("offset", Integer.class))));
            }
            pollSize.set(fetch.size());
            return fetch.size();
        }));
    }

    protected String queueName(Class<?> queueType) {
        return CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, queueType.getSimpleName());
    }

    protected Runnable poll(Supplier<Integer> runnable) {
        AtomicBoolean running = new AtomicBoolean(true);
        this.poolExecutor.execute(() -> {
            List<Configuration.Step> steps = this.configuration.computeSteps();
            Duration sleep = this.configuration.minPollInterval;
            ZonedDateTime lastPoll = ZonedDateTime.now();
            while (running.get() && !this.isClosed.get()) {
                block9: {
                    if (!this.isPaused.get()) {
                        try {
                            Integer count = (Integer)runnable.get();
                            if (count > 0) {
                                lastPoll = ZonedDateTime.now();
                                sleep = this.configuration.minPollInterval;
                                if (this.immediateRepoll || count.equals(this.configuration.pollSize)) {
                                    continue;
                                }
                            } else {
                                ZonedDateTime finalLastPoll = lastPoll;
                                List<Configuration.Step> selectedSteps = steps.stream().takeWhile(step -> finalLastPoll.plus(step.switchInterval()).compareTo(ZonedDateTime.now()) < 0).toList();
                                sleep = selectedSteps.isEmpty() ? this.configuration.minPollInterval : selectedSteps.getLast().pollInterval();
                            }
                        }
                        catch (CannotCreateTransactionException e) {
                            if (!log.isDebugEnabled()) break block9;
                            log.debug("Can't poll on receive", (Throwable)e);
                        }
                    }
                }
                try {
                    Thread.sleep(sleep);
                }
                catch (InterruptedException e) {
                    throw new RuntimeException(e);
                }
            }
        });
        return () -> running.set(false);
    }

    protected List<Either<T, DeserializationException>> map(Result<Record> fetch) {
        return fetch.map(record -> {
            try {
                return Either.left((Object)MAPPER.readValue((String)record.get("value", String.class), this.cls));
            }
            catch (JsonProcessingException e) {
                return Either.right((Object)((Object)new DeserializationException((Exception)((Object)e), (String)record.get("value", String.class))));
            }
        });
    }

    protected void send(Result<Record> fetch, Consumer<Either<T, DeserializationException>> consumer) {
        this.map(fetch).forEach(consumer);
    }

    public void pause() {
        this.isPaused.set(true);
    }

    public void resume() {
        this.isPaused.set(false);
    }

    public void close() throws IOException {
        if (!this.isClosed.compareAndSet(false, true)) {
            return;
        }
        this.poolExecutor.shutdown();
        this.asyncPoolExecutor.shutdown();
    }

    @ConfigurationProperties(value="kestra.jdbc.queues")
    public static class Configuration {
        Duration minPollInterval = Duration.ofMillis(25L);
        Duration maxPollInterval = Duration.ofMillis(500L);
        Duration pollSwitchInterval = Duration.ofSeconds(60L);
        Integer pollSize = 100;
        Integer switchSteps = 5;

        public List<Step> computeSteps() {
            if (this.maxPollInterval.compareTo(this.minPollInterval) <= 0) {
                throw new IllegalArgumentException("'maxPollInterval' (" + String.valueOf(this.maxPollInterval) + ") must be greater than 'minPollInterval' (" + String.valueOf(this.minPollInterval) + ")");
            }
            ArrayList<Step> steps = new ArrayList<Step>();
            Step currentStep = new Step(this.maxPollInterval, this.pollSwitchInterval);
            steps.add(currentStep);
            for (int i = 0; i < this.switchSteps; ++i) {
                Duration stepPollInterval = Duration.ofMillis(currentStep.pollInterval().toMillis() / 2L);
                if (stepPollInterval.compareTo(this.minPollInterval) < 0) {
                    stepPollInterval = this.minPollInterval;
                }
                Duration stepSwitchInterval = Duration.ofMillis(currentStep.switchInterval().toMillis() / 2L);
                currentStep = new Step(stepPollInterval, stepSwitchInterval);
                steps.add(currentStep);
            }
            Collections.sort(steps);
            return steps;
        }

        @Generated
        public Duration getMinPollInterval() {
            return this.minPollInterval;
        }

        @Generated
        public Duration getMaxPollInterval() {
            return this.maxPollInterval;
        }

        @Generated
        public Duration getPollSwitchInterval() {
            return this.pollSwitchInterval;
        }

        @Generated
        public Integer getPollSize() {
            return this.pollSize;
        }

        @Generated
        public Integer getSwitchSteps() {
            return this.switchSteps;
        }

        public record Step(Duration pollInterval, Duration switchInterval) implements Comparable<Step>
        {
            @Override
            public int compareTo(Step o) {
                return this.switchInterval.compareTo(o.switchInterval);
            }
        }
    }
}

