/*
 * Decompiled with CFR 0.152.
 */
package ai.knowly.langtorch.store.vectordb;

import ai.knowly.langtorch.processor.EmbeddingProcessor;
import ai.knowly.langtorch.schema.embeddings.EmbeddingInput;
import ai.knowly.langtorch.schema.embeddings.EmbeddingOutput;
import ai.knowly.langtorch.schema.io.DomainDocument;
import ai.knowly.langtorch.schema.io.Metadata;
import ai.knowly.langtorch.store.vectordb.integration.VectorStore;
import ai.knowly.langtorch.store.vectordb.integration.pgvector.PGVectorService;
import ai.knowly.langtorch.store.vectordb.integration.pgvector.SqlCommandProvider;
import ai.knowly.langtorch.store.vectordb.integration.pgvector.schema.PGVectorQueryParameters;
import ai.knowly.langtorch.store.vectordb.integration.pgvector.schema.PGVectorStoreSpec;
import ai.knowly.langtorch.store.vectordb.integration.pgvector.schema.PGVectorValues;
import ai.knowly.langtorch.store.vectordb.integration.pgvector.schema.distance.DistanceStrategy;
import ai.knowly.langtorch.store.vectordb.integration.schema.SimilaritySearchQuery;
import com.google.common.flogger.FluentLogger;
import com.google.common.primitives.Floats;
import com.google.inject.Inject;
import com.pgvector.PGvector;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import lombok.NonNull;

/**
 * {@link VectorStore} backed by PostgreSQL with the pgvector extension.
 *
 * <p>Each added document's page content is embedded with the configured {@link EmbeddingProcessor};
 * the resulting vector is written as an (id, vector) row to the embeddings table and every metadata
 * entry is written as an (id, key, value, vector id) row to the metadata table. The SQL text for
 * both tables comes from {@link SqlCommandProvider}.
 *
 * <p>Similarity search binds the query vector to the provider's select statement and rebuilds one
 * {@link DomainDocument} per vector id from the returned (vector id, vector, key, value) rows.
 */
public class PGVectorStore implements VectorStore {
    // Layout of one row in the embeddings INSERT: (id, vector).
    private static final int EMBEDDINGS_COLUMN_COUNT = 2;
    private static final int EMBEDDINGS_INDEX_ID = 0;
    private static final int EMBEDDINGS_INDEX_VECTOR = 1;
    // Layout of one row in the metadata INSERT: (id, key, value, vector id).
    private static final int METADATA_COLUMN_COUNT = 4;
    private static final int METADATA_INDEX_ID = 0;
    private static final int METADATA_INDEX_KEY = 1;
    private static final int METADATA_INDEX_VALUE = 2;
    private static final int METADATA_INDEX_VECTOR_ID = 3;
    private static final FluentLogger logger = FluentLogger.forEnclosingClass();
    @NonNull
    private final EmbeddingProcessor embeddingsProcessor;
    private final PGVectorStoreSpec pgVectorStoreSpec;
    private final SqlCommandProvider sqlCommandProvider;
    @NonNull
    private final PGVectorService pgVectorService;
    private final DistanceStrategy distanceStrategy;

    /**
     * Creates the store and issues the CREATE statements for the embeddings and metadata tables.
     *
     * @param embeddingsProcessor processor used to embed document page content; must not be null
     * @param pgVectorStoreSpec database name, overwrite flag, vector dimensions, model and text key
     * @param pgVectorService executes statements against the database; must not be null
     * @param distanceStrategy supplies the SQL distance operator and the Java-side score function
     * @throws NullPointerException if {@code embeddingsProcessor} or {@code pgVectorService} is null
     * @throws SQLException if creating either table fails
     */
    @Inject
    public PGVectorStore(@NonNull EmbeddingProcessor embeddingsProcessor, PGVectorStoreSpec pgVectorStoreSpec, @NonNull PGVectorService pgVectorService, DistanceStrategy distanceStrategy) throws SQLException {
        if (embeddingsProcessor == null) {
            throw new NullPointerException("embeddingsProcessor is marked non-null but is null");
        }
        if (pgVectorService == null) {
            throw new NullPointerException("pgVectorService is marked non-null but is null");
        }
        this.distanceStrategy = distanceStrategy;
        this.pgVectorService = pgVectorService;
        this.embeddingsProcessor = embeddingsProcessor;
        this.pgVectorStoreSpec = pgVectorStoreSpec;
        this.sqlCommandProvider = new SqlCommandProvider(pgVectorStoreSpec.getDatabaseName(), pgVectorStoreSpec.isOverwriteExistingTables());
        this.createNecessaryTables();
    }

