/*
 * Decompiled with CFR 0.152.
 */
package com.yugabyte.oss.driver.internal.core.loadbalancing;

import com.datastax.oss.driver.api.core.ConsistencyLevel;
import com.datastax.oss.driver.api.core.context.DriverContext;
import com.datastax.oss.driver.api.core.cql.BatchStatement;
import com.datastax.oss.driver.api.core.cql.BatchableStatement;
import com.datastax.oss.driver.api.core.cql.BoundStatement;
import com.datastax.oss.driver.api.core.cql.ColumnDefinitions;
import com.datastax.oss.driver.api.core.cql.PreparedStatement;
import com.datastax.oss.driver.api.core.loadbalancing.LoadBalancingPolicy;
import com.datastax.oss.driver.api.core.loadbalancing.NodeDistance;
import com.datastax.oss.driver.api.core.metadata.Node;
import com.datastax.oss.driver.api.core.metadata.NodeState;
import com.datastax.oss.driver.api.core.session.Request;
import com.datastax.oss.driver.api.core.session.Session;
import com.datastax.oss.driver.api.core.tracker.RequestTracker;
import com.datastax.oss.driver.api.core.type.DataType;
import com.datastax.oss.driver.api.core.type.ListType;
import com.datastax.oss.driver.api.core.type.MapType;
import com.datastax.oss.driver.api.core.type.SetType;
import com.datastax.oss.driver.api.core.type.UserDefinedType;
import com.datastax.oss.driver.internal.core.util.collection.QueryPlan;
import com.yugabyte.oss.driver.api.core.DefaultPartitionMetadata;
import com.yugabyte.oss.driver.api.core.TableSplitMetadata;
import com.yugabyte.oss.driver.api.core.utils.Jenkins;
import com.yugabyte.oss.driver.internal.core.loadbalancing.YugabyteDefaultLoadBalancingPolicy;
import edu.umd.cs.findbugs.annotations.NonNull;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.Queue;
import java.util.UUID;
import net.jcip.annotations.ThreadSafe;
import org.apache.commons.collections.CollectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Load-balancing policy that routes {@link BoundStatement}s (and {@link BatchStatement}s containing
 * at least one bound statement) to the hosts that {@link TableSplitMetadata} associates with the
 * statement's 16-bit partition-key hash.
 *
 * <p>When no partition-aware plan can be computed (no bound variables, no partition key, no split
 * metadata for the table, or an empty result), the plan of
 * {@link YugabyteDefaultLoadBalancingPolicy} is used unchanged.
 */
