/*
 * Decompiled with CFR 0.152.
 */
package org.apache.arrow.adbc.driver.flightsql;

import com.github.benmanes.caffeine.cache.LoadingCache;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.arrow.adbc.core.AdbcException;
import org.apache.arrow.adbc.core.AdbcInfoCode;
import org.apache.arrow.adbc.core.StandardSchemas;
import org.apache.arrow.adbc.driver.flightsql.BaseFlightReader;
import org.apache.arrow.adbc.driver.flightsql.FlightSqlClientWithCallOptions;
import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.sql.impl.FlightSql;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.UInt4Vector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.DenseUnionVector;
import org.apache.arrow.vector.types.pojo.Schema;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.checkerframework.dataflow.qual.Pure;

/**
 * Reader producing ADBC {@code GetInfo} results in the {@link StandardSchemas#GET_INFO_SCHEMA}
 * layout.
 *
 * <p>Driver-provided codes ({@code DRIVER_NAME}, {@code DRIVER_VERSION}) are answered from memory
 * in a first batch. Vendor codes are translated to Flight SQL {@code SqlInfo} codes and requested
 * from the server via {@code GetSqlInfo}. Each supported row the server returns is re-encoded into
 * the ADBC schema.
 */
final class GetInfoMetadataReader extends BaseFlightReader {
    /**
     * Dense-union type id of string values, used both when reading Flight SQL {@code SqlInfo}
     * results and when writing the ADBC result.
     */
    private static final byte STRING_VALUE_TYPE_ID = 0;
    /** ADBC info code -> Flight SQL {@code SqlInfo} code that answers it. */
    private static final Map<Integer, Integer> ADBC_TO_FLIGHT_SQL_CODES = new HashMap<>();
    /** Flight SQL {@code SqlInfo} code -> routine that copies that row into the ADBC result. */
    private static final Map<Integer, AddInfo> SUPPORTED_CODES = new HashMap<>();
    private static final byte[] DRIVER_NAME = "ADBC Flight SQL Driver".getBytes(StandardCharsets.UTF_8);
    // NOTE(review): hard-coded; presumably should track the released driver version.
    private static final byte[] DRIVER_VERSION = "0.0.1".getBytes(StandardCharsets.UTF_8);

    private final BufferAllocator allocator;
    /** ADBC info codes requested by the caller (deduplicated, iteration order preserved). */
    private final Collection<Integer> requestedCodes;
    // Output columns of the batch currently being built; rebound before every batch is written.
    private @Nullable UInt4Vector infoCodes = null;
    private @Nullable DenseUnionVector infoValues = null;
    private @Nullable VarCharVector stringValues = null;
    private boolean hasInMemoryDataBeenWritten = false;
    /** Whether any driver-provided (in-memory) code was requested. */
    private final boolean hasInMemoryData;
    /** Whether any requested code must be fetched from the server. */
    private final boolean hasSupportedCodes;
    private boolean hasRequestBeenIssued = false;

    /**
     * Creates a reader for the given ADBC info codes.
     *
     * @param allocator allocator for result batches
     * @param client Flight SQL client used to issue {@code GetSqlInfo}
     * @param clientCache cache of clients keyed by endpoint location, passed to the base reader
     * @param infoCodes ADBC info codes to return, or {@code null} for every code this driver
     *     supports; explicit codes are deduplicated and sorted ascending
     * @return a new reader
     */
    static GetInfoMetadataReader CreateGetInfoMetadataReader(BufferAllocator allocator, FlightSqlClientWithCallOptions client, LoadingCache<Location, FlightSqlClientWithCallOptions> clientCache, int @Nullable [] infoCodes) {
        LinkedHashSet<Integer> requestedCodes;
        if (infoCodes == null) {
            // Every ADBC code we can answer: those translatable to Flight SQL, plus the
            // driver-provided ones served from memory.
            requestedCodes = new LinkedHashSet<>(ADBC_TO_FLIGHT_SQL_CODES.keySet());
            requestedCodes.add(AdbcInfoCode.DRIVER_NAME.getValue());
            requestedCodes.add(AdbcInfoCode.DRIVER_VERSION.getValue());
        } else {
            requestedCodes = IntStream.of(infoCodes).sorted().boxed().collect(Collectors.toCollection(LinkedHashSet::new));
        }
        return new GetInfoMetadataReader(allocator, client, clientCache, requestedCodes);
    }

