/*
 * Decompiled with CFR 0.152.
 */
package io.seata.rm.datasource.exec;

import com.google.common.collect.Lists;
import io.seata.common.exception.NotSupportYetException;
import io.seata.common.exception.ShouldNeverHappenException;
import io.seata.common.util.CollectionUtils;
import io.seata.rm.datasource.ColumnUtils;
import io.seata.rm.datasource.PreparedStatementProxy;
import io.seata.rm.datasource.StatementProxy;
import io.seata.rm.datasource.exec.AbstractDMLBaseExecutor;
import io.seata.rm.datasource.exec.InsertExecutor;
import io.seata.rm.datasource.exec.StatementCallback;
import io.seata.rm.datasource.sql.struct.ColumnMeta;
import io.seata.rm.datasource.sql.struct.TableRecords;
import io.seata.sqlparser.SQLInsertRecognizer;
import io.seata.sqlparser.SQLRecognizer;
import io.seata.sqlparser.struct.Null;
import io.seata.sqlparser.struct.Sequenceable;
import io.seata.sqlparser.struct.SqlDefaultExpr;
import io.seata.sqlparser.struct.SqlMethodExpr;
import io.seata.sqlparser.struct.SqlSequenceExpr;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class BaseInsertExecutor<T, S extends Statement>
extends AbstractDMLBaseExecutor<T, S>
implements InsertExecutor<T> {
    private static final Logger LOGGER = LoggerFactory.getLogger(BaseInsertExecutor.class);
    protected static final String PLACEHOLDER = "?";

    public BaseInsertExecutor(StatementProxy<S> statementProxy, StatementCallback<T, S> statementCallback, SQLRecognizer sqlRecognizer) {
        super(statementProxy, statementCallback, sqlRecognizer);
    }

    @Override
    protected TableRecords beforeImage() throws SQLException {
        return TableRecords.empty(this.getTableMeta());
    }

    @Override
    protected TableRecords afterImage(TableRecords beforeImage) throws SQLException {
        Map<String, List<Object>> pkValues = this.getPkValues();
        TableRecords afterImage = this.buildTableRecords(pkValues);
        if (afterImage == null) {
            throw new SQLException("Failed to build after-image for insert");
        }
        return afterImage;
    }

    protected boolean containsPK() {
        SQLInsertRecognizer recognizer = (SQLInsertRecognizer)this.sqlRecognizer;
        List<String> insertColumns = recognizer.getInsertColumns();
        if (CollectionUtils.isEmpty(insertColumns)) {
            return false;
        }
        return this.containsPK(insertColumns);
    }

    protected boolean containsColumns() {
        return !((SQLInsertRecognizer)this.sqlRecognizer).insertColumnsIsEmpty();
    }

    protected Map<String, Integer> getPkIndex() {
        HashMap<String, Integer> pkIndexMap = new HashMap<String, Integer>();
        SQLInsertRecognizer recognizer = (SQLInsertRecognizer)this.sqlRecognizer;
        List<String> insertColumns = recognizer.getInsertColumns();
        if (CollectionUtils.isNotEmpty(insertColumns)) {
            int insertColumnsSize = insertColumns.size();
            for (int paramIdx = 0; paramIdx < insertColumnsSize; ++paramIdx) {
                String sqlColumnName = insertColumns.get(paramIdx);
                if (!this.containPK(sqlColumnName)) continue;
                pkIndexMap.put(this.getStandardPkColumnName(sqlColumnName), paramIdx);
            }
            return pkIndexMap;
        }
        int pkIndex = -1;
        Map<String, ColumnMeta> allColumns = this.getTableMeta().getAllColumns();
        for (Map.Entry<String, ColumnMeta> entry : allColumns.entrySet()) {
            ++pkIndex;
            if (!this.containPK(entry.getValue().getColumnName())) continue;
            pkIndexMap.put(ColumnUtils.delEscape(entry.getValue().getColumnName(), this.getDbType()), pkIndex);
        }
        return pkIndexMap;
    }

    protected Map<String, List<Object>> parsePkValuesFromStatement() {
        SQLInsertRecognizer recognizer = (SQLInsertRecognizer)this.sqlRecognizer;
        Map<String, Integer> pkIndexMap = this.getPkIndex();
        if (pkIndexMap.isEmpty()) {
            throw new ShouldNeverHappenException("pkIndex is not found");
        }
        HashMap<String, List<Object>> pkValuesMap = new HashMap<String, List<Object>>();
        boolean ps = true;
        if (this.statementProxy instanceof PreparedStatementProxy) {
            PreparedStatementProxy preparedStatementProxy = (PreparedStatementProxy)this.statementProxy;
            List<List<Object>> insertRows = recognizer.getInsertRows(pkIndexMap.values());
            if (insertRows != null && !insertRows.isEmpty()) {
                Map<Integer, ArrayList<Object>> parameters = preparedStatementProxy.getParameters();
                int rowSize = insertRows.size();
                int totalPlaceholderNum = -1;
                for (List<Object> row : insertRows) {
                    if (row.isEmpty()) continue;
                    int currentRowPlaceholderNum = -1;
                    for (Object r : row) {
                        if (!PLACEHOLDER.equals(r)) continue;
                        ++totalPlaceholderNum;
                        ++currentRowPlaceholderNum;
                    }
                    for (Map.Entry<String, Integer> entry : pkIndexMap.entrySet()) {
                        int pkIndex2;
                        Object pkValue;
                        String pkKey2 = entry.getKey();
                        ArrayList<Object> pkValues = (ArrayList<Object>)pkValuesMap.get(pkKey2);
                        if (Objects.isNull(pkValues)) {
                            pkValues = new ArrayList<Object>(rowSize);
                        }
                        if (PLACEHOLDER.equals(pkValue = row.get(pkIndex2 = entry.getValue().intValue()))) {
                            int currentRowNotPlaceholderNumBeforePkIndex = 0;
                            int len = row.size();
                            for (int n = 0; n < len; ++n) {
                                Object r = row.get(n);
                                if (n >= pkIndex2 || PLACEHOLDER.equals(r)) continue;
                                ++currentRowNotPlaceholderNumBeforePkIndex;
                            }
                            int idx = totalPlaceholderNum - currentRowPlaceholderNum + pkIndex2 - currentRowNotPlaceholderNumBeforePkIndex;
                            ArrayList<Object> parameter = parameters.get(idx + 1);
                            pkValues.addAll(parameter);
                        } else {
                            pkValues.add(pkValue);
                        }
                        if (pkValuesMap.containsKey(ColumnUtils.delEscape(pkKey2, this.getDbType()))) continue;
                        pkValuesMap.put(ColumnUtils.delEscape(pkKey2, this.getDbType()), pkValues);
                    }
                }
            }
        } else {
            ps = false;
            List<List<Object>> insertRows = recognizer.getInsertRows(pkIndexMap.values());
            for (List<Object> row : insertRows) {
                pkIndexMap.forEach((pkKey, pkIndex) -> {
                    List pkValues = (List)pkValuesMap.get(pkKey);
                    if (Objects.isNull(pkValues)) {
                        pkValuesMap.put(ColumnUtils.delEscape(pkKey, this.getDbType()), Lists.newArrayList((Object[])new Object[]{row.get((int)pkIndex)}));
                    } else {
                        pkValues.add(row.get((int)pkIndex));
                    }
                });
            }
        }
        if (pkValuesMap.isEmpty()) {
            throw new ShouldNeverHappenException();
        }
        boolean b = this.checkPkValues(pkValuesMap, ps);
        if (!b) {
            throw new NotSupportYetException(String.format("not support sql [%s]", this.sqlRecognizer.getOriginalSQL()));
        }
        return pkValuesMap;
    }

    public List<Object> getGeneratedKeys() throws SQLException {
        ResultSet genKeys = this.statementProxy.getGeneratedKeys();
        ArrayList<Object> pkValues = new ArrayList<Object>();
        while (genKeys.next()) {
            Object v = genKeys.getObject(1);
            pkValues.add(v);
        }
        if (pkValues.isEmpty()) {
            throw new NotSupportYetException(String.format("not support sql [%s]", this.sqlRecognizer.getOriginalSQL()));
        }
        try {
            genKeys.beforeFirst();
        }
        catch (SQLException e) {
            LOGGER.warn("Fail to reset ResultSet cursor. can not get primary key value");
        }
        return pkValues;
    }

    protected List<Object> getPkValuesBySequence(SqlSequenceExpr expr) throws SQLException {
        List<Object> pkValues = null;
        try {
            pkValues = this.getGeneratedKeys();
        }
        catch (NotSupportYetException | SQLException exception) {
            // empty catch block
        }
        if (!CollectionUtils.isEmpty(pkValues)) {
            return pkValues;
        }
        Sequenceable sequenceable = (Sequenceable)((Object)this);
        String sql = sequenceable.getSequenceSql(expr);
        LOGGER.warn("Fail to get auto-generated keys, use '{}' instead. Be cautious, statement could be polluted. Recommend you set the statement to return generated keys.", (Object)sql);
        ResultSet genKeys = this.statementProxy.getConnection().createStatement().executeQuery(sql);
        pkValues = new ArrayList<Object>();
        while (genKeys.next()) {
            Object v = genKeys.getObject(1);
            pkValues.add(v);
        }
        return pkValues;
    }

    protected boolean checkPkValuesForMultiPk(Map<String, List<Object>> pkValues) {
        Set<String> pkNames = pkValues.keySet();
        if (pkNames.isEmpty()) {
            throw new ShouldNeverHappenException();
        }
        int rowSize = pkValues.get(pkNames.iterator().next()).size();
        for (int i = 0; i < rowSize; ++i) {
            int n = 0;
            int m = 0;
            for (String name : pkNames) {
                Object pkValue = pkValues.get(name).get(i);
                if (pkValue instanceof Null) {
                    ++n;
                }
                if (!(pkValue instanceof SqlMethodExpr)) continue;
                ++m;
            }
            if (n > 1) {
                return false;
            }
            if (m <= 0) continue;
            return false;
        }
        return true;
    }

    protected boolean checkPkValues(Map<String, List<Object>> pkValues, boolean ps) {
        Set<String> pkNames = pkValues.keySet();
        if (pkNames.size() == 1) {
            return this.checkPkValuesForSinglePk(pkValues.get(pkNames.iterator().next()), ps);
        }
        return this.checkPkValuesForMultiPk(pkValues);
    }

    protected boolean checkPkValuesForSinglePk(List<Object> pkValues, boolean ps) {
        int n = 0;
        int v = 0;
        int m = 0;
        int s = 0;
        int d = 0;
        for (Object pkValue : pkValues) {
            if (pkValue instanceof Null) {
                ++n;
                continue;
            }
            if (pkValue instanceof SqlMethodExpr) {
                ++m;
                continue;
            }
            if (pkValue instanceof SqlSequenceExpr) {
                ++s;
                continue;
            }
            if (pkValue instanceof SqlDefaultExpr) {
                ++d;
                continue;
            }
            ++v;
        }
        if (!ps) {
            if (m > 0) {
                return false;
            }
            if (n == 1 && v == 0 && m == 0 && s == 0 && d == 0) {
                return true;
            }
            if (n == 0 && v > 0 && m == 0 && s == 0 && d == 0) {
                return true;
            }
            if (n == 0 && v == 0 && m == 0 && s == 1 && d == 0) {
                return true;
            }
            return n == 0 && v == 0 && m == 0 && s == 0 && d == 1;
        }
        if (n > 0 && v == 0 && m == 0 && s == 0 && d == 0) {
            return true;
        }
        if (n == 0 && v > 0 && m == 0 && s == 0 && d == 0) {
            return true;
        }
        if (n == 0 && v == 0 && m > 0 && s == 0 && d == 0) {
            return true;
        }
        if (n == 0 && v == 0 && m == 0 && s > 0 && d == 0) {
            return true;
        }
        return n == 0 && v == 0 && m == 0 && s == 0 && d > 0;
    }
}