@ThreadSafe
public class PartitionAwarePolicy
extends YugabyteDefaultLoadBalancingPolicy
implements RequestTracker {
    private static final Logger LOG = LoggerFactory.getLogger(PartitionAwarePolicy.class);

    /** Seed passed to {@link Jenkins#hash64} when hashing the encoded partition key. */
    private static final long HASH_SEED = 97L;

    /**
     * Creates the policy for the given driver context and execution profile.
     *
     * @param context the driver context
     * @param profileName the name of the execution profile this policy serves
     */
    public PartitionAwarePolicy(@NonNull DriverContext context, @NonNull String profileName) {
        super(context, profileName);
    }

    /** Delegates to the base policy; this policy keeps no node state of its own. */
    @Override
    public void init(Map<UUID, Node> nodes, LoadBalancingPolicy.DistanceReporter distanceReporter) {
        super.init(nodes, distanceReporter);
    }

    /**
     * Builds the query plan for {@code request}.
     *
     * <p>For a bound statement, or a batch whose first routable bound statement yields a plan, the
     * partition-aware nodes come first (see {@link UpHostIterator} for ordering and filtering).
     * Duplicate nodes are dropped while keeping their first position. If that produces no nodes,
     * the base policy's plan is returned.
     */
    @Override
    public Queue<Node> newQueryPlan(Request request, Session session) {
        Iterator<Node> partitionAwareNodeIterator = null;
        if (request instanceof BoundStatement) {
            partitionAwareNodeIterator = getQueryPlan(session, (BoundStatement) request);
        } else if (request instanceof BatchStatement) {
            partitionAwareNodeIterator = getQueryPlan(session, (BatchStatement) request);
        }
        LinkedHashSet<Node> partitionAwareNodes = null;
        if (partitionAwareNodeIterator != null) {
            // LinkedHashSet preserves the iterator's order while discarding repeated nodes.
            partitionAwareNodes = new LinkedHashSet<Node>();
            partitionAwareNodeIterator.forEachRemaining(partitionAwareNodes::add);
            LOG.debug("newQueryPlan: Number of Nodes = {}", partitionAwareNodes.size());
        }
        return CollectionUtils.isEmpty(partitionAwareNodes)
            ? super.newQueryPlan(request, session)
            : new QueryPlan(partitionAwareNodes.toArray());
    }

    /**
     * Returns a partition-aware node iterator for {@code statement}, or {@code null} if the
     * statement has no bound variables, no computable partition key, or the session has no split
     * metadata for the target table.
     */
    private Iterator<Node> getQueryPlan(Session session, BoundStatement statement) {
        PreparedStatement pstmt = statement.getPreparedStatement();
        ColumnDefinitions variables = pstmt.getVariableDefinitions();
        if (variables.size() == 0) {
            return null;
        }
        int key = getKey(statement);
        if (key < 0) {
            return null;
        }
        // The target keyspace/table are taken from the first bound variable's definition.
        String queryKeySpace = variables.get(0).getKeyspace().asInternal();
        String queryTable = variables.get(0).getTable().asInternal();
        LOG.debug("getQueryPlan: keyspace = {}, query = {}", queryKeySpace, pstmt.getQuery());
        Optional<DefaultPartitionMetadata> partitionMetadata = session.getMetadata().getDefaultPartitionMetadata();
        if (!partitionMetadata.isPresent()) {
            return null;
        }
        TableSplitMetadata tableSplitMetadata = partitionMetadata.get().getTableSplitMetadata(queryKeySpace, queryTable);
        if (tableSplitMetadata == null) {
            return null;
        }
        Iterator<Node> nodesFromBasePolicy = super.newQueryPlan(statement, session).iterator();
        // Copy the hosts: UpHostIterator may shuffle the list in place.
        return new UpHostIterator(statement, new ArrayList<Node>(tableSplitMetadata.getHosts(key)), nodesFromBasePolicy);
    }

    /**
     * Returns the partition-aware plan of the first bound statement in {@code batch} that has one,
     * or {@code null} if none does.
     */
    private Iterator<Node> getQueryPlan(Session session, BatchStatement batch) {
        for (BatchableStatement<?> nextStatement : batch) {
            if (nextStatement instanceof BoundStatement) {
                Iterator<Node> plan = getQueryPlan(session, (BoundStatement) nextStatement);
                if (plan != null) {
                    return plan;
                }
            }
        }
        return null;
    }

    /** Folds a 64-bit Jenkins hash of {@code bytes} into a 16-bit partition key in {@code [0, 0xFFFF]}. */
    private static int getKey(byte[] bytes) {
        long h = Jenkins.hash64(bytes, HASH_SEED);
        long h1 = h >>> 48;
        long h2 = 3L * (h >>> 32);
        long h3 = 5L * (h >>> 16);
        long h4 = 7L * (h & 0xFFFFL);
        return (int) ((h1 ^ h2 ^ h3 ^ h4) & 0xFFFFL);
    }

    /**
     * Converts a signed 64-bit CQL token into a 16-bit YugabyteDB hash code in {@code [0, 0xFFFF]}.
     * This is the inverse of {@link #YBToCqlHashCode(int)} for hash codes in that range.
     */
    public static int CqlToYBHashCode(long cql_hash) {
        // Unsigned shift: keep exactly the top 16 bits without sign extension, so negative tokens
        // do not produce negative hash codes. Flipping bit 15 maps signed token order onto
        // unsigned hash order.
        return (int) (cql_hash >>> 48) ^ 0x8000;
    }

    /**
     * Converts a 16-bit YugabyteDB hash code into the corresponding signed 64-bit CQL token
     * (the hash, with bit 15 flipped, placed in the top 16 bits).
     */
    public static long YBToCqlHashCode(int hash) {
        return ((long) (hash ^ 0x8000)) << 48;
    }

    /**
     * Computes the 16-bit partition-key hash of {@code stmt} from its serialized partition-key
     * values.
     *
     * @param stmt the bound statement to hash
     * @return the hash in {@code [0, 0xFFFF]}, or {@code -1} in any of these cases: the prepared
     *     statement reports no partition-key indices, a partition-key value is CQL NULL, or
     *     encoding fails with an {@link IOException}
     * @throws UnsupportedOperationException if a partition-key column has a custom, counter or
     *     tuple type
     */
    public static int getKey(BoundStatement stmt) {
        PreparedStatement pstmt = stmt.getPreparedStatement();
        List<Integer> hashIndexes = pstmt.getPartitionKeyIndices();
        if (hashIndexes == null || hashIndexes.isEmpty()) {
            return -1;
        }
        ColumnDefinitions variables = pstmt.getVariableDefinitions();
        ByteArrayOutputStream bs = new ByteArrayOutputStream();
        try (WritableByteChannel channel = Channels.newChannel(bs)) {
            for (int index : hashIndexes) {
                ByteBuffer raw = stmt.getBytesUnsafe(index);
                if (raw == null) {
                    // CQL NULL: no routing key can be computed; let the caller fall back.
                    return -1;
                }
                // duplicate() so encoding never moves the statement's own buffer position.
                AppendValueToChannel(variables.get(index).getType(), raw.duplicate(), channel);
            }
        } catch (IOException e) {
            LOG.error("hash key encoding failed", e);
            return -1;
        }
        return getKey(bs.toByteArray());
    }

    /**
     * Appends the hash-input encoding of one serialized CQL value to {@code channel}.
     *
     * <p>{@code value} is consumed (its position may advance), so callers pass a duplicate or
     * slice. Encoding by protocol type code:
     * <ul>
     *   <li>most scalar types are written as-is;</li>
     *   <li>float / double: any NaN is replaced by the canonical NaN bit pattern;</li>
     *   <li>timestamp: the 8-byte value is multiplied by 1000 before writing;</li>
     *   <li>list / set / map / UDT: elements (or fields) are encoded recursively, in order,
     *       without their count and size prefixes;</li>
     *   <li>custom (0), counter (5) and tuple (49) are rejected;</li>
     *   <li>any other code contributes no bytes.</li>
     * </ul>
     *
     * @throws UnsupportedOperationException for custom, counter or tuple types
     * @throws IOException if writing to {@code channel} fails
     */
    private static void AppendValueToChannel(DataType type, ByteBuffer value, WritableByteChannel channel) throws IOException {
        int typeCode = type.getProtocolCode();
        switch (typeCode) {
            case 1:   // ascii
            case 2:   // bigint
            case 3:   // blob
            case 4:   // boolean
            case 6:   // decimal
            case 9:   // int
            case 12:  // uuid
            case 13:  // varchar / text
            case 14:  // varint
            case 15:  // timeuuid
            case 16:  // inet
            case 17:  // date
            case 18:  // time
            case 19:  // smallint
            case 20:  // tinyint
            case 128: // NOTE(review): Yugabyte-specific type code (presumably JSONB) -- confirm
                channel.write(value);
                break;
            case 8: { // float
                // Read at the current position (not index 0) so non-zero-position buffers hash correctly.
                if (Float.isNaN(value.getFloat(value.position()))) {
                    value = ByteBuffer.allocate(4);
                    value.putInt(0x7FC00000); // canonical float NaN bits
                    value.flip();
                }
                channel.write(value);
                break;
            }
            case 7: { // double
                if (Double.isNaN(value.getDouble(value.position()))) {
                    value = ByteBuffer.allocate(8);
                    value.putLong(0x7FF8000000000000L); // canonical double NaN bits
                    value.flip();
                }
                channel.write(value);
                break;
            }
            case 11: { // timestamp
                // Scaled by 1000 (ms -> us), presumably because the server hashes microseconds -- confirm.
                ByteBuffer scaled = ByteBuffer.allocate(8);
                scaled.putLong(value.getLong() * 1000L);
                scaled.flip();
                channel.write(scaled);
                break;
            }
            case 32:   // list
            case 34: { // set
                // Lists and sets share the wire layout but have distinct DataType interfaces.
                DataType elementType = type instanceof ListType
                    ? ((ListType) type).getElementType()
                    : ((SetType) type).getElementType();
                int count = value.getInt();
                for (int i = 0; i < count; ++i) {
                    appendSizedElement(elementType, value, channel);
                }
                break;
            }
            case 33: { // map
                MapType mapType = (MapType) type;
                DataType keyType = mapType.getKeyType();
                DataType valueType = mapType.getValueType();
                int count = value.getInt();
                for (int i = 0; i < count; ++i) {
                    appendSizedElement(keyType, value, channel);
                    appendSizedElement(valueType, value, channel);
                }
                break;
            }
            case 48: { // user-defined type
                for (DataType fieldType : ((UserDefinedType) type).getFieldTypes()) {
                    // A serialized UDT may omit trailing fields.
                    if (!value.hasRemaining()) {
                        break;
                    }
                    // NOTE(review): a null field (size -1) makes limit(-1) throw -- confirm expected handling.
                    appendSizedElement(fieldType, value, channel);
                }
                break;
            }
            case 0:  // custom
            case 5:  // counter
            case 49: // tuple
                throw new UnsupportedOperationException("Datatype with Hex Code: " + typeCode + " not supported in a partition key column");
            default:
                // NOTE(review): unlisted codes contribute nothing to the hash -- confirm this matches the server.
                break;
        }
    }

    /**
     * Reads one {@code [int size][size bytes]} element at {@code value}'s position, encodes it as
     * {@code elementType}, and advances {@code value} past it.
     */
    private static void appendSizedElement(DataType elementType, ByteBuffer value, WritableByteChannel channel) throws IOException {
        int size = value.getInt();
        ByteBuffer element = value.slice();
        element.limit(size);
        AppendValueToChannel(elementType, element, channel);
        value.position(value.position() + size);
    }

    /**
     * Iterates the partition's hosts, then fallback nodes from the base policy.
     *
     * <p>Under {@code YB_CONSISTENT_PREFIX} the host list is shuffled first. A partition host is
     * yielded when it is {@code UP} and either {@code LOCAL} or the consistency level is YB-strong
     * (a null consistency level is treated as {@code YB_STRONG}). Base-policy nodes are then
     * yielded, except partition hosts matching that same LOCAL-or-strong rule. The fallback pass
     * ignores node state.
     *
     * <p>{@link #hasNext()} is idempotent; {@link #next()} advances and throws
     * {@link NoSuchElementException} when exhausted.
     */
    private static class UpHostIterator
    implements Iterator<Node> {
        private final List<Node> hosts;
        private final Iterator<Node> iterator;
        private final Iterator<Node> childIterator;
        private final boolean ybStrong;
        /** Look-ahead computed by hasNext() and not yet returned by next(); null when none is pending. */
        private Node nextHost;

        public UpHostIterator(BoundStatement statement, List<Node> hosts, Iterator<Node> nodesFromBasePolicy) {
            ConsistencyLevel consistencyLevel = statement.getConsistencyLevel() != null
                ? statement.getConsistencyLevel()
                : ConsistencyLevel.YB_STRONG;
            if (consistencyLevel == ConsistencyLevel.YB_CONSISTENT_PREFIX) {
                // Shuffle before creating the iterator, presumably to spread such reads across replicas.
                Collections.shuffle(hosts);
            }
            this.hosts = hosts;
            this.iterator = hosts.iterator();
            this.childIterator = nodesFromBasePolicy;
            this.ybStrong = consistencyLevel.isYBStrong();
        }

        @Override
        public boolean hasNext() {
            if (nextHost == null) {
                nextHost = computeNext();
            }
            return nextHost != null;
        }

        @Override
        public Node next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            Node result = nextHost;
            nextHost = null;
            return result;
        }

        /** True if {@code host} may serve this request as a partition replica (state not considered). */
        private boolean isEligibleReplica(Node host) {
            return host.getDistance() == NodeDistance.LOCAL || ybStrong;
        }

        /** Returns the next node to yield, or null when both passes are exhausted. */
        private Node computeNext() {
            while (iterator.hasNext()) {
                Node host = iterator.next();
                if (host.getState() == NodeState.UP && isEligibleReplica(host)) {
                    return host;
                }
            }
            if (childIterator != null) {
                while (childIterator.hasNext()) {
                    Node host = childIterator.next();
                    if (!(hosts.contains(host) && isEligibleReplica(host))) {
                        return host;
                    }
                }
            }
            return null;
        }
    }
}