    /** Issues the CREATE statements for the embeddings table and then the metadata table. */
    private void createNecessaryTables() throws SQLException {
        this.createEmbeddingsTable();
        this.createMetadataTable();
    }

    /**
     * Embeds and inserts the given documents together with their metadata entries.
     *
     * <p>NOTE(review): the embeddings and metadata INSERTs run as two separate statements with no
     * explicit transaction here, so a failure of the second can leave the first committed
     * (depending on the connection's auto-commit setting, which is not visible in this class).
     *
     * @param documents documents to store; an empty list is a no-op
     * @return {@code true} if the number of inserted embedding rows equals the number of documents
     *     and the number of inserted metadata rows equals the number of metadata entries;
     *     {@code false} on a mismatch or on any {@link SQLException} (which is logged)
     */
    @Override
    public boolean addDocuments(List<DomainDocument> documents) {
        if (documents.isEmpty()) {
            return true;
        }
        PGVectorQueryParameters queryParameters = this.getVectorQueryParameters(documents);
        List<PGVectorValues> vectorValues = queryParameters.getVectorValues();
        int expectedMetadataRows = queryParameters.getMetadataSize();
        // With no metadata entries the VALUES list would be empty, which is not a valid INSERT;
        // skip that statement entirely. A null try-with-resources resource is simply not closed.
        boolean hasMetadata = expectedMetadataRows > 0;
        try (PreparedStatement insertEmbeddingsStmt = this.pgVectorService.prepareStatement(this.sqlCommandProvider.getInsertEmbeddingsQuery(queryParameters.getVectorParameters()));
             PreparedStatement insertMetadataStmt = hasMetadata
                     ? this.pgVectorService.prepareStatement(this.sqlCommandProvider.getInsertMetadataQuery(queryParameters.getMetadataParameters()))
                     : null) {
            this.setQueryParameters(vectorValues, insertEmbeddingsStmt, insertMetadataStmt);
            int insertedEmbeddingRows = insertEmbeddingsStmt.executeUpdate();
            int insertedMetadataRows = insertMetadataStmt == null ? 0 : insertMetadataStmt.executeUpdate();
            return insertedEmbeddingRows == vectorValues.size() && insertedMetadataRows == expectedMetadataRows;
        } catch (SQLException e) {
            logger.atSevere().withCause(e).log("Error with SQL Exception");
            return false;
        }
    }

    /**
     * Returns the stored documents nearest to the query vector.
     *
     * <p>Rows are grouped by vector id in the order the database returns them. Each document's
     * metadata is filled from the returned key/value pairs, and its page content is taken from the
     * value whose key equals the spec's text key (if one is configured).
     *
     * @param similaritySearchQuery query vector and top-K limit
     * @return documents with their similarity score; on {@link SQLException} (logged) the documents
     *     assembled before the failure are returned
     */
    @Override
    public List<DomainDocument> similaritySearch(SimilaritySearchQuery similaritySearchQuery) {
        float[] queryVectorAsFloats = this.getFloatVectorValues(similaritySearchQuery.getQuery());
        double[] queryVectorAsDoubles = this.getDoubleVectorValues(queryVectorAsFloats);
        // LinkedHashMap keeps the database's result order across the grouped documents.
        Map<String, DomainDocument> documentsById = new LinkedHashMap<>();
        try (PreparedStatement neighborStmt = this.pgVectorService.prepareStatement(this.sqlCommandProvider.getSelectEmbeddingsQuery(this.distanceStrategy.getSyntax(), similaritySearchQuery.getTopK()))) {
            neighborStmt.setObject(1, new PGvector(queryVectorAsFloats));
            try (ResultSet result = neighborStmt.executeQuery()) {
                while (result.next()) {
                    this.accumulateResultRow(result, queryVectorAsDoubles, documentsById);
                }
            }
        } catch (SQLException e) {
            logger.atSevere().withCause(e).log("Error with SQL Exception");
        }
        return new ArrayList<>(documentsById.values());
    }