    GetInfoMetadataReader(BufferAllocator allocator, FlightSqlClientWithCallOptions client, LoadingCache<Location, FlightSqlClientWithCallOptions> clientCache, Collection<Integer> requestedCodes) {
        super(allocator, client, clientCache, () -> GetInfoMetadataReader.issueGetSqlInfoRequest(client, requestedCodes));
        this.requestedCodes = requestedCodes;
        this.allocator = allocator;
        this.hasInMemoryData = requestedCodes.contains(AdbcInfoCode.DRIVER_NAME.getValue()) || requestedCodes.contains(AdbcInfoCode.DRIVER_VERSION.getValue());
        // requestedCodes holds ADBC codes, so look them up in the ADBC-keyed translation table.
        this.hasSupportedCodes = requestedCodes.stream().anyMatch(ADBC_TO_FLIGHT_SQL_CODES::containsKey);
    }

    /**
     * Writes a string into row {@code index} of the bound dense-union value column.
     *
     * <p>Rows must be written in order starting from 0. Every row written by this reader is a
     * string, so the string child's position for row {@code index} is {@code index} itself.
     *
     * @param index destination row
     * @param value UTF-8 bytes to store
     * @throws IllegalStateException if the output vectors have not been bound yet
     */
    void setStringValue(int index, byte[] value) {
        final DenseUnionVector values = this.infoValues;
        final VarCharVector strings = this.stringValues;
        if (values == null || strings == null) {
            throw new IllegalStateException("GetInfo output vectors have not been initialized");
        }
        values.setValueCount(index + 1);
        values.setTypeId(index, STRING_VALUE_TYPE_ID);
        strings.setSafe(index, value);
        // A dense union records, per row, the value's position inside the selected child vector.
        values.getOffsetBuffer().setInt((long) index * DenseUnionVector.OFFSET_WIDTH, strings.getLastSet());
    }

    /**
     * Loads the next batch. The first batch holds the in-memory driver values (if any were
     * requested). Subsequent batches come from the server's {@code GetSqlInfo} response (if any
     * requested code needs the server).
     *
     * @return {@code true} if a batch was loaded, {@code false} when exhausted
     * @throws IOException wrapping any {@link AdbcException} raised while fetching endpoints
     */
    @Override
    public boolean loadNextBatch() throws IOException {
        if (this.hasInMemoryData && !this.hasInMemoryDataBeenWritten) {
            this.hasInMemoryDataBeenWritten = true;
            this.loadInMemoryBatch();
            return true;
        }
        if (!this.hasSupportedCodes) {
            return false;
        }
        if (!this.hasRequestBeenIssued) {
            // Defer the server round-trip until vendor data is actually needed.
            this.hasRequestBeenIssued = true;
            try {
                this.populateEndpointData();
            } catch (AdbcException e) {
                throw new IOException(e);
            }
        }
        return super.loadNextBatch();
    }

    @Override
    protected Schema readSchema() {
        return StandardSchemas.GET_INFO_SCHEMA;
    }

    /**
     * Converts one Flight SQL {@code SqlInfo} batch into the ADBC {@code GetInfo} layout. Rows
     * whose code has no entry in {@link #SUPPORTED_CODES} are dropped.
     *
     * @param root batch received from the server; read only, never reallocated
     */
    @Override
    protected void processRootFromStream(VectorSchemaRoot root) {
        try (VectorSchemaRoot tmpRoot = VectorSchemaRoot.create(this.readSchema(), this.allocator)) {
            // Allocate the OUTPUT root. Allocating `root` here would wipe the server's data.
            tmpRoot.allocateNew();
            this.bindOutputVectors(tmpRoot);
            UInt4Vector sqlCode = (UInt4Vector) root.getVector(0);
            DenseUnionVector sqlInfo = (DenseUnionVector) root.getVector(1);
            int dstIndex = 0;
            for (int srcIndex = 0; srcIndex < root.getRowCount(); ++srcIndex) {
                AddInfo addInfo = SUPPORTED_CODES.get(sqlCode.get(srcIndex));
                if (addInfo != null) {
                    addInfo.accept(this, sqlInfo, srcIndex, dstIndex++);
                }
            }
            tmpRoot.setRowCount(dstIndex);
            this.loadRoot(tmpRoot);
        }
    }

