/*
 * Decompiled with CFR 0.152.
 */
package org.dflib.jdbc.connector.saver;

import java.sql.Connection;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.BitSet;
import java.util.function.Supplier;
import org.dflib.DataFrame;
import org.dflib.GroupBy;
import org.dflib.Hasher;
import org.dflib.Index;
import org.dflib.IntSeries;
import org.dflib.RowToValueMapper;
import org.dflib.Series;
import org.dflib.jdbc.SaveOp;
import org.dflib.jdbc.connector.JdbcConnector;
import org.dflib.jdbc.connector.StatementBuilder;
import org.dflib.jdbc.connector.TableLoader;
import org.dflib.jdbc.connector.metadata.TableFQName;
import org.dflib.jdbc.connector.saver.TableSaveStrategy;
import org.dflib.jdbc.connector.saver.UpsertInfoTracker;
import org.dflib.join.JoinIndicator;
import org.dflib.row.RowProxy;
import org.dflib.series.SingleValueSeries;

/**
 * A {@link TableSaveStrategy} that "upserts" a DataFrame into a table, matching DataFrame rows against table rows by
 * the configured key columns.
 *
 * <ul>
 * <li>Rows with no matching DB row are inserted.</li>
 * <li>For rows that match a DB row, only the columns whose values differ from the stored ones are updated.</li>
 * </ul>
 */
public class SaveViaUpsert extends TableSaveStrategy {

    // synthetic column names chosen to be unlikely to clash with user columns
    private static final String INDICATOR_COLUMN = "dflib_ind_%$#86AcD3";
    private static final String DIFF_COLUMN = "dflib_dif_%4$#96Ac3";

    protected String[] keyColumns;

    /**
     * @param connector  connector used to talk to the DB
     * @param tableName  fully-qualified name of the target table
     * @param keyColumns columns used to match DataFrame rows to existing DB rows
     * @param batchSize  batch size passed to the superclass
     */
    public SaveViaUpsert(JdbcConnector connector, TableFQName tableName, String[] keyColumns, int batchSize) {
        super(connector, tableName, batchSize);
        this.keyColumns = keyColumns;
    }

    /**
     * Loads DB rows whose key values match the DataFrame's, then inserts the unmatched rows and updates the changed
     * columns of the matched ones.
     *
     * @return a supplier of per-row save operations
     * @throws IllegalStateException if a DataFrame key matches more than one DB row, or if the join result is
     *                               inconsistent with the source DataFrame
     */
    @Override
    protected Supplier<Series<SaveOp>> doInsertOrUpdate(JdbcConnector connector, DataFrame df) {
        DataFrame keyDf = keyValues(df);
        DataFrame previouslySaved = new TableLoader(connector, this.tableName)
                .cols(df.getColumnsIndex().toArray())
                .eq(keyDf)
                .load();

        // nothing matched in the DB, so every row is an insert
        if (previouslySaved.height() == 0) {
            doInsert(connector, df);
            return () -> new SingleValueSeries<>(SaveOp.insert, df.height());
        }

        DataFrame insertAndUpdate = df.leftJoin(previouslySaved)
                .on(keyHasher())
                .indicatorColumn(INDICATOR_COLUMN)
                .select();

        // A left join keeps every left row, so it can only grow the frame, and only when
        // a single key matches multiple DB rows. Check this before doing any work.
        int heightDelta = insertAndUpdate.height() - df.height();
        if (heightDelta > 0) {
            String message = String.format("Duplicate rows in the database table %s using key columns %s. Specify key columns that produce unique DB rows.", this.tableName, Arrays.toString(this.keyColumns));
            throw new IllegalStateException(message);
        }
        if (heightDelta < 0) {
            String message = String.format(
                    "Unexpected join result for table %s: %s joined rows for %s source rows",
                    this.tableName, insertAndUpdate.height(), df.height());
            throw new IllegalStateException(message);
        }

        Series<JoinIndicator> indicators = insertAndUpdate.getColumn(INDICATOR_COLUMN);
        IntSeries insertIndex = indicators.index(i -> i == JoinIndicator.left_only);
        IntSeries updateIndex = indicators.index(i -> i == JoinIndicator.both);

        UpsertInfoTracker infoTracker = new UpsertInfoTracker(df.width(), df.height());
        infoTracker.insertAndUpdate(indicators);

        if (insertIndex.size() > 0) {
            doInsert(connector, df.rows(insertIndex).select());
        }

        if (updateIndex.size() > 0) {
            // The joined frame starts with df's columns, followed by the columns loaded from the DB.
            // Those were requested as df's column names (see TableLoader.cols above), so slicing
            // [w, 2w) and renaming back to df's names aligns them with df column-by-column.
            // NOTE(review): assumes the join places right-side columns right after left-side ones.
            Index mainColumns = df.getColumnsIndex();
            Index joinedIndex = insertAndUpdate.getColumnsIndex().selectRange(mainColumns.size(), mainColumns.size() * 2);
            DataFrame previouslySavedOrdered = insertAndUpdate.cols(joinedIndex).select().cols().as(mainColumns.toArray());
            doUpdate(connector, df.rows(updateIndex).select(), previouslySavedOrdered.rows(updateIndex).select(), infoTracker);
        }

        return infoTracker::getInfo;
    }

    /**
     * Returns a DataFrame containing only the key columns of {@code df}.
     */
    protected DataFrame keyValues(DataFrame df) {
        return df.cols(this.keyColumns).select();
    }

