/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.jdbc;

import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.inject.Inject;
import io.airlift.log.Logger;
import io.airlift.slice.Slice;
import io.trino.plugin.jdbc.BooleanWriteFunction;
import io.trino.plugin.jdbc.DoubleWriteFunction;
import io.trino.plugin.jdbc.JdbcAssignmentItem;
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcJoinCondition;
import io.trino.plugin.jdbc.JdbcNamedRelationHandle;
import io.trino.plugin.jdbc.JdbcProcedureHandle;
import io.trino.plugin.jdbc.JdbcQueryRelationHandle;
import io.trino.plugin.jdbc.JdbcRelationHandle;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.LongWriteFunction;
import io.trino.plugin.jdbc.ObjectWriteFunction;
import io.trino.plugin.jdbc.PreparedQuery;
import io.trino.plugin.jdbc.QueryBuilder;
import io.trino.plugin.jdbc.QueryParameter;
import io.trino.plugin.jdbc.RemoteTableName;
import io.trino.plugin.jdbc.SliceWriteFunction;
import io.trino.plugin.jdbc.WriteFunction;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.JoinType;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.Range;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.predicate.ValueSet;
import io.trino.spi.type.Type;
import java.sql.CallableStatement;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.stream.Collectors;

/**
 * Default {@link QueryBuilder} that renders SELECT, JOIN, DELETE and UPDATE statements as SQL text.
 * Values are never inlined: each value is emitted as the bind expression of its column's
 * {@link WriteFunction}, and the corresponding {@link QueryParameter} is collected in the same order
 * in which its placeholder appears in the generated SQL.
 */