    /**
     * Merges one (vector id, vector, key, value) result row into the per-id document map.
     *
     * <p>The score is computed only when a vector id is first seen: every row for the same id
     * carries the same stored vector, so later rows would yield the same score.
     */
    private void accumulateResultRow(ResultSet row, double[] queryVector, Map<String, DomainDocument> documentsById) throws SQLException {
        String vectorId = (String) row.getObject(1);
        PGvector storedVector = (PGvector) row.getObject(2);
        String key = (String) row.getObject(3);
        String value = (String) row.getObject(4);
        DomainDocument document = documentsById.computeIfAbsent(vectorId,
                id -> this.newScoredDocument(id, this.distanceStrategy.calculateDistance(queryVector, this.getDoubleVectorValues(storedVector.toArray()))));
        this.saveValueToMetadataIfPresent(document, key, value);
        documentsById.put(vectorId, this.getDocumentWithScoreWithPageContent(document, key, value));
    }

    /** Builds an empty-content document carrying the given id, score and empty metadata. */
    private DomainDocument newScoredDocument(String vectorId, double score) {
        Metadata defaultMetadata = Metadata.builder().build();
        return DomainDocument.builder().setId(vectorId).setPageContent("").setSimilarityScore(Optional.of(score)).setMetadata(defaultMetadata).build();
    }

    /** Issues the CREATE statement for the embeddings table sized to the spec's vector dimensions. */
    private void createEmbeddingsTable() throws SQLException {
        this.pgVectorService.executeUpdate(this.sqlCommandProvider.getCreateEmbeddingsTableQuery(this.pgVectorStoreSpec.getVectorDimensions()));
    }

    /** Issues the CREATE statement for the metadata table. */
    private void createMetadataTable() throws SQLException {
        this.pgVectorService.executeUpdate(this.sqlCommandProvider.getCreateMetadataTableQuery());
    }

    /**
     * Embeds every document and builds the placeholder lists and bind values for both INSERTs.
     * Documents without an id get a random UUID.
     */
    private PGVectorQueryParameters getVectorQueryParameters(List<DomainDocument> documents) {
        List<PGVectorValues> vectorValues = new ArrayList<>(documents.size());
        StringBuilder vectorParameters = new StringBuilder();
        StringBuilder metadataParameters = new StringBuilder();
        int metadataSize = 0;
        for (DomainDocument document : documents) {
            List<Double> vector = this.createVector(document);
            // orElseGet: only generate a UUID when the document really has no id.
            String id = document.getId().orElseGet(() -> UUID.randomUUID().toString());
            vectorValues.add(this.buildPGVectorValues(id, vector, document.getMetadata()));
            vectorParameters.append(this.getVectorParameters());
            metadataSize += this.processMetadata(metadataParameters, document.getMetadata());
        }
        this.trimStringBuilder(vectorParameters);
        this.trimStringBuilder(metadataParameters);
        return this.buildPGVectorQueryParameters(vectorValues, vectorParameters.toString(), metadataParameters.toString(), metadataSize);
    }

    /** Packs id, float vector and metadata (empty when absent) into a {@link PGVectorValues}. */
    private PGVectorValues buildPGVectorValues(String id, List<Double> vector, Optional<Metadata> metadata) {
        return PGVectorValues.builder().setId(id).setValues(this.getFloatVectorValues(vector)).setMetadata(metadata.orElse(Metadata.builder().build())).build();
    }

    /** Placeholder group (with trailing separator) for one embeddings row. */
    private String getVectorParameters() {
        return "(?, ?), ";
    }

    /**
     * Appends one metadata placeholder group per metadata entry.
     *
     * @return the number of metadata entries (0 when metadata is absent)
     */
    private int processMetadata(StringBuilder metadataParameters, Optional<Metadata> metadata) {
        if (!metadata.isPresent()) {
            return 0;
        }
        int entryCount = metadata.get().getValue().size();
        for (int i = 0; i < entryCount; ++i) {
            metadataParameters.append("(?, ?, ?, ?), ");
        }
        return entryCount;
    }

    /** Removes the trailing ", " separator left by the placeholder appends, if any. */
    private void trimStringBuilder(StringBuilder stringBuilder) {
        int index = stringBuilder.lastIndexOf(", ");
        if (index > 0) {
            stringBuilder.delete(index, stringBuilder.length());
        }
    }

    private PGVectorQueryParameters buildPGVectorQueryParameters(List<PGVectorValues> vectorValues, String vectorParameters, String metadataParameters, int metadataSize) {
        return PGVectorQueryParameters.builder().setVectorValues(vectorValues).setVectorParameters(vectorParameters).setMetadataParameters(metadataParameters).setMetadataSize(metadataSize).build();
    }