    /** Writes the requested driver-provided values into a fresh batch and loads it. */
    private void loadInMemoryBatch() {
        try (VectorSchemaRoot root = VectorSchemaRoot.create(this.readSchema(), this.allocator)) {
            root.allocateNew();
            this.bindOutputVectors(root);
            UInt4Vector codes = (UInt4Vector) root.getVector(0);
            int dstIndex = 0;
            if (this.requestedCodes.contains(AdbcInfoCode.DRIVER_NAME.getValue())) {
                codes.setSafe(dstIndex, AdbcInfoCode.DRIVER_NAME.getValue());
                this.setStringValue(dstIndex++, DRIVER_NAME);
            }
            if (this.requestedCodes.contains(AdbcInfoCode.DRIVER_VERSION.getValue())) {
                codes.setSafe(dstIndex, AdbcInfoCode.DRIVER_VERSION.getValue());
                this.setStringValue(dstIndex++, DRIVER_VERSION);
            }
            root.setRowCount(dstIndex);
            this.loadRoot(root);
        }
    }

    /** Points the write-side vector fields at the columns of {@code root} (GET_INFO_SCHEMA layout). */
    private void bindOutputVectors(VectorSchemaRoot root) {
        this.infoCodes = (UInt4Vector) root.getVector(0);
        DenseUnionVector values = (DenseUnionVector) root.getVector(1);
        this.infoValues = values;
        this.stringValues = values.getVarCharVector(STRING_VALUE_TYPE_ID);
    }

    /**
     * Reads the string stored at row {@code index} of a dense-union value column. The row is
     * mapped to its child-vector position through the offset buffer; the row index is NOT the
     * child position when the union mixes value types.
     */
    private static byte[] getStringValue(DenseUnionVector values, int index) {
        int childIndex = values.getOffsetBuffer().getInt((long) index * DenseUnionVector.OFFSET_WIDTH);
        return values.getVarCharVector(STRING_VALUE_TYPE_ID).get(childIndex);
    }

    /**
     * Builds an {@link AddInfo} that copies a string {@code SqlInfo} value into the ADBC result
     * under {@code adbcCode}.
     */
    private static AddInfo copyStringAs(int adbcCode) {
        return (reader, sqlInfo, srcIndex, dstIndex) -> {
            UInt4Vector codes = reader.infoCodes;
            if (codes == null) {
                throw new IllegalStateException("GetInfo output vectors have not been initialized");
            }
            codes.setSafe(dstIndex, adbcCode);
            reader.setStringValue(dstIndex, getStringValue(sqlInfo, srcIndex));
        };
    }

    /**
     * Issues {@code GetSqlInfo} for the Flight SQL translations of {@code requestedCodes}.
     * Codes with no translation are skipped.
     */
    private static List<FlightEndpoint> issueGetSqlInfoRequest(FlightSqlClientWithCallOptions client, Collection<Integer> requestedCodes) {
        List<Integer> translatedCodes = new ArrayList<>();
        for (int code : requestedCodes) {
            Integer translatedCode = ADBC_TO_FLIGHT_SQL_CODES.get(code);
            if (translatedCode != null) {
                translatedCodes.add(translatedCode);
            }
        }
        return client.getSqlInfo(translatedCodes, new CallOption[0]).getEndpoints();
    }

    static {
        ADBC_TO_FLIGHT_SQL_CODES.put(AdbcInfoCode.VENDOR_NAME.getValue(), FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME.getNumber());
        ADBC_TO_FLIGHT_SQL_CODES.put(AdbcInfoCode.VENDOR_VERSION.getValue(), FlightSql.SqlInfo.FLIGHT_SQL_SERVER_VERSION.getNumber());
        SUPPORTED_CODES.put(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME.getNumber(), copyStringAs(AdbcInfoCode.VENDOR_NAME.getValue()));
        SUPPORTED_CODES.put(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_VERSION.getNumber(), copyStringAs(AdbcInfoCode.VENDOR_VERSION.getValue()));
    }

    /** Copies row {@code srcIndex} of a Flight SQL value column into row {@code dstIndex} of the ADBC result. */
    @FunctionalInterface
    static interface AddInfo {
        public void accept(GetInfoMetadataReader reader, DenseUnionVector sqlInfo, int srcIndex, int dstIndex);
    }
}

