/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution.scheduler;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Ordering;
import com.google.common.net.InetAddresses;
import com.google.inject.Inject;
import io.trino.execution.scheduler.NetworkLocation;
import io.trino.execution.scheduler.NetworkTopology;
import io.trino.execution.scheduler.SubnetTopologyConfig;
import io.trino.spi.HostAddress;
import java.net.Inet4Address;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Collection;
import java.util.List;
import java.util.Objects;

/**
 * {@link NetworkTopology} that places a host in a hierarchy of subnets derived from its IP address.
 *
 * <p>For each configured CIDR prefix length (listed in strictly increasing order), the host's
 * address is masked to that prefix and the resulting network address becomes one segment of the
 * {@link NetworkLocation}; the host's full address is appended as the final segment. A host with
 * no address of the configured {@link AddressProtocol}, or whose address lookup throws
 * {@link UnknownHostException}, is placed at {@link NetworkLocation#ROOT_LOCATION}.
 */
public class SubnetBasedTopology
        implements NetworkTopology {
    // One mask per configured prefix length, in the same (increasing) order as the configuration.
    private final List<byte[]> subnetMasks;
    private final AddressProtocol protocol;

    /**
     * Creates a topology from the injected configuration.
     *
     * @param config supplies the CIDR prefix lengths and the address protocol
     */
    @Inject
    public SubnetBasedTopology(SubnetTopologyConfig config) {
        this(config.getCidrPrefixLengths(), config.getAddressProtocol());
    }

    /**
     * Creates a topology with an explicit subnet hierarchy.
     *
     * @param cidrPrefixLengths prefix lengths in strictly increasing order, each greater than 0 and
     *        less than the protocol's total bit count; may be empty
     * @param protocol address family whose addresses are used to locate hosts
     * @throws NullPointerException if either argument is null
     * @throws IllegalArgumentException if the prefix lengths are not strictly increasing or are out
     *         of range for {@code protocol}
     */
    public SubnetBasedTopology(List<Integer> cidrPrefixLengths, AddressProtocol protocol) {
        Objects.requireNonNull(cidrPrefixLengths, "cidrPrefixLengths is null");
        Objects.requireNonNull(protocol, "protocol is null");
        validateHierarchy(cidrPrefixLengths, protocol);
        this.protocol = protocol;
        this.subnetMasks = cidrPrefixLengths.stream()
                .map(protocol::computeSubnetMask)
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Locates {@code address} in the subnet hierarchy.
     *
     * @param address host to locate
     * @return a location whose segments are the host's network address under each configured
     *         prefix length (shortest prefix first), followed by the host's own address; or
     *         {@link NetworkLocation#ROOT_LOCATION} if the host has no address of the configured
     *         protocol or an {@link UnknownHostException} is raised
     */
    @Override
    public NetworkLocation locate(HostAddress address) {
        try {
            InetAddress inetAddress = protocol.getInetAddress(address.getAllInetAddresses());
            if (inetAddress == null) {
                return NetworkLocation.ROOT_LOCATION;
            }
            byte[] addressBytes = inetAddress.getAddress();
            ImmutableList.Builder<String> segments = ImmutableList.builder();
            for (byte[] subnetMask : subnetMasks) {
                InetAddress network = InetAddress.getByAddress(applyMask(addressBytes, subnetMask));
                segments.add(InetAddresses.toAddrString(network));
            }
            segments.add(InetAddresses.toAddrString(inetAddress));
            return new NetworkLocation(segments.build());
        }
        catch (UnknownHostException ignored) {
            // Best-effort: a host that cannot be resolved is placed at the root instead of failing.
            return NetworkLocation.ROOT_LOCATION;
        }
    }

    /**
     * Returns the byte-wise AND of an address and a mask, i.e. the network address of
     * {@code addressBytes} under {@code subnetMask}.
     *
     * <p>Only the first {@code subnetMask.length} bytes of the address are read. The caller passes
     * an address already filtered to the mask's protocol, so the lengths are expected to match.
     */
    private static byte[] applyMask(byte[] addressBytes, byte[] subnetMask) {
        byte[] subnet = new byte[subnetMask.length];
        for (int i = 0; i < subnet.length; i++) {
            subnet[i] = (byte) (addressBytes[i] & subnetMask[i]);
        }
        return subnet;
    }

    /**
     * Checks that the prefix lengths are strictly increasing and lie strictly between 0 and the
     * protocol's total bit count.
     *
     * @throws IllegalArgumentException if either condition is violated
     */
    private static void validateHierarchy(List<Integer> lengths, AddressProtocol protocol) {
        if (!Ordering.<Integer>natural().isStrictlyOrdered(lengths)) {
            throw new IllegalArgumentException("Subnet hierarchy should be listed in the order of increasing prefix lengths");
        }
        // The list is strictly increasing, so bounding the first and last elements bounds them all.
        if (!lengths.isEmpty()
                && (lengths.get(0) <= 0 || Iterables.getLast(lengths) >= protocol.getTotalBitCount())) {
            throw new IllegalArgumentException("Subnet mask prefix lengths are invalid");
        }
    }

    /** IP address family used to select which of a host's addresses is located. */
    public enum AddressProtocol {
        IPv4(Inet4Address.class, 32),
        IPv6(Inet6Address.class, 128);

        private final Class<? extends InetAddress> addressClass;
        private final int totalBitCount;

        AddressProtocol(Class<? extends InetAddress> addressClass, int totalBitCount) {
            this.addressClass = addressClass;
            this.totalBitCount = totalBitCount;
        }

        /** Returns the number of bits in an address of this protocol (32 or 128). */
        int getTotalBitCount() {
            return totalBitCount;
        }

        /**
         * Builds a network mask of {@code totalBitCount / 8} bytes whose leading
         * {@code prefixLength} bits are set and whose remaining bits are clear.
         *
         * @param prefixLength CIDR prefix length; must satisfy {@code 0 < prefixLength < totalBitCount}
         * @throws IllegalArgumentException if {@code prefixLength} is out of range
         */
        byte[] computeSubnetMask(int prefixLength) {
            Preconditions.checkArgument(prefixLength > 0 && prefixLength < totalBitCount, "Invalid length for subnet mask");
            byte[] mask = new byte[totalBitCount / Byte.SIZE];
            int fullBytes = prefixLength / Byte.SIZE;
            int remainingBits = prefixLength % Byte.SIZE;
            for (int i = 0; i < fullBytes; i++) {
                mask[i] = (byte) 0xFF;
            }
            // prefixLength < totalBitCount guarantees fullBytes < mask.length whenever bits remain.
            if (remainingBits > 0) {
                mask[fullBytes] = (byte) (0xFF << (Byte.SIZE - remainingBits));
            }
            return mask;
        }

        /**
         * Returns the first address in {@code inetAddresses} that belongs to this protocol.
         *
         * @return the first matching address, or {@code null} if there is none
         */
        InetAddress getInetAddress(List<InetAddress> inetAddresses) {
            return inetAddresses.stream()
                    .filter(addressClass::isInstance)
                    .findFirst()
                    .orElse(null);
        }
    }
}