    /** Embeds a single document's page content with the spec's model and returns the first vector. */
    private List<Double> createVector(DomainDocument document) {
        EmbeddingOutput embeddingOutput = (EmbeddingOutput) this.embeddingsProcessor.run(EmbeddingInput.builder().setModel(this.pgVectorStoreSpec.getModel()).setInput(Collections.singletonList(document.getPageContent())).build());
        return embeddingOutput.getValue().get(0).getVector();
    }

    /**
     * Binds one (id, key, value, vector id) group per metadata entry, starting at
     * {@code parameterIndex}.
     *
     * <p>NOTE(review): the row id is {@code vectorId + key} with no separator, so distinct pairs
     * such as ("a", "bc") and ("ab", "c") produce the same id. It is left unchanged here to stay
     * compatible with rows already stored.
     *
     * @return the next unused parameter index
     */
    private int setMetadataQueryParameters(PGVectorValues values, int parameterIndex, PreparedStatement insertStmt) throws SQLException {
        String vectorId = values.getId();
        int index = parameterIndex;
        for (Map.Entry<String, String> entry : values.getMetadata().getValue().entrySet()) {
            insertStmt.setString(index + METADATA_INDEX_ID, vectorId + entry.getKey());
            insertStmt.setString(index + METADATA_INDEX_KEY, entry.getKey());
            insertStmt.setString(index + METADATA_INDEX_VALUE, entry.getValue());
            insertStmt.setString(index + METADATA_INDEX_VECTOR_ID, vectorId);
            index += METADATA_COLUMN_COUNT;
        }
        return index;
    }

    /**
     * Binds one (id, vector) group starting at {@code parameterIndex}.
     *
     * @return the next unused parameter index
     */
    private int setVectorQueryParameters(PGVectorValues values, int parameterIndex, PreparedStatement insertStmt) throws SQLException {
        insertStmt.setString(parameterIndex + EMBEDDINGS_INDEX_ID, values.getId());
        insertStmt.setObject(parameterIndex + EMBEDDINGS_INDEX_VECTOR, new PGvector(values.getValues()));
        return parameterIndex + EMBEDDINGS_COLUMN_COUNT;
    }

    /**
     * Binds all embeddings parameters and, when {@code insertMetadataStmt} is non-null, all
     * metadata parameters. Parameter indices are 1-based, as JDBC requires.
     */
    private void setQueryParameters(List<PGVectorValues> vectorValues, PreparedStatement insertEmbeddingsStmt, PreparedStatement insertMetadataStmt) throws SQLException {
        int embeddingParameterIndex = 1;
        int metadataParameterIndex = 1;
        for (PGVectorValues values : vectorValues) {
            embeddingParameterIndex = this.setVectorQueryParameters(values, embeddingParameterIndex, insertEmbeddingsStmt);
            if (insertMetadataStmt != null) {
                metadataParameterIndex = this.setMetadataQueryParameters(values, metadataParameterIndex, insertMetadataStmt);
            }
        }
    }

    /**
     * Puts {@code key -> value} into the document's metadata map in place, when metadata is present
     * and the key is non-null. NOTE(review): assumes the map returned by {@code getValue()} is
     * mutable — confirm for {@code Metadata.builder().build()}.
     */
    private void saveValueToMetadataIfPresent(DomainDocument document, String key, String value) {
        Optional<Metadata> metadata = document.getMetadata();
        if (!metadata.isPresent() || key == null) {
            return;
        }
        metadata.get().getValue().put(key, value);
    }

    /**
     * Returns a copy of the document whose page content is {@code value} when {@code key} equals the
     * spec's configured text key; otherwise returns the document unchanged.
     */
    private DomainDocument getDocumentWithScoreWithPageContent(DomainDocument documentWithScore, String key, String value) {
        if (key == null) {
            return documentWithScore;
        }
        Optional<String> textKey = this.pgVectorStoreSpec.getTextKey();
        if (!textKey.isPresent() || !key.equals(textKey.get())) {
            return documentWithScore;
        }
        return documentWithScore.toBuilder().setPageContent(value).build();
    }

    private float[] getFloatVectorValues(List<Double> vectorValues) {
        return Floats.toArray(vectorValues);
    }

    private double[] getDoubleVectorValues(float[] vectorValues) {
        double[] doubles = new double[vectorValues.length];
        for (int i = 0; i < vectorValues.length; ++i) {
            doubles[i] = vectorValues[i];
        }
        return doubles;
    }
}

