/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.jdbc;

import com.google.common.base.Preconditions;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.plugin.jdbc.JdbcAssignmentItem;
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcErrorCode;
import io.trino.plugin.jdbc.JdbcMergeTableHandle;
import io.trino.plugin.jdbc.JdbcNamedRelationHandle;
import io.trino.plugin.jdbc.JdbcOutputTableHandle;
import io.trino.plugin.jdbc.JdbcPageSink;
import io.trino.plugin.jdbc.QueryBuilder;
import io.trino.plugin.jdbc.QueryParameter;
import io.trino.plugin.jdbc.SinkSqlProvider;
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.Page;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.block.RowBlock;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorMergeSink;
import io.trino.spi.connector.ConnectorMergeTableHandle;
import io.trino.spi.connector.ConnectorPageSink;
import io.trino.spi.connector.ConnectorPageSinkId;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import java.lang.invoke.LambdaMetafactory;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.IntStream;

public class JdbcMergeSink
implements ConnectorMergeSink {
    private final int columnCount;
    private final ConnectorPageSinkId pageSinkId;
    private final ConnectorPageSink insertSink;
    private final ConnectorPageSink deleteSink;
    private final Map<Integer, Supplier<ConnectorPageSink>> updateSinkSuppliers;
    private final Map<Integer, int[]> updateCaseChannels;

    public JdbcMergeSink(ConnectorSession session, ConnectorMergeTableHandle mergeTableHandle, JdbcClient jdbcClient, ConnectorPageSinkId pageSinkId, RemoteQueryModifier queryModifier, QueryBuilder queryBuilder) {
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(mergeTableHandle, "mergeTableHandle is null");
        Objects.requireNonNull(jdbcClient, "jdbcClient is null");
        Objects.requireNonNull(queryModifier, "queryModifier is null");
        Objects.requireNonNull(queryBuilder, "queryBuilder is null");
        JdbcMergeTableHandle mergeHandle = (JdbcMergeTableHandle)mergeTableHandle;
        JdbcOutputTableHandle outputHandle = mergeHandle.getOutputTableHandle();
        List<JdbcColumnHandle> primaryKeys = mergeHandle.getPrimaryKeys();
        Preconditions.checkArgument((!primaryKeys.isEmpty() ? 1 : 0) != 0, (Object)"primary keys not exists");
        this.columnCount = outputHandle.getColumnNames().size();
        this.pageSinkId = Objects.requireNonNull(pageSinkId, "pageSinkId is null");
        ImmutableMap.Builder primaryKeysDomainBuilder = ImmutableMap.builder();
        Domain dummy = Domain.singleValue((Type)BigintType.BIGINT, (Object)0L);
        for (JdbcColumnHandle columnHandle : primaryKeys) {
            primaryKeysDomainBuilder.put((Object)columnHandle, (Object)dummy);
        }
        TupleDomain primaryKeysDomain = TupleDomain.withColumnDomains((Map)primaryKeysDomainBuilder.buildOrThrow());
        JdbcNamedRelationHandle relation = mergeHandle.getTableHandle().getRequiredNamedRelation();
        this.insertSink = new JdbcPageSink(session, outputHandle, jdbcClient, pageSinkId, queryModifier, JdbcClient::buildInsertSql);
        this.deleteSink = JdbcMergeSink.createDeleteSink(session, relation, (TupleDomain<ColumnHandle>)primaryKeysDomain, primaryKeys, jdbcClient, pageSinkId, queryModifier, queryBuilder);
        Map<Integer, Collection<ColumnHandle>> updateCaseColumns = mergeHandle.getUpdateCaseColumns();
        List<JdbcColumnHandle> columns = mergeHandle.getDataColumns();
        ImmutableMap.Builder updateSinksBuilder = ImmutableMap.builder();
        ImmutableMap.Builder updateCaseChannelsBuilder = ImmutableMap.builder();
        for (Map.Entry<Integer, Collection<ColumnHandle>> entry : updateCaseColumns.entrySet()) {
            int caseNumber = entry.getKey();
            Set columnChannels = (Set)entry.getValue().stream().map(JdbcColumnHandle.class::cast).map(columns::indexOf).collect(ImmutableSet.toImmutableSet());
            com.google.common.base.Supplier updateSupplier = Suppliers.memoize(() -> JdbcMergeSink.createUpdateSink(session, relation, (TupleDomain<ColumnHandle>)primaryKeysDomain, primaryKeys, jdbcClient, pageSinkId, queryModifier, queryBuilder, columns, columnChannels));
            updateSinksBuilder.put((Object)caseNumber, (Object)updateSupplier);
            updateCaseChannelsBuilder.put((Object)caseNumber, (Object)columnChannels.stream().mapToInt(Integer::intValue).sorted().toArray());
        }
        this.updateSinkSuppliers = updateSinksBuilder.buildOrThrow();
        this.updateCaseChannels = updateCaseChannelsBuilder.buildOrThrow();
    }

    private static ConnectorPageSink createUpdateSink(ConnectorSession session, JdbcNamedRelationHandle relation, TupleDomain<ColumnHandle> domain, List<JdbcColumnHandle> primaryKeys, JdbcClient jdbcClient, ConnectorPageSinkId pageSinkId, RemoteQueryModifier remoteQueryModifier, QueryBuilder queryBuilder, List<JdbcColumnHandle> columns, Set<Integer> updateChannels) {
        ImmutableList.Builder assignmentItemBuilder = ImmutableList.builder();
        ImmutableList.Builder columnNamesBuilder = ImmutableList.builder();
        ImmutableList.Builder columnTypesBuilder = ImmutableList.builder();
        QueryParameter dummy = new QueryParameter((Type)BigintType.BIGINT, Optional.empty());
        for (int channel = 0; channel < columns.size(); ++channel) {
            JdbcColumnHandle columnHandle = columns.get(channel);
            if (!updateChannels.contains(channel)) continue;
            columnNamesBuilder.add((Object)columnHandle.getColumnName());
            columnTypesBuilder.add((Object)columnHandle.getColumnType());
            assignmentItemBuilder.add((Object)new JdbcAssignmentItem(columnHandle, dummy));
        }
        for (JdbcColumnHandle columnHandle : primaryKeys) {
            columnNamesBuilder.add((Object)columnHandle.getColumnName());
            columnTypesBuilder.add((Object)columnHandle.getColumnType());
        }
        return new JdbcPageSink(session, new JdbcOutputTableHandle(relation.getRemoteTableName(), (List<String>)columnNamesBuilder.build(), (List<Type>)columnTypesBuilder.build(), Optional.empty(), Optional.empty(), Optional.empty()), jdbcClient, pageSinkId, remoteQueryModifier, JdbcMergeSink.updateSqlProvider(session, relation, domain, (List<JdbcAssignmentItem>)assignmentItemBuilder.build(), jdbcClient, queryBuilder));
    }

    private static SinkSqlProvider updateSqlProvider(ConnectorSession session, JdbcNamedRelationHandle relation, TupleDomain<ColumnHandle> domain, List<JdbcAssignmentItem> assignmentItems, JdbcClient jdbcClient, QueryBuilder queryBuilder) {
        SinkSqlProvider sinkSqlProvider;
        block8: {
            Connection connection = jdbcClient.getConnection(session);
            try {
                sinkSqlProvider = (jdbcClient2, jdbcOutputTableHandle, list) -> queryBuilder.prepareUpdateQuery(jdbcClient, session, connection, relation, domain, Optional.empty(), assignmentItems).query();
                if (connection == null) break block8;
            }
            catch (Throwable throwable) {
                try {
                    if (connection != null) {
                        try {
                            connection.close();
                        }
                        catch (Throwable throwable2) {
                            throwable.addSuppressed(throwable2);
                        }
                    }
                    throw throwable;
                }
                catch (SQLException e) {
                    throw new TrinoException((ErrorCodeSupplier)JdbcErrorCode.JDBC_ERROR, (Throwable)e);
                }
            }
            connection.close();
        }
        return sinkSqlProvider;
    }

    private static ConnectorPageSink createDeleteSink(ConnectorSession session, JdbcNamedRelationHandle relation, TupleDomain<ColumnHandle> domain, List<JdbcColumnHandle> primaryKeys, JdbcClient jdbcClient, ConnectorPageSinkId pageSinkId, RemoteQueryModifier remoteQueryModifier, QueryBuilder queryBuilder) {
        ImmutableList.Builder columnNamesBuilder = ImmutableList.builder();
        ImmutableList.Builder columnTypesBuilder = ImmutableList.builder();
        for (JdbcColumnHandle columnHandle : primaryKeys) {
            columnNamesBuilder.add((Object)columnHandle.getColumnName());
            columnTypesBuilder.add((Object)columnHandle.getColumnType());
        }
        return new JdbcPageSink(session, new JdbcOutputTableHandle(relation.getRemoteTableName(), (List<String>)columnNamesBuilder.build(), (List<Type>)columnTypesBuilder.build(), Optional.empty(), Optional.empty(), Optional.empty()), jdbcClient, pageSinkId, remoteQueryModifier, JdbcMergeSink.deleteSqlProvider(session, relation, domain, jdbcClient, queryBuilder));
    }

    private static SinkSqlProvider deleteSqlProvider(ConnectorSession session, JdbcNamedRelationHandle relation, TupleDomain<ColumnHandle> domain, JdbcClient jdbcClient, QueryBuilder queryBuilder) {
        SinkSqlProvider sinkSqlProvider;
        block8: {
            Connection connection = jdbcClient.getConnection(session);
            try {
                sinkSqlProvider = (jdbcClient2, jdbcOutputTableHandle, list) -> queryBuilder.prepareDeleteQuery(jdbcClient, session, connection, relation, domain, Optional.empty()).query();
                if (connection == null) break block8;
            }
            catch (Throwable throwable) {
                try {
                    if (connection != null) {
                        try {
                            connection.close();
                        }
                        catch (Throwable throwable2) {
                            throwable.addSuppressed(throwable2);
                        }
                    }
                    throw throwable;
                }
                catch (SQLException e) {
                    throw new TrinoException((ErrorCodeSupplier)JdbcErrorCode.JDBC_ERROR, (Throwable)e);
                }
            }
            connection.close();
        }
        return sinkSqlProvider;
    }

    public void storeMergedRows(Page page) {
        Preconditions.checkArgument((page.getChannelCount() == 3 + this.columnCount ? 1 : 0) != 0, (String)"The page size should be 3 + columnCount (%s), but is %s", (int)this.columnCount, (int)page.getChannelCount());
        int positionCount = page.getPositionCount();
        Block operationBlock = page.getBlock(this.columnCount);
        int[] dataChannel = IntStream.range(0, this.columnCount).toArray();
        Page dataPage = page.getColumns(dataChannel);
        int[] insertPositions = new int[positionCount];
        int insertPositionCount = 0;
        int[] deletePositions = new int[positionCount];
        int deletePositionCount = 0;
        Block updateCaseBlock = page.getBlock(this.columnCount + 1);
        HashMap<Integer, int[]> updatePositions = new HashMap<Integer, int[]>();
        HashMap<Integer, Integer> updatePositionCounts = new HashMap<Integer, Integer>();
        block5: for (int position = 0; position < positionCount; ++position) {
            byte operation = TinyintType.TINYINT.getByte(operationBlock, position);
            switch (operation) {
                case 1: {
                    insertPositions[insertPositionCount] = position;
                    ++insertPositionCount;
                    continue block5;
                }
                case 2: {
                    deletePositions[deletePositionCount] = position;
                    ++deletePositionCount;
                    continue block5;
                }
                case 3: {
                    int caseNumber = IntegerType.INTEGER.getInt(updateCaseBlock, position);
                    int updatePositionCount = updatePositionCounts.getOrDefault(caseNumber, 0);
                    updatePositions.computeIfAbsent(Integer.valueOf((int)caseNumber), (Function<Integer, int[]>)LambdaMetafactory.metafactory(null, null, null, (Ljava/lang/Object;)Ljava/lang/Object;, lambda$storeMergedRows$3(int java.lang.Integer ), (Ljava/lang/Integer;)[I)((int)positionCount))[updatePositionCount] = position;
                    updatePositionCounts.put(caseNumber, updatePositionCount + 1);
                    continue block5;
                }
                default: {
                    throw new IllegalStateException("Unexpected value: " + operation);
                }
            }
        }
        if (insertPositionCount > 0) {
            this.insertSink.appendPage(dataPage.getPositions(insertPositions, 0, insertPositionCount));
        }
        List rowIdFields = RowBlock.getRowFieldsFromBlock((Block)page.getBlock(this.columnCount + 2));
        if (deletePositionCount > 0) {
            this.deleteSink.appendPage(new Page(deletePositionCount, JdbcMergeSink.extractRowIdBlocks(rowIdFields, deletePositionCount, deletePositions)));
        }
        for (Map.Entry entry : updatePositionCounts.entrySet()) {
            int caseNumber = (Integer)entry.getKey();
            int updatePositionCount = (Integer)entry.getValue();
            if (updatePositionCount <= 0) continue;
            Preconditions.checkArgument((boolean)updatePositions.containsKey(caseNumber), (String)"Unexpected case number %s", (int)caseNumber);
            int[] positions = (int[])updatePositions.get(caseNumber);
            int[] updateAssignmentChannels = this.updateCaseChannels.get(caseNumber);
            Block[] updateBlocks = new Block[updateAssignmentChannels.length + rowIdFields.size()];
            for (int channel = 0; channel < updateAssignmentChannels.length; ++channel) {
                updateBlocks[channel] = dataPage.getBlock(updateAssignmentChannels[channel]).getPositions(positions, 0, updatePositionCount);
            }
            System.arraycopy(JdbcMergeSink.extractRowIdBlocks(rowIdFields, updatePositionCount, positions), 0, updateBlocks, updateAssignmentChannels.length, rowIdFields.size());
            this.updateSinkSuppliers.get(caseNumber).get().appendPage(new Page(updatePositionCount, updateBlocks));
        }
    }

    private static Block[] extractRowIdBlocks(List<Block> rowIdFields, int positionCount, int[] positions) {
        Block[] blocks = new Block[rowIdFields.size()];
        for (int field = 0; field < rowIdFields.size(); ++field) {
            blocks[field] = rowIdFields.get(field).getPositions(positions, 0, positionCount);
        }
        return blocks;
    }

    public CompletableFuture<Collection<Slice>> finish() {
        this.insertSink.finish();
        this.deleteSink.finish();
        this.updateSinkSuppliers.values().stream().map(Supplier::get).forEach(ConnectorPageSink::finish);
        Slice value = Slices.allocate((int)8);
        value.setLong(0, this.pageSinkId.getId());
        return CompletableFuture.completedFuture(ImmutableList.of((Object)value));
    }

    public void abort() {
        this.insertSink.abort();
        this.deleteSink.abort();
        this.updateSinkSuppliers.values().stream().map(Supplier::get).forEach(ConnectorPageSink::abort);
    }

    private static /* synthetic */ int[] lambda$storeMergedRows$3(int positionCount, Integer n) {
        return new int[positionCount];
    }
}

