/*
 * Decompiled with CFR 0.152.
 */
package com.dtsx.astra.sdk.cassio;

import com.datastax.oss.driver.api.core.CqlSession;
import com.datastax.oss.driver.api.core.cql.Row;
import com.datastax.oss.driver.api.core.cql.SimpleStatement;
import com.datastax.oss.driver.api.core.cql.Statement;
import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.core.uuid.Uuids;
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
import com.datastax.oss.driver.api.querybuilder.SchemaBuilder;
import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert;
import com.datastax.oss.driver.api.querybuilder.term.Term;
import com.dtsx.astra.sdk.cassio.AbstractCassandraTable;
import com.dtsx.astra.sdk.cassio.SimilarityMetric;
import com.dtsx.astra.sdk.cassio.SimilaritySearchQuery;
import com.dtsx.astra.sdk.cassio.SimilaritySearchResult;
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Cassandra table that stores text records together with a float vector embedding and a
 * {@code map<text, text>} of metadata, and supports ANN similarity search over the vector.
 *
 * <p>Schema created by {@link #createSchema()}:
 * {@code row_id text} (primary key), {@code attributes_blob text}, {@code body_blob text},
 * {@code metadata_s map<text, text>}, {@code vector vector<float, N>}, plus a
 * StorageAttachedIndex on {@code vector} (configured with the table's similarity metric)
 * and a StorageAttachedIndex on the entries of {@code metadata_s}.
 */
public class MetadataVectorCassandraTable
extends AbstractCassandraTable<MetadataVectorCassandraTable.Record> {
    private static final Logger log = LoggerFactory.getLogger(MetadataVectorCassandraTable.class);

    /** Index implementation class used for both the vector and the metadata index. */
    private static final String SAI_INDEX_CLASS = "org.apache.cassandra.index.sai.StorageAttachedIndex";

    /** Row limit applied by {@link #similaritySearch} when the query asks for a non-positive count. */
    private static final int DEFAULT_RECORD_COUNT = 4;

    private final int vectorDimension;
    private final SimilarityMetric similarityMetric;

    /**
     * Creates the table wrapper using {@link SimilarityMetric#DOT_PRODUCT} for the vector index.
     *
     * @param session         CQL session used for every statement
     * @param keyspaceName    keyspace holding the table
     * @param tableName       table name
     * @param vectorDimension dimension of the {@code vector<float, N>} column
     */
    public MetadataVectorCassandraTable(CqlSession session, String keyspaceName, String tableName, int vectorDimension) {
        this(session, keyspaceName, tableName, vectorDimension, SimilarityMetric.DOT_PRODUCT);
    }

    /**
     * Creates the table wrapper and immediately creates the table and its indexes if needed.
     *
     * @param session         CQL session used for every statement
     * @param keyspaceName    keyspace holding the table
     * @param tableName       table name
     * @param vectorDimension dimension of the {@code vector<float, N>} column
     * @param metric          similarity function configured on the vector index
     */
    public MetadataVectorCassandraTable(CqlSession session, String keyspaceName, String tableName, int vectorDimension, SimilarityMetric metric) {
        super(session, keyspaceName, tableName);
        this.vectorDimension = vectorDimension;
        this.similarityMetric = metric;
        // NOTE(review): calls an overridable method from the constructor; a subclass override
        // would run before the subclass's own fields are initialized.
        this.createSchema();
    }

    /**
     * Creates the table, the vector index and the metadata-entries index, each with
     * {@code IF NOT EXISTS}, so calling it repeatedly is harmless.
     */
    @Override
    public void createSchema() {
        cqlSession.execute("CREATE TABLE IF NOT EXISTS " + qualifiedTableName() + " ("
                + "row_id text, "
                + "attributes_blob text, "
                + "body_blob text, "
                + "metadata_s map<text, text>, "
                + "vector vector<float, " + vectorDimension + ">, "
                + "PRIMARY KEY (row_id))");
        log.info("+ Table '{}' has been created (if needed).", tableName);

        String vectorIndexName = "idx_vector_" + tableName;
        Map<String, String> vectorIndexOptions = new HashMap<>();
        vectorIndexOptions.put("similarity_function", similarityMetric.getOption());
        cqlSession.execute(SchemaBuilder.createIndex(vectorIndexName)
                .ifNotExists()
                .custom(SAI_INDEX_CLASS)
                .onTable(keyspaceName, tableName)
                .andColumn("vector")
                .withSASIOptions(vectorIndexOptions)
                .build());
        log.info("+ Index '{}' has been created (if needed).", vectorIndexName);

        String metadataIndexName = "eidx_metadata_s_" + tableName;
        cqlSession.execute(SchemaBuilder.createIndex(metadataIndexName)
                .ifNotExists()
                .custom(SAI_INDEX_CLASS)
                .onTable(keyspaceName, tableName)
                .andColumnEntries("metadata_s")
                .build());
        log.info("+ Index '{}' has been created (if needed).", metadataIndexName);
    }

    /**
     * Inserts (upserts) a record.
     *
     * @param row record to write; must not be {@code null}
     * @throws NullPointerException  if {@code row} is {@code null}
     * @throws IllegalStateException if the record has no row id or no vector
     */
    @Override
    public void put(Record row) {
        Objects.requireNonNull(row, "row");
        cqlSession.execute(row.insertStatement(keyspaceName, tableName));
    }

    /** Wraps a row from {@link #similaritySearch} with its {@code similarity} column; {@code null} for a {@code null} row. */
    private SimilaritySearchResult<Record> mapResult(Row cqlRow) {
        if (cqlRow == null) {
            return null;
        }
        SimilaritySearchResult<Record> result = new SimilaritySearchResult<>();
        result.setEmbedded(mapRow(cqlRow));
        result.setSimilarity(cqlRow.getFloat("similarity"));
        return result;
    }

    /**
     * Converts a CQL row into a {@link Record}.
     *
     * <p>{@code row_id}, {@code body_blob} and {@code vector} are always read;
     * {@code attributes_blob} and {@code metadata_s} are read only when present in the row.
     *
     * @param cqlRow row to convert, may be {@code null}
     * @return the record, or {@code null} when {@code cqlRow} is {@code null}
     * @throws NullPointerException if the row's {@code vector} column is {@code null}
     */
    @Override
    public Record mapRow(Row cqlRow) {
        if (cqlRow == null) {
            return null;
        }
        // The driver decodes vector<float, N> columns as CqlVector<Float>.
        @SuppressWarnings("unchecked")
        CqlVector<Float> vector = (CqlVector<Float>) Objects.requireNonNull(
                cqlRow.getObject("vector"), "Column 'vector' must not be null");
        Record record = new Record(cqlRow.getString("row_id"),
                vector.stream().collect(Collectors.toList()));
        record.setBody(cqlRow.getString("body_blob"));
        if (cqlRow.getColumnDefinitions().contains("attributes_blob")) {
            record.setAttributes(cqlRow.getString("attributes_blob"));
        }
        if (cqlRow.getColumnDefinitions().contains("metadata_s")) {
            record.setMetadata(cqlRow.getMap("metadata_s", String.class, String.class));
        }
        return record;
    }

    /**
     * Runs an ANN similarity search.
     *
     * <p>The query selects every column plus {@code <distance function>(vector, :vector) AS similarity}.
     * It optionally restricts {@code metadata_s[key] = value} for every entry of
     * {@link SimilaritySearchQuery#getMetaData()} (all joined with {@code AND}). It then orders by
     * {@code vector ANN OF :vector} and limits to {@code query.getRecordCount()}, or
     * {@value #DEFAULT_RECORD_COUNT} when that count is not positive. Rows whose similarity is
     * below {@code query.getThreshold()} are dropped client-side.
     *
     * @param query search parameters; must not be {@code null}
     * @return matching results in the order returned by the server
     */
    public List<SimilaritySearchResult<Record>> similaritySearch(SimilaritySearchQuery query) {
        Objects.requireNonNull(query, "query");
        StringBuilder cql = new StringBuilder("SELECT row_id,vector,body_blob,attributes_blob,metadata_s,")
                .append(query.getDistance().getFunction())
                .append("(vector, :vector) as similarity")
                .append(" FROM ").append(qualifiedTableName());
        Map<String, String> metadataFilters = query.getMetaData();
        if (metadataFilters != null && !metadataFilters.isEmpty()) {
            // Keys and values are caller-supplied: emit them as escaped CQL string literals
            // so quotes cannot terminate the literal and alter the statement.
            cql.append(" WHERE ").append(metadataFilters.entrySet().stream()
                    .map(e -> "metadata_s[" + cqlStringLiteral(e.getKey()) + "] = " + cqlStringLiteral(e.getValue()))
                    .collect(Collectors.joining(" AND ")));
        }
        cql.append(" ORDER BY vector ANN OF :vector ");
        cql.append(" LIMIT :maxRecord");

        int limit = query.getRecordCount() > 0 ? query.getRecordCount() : DEFAULT_RECORD_COUNT;
        SimpleStatement statement = SimpleStatement.builder(cql.toString())
                .addNamedValue("vector", CqlVector.newInstance(query.getEmbeddings()))
                .addNamedValue("maxRecord", limit)
                .build();
        return cqlSession.execute(statement).all().stream()
                .map(this::mapResult)
                .filter(r -> (double) r.getSimilarity() >= query.getThreshold())
                .collect(Collectors.toList());
    }

    /** Returns the dimension of the {@code vector} column. */
    public int getVectorDimension() {
        return this.vectorDimension;
    }

    /** Returns the similarity function configured on the vector index. */
    public SimilarityMetric getSimilarityMetric() {
        return this.similarityMetric;
    }

    /** Table name prefixed with the keyspace (when one is set), matching the target used by {@link #put}. */
    private String qualifiedTableName() {
        return keyspaceName == null ? tableName : keyspaceName + "." + tableName;
    }

    /**
     * Renders {@code value} as a single-quoted CQL string literal, doubling embedded quotes.
     * A {@code null} value is rendered as {@code 'null'}, as the previous concatenation did.
     */
    private static String cqlStringLiteral(String value) {
        return "'" + String.valueOf(value).replace("'", "''") + "'";
    }

    /**
     * One row of the table: id, optional attributes and body text, metadata map and vector.
     */
    public static class Record
    implements Serializable {
        private String rowId;
        private String attributes;
        private String body;
        private Map<String, String> metadata = new HashMap<String, String>();
        private List<Float> vector;

        /** Creates a record with a time-based UUID row id and no vector. */
        public Record() {
            this(Uuids.timeBased().toString(), null);
        }

        /** Creates a record with a time-based UUID row id and the given vector. */
        public Record(List<Float> vector) {
            this(Uuids.timeBased().toString(), vector);
        }

        /** Creates a record with an explicit row id and vector. */
        public Record(String rowId, List<Float> vector) {
            this.rowId = rowId;
            this.vector = vector;
        }

        /**
         * Builds the INSERT statement for this record.
         *
         * <p>{@code row_id} and {@code vector} are always written. {@code attributes_blob} and
         * {@code body_blob} are written only when non-null. {@code metadata_s} is written only
         * when non-null and non-empty.
         *
         * @param keyspaceName target keyspace
         * @param tableName    target table
         * @return the INSERT statement, with values inlined as literals
         * @throws IllegalStateException if the row id or the vector is {@code null}
         */
        public SimpleStatement insertStatement(String keyspaceName, String tableName) {
            if (this.rowId == null) {
                throw new IllegalStateException("Row Id cannot be null");
            }
            if (this.vector == null) {
                throw new IllegalStateException("Vector cannot be null");
            }
            RegularInsert insert = QueryBuilder.insertInto(keyspaceName, tableName)
                    .value("row_id", QueryBuilder.literal(this.rowId))
                    .value("vector", QueryBuilder.literal(CqlVector.newInstance(this.vector)));
            if (this.attributes != null) {
                insert = insert.value("attributes_blob", QueryBuilder.literal(this.attributes));
            }
            if (this.body != null) {
                insert = insert.value("body_blob", QueryBuilder.literal(this.body));
            }
            if (this.metadata != null && !this.metadata.isEmpty()) {
                insert = insert.value("metadata_s", QueryBuilder.literal(this.metadata));
            }
            return insert.build();
        }

        public String getRowId() {
            return this.rowId;
        }

        public String getAttributes() {
            return this.attributes;
        }

        public String getBody() {
            return this.body;
        }

        /** Returns the live metadata map (not a copy). */
        public Map<String, String> getMetadata() {
            return this.metadata;
        }

        public List<Float> getVector() {
            return this.vector;
        }

        public void setRowId(String rowId) {
            this.rowId = rowId;
        }

        public void setAttributes(String attributes) {
            this.attributes = attributes;
        }

        public void setBody(String body) {
            this.body = body;
        }

        public void setMetadata(Map<String, String> metadata) {
            this.metadata = metadata;
        }

        public void setVector(List<Float> vector) {
            this.vector = vector;
        }

        /** Field-by-field equality over row id, attributes, body, metadata and vector. */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof Record)) {
                return false;
            }
            Record other = (Record) o;
            return other.canEqual(this)
                    && Objects.equals(this.getRowId(), other.getRowId())
                    && Objects.equals(this.getAttributes(), other.getAttributes())
                    && Objects.equals(this.getBody(), other.getBody())
                    && Objects.equals(this.getMetadata(), other.getMetadata())
                    && Objects.equals(this.getVector(), other.getVector());
        }

        /** Symmetry hook for subclasses overriding {@link #equals(Object)}. */
        protected boolean canEqual(Object other) {
            return other instanceof Record;
        }

        /** Hash over the same fields as {@link #equals(Object)} (null fields contribute 43). */
        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            result = result * prime + hashOf(this.getRowId());
            result = result * prime + hashOf(this.getAttributes());
            result = result * prime + hashOf(this.getBody());
            result = result * prime + hashOf(this.getMetadata());
            result = result * prime + hashOf(this.getVector());
            return result;
        }

        private static int hashOf(Object value) {
            return value == null ? 43 : value.hashCode();
        }

        @Override
        public String toString() {
            return "MetadataVectorCassandraTable.Record(rowId=" + this.getRowId() + ", attributes=" + this.getAttributes() + ", body=" + this.getBody() + ", metadata=" + this.getMetadata() + ", vector=" + this.getVector() + ")";
        }
    }
}