    /**
     * Updates DB rows that differ from their DataFrame counterparts.
     *
     * <p>Rows are grouped by which columns changed, and one batched UPDATE is issued per group. Each UPDATE sets only
     * that group's changed non-key columns and uses the key columns in the WHERE clause.</p>
     *
     * @param toSave          rows to save, already known to match DB rows by key
     * @param previouslySaved DB rows aligned row-by-row and column-by-column with {@code toSave}
     * @param infoTracker     receives the per-row change pattern
     */
    protected void doUpdate(JdbcConnector connector, DataFrame toSave, DataFrame previouslySaved, UpsertInfoTracker infoTracker) {
        int w = toSave.width();
        if (w == this.keyColumns.length) {
            this.log("All DataFrame columns are key columns. Skipping update.", new Object[0]);
            return;
        }

        // per-row BitSet: bit i is set when column i equals the previously saved value
        DataFrame eqMatrix = toSave.eq(previouslySaved).colsAppend(new String[]{DIFF_COLUMN}).merge(new RowToValueMapper[]{this::booleansAsBitSet});
        DataFrame toSaveClassified = toSave.colsAppend(new String[]{DIFF_COLUMN}).merge(new Series[]{eqMatrix.getColumn(DIFF_COLUMN)});
        Series<BitSet> diffs = toSaveClassified.getColumn(DIFF_COLUMN);
        infoTracker.updatesCardinality(diffs);

        GroupBy byUpdatePattern = toSaveClassified.group(new String[]{DIFF_COLUMN});
        for (Object groupKey : byUpdatePattern.getGroupKeys()) {
            BitSet unchanged = (BitSet) groupKey;
            int unchangedCount = unchanged.cardinality();

            // every column is equal to the DB value, so there is nothing to update for this group
            if (unchangedCount == w) {
                continue;
            }

            DataFrame toUpdate = byUpdatePattern.getGroup(unchanged);
            Index groupColumns = toUpdate.getColumnsIndex();

            String[] changedColumns = new String[w - unchangedCount];
            int j = 0;
            for (int i = 0; i < w; i++) {
                if (!unchanged.get(i)) {
                    changedColumns[j++] = groupColumns.get(i);
                }
            }

            Index valueIndex = Index.of(changedColumns).selectExcept(this.keyColumns);

            // Defensive: if only key columns differ there is no SET clause to build, and an
            // UPDATE statement would be malformed. Skip the group instead of failing.
            if (valueIndex.size() == 0) {
                continue;
            }

            Index valueAndKeyIndex = valueIndex.expand(this.keyColumns);
            StatementBuilder builder = connector
                    .createStatementBuilder(createUpdateStatement(this.keyColumns, valueIndex.toArray()))
                    .paramDescriptors(fixedParams(valueAndKeyIndex))
                    .bindBatch(toUpdate.cols(valueAndKeyIndex).select());

            // try-with-resources: always close the connection, let update failures propagate,
            // and attach close() failures as suppressed exceptions
            try (Connection c = connector.getConnection()) {
                builder.update(c);
            } catch (SQLException e) {
                throw new RuntimeException("Error closing DB connection", e);
            }
        }
    }

    /**
     * Converts a row of Boolean values into a BitSet with bit {@code i} set when value {@code i} is {@code true}.
     */
    protected BitSet booleansAsBitSet(RowProxy booleanRow) {
        int w = booleanRow.getIndex().size();
        BitSet s = new BitSet(w);
        for (int i = 0; i < w; ++i) {
            if ((Boolean) booleanRow.get(i)) {
                s.set(i);
            }
        }
        return s;
    }

    /**
     * Builds a Hasher over all key columns, used to join DataFrame rows with DB rows.
     *
     * @throws IllegalStateException if no key columns are configured
     */
    protected Hasher keyHasher() {
        if (this.keyColumns.length == 0) {
            throw new IllegalStateException("No key columns specified for the upsert into table " + this.tableName);
        }

        Hasher h = Hasher.of(this.keyColumns[0]);
        for (int i = 1; i < this.keyColumns.length; ++i) {
            h = h.and(this.keyColumns[i]);
        }
        return h;
    }

    /**
     * Builds a parameterized statement of the form
     * {@code update <table> set v1 = ?, v2 = ? where k1 = ? and k2 = ?}, with quoted identifiers.
     * Value parameters come before condition parameters.
     *
     * @throws IllegalArgumentException if either column array is empty
     */
    protected String createUpdateStatement(String[] conditionColumns, String[] valueColumns) {
        if (valueColumns.length == 0) {
            throw new IllegalArgumentException("No value columns to update in table " + this.tableName);
        }
        if (conditionColumns.length == 0) {
            throw new IllegalArgumentException("No condition columns for an update of table " + this.tableName);
        }

        StringBuilder sql = new StringBuilder("update ")
                .append(this.connector.quoteTableName(this.tableName))
                .append(" set ");

        for (int i = 0; i < valueColumns.length; i++) {
            if (i > 0) {
                sql.append(", ");
            }
            sql.append(this.connector.quoteIdentifier(valueColumns[i])).append(" = ?");
        }

        sql.append(" where ");
        for (int i = 0; i < conditionColumns.length; i++) {
            if (i > 0) {
                sql.append(" and ");
            }
            sql.append(this.connector.quoteIdentifier(conditionColumns[i])).append(" = ?");
        }

        return sql.toString();
    }
}