public class DefaultQueryBuilder
implements QueryBuilder {
    private static final Logger log = Logger.get(DefaultQueryBuilder.class);
    // Literal predicates used when a constraint is trivially satisfied or can never match.
    private static final String ALWAYS_TRUE = "1=1";
    private static final String ALWAYS_FALSE = "1=0";
    private final RemoteQueryModifier queryModifier;

    @Inject
    public DefaultQueryBuilder(RemoteQueryModifier queryModifier) {
        this.queryModifier = Objects.requireNonNull(queryModifier, "queryModifier is null");
    }

    /**
     * Builds {@code SELECT <projection> FROM <relation> [WHERE ...] [GROUP BY ...]}.
     *
     * @throws IllegalArgumentException if a column has both a projection expression in
     *         {@code columnExpressions} and a constraint in {@code tupleDomain}
     */
    @Override
    public PreparedQuery prepareSelectQuery(JdbcClient client, ConnectorSession session, Connection connection, JdbcRelationHandle baseRelation, Optional<List<List<JdbcColumnHandle>>> groupingSets, List<JdbcColumnHandle> columns, Map<String, ParameterizedExpression> columnExpressions, TupleDomain<ColumnHandle> tupleDomain, Optional<ParameterizedExpression> additionalPredicate) {
        if (!tupleDomain.isNone()) {
            Map<ColumnHandle, Domain> domains = tupleDomain.getDomains().orElseThrow();
            columns.stream()
                    .filter(domains::containsKey)
                    .filter(column -> columnExpressions.containsKey(column.getColumnName()))
                    .findFirst()
                    .ifPresent(column -> {
                        throw new IllegalArgumentException(String.format("Column %s has an expression and a constraint attached at the same time", column));
                    });
        }
        ImmutableList.Builder<QueryParameter> accumulator = ImmutableList.builder();
        // Parameters must be accumulated in SQL text order: projection, FROM, then WHERE.
        String sql = "SELECT " + getProjection(client, columns, columnExpressions, accumulator::add);
        sql += getFrom(client, baseRelation, accumulator::add);
        sql += buildWhereClause(client, session, connection, tupleDomain, additionalPredicate, accumulator::add);
        sql += getGroupBy(client, groupingSets);
        return new PreparedQuery(sql, accumulator.build());
    }

    /**
     * Joins two prepared sub-queries. Each side is wrapped in a sub-select that renames its columns
     * to the projection aliases, and the outer query selects all aliases and joins on the
     * AND-ed {@code joinConditions}.
     *
     * @throws VerifyException if {@code joinConditions} is empty
     */
    @Override
    public PreparedQuery prepareJoinQuery(JdbcClient client, ConnectorSession session, Connection connection, JoinType joinType, PreparedQuery leftSource, Map<JdbcColumnHandle, String> leftProjections, PreparedQuery rightSource, Map<JdbcColumnHandle, String> rightProjections, List<ParameterizedExpression> joinConditions) {
        Verify.verify(!joinConditions.isEmpty(), "joinConditions is empty");
        List<String> outputAliases = ImmutableList.<String>builder()
                .addAll(leftProjections.values())
                .addAll(rightProjections.values())
                .build();
        String query = String.format(
                "SELECT %s FROM (SELECT %s FROM (%s) l) l %s (SELECT %s FROM (%s) r) r ON %s",
                formatProjectionAliases(client, outputAliases),
                formatProjections(client, leftProjections),
                leftSource.query(),
                formatJoinType(joinType),
                formatProjections(client, rightProjections),
                rightSource.query(),
                joinConditions.stream()
                        .map(ParameterizedExpression::expression)
                        .collect(Collectors.joining(") AND (", "(", ")")));
        // Placeholder order in the SQL: left source, right source, then join conditions.
        List<QueryParameter> parameters = ImmutableList.<QueryParameter>builder()
                .addAll(leftSource.parameters())
                .addAll(rightSource.parameters())
                .addAll(joinConditions.stream().flatMap(expression -> expression.parameters().stream()).iterator())
                .build();
        return new PreparedQuery(query, parameters);
    }

    /**
     * Joins two prepared sub-queries aliased {@code l} and {@code r}, selecting the given column
     * assignments from each side and joining on the AND-ed column comparisons.
     *
     * @throws VerifyException if either assignment map or {@code joinConditions} is empty
     */
    @Override
    public PreparedQuery legacyPrepareJoinQuery(JdbcClient client, ConnectorSession session, Connection connection, JoinType joinType, PreparedQuery leftSource, PreparedQuery rightSource, List<JdbcJoinCondition> joinConditions, Map<JdbcColumnHandle, String> leftAssignments, Map<JdbcColumnHandle, String> rightAssignments) {
        Verify.verify(!leftAssignments.isEmpty(), "leftAssignments is empty");
        Verify.verify(!rightAssignments.isEmpty(), "rightAssignments is empty");
        Verify.verify(!joinConditions.isEmpty(), "joinConditions is empty");
        String leftRelationAlias = "l";
        String rightRelationAlias = "r";
        String joinCondition = joinConditions.stream()
                .map(condition -> formatJoinCondition(client, leftRelationAlias, rightRelationAlias, condition))
                .collect(Collectors.joining(" AND "));
        String query = String.format(
                "SELECT %s, %s FROM (%s) %s %s (%s) %s ON %s",
                formatAssignments(client, leftRelationAlias, leftAssignments),
                formatAssignments(client, rightRelationAlias, rightAssignments),
                leftSource.query(),
                leftRelationAlias,
                formatJoinType(joinType),
                rightSource.query(),
                rightRelationAlias,
                joinCondition);
        List<QueryParameter> parameters = ImmutableList.<QueryParameter>builder()
                .addAll(leftSource.parameters())
                .addAll(rightSource.parameters())
                .build();
        return new PreparedQuery(query, parameters);
    }

    /**
     * Builds {@code DELETE FROM <table> [WHERE ...]} for the given constraint.
     */
    @Override
    public PreparedQuery prepareDeleteQuery(JdbcClient client, ConnectorSession session, Connection connection, JdbcNamedRelationHandle baseRelation, TupleDomain<ColumnHandle> tupleDomain, Optional<ParameterizedExpression> additionalPredicate) {
        ImmutableList.Builder<QueryParameter> accumulator = ImmutableList.builder();
        String sql = "DELETE FROM " + getRelation(client, baseRelation.getRemoteTableName());
        sql += buildWhereClause(client, session, connection, tupleDomain, additionalPredicate, accumulator::add);
        return new PreparedQuery(sql, accumulator.build());
    }

    /**
     * Builds {@code UPDATE <table> SET col = <bind>, ... [WHERE ...]}. Assignment values are
     * collected before the WHERE parameters because their placeholders come first in the SQL.
     */
    @Override
    public PreparedQuery prepareUpdateQuery(JdbcClient client, ConnectorSession session, Connection connection, JdbcNamedRelationHandle baseRelation, TupleDomain<ColumnHandle> tupleDomain, Optional<ParameterizedExpression> additionalPredicate, List<JdbcAssignmentItem> assignments) {
        ImmutableList.Builder<QueryParameter> accumulator = ImmutableList.builder();
        String sql = "UPDATE " + getRelation(client, baseRelation.getRemoteTableName()) + " SET ";
        for (JdbcAssignmentItem assignment : assignments) {
            JdbcColumnHandle column = assignment.column();
            accumulator.add(new QueryParameter(column.getJdbcTypeHandle(), column.getColumnType(), assignment.queryParameter().getValue()));
        }
        sql += assignments.stream()
                .map(JdbcAssignmentItem::column)
                .map(column -> {
                    String bindExpression = getWriteFunction(client, session, connection, column.getJdbcTypeHandle(), column.getColumnType()).getBindExpression();
                    return client.quoted(column.getColumnName()) + " = " + bindExpression;
                })
                .collect(Collectors.joining(", "));
        sql += buildWhereClause(client, session, connection, tupleDomain, additionalPredicate, accumulator::add);
        return new PreparedQuery(sql, accumulator.build());
    }

    /**
     * Applies the configured {@link RemoteQueryModifier}, prepares the statement and binds every
     * parameter of {@code preparedQuery} in order. If binding fails, the statement is closed before
     * the exception propagates, because the caller never gets a reference to it.
     */
    @Override
    public PreparedStatement prepareStatement(JdbcClient client, ConnectorSession session, Connection connection, PreparedQuery preparedQuery, Optional<Integer> columnCount) throws SQLException {
        String modifiedQuery = queryModifier.apply(session, preparedQuery.query());
        log.debug("Preparing query: %s", modifiedQuery);
        // Presumably clamped because an empty projection is rendered as "1 x", which still yields one column.
        Optional<Integer> effectiveColumnCount = columnCount.map(count -> Math.max(count, 1));
        PreparedStatement statement = client.getPreparedStatement(connection, modifiedQuery, effectiveColumnCount);
        try {
            List<QueryParameter> parameters = preparedQuery.parameters();
            for (int i = 0; i < parameters.size(); i++) {
                // JDBC parameter indexes are 1-based.
                bindParameter(client, session, connection, statement, i + 1, parameters.get(i));
            }
        } catch (SQLException | RuntimeException e) {
            try {
                statement.close();
            } catch (SQLException closeException) {
                e.addSuppressed(closeException);
            }
            throw e;
        }
        return statement;
    }

    @Override
    public CallableStatement callProcedure(JdbcClient client, ConnectorSession session, Connection connection, JdbcProcedureHandle.ProcedureQuery procedureQuery) throws SQLException {
        return connection.prepareCall(procedureQuery.query());
    }

    /**
     * Renders {@code <leftAlias>.<leftColumn> <op> <rightAlias>.<rightColumn>}.
     */
    protected String formatJoinCondition(JdbcClient client, String leftRelationAlias, String rightRelationAlias, JdbcJoinCondition condition) {
        return String.format(
                "%s.%s %s %s.%s",
                leftRelationAlias,
                buildJoinColumn(client, condition.getLeftColumn()),
                condition.getOperator().getValue(),
                rightRelationAlias,
                buildJoinColumn(client, condition.getRightColumn()));
    }

    /**
     * Returns the quoted column name used on either side of a legacy join condition.
     */
    protected String buildJoinColumn(JdbcClient client, JdbcColumnHandle columnHandle) {
        return client.quoted(columnHandle.getColumnName());
    }

    /**
     * Renders {@code col AS alias, ...}, or the placeholder projection {@code 1 x} when empty.
     */
    protected String formatProjections(JdbcClient client, Map<JdbcColumnHandle, String> projections) {
        if (projections.isEmpty()) {
            return "1 x";
        }
        return projections.entrySet().stream()
                .map(entry -> String.format("%s AS %s", client.quoted(entry.getKey().getColumnName()), client.quoted(entry.getValue())))
                .collect(Collectors.joining(", "));
    }

    /**
     * Renders the quoted aliases as a comma-separated list.
     */
    protected String formatProjectionAliases(JdbcClient client, Collection<String> aliases) {
        return aliases.stream()
                .map(client::quoted)
                .collect(Collectors.joining(", "));
    }

    /**
     * Renders {@code <relationAlias>.col AS alias, ...} for the given assignments.
     */
    protected String formatAssignments(JdbcClient client, String relationAlias, Map<JdbcColumnHandle, String> assignments) {
        return assignments.entrySet().stream()
                .map(entry -> String.format("%s.%s AS %s", relationAlias, client.quoted(entry.getKey().getColumnName()), client.quoted(entry.getValue())))
                .collect(Collectors.joining(", "));
    }

    /**
     * Maps a {@link JoinType} to its SQL keyword. The switch is exhaustive, so adding a new
     * {@link JoinType} constant becomes a compile-time error here instead of a runtime failure.
     */
    protected String formatJoinType(JoinType joinType) {
        return switch (joinType) {
            case INNER -> "INNER JOIN";
            case LEFT_OUTER -> "LEFT JOIN";
            case RIGHT_OUTER -> "RIGHT JOIN";
            case FULL_OUTER -> "FULL JOIN";
        };
    }

    /**
     * Returns the quoted, fully-qualified remote table name.
     */
    protected String getRelation(JdbcClient client, RemoteTableName remoteTableName) {
        return client.quoted(remoteTableName);
    }

    /**
     * Renders the SELECT list. A column with an entry in {@code columnExpressions} is rendered as
     * {@code <expression> AS <column>} and its parameters are passed to {@code accumulator}. Other
     * columns are rendered as the quoted column name. An empty column list yields {@code 1 x}.
     */
    protected String getProjection(JdbcClient client, List<JdbcColumnHandle> columns, Map<String, ParameterizedExpression> columnExpressions, Consumer<QueryParameter> accumulator) {
        if (columns.isEmpty()) {
            return "1 x";
        }
        List<String> projections = new ArrayList<>(columns.size());
        for (JdbcColumnHandle column : columns) {
            String columnAlias = client.quoted(column.getColumnName());
            ParameterizedExpression expression = columnExpressions.get(column.getColumnName());
            if (expression == null) {
                projections.add(columnAlias);
            } else {
                projections.add(String.format("%s AS %s", expression.expression(), columnAlias));
                expression.parameters().forEach(accumulator);
            }
        }
        return String.join(", ", projections);
    }

    /**
     * Renders the FROM clause: a quoted table for named relations, or the wrapped sub-query
     * (aliased {@code o}) for query relations, whose parameters are passed to {@code accumulator}.
     *
     * @throws IllegalArgumentException for any other relation type
     */
    protected String getFrom(JdbcClient client, JdbcRelationHandle baseRelation, Consumer<QueryParameter> accumulator) {
        if (baseRelation instanceof JdbcNamedRelationHandle namedRelation) {
            return " FROM " + getRelation(client, namedRelation.getRemoteTableName());
        }
        if (baseRelation instanceof JdbcQueryRelationHandle queryRelation) {
            PreparedQuery preparedQuery = queryRelation.getPreparedQuery();
            preparedQuery.parameters().forEach(accumulator);
            return " FROM (" + preparedQuery.query() + ") o";
        }
        throw new IllegalArgumentException("Unsupported relation: " + String.valueOf(baseRelation));
    }

    /**
     * Returns the part of {@code domain} that this column's predicate pushdown controller allows
     * to be sent to the remote database.
     *
     * @throws IllegalStateException if the client has no column mapping for the column's type
     */
    protected Domain pushDownDomain(JdbcClient client, ConnectorSession session, Connection connection, JdbcColumnHandle column, Domain domain) {
        return client.toColumnMapping(session, connection, column.getJdbcTypeHandle())
                .orElseThrow(() -> new IllegalStateException(String.format("Unsupported type %s with handle %s", column.getColumnType(), column.getJdbcTypeHandle())))
                .getPredicatePushdownController()
                .apply(session, domain)
                .getPushedDown();
    }

    /**
     * Appends one predicate per constrained column to {@code result}. A "none" tuple domain
     * becomes the single {@code 1=0} predicate.
     */
    protected void toConjuncts(JdbcClient client, ConnectorSession session, Connection connection, TupleDomain<ColumnHandle> tupleDomain, ImmutableList.Builder<String> result, Consumer<QueryParameter> accumulator) {
        if (tupleDomain.isNone()) {
            result.add(ALWAYS_FALSE);
            return;
        }
        for (Map.Entry<ColumnHandle, Domain> entry : tupleDomain.getDomains().orElseThrow().entrySet()) {
            JdbcColumnHandle column = (JdbcColumnHandle) entry.getKey();
            Domain domain = pushDownDomain(client, session, connection, column, entry.getValue());
            result.add(toPredicate(client, session, connection, column, domain, accumulator));
        }
    }

    /**
     * Renders a predicate for a single column's {@link Domain}, including its null handling.
     */
    protected String toPredicate(JdbcClient client, ConnectorSession session, Connection connection, JdbcColumnHandle column, Domain domain, Consumer<QueryParameter> accumulator) {
        if (domain.getValues().isNone()) {
            return domain.isNullAllowed() ? client.quoted(column.getColumnName()) + " IS NULL" : ALWAYS_FALSE;
        }
        if (domain.getValues().isAll()) {
            return domain.isNullAllowed() ? ALWAYS_TRUE : client.quoted(column.getColumnName()) + " IS NOT NULL";
        }
        String predicate = toPredicate(client, session, connection, column, domain.getValues(), accumulator);
        if (!domain.isNullAllowed()) {
            return predicate;
        }
        return String.format("(%s OR %s IS NULL)", predicate, client.quoted(column.getColumnName()));
    }

    /**
     * Renders a predicate for a non-empty {@link ValueSet}. A non-discrete set whose complement is
     * discrete is rendered as {@code NOT (...)} of the complement. Otherwise ranges become
     * comparisons, a single point becomes {@code =}, and several points become {@code IN (...)}.
     * All of these are OR-ed together.
     *
     * @throws IllegalArgumentException if {@code valueSet} is none
     */
    protected String toPredicate(JdbcClient client, ConnectorSession session, Connection connection, JdbcColumnHandle column, ValueSet valueSet, Consumer<QueryParameter> accumulator) {
        Preconditions.checkArgument(!valueSet.isNone(), "none values should be handled earlier");
        if (!valueSet.isDiscreteSet()) {
            ValueSet complement = valueSet.complement();
            if (complement.isDiscreteSet()) {
                return String.format("NOT (%s)", toPredicate(client, session, connection, column, complement, accumulator));
            }
        }
        JdbcTypeHandle jdbcType = column.getJdbcTypeHandle();
        Type type = column.getColumnType();
        WriteFunction writeFunction = getWriteFunction(client, session, connection, jdbcType, type);
        List<String> disjuncts = new ArrayList<>();
        List<Object> singleValues = new ArrayList<>();
        for (Range range : valueSet.getRanges().getOrderedRanges()) {
            Preconditions.checkState(!range.isAll());
            if (range.isSingleValue()) {
                singleValues.add(range.getSingleValue());
                continue;
            }
            List<String> rangeConjuncts = new ArrayList<>();
            if (!range.isLowUnbounded()) {
                rangeConjuncts.add(toPredicate(client, session, column, jdbcType, type, writeFunction, range.isLowInclusive() ? ">=" : ">", range.getLowBoundedValue(), accumulator));
            }
            if (!range.isHighUnbounded()) {
                rangeConjuncts.add(toPredicate(client, session, column, jdbcType, type, writeFunction, range.isHighInclusive() ? "<=" : "<", range.getHighBoundedValue(), accumulator));
            }
            Preconditions.checkState(!rangeConjuncts.isEmpty());
            if (rangeConjuncts.size() == 1) {
                disjuncts.add(Iterables.getOnlyElement(rangeConjuncts));
            } else {
                disjuncts.add("(" + Joiner.on(" AND ").join(rangeConjuncts) + ")");
            }
        }
        if (singleValues.size() == 1) {
            disjuncts.add(toPredicate(client, session, column, jdbcType, type, writeFunction, "=", Iterables.getOnlyElement(singleValues), accumulator));
        } else if (singleValues.size() > 1) {
            for (Object value : singleValues) {
                accumulator.accept(new QueryParameter(jdbcType, type, Optional.of(value)));
            }
            String values = Joiner.on(",").join(Collections.nCopies(singleValues.size(), writeFunction.getBindExpression()));
            disjuncts.add(client.quoted(column.getColumnName()) + " IN (" + values + ")");
        }
        Preconditions.checkState(!disjuncts.isEmpty());
        if (disjuncts.size() == 1) {
            return Iterables.getOnlyElement(disjuncts);
        }
        return "(" + Joiner.on(" OR ").join(disjuncts) + ")";
    }

    /**
     * Renders {@code <column> <operator> <bind expression>} and records {@code value} as its parameter.
     */
    protected String toPredicate(JdbcClient client, ConnectorSession session, JdbcColumnHandle column, JdbcTypeHandle jdbcType, Type type, WriteFunction writeFunction, String operator, Object value, Consumer<QueryParameter> accumulator) {
        accumulator.accept(new QueryParameter(jdbcType, type, Optional.of(value)));
        return String.format("%s %s %s", client.quoted(column.getColumnName()), operator, writeFunction.getBindExpression());
    }

    /**
     * Renders the GROUP BY clause. It returns an empty string when there is no grouping or only
     * one empty grouping set, a plain {@code GROUP BY} for a single grouping set, and
     * {@code GROUP BY GROUPING SETS (...)} otherwise.
     *
     * @throws VerifyException if {@code groupingSets} is present but empty
     */
    protected String getGroupBy(JdbcClient client, Optional<List<List<JdbcColumnHandle>>> groupingSets) {
        if (groupingSets.isEmpty()) {
            return "";
        }
        List<List<JdbcColumnHandle>> sets = groupingSets.get();
        Verify.verify(!sets.isEmpty());
        if (sets.size() == 1) {
            List<JdbcColumnHandle> groupingSet = Iterables.getOnlyElement(sets);
            if (groupingSet.isEmpty()) {
                return "";
            }
            return " GROUP BY " + groupingSet.stream()
                    .map(JdbcColumnHandle::getColumnName)
                    .map(client::quoted)
                    .collect(Collectors.joining(", "));
        }
        return " GROUP BY GROUPING SETS " + sets.stream()
                .map(groupingSet -> groupingSet.stream()
                        .map(JdbcColumnHandle::getColumnName)
                        .map(client::quoted)
                        .collect(Collectors.joining(", ", "(", ")")))
                .collect(Collectors.joining(", ", "(", ")"));
    }

    /**
     * Renders {@code " WHERE c1 AND c2 ..."} from the tuple-domain conjuncts followed by the optional
     * additional predicate, or an empty string when there are none. Parameters are passed to
     * {@code accumulator} in the same order as their placeholders.
     */
    private String buildWhereClause(JdbcClient client, ConnectorSession session, Connection connection, TupleDomain<ColumnHandle> tupleDomain, Optional<ParameterizedExpression> additionalPredicate, Consumer<QueryParameter> accumulator) {
        ImmutableList.Builder<String> conjuncts = ImmutableList.builder();
        toConjuncts(client, session, connection, tupleDomain, conjuncts, accumulator);
        additionalPredicate.ifPresent(predicate -> {
            conjuncts.add(predicate.expression());
            predicate.parameters().forEach(accumulator);
        });
        List<String> clauses = conjuncts.build();
        if (clauses.isEmpty()) {
            return "";
        }
        return " WHERE " + Joiner.on(" AND ").join(clauses);
    }

    /**
     * Binds one parameter to {@code statement}, dispatching on the write function's Java type.
     * An absent value is bound as SQL NULL.
     */
    private static void bindParameter(JdbcClient client, ConnectorSession session, Connection connection, PreparedStatement statement, int parameterIndex, QueryParameter parameter) throws SQLException {
        WriteFunction writeFunction = parameter.getJdbcType()
                .map(jdbcType -> getWriteFunction(client, session, connection, jdbcType, parameter.getType()))
                .orElseGet(() -> getWriteFunction(client, session, parameter.getType()));
        if (parameter.getValue().isEmpty()) {
            writeFunction.setNull(statement, parameterIndex);
            return;
        }
        Class<?> javaType = writeFunction.getJavaType();
        Object value = parameter.getValue().orElseThrow(() -> new VerifyException("Value is missing"));
        if (javaType == boolean.class) {
            ((BooleanWriteFunction) writeFunction).set(statement, parameterIndex, (Boolean) value);
        } else if (javaType == long.class) {
            ((LongWriteFunction) writeFunction).set(statement, parameterIndex, (Long) value);
        } else if (javaType == double.class) {
            ((DoubleWriteFunction) writeFunction).set(statement, parameterIndex, (Double) value);
        } else if (javaType == Slice.class) {
            ((SliceWriteFunction) writeFunction).set(statement, parameterIndex, (Slice) value);
        } else {
            ((ObjectWriteFunction) writeFunction).set(statement, parameterIndex, value);
        }
    }

    /**
     * Resolves the write function from the column mapping for {@code jdbcType}.
     *
     * @throws VerifyException if no mapping exists, or if the write function's Java type differs
     *         from {@code type}'s Java type
     */
    private static WriteFunction getWriteFunction(JdbcClient client, ConnectorSession session, Connection connection, JdbcTypeHandle jdbcType, Type type) {
        WriteFunction writeFunction = client.toColumnMapping(session, connection, jdbcType)
                .orElseThrow(() -> new VerifyException(String.format("Unsupported type %s with handle %s", type, jdbcType)))
                .getWriteFunction();
        Verify.verify(writeFunction.getJavaType() == type.getJavaType(), "Java type mismatch: %s, %s", writeFunction, type);
        return writeFunction;
    }

    /**
     * Resolves the write function from the client's write mapping for a Trino {@link Type}. This is
     * used for parameters that carry no JDBC type handle.
     */
    private static WriteFunction getWriteFunction(JdbcClient client, ConnectorSession session, Type type) {
        return client.toWriteMapping(session, type).getWriteFunction();
    }
}

