/*
 * Decompiled with CFR 0.152.
 */
package io.stargate.db.dse.impl.interceptors;

import com.datastax.bdp.db.nodes.BootstrapState;
import com.datastax.bdp.db.util.ProductVersion;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import io.reactivex.Single;
import io.stargate.db.EventListener;
import io.stargate.db.dse.impl.ClientStateWithPublicAddress;
import io.stargate.db.dse.impl.StargateSystemKeyspace;
import io.stargate.db.dse.impl.interceptors.QueryInterceptor;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.security.Security;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.apache.cassandra.config.DatabaseDescriptor;
import org.apache.cassandra.cql3.CQLStatement;
import org.apache.cassandra.cql3.QueryOptions;
import org.apache.cassandra.cql3.QueryProcessor;
import org.apache.cassandra.cql3.ResultSet;
import org.apache.cassandra.cql3.statements.SelectStatement;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.db.marshal.InetAddressType;
import org.apache.cassandra.db.marshal.Int32Type;
import org.apache.cassandra.db.marshal.SetType;
import org.apache.cassandra.db.marshal.UTF8Type;
import org.apache.cassandra.db.marshal.UUIDType;
import org.apache.cassandra.service.ClientState;
import org.apache.cassandra.service.QueryState;
import org.apache.cassandra.stargate.transport.ServerError;
import org.apache.cassandra.transport.ProtocolVersion;
import org.apache.cassandra.transport.messages.ResultMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link QueryInterceptor} that answers system local/peers queries on behalf of clients whose
 * {@link ClientState} carries a public address ({@link ClientStateWithPublicAddress}).
 *
 * <p>The set of peers is obtained by resolving a DNS name ({@link #PROXY_DNS_NAME}) and is
 * re-resolved periodically on a shared single-threaded scheduler. Registered
 * {@link EventListener}s are told when addresses appear in or disappear from that set.
 *
 * <p>Queries that are not handled here are delegated to the optional wrapped interceptor.
 */
public class ProxyProtocolQueryInterceptor
implements QueryInterceptor {
    /** DNS name resolved to discover peers; taken from {@code stargate.proxy_protocol.dns_name}. */
    public static final String PROXY_DNS_NAME = System.getProperty("stargate.proxy_protocol.dns_name");
    /** Port advertised for peers; {@code stargate.proxy_protocol.port}, defaulting to the native transport port. */
    public static final int PROXY_PORT = Integer.getInteger("stargate.proxy_protocol.port", DatabaseDescriptor.getNativeTransportPort());
    /** Delay between DNS resolutions; {@code stargate.proxy_protocol.resolve_delay_secs}, defaulting to 10 seconds. */
    public static final long RESOLVE_DELAY_SECS = Long.getLong("stargate.proxy_protocol.resolve_delay_secs", 10L);
    private static final Logger logger = LoggerFactory.getLogger(ProxyProtocolQueryInterceptor.class);
    private static final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
    private final Resolver resolver;
    private final String proxyDnsName;
    private final int proxyPort;
    private final long resolveDelaySecs;
    private final List<EventListener> listeners = new CopyOnWriteArrayList<>();
    /** Tokens generated for addresses that are currently part of {@link #peers}. */
    private final Map<InetAddress, Set<String>> tokensCache = new ConcurrentHashMap<>();
    /** Most recently resolved peer set; replaced wholesale on every change. */
    private volatile Set<InetAddress> peers = Collections.emptySet();
    /** Interceptor consulted for queries this class does not handle; may be absent. */
    private final Optional<QueryInterceptor> wrapped;

    /**
     * Creates an interceptor configured from the system properties.
     *
     * @param wrapped interceptor to delegate unhandled queries to, or {@code null} for none
     */
    public ProxyProtocolQueryInterceptor(QueryInterceptor wrapped) {
        this(new DefaultResolver(), PROXY_DNS_NAME, PROXY_PORT, RESOLVE_DELAY_SECS, wrapped);
    }

    /** Creates an interceptor configured from the system properties with no wrapped interceptor. */
    @VisibleForTesting
    public ProxyProtocolQueryInterceptor() {
        this(null);
    }

    /**
     * Creates an interceptor with explicit settings and no wrapped interceptor.
     *
     * @param resolver         strategy used to resolve {@code proxyDnsName}
     * @param proxyDnsName     DNS name to resolve; peer resolution is disabled when null or empty
     * @param proxyPort        port reported to listeners and in the native transport port columns
     * @param resolveDelaySecs delay, in seconds, between resolutions
     */
    @VisibleForTesting
    public ProxyProtocolQueryInterceptor(Resolver resolver, String proxyDnsName, int proxyPort, long resolveDelaySecs) {
        this(resolver, proxyDnsName, proxyPort, resolveDelaySecs, null);
    }

    private ProxyProtocolQueryInterceptor(Resolver resolver, String proxyDnsName, int proxyPort, long resolveDelaySecs, QueryInterceptor wrapped) {
        this.resolver = resolver;
        this.proxyDnsName = proxyDnsName;
        this.proxyPort = proxyPort;
        this.resolveDelaySecs = resolveDelaySecs;
        this.wrapped = Optional.ofNullable(wrapped);
    }

    /**
     * Ensures a DNS cache TTL is configured, performs the first peer resolution and then
     * initializes the wrapped interceptor, if any.
     *
     * @throws ServerError if the first DNS resolution fails
     */
    @Override
    public void initialize() {
        String ttl = Security.getProperty("networkaddress.cache.ttl");
        if (Strings.isNullOrEmpty(ttl)) {
            logger.info("DNS cache TTL (property \"networkaddress.cache.ttl\") not explicitly set. Setting to 60 seconds.");
            Security.setProperty("networkaddress.cache.ttl", "60");
        }
        this.resolvePeers();
        this.wrapped.ifPresent(QueryInterceptor::initialize);
    }

    /**
     * Answers system local/peers {@code SELECT}s from the resolved peer set.
     *
     * <p>For table {@code peer_nodes} one row is produced per resolved peer other than the
     * client's own public address; otherwise (table {@code local_node}) a single row describing
     * the client's public address is produced.
     *
     * @return the synthesized rows; or, when the statement is not a system local/peers query or
     *     the client state has no public address support, whatever the wrapped interceptor
     *     returns ({@code null} when there is no wrapped interceptor)
     * @throws ServerError if the client state supports a public address but none is set
     */
    @Override
    public Single<ResultMessage> interceptQuery(CQLStatement statement, QueryState state, QueryOptions options, Map<String, ByteBuffer> customPayload, long queryStartNanoTime) {
        ClientState clientState = state.getClientState();
        if (!StargateSystemKeyspace.isSystemLocalOrPeers(statement) || !(clientState instanceof ClientStateWithPublicAddress)) {
            return this.wrapped.map(i -> i.interceptQuery(statement, state, options, customPayload, queryStartNanoTime)).orElse(null);
        }
        SelectStatement selectStatement = (SelectStatement)statement;
        InetSocketAddress publicAddress = ((ClientStateWithPublicAddress)clientState).publicAddress();
        if (publicAddress == null) {
            throw new ServerError("Unable to intercept proxy protocol system query without a valid public address");
        }
        ResultSet.ResultMetadata metadata = selectStatement.getResultMetadata();
        InetAddress clientAddress = publicAddress.getAddress();
        String tableName = selectStatement.table();
        List<List<ByteBuffer>> rows;
        if (tableName.equals("peer_nodes")) {
            // Work on one snapshot: the field may be swapped by the scheduler at any time.
            Set<InetAddress> currentPeers = this.peers;
            rows = new ArrayList<>(currentPeers.size());
            for (InetAddress peer : currentPeers) {
                // The client's own address is reported through the local table, not as a peer.
                if (peer.equals(clientAddress)) {
                    continue;
                }
                rows.add(this.buildRow(metadata, peer));
            }
        } else {
            assert tableName.equals("local_node");
            rows = Collections.singletonList(this.buildRow(metadata, clientAddress));
        }
        ResultSet resultSet = new ResultSet(metadata, rows);
        return Single.just(new ResultMessage.Rows(resultSet));
    }

    /** Registers {@code listener} for peer join/up/leave notifications and with the wrapped interceptor. */
    @Override
    public void register(EventListener listener) {
        this.listeners.add(listener);
        this.wrapped.ifPresent(w -> w.register(listener));
    }

    /**
     * Performs the first, synchronous peer resolution and arms the periodic refresh.
     * Does nothing when no DNS name is configured.
     *
     * @throws ServerError if the DNS name cannot be resolved
     */
    private void resolvePeers() {
        if (Strings.isNullOrEmpty(this.proxyDnsName)) {
            return;
        }
        try {
            this.refreshPeers();
        }
        catch (UnknownHostException e) {
            throw new ServerError("Unable to resolve DNS for proxy protocol peers table", e);
        }
        this.scheduleNextResolution();
    }

    /** Queues the next background refresh after the configured delay. */
    private void scheduleNextResolution() {
        scheduler.schedule(this::resolvePeersInBackground, this.resolveDelaySecs, TimeUnit.SECONDS);
    }

    /**
     * Background refresh. A failed resolution or a misbehaving listener is logged rather than
     * propagated, and the next run is always queued, so a transient failure cannot permanently
     * stop peer discovery.
     */
    private void resolvePeersInBackground() {
        try {
            this.refreshPeers();
        }
        catch (UnknownHostException | RuntimeException e) {
            logger.warn("Unable to refresh proxy protocol peers for DNS name \"{}\"; retrying in {} seconds", this.proxyDnsName, this.resolveDelaySecs, e);
        }
        finally {
            this.scheduleNextResolution();
        }
    }

    /**
     * Resolves the DNS name once and, if the result differs from the current peer set, notifies
     * listeners of the additions and removals, publishes the new set and drops cached tokens of
     * removed peers.
     */
    private void refreshPeers() throws UnknownHostException {
        Set<InetAddress> resolved = this.resolver.resolve(this.proxyDnsName);
        Set<InetAddress> previous = this.peers;
        if (previous.equals(resolved)) {
            return;
        }
        Set<InetAddress> added = Sets.difference(resolved, previous);
        Set<InetAddress> removed = Sets.difference(previous, resolved);
        for (EventListener listener : this.listeners) {
            for (InetAddress peer : added) {
                listener.onJoinCluster(peer, this.proxyPort);
                listener.onUp(peer, this.proxyPort);
            }
            for (InetAddress peer : removed) {
                listener.onLeaveCluster(peer, this.proxyPort);
            }
        }
        this.peers = resolved;
        // Evict independently of the listeners so the cache is cleaned even when none are registered.
        for (InetAddress peer : removed) {
            this.tokensCache.remove(peer);
        }
    }

    /**
     * Produces the serialized value of one system-table column for the given address.
     *
     * @param name          column name
     * @param publicAddress address the row describes
     * @return the serialized value, or {@code null} for unknown columns or an absent JMX port
     */
    private ByteBuffer buildColumnValue(String name, InetAddress publicAddress) {
        switch (name) {
            case "key": {
                return UTF8Type.instance.decompose("local");
            }
            case "bootstrapped": {
                return UTF8Type.instance.decompose(BootstrapState.COMPLETED.toString());
            }
            case "peer":
            case "preferred_ip":
            case "broadcast_address":
            case "native_transport_address":
            case "listen_address":
            case "rpc_address": {
                return InetAddressType.instance.decompose(publicAddress);
            }
            case "cluster_name": {
                return UTF8Type.instance.decompose(DatabaseDescriptor.getClusterName());
            }
            case "cql_version": {
                return UTF8Type.instance.decompose(QueryProcessor.CQL_VERSION.toString());
            }
            case "data_center": {
                return UTF8Type.instance.decompose(DatabaseDescriptor.getLocalDataCenter());
            }
            case "host_id": {
                // Name-based UUID derived from the raw address bytes: stable for a given address.
                return UUIDType.instance.decompose(UUID.nameUUIDFromBytes(publicAddress.getAddress()));
            }
            case "native_protocol_version": {
                return UTF8Type.instance.decompose(String.valueOf(ProtocolVersion.CURRENT.asInt()));
            }
            case "partitioner": {
                return UTF8Type.instance.decompose(DatabaseDescriptor.getPartitioner().getClass().getName());
            }
            case "rack": {
                return UTF8Type.instance.decompose(DatabaseDescriptor.getLocalRack());
            }
            case "release_version": {
                return UTF8Type.instance.decompose(ProductVersion.getReleaseVersion().toString());
            }
            case "schema_version": {
                return UUIDType.instance.decompose(StargateSystemKeyspace.SCHEMA_VERSION);
            }
            case "tokens": {
                return SetType.getInstance(UTF8Type.instance, false).decompose(this.getTokens(publicAddress));
            }
            case "native_transport_port":
            case "native_transport_port_ssl": {
                // Use the instance port so the table agrees with the port sent in listener events.
                return Int32Type.instance.decompose(this.proxyPort);
            }
            case "storage_port": {
                return Int32Type.instance.decompose(DatabaseDescriptor.getStoragePort());
            }
            case "storage_port_ssl": {
                return Int32Type.instance.decompose(DatabaseDescriptor.getSSLStoragePort());
            }
            case "jmx_port": {
                return DatabaseDescriptor.getJMXPort().map(p -> Int32Type.instance.decompose(p)).orElse(null);
            }
        }
        return null;
    }

    /**
     * Returns the tokens reported for {@code publicAddress}. Tokens of addresses in the current
     * peer set are cached; tokens for any other address are generated on each call.
     */
    private Set<String> getTokens(InetAddress publicAddress) {
        if (this.peers.contains(publicAddress)) {
            return this.tokensCache.computeIfAbsent(publicAddress, pa -> StargateSystemKeyspace.generateRandomTokens(pa, DatabaseDescriptor.getNumTokens()));
        }
        return StargateSystemKeyspace.generateRandomTokens(publicAddress, DatabaseDescriptor.getNumTokens());
    }

    /** Builds one result row with a value (possibly {@code null}) for each column in {@code metadata}. */
    private List<ByteBuffer> buildRow(ResultSet.ResultMetadata metadata, InetAddress publicAddress) {
        List<ByteBuffer> row = Lists.newArrayListWithCapacity(metadata.names.size());
        metadata.names.forEach(column -> row.add(this.buildColumnValue(column.name.toString(), publicAddress)));
        return row;
    }

    /** {@link Resolver} backed by {@link InetAddress#getAllByName(String)}. */
    private static class DefaultResolver
    implements Resolver {
        private DefaultResolver() {
        }

        @Override
        public Set<InetAddress> resolve(String name) throws UnknownHostException {
            return Arrays.stream(InetAddress.getAllByName(name)).collect(Collectors.toSet());
        }
    }

    /** Strategy for turning a DNS name into the set of addresses it maps to. */
    public static interface Resolver {
        /**
         * Resolves {@code name} to all of its addresses.
         *
         * @param name host name to resolve
         * @return the addresses the name resolves to
         * @throws UnknownHostException if the name cannot be resolved
         */
        public Set<InetAddress> resolve(String name) throws UnknownHostException;
    }
}

