/*
 * Decompiled with CFR 0.152.
 */
package com.starrocks.connector.spark.serialization;

import com.google.common.base.Preconditions;
import com.starrocks.connector.spark.exception.StarrocksException;
import com.starrocks.connector.spark.rest.models.Field;
import com.starrocks.connector.spark.rest.models.Schema;
import com.starrocks.connector.spark.util.DataTypeUtils;
import com.starrocks.thrift.TScanBatchResult;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.sql.Date;
import java.sql.Timestamp;
import java.text.SimpleDateFormat;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.DecimalVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.SmallIntVector;
import org.apache.arrow.vector.TinyIntVector;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.types.Types;
import org.apache.spark.sql.types.Decimal;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.math.BigDecimal;

/**
 * Decodes the Arrow IPC stream carried by one {@link TScanBatchResult} into in-memory rows.
 *
 * <p>Every Arrow record batch in the stream is converted inside the constructor. The Arrow reader
 * and allocator are closed before the constructor returns (normally or exceptionally), and rows are
 * then served from memory through {@link #hasNext()} and {@link #next()}.
 *
 * <p>Instances hold a mutable read cursor and a per-instance {@link SimpleDateFormat}, so they are
 * not designed for concurrent use.
 */
public class RowBatch {
    private static final Logger logger = LoggerFactory.getLogger(RowBatch.class);

    /** Parses DATE values; SimpleDateFormat uses the JVM default time zone. */
    private final SimpleDateFormat dateFormatter = new SimpleDateFormat("yyyy-MM-dd");
    /** Parses DATETIME values in the time zone passed to the constructor. */
    private final DateTimeFormatter dateTimeFormatter;
    /** Index of the next row returned by {@link #next()}. */
    private int offsetInRowBatch = 0;
    /** Row count of the Arrow batch currently being converted. */
    private int rowCountInOneBatch = 0;
    /** Rows converted so far; also the base index in {@link #rowBatch} of the batch being converted. */
    private int readRowCount = 0;
    private final List<Row> rowBatch = new ArrayList<>();
    private final ArrowStreamReader arrowStreamReader;
    private final VectorSchemaRoot root;
    private List<FieldVector> fieldVectors;
    private final RootAllocator rootAllocator;
    private final Schema schema;
    // Not read elsewhere in this class; building it still rejects schemas with duplicate column
    // names, because Collectors.toMap throws IllegalStateException on duplicate keys.
    private final Map<String, Field> fieldMap;

    /**
     * Decodes {@code nextResult} using the JVM default time zone for DATETIME values.
     *
     * @throws StarrocksException if the stream cannot be read or converted
     */
    public RowBatch(TScanBatchResult nextResult, Schema schema) throws StarrocksException {
        this(nextResult, schema, ZoneId.systemDefault());
    }

    /**
     * Decodes every Arrow batch in {@code nextResult.getRows()} into rows.
     *
     * @param nextResult scan result whose {@code rows} bytes hold an Arrow IPC stream
     * @param schema expected column layout; its size must match the Arrow field count
     * @param timeZone zone used to interpret DATETIME strings
     * @throws StarrocksException wrapping any failure while reading or converting the stream
     */
    public RowBatch(TScanBatchResult nextResult, Schema schema, ZoneId timeZone) throws StarrocksException {
        this.schema = schema;
        this.dateTimeFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss[.SSSSSS]").withZone(timeZone);
        this.fieldMap = schema.getProperties().stream()
                .collect(Collectors.toMap(Field::getName, Function.identity()));
        this.rootAllocator = new RootAllocator(Integer.MAX_VALUE);
        this.arrowStreamReader = new ArrowStreamReader(new ByteArrayInputStream(nextResult.getRows()), this.rootAllocator);
        try {
            this.root = this.arrowStreamReader.getVectorSchemaRoot();
            while (this.arrowStreamReader.loadNextBatch()) {
                this.fieldVectors = this.root.getFieldVectors();
                if (this.fieldVectors.size() != schema.size()) {
                    logger.error("Schema size '{}' is not equal to arrow field size '{}'.",
                            schema.size(), this.fieldVectors.size());
                    throw new StarrocksException("Load StarRocks data failed, schema size of fetch data is wrong.");
                }
                if (this.fieldVectors.isEmpty() || this.root.getRowCount() == 0) {
                    logger.debug("One batch in arrow has no data.");
                    continue;
                }
                this.rowCountInOneBatch = this.root.getRowCount();
                // Pre-create the rows of this batch; columns are then appended column by column.
                for (int i = 0; i < this.rowCountInOneBatch; ++i) {
                    this.rowBatch.add(new Row(this.fieldVectors.size()));
                }
                convertArrowToRowBatch();
                this.readRowCount += this.root.getRowCount();
            }
        } catch (Exception e) {
            logger.error("Read StarRocks Data failed because: ", e);
            throw new StarrocksException(e);
        } finally {
            // All values have been copied into Java objects, so Arrow memory can be released now.
            close();
        }
    }

    /** Returns {@code true} while there are converted rows not yet returned by {@link #next()}. */
    public boolean hasNext() {
        return this.offsetInRowBatch < this.readRowCount;
    }

    /**
     * Appends {@code obj} as the next column value of row {@code rowIndex} in the current batch.
     *
     * @throws NoSuchElementException if {@code rowIndex} is outside the current batch
     */
    private void addValueToRow(int rowIndex, Object obj) {
        // Valid indices are [0, rowCountInOneBatch); the old '>' check let rowCountInOneBatch through.
        if (rowIndex < 0 || rowIndex >= this.rowCountInOneBatch) {
            String errMsg = "Get row offset: " + rowIndex + " larger than row size: " + this.rowCountInOneBatch;
            logger.error(errMsg);
            throw new NoSuchElementException(errMsg);
        }
        this.rowBatch.get(this.readRowCount + rowIndex).put(obj);
    }

    /**
     * Converts every column of the currently loaded Arrow batch into row values.
     *
     * <p>On any failure the Arrow resources are closed before the exception is rethrown.
     *
     * @throws Exception on type mismatch, unsupported type, or unparsable value
     */
    public void convertArrowToRowBatch() throws Exception {
        try {
            for (int col = 0; col < this.fieldVectors.size(); ++col) {
                convertColumn(col);
            }
        } catch (Exception e) {
            close();
            throw e;
        }
    }

    /** Converts column {@code col} of the current batch, appending one value to each row. */
    private void convertColumn(int col) throws Exception {
        FieldVector curFieldVector = this.fieldVectors.get(col);
        Types.MinorType mt = curFieldVector.getMinorType();
        String vectorName = curFieldVector.getName();
        Field field = this.schema.get(col);
        Preconditions.checkNotNull(field, "Can't find schema for arrow vector [%s] at index [%s]", vectorName, col);
        if (!vectorName.isEmpty()) {
            Preconditions.checkState(vectorName.equals(field.getName()),
                    "The column at [%s] has inconsistent column names between schema [%s] and arrow vector [%s]",
                    col, field.getName(), vectorName);
        }
        // Prefer the type declared by the schema; otherwise derive it from the arrow vector type.
        String currentType = field.getType().orElseGet(() -> DataTypeUtils.map(mt));
        switch (currentType) {
            case "NULL_TYPE": {
                for (int rowIndex = 0; rowIndex < this.rowCountInOneBatch; ++rowIndex) {
                    addValueToRow(rowIndex, null);
                }
                break;
            }
            case "BOOLEAN": {
                checkArrowType(currentType, mt, Types.MinorType.BIT);
                BitVector vector = (BitVector) curFieldVector;
                for (int rowIndex = 0; rowIndex < this.rowCountInOneBatch; ++rowIndex) {
                    addValueToRow(rowIndex, vector.isNull(rowIndex) ? null : Boolean.valueOf(vector.get(rowIndex) != 0));
                }
                break;
            }
            case "TINYINT": {
                checkArrowType(currentType, mt, Types.MinorType.TINYINT);
                TinyIntVector vector = (TinyIntVector) curFieldVector;
                for (int rowIndex = 0; rowIndex < this.rowCountInOneBatch; ++rowIndex) {
                    addValueToRow(rowIndex, vector.isNull(rowIndex) ? null : Byte.valueOf(vector.get(rowIndex)));
                }
                break;
            }
            case "SMALLINT": {
                checkArrowType(currentType, mt, Types.MinorType.SMALLINT);
                SmallIntVector vector = (SmallIntVector) curFieldVector;
                for (int rowIndex = 0; rowIndex < this.rowCountInOneBatch; ++rowIndex) {
                    addValueToRow(rowIndex, vector.isNull(rowIndex) ? null : Short.valueOf(vector.get(rowIndex)));
                }
                break;
            }
            case "INT": {
                checkArrowType(currentType, mt, Types.MinorType.INT);
                IntVector vector = (IntVector) curFieldVector;
                for (int rowIndex = 0; rowIndex < this.rowCountInOneBatch; ++rowIndex) {
                    addValueToRow(rowIndex, vector.isNull(rowIndex) ? null : Integer.valueOf(vector.get(rowIndex)));
                }
                break;
            }
            case "BIGINT": {
                checkArrowType(currentType, mt, Types.MinorType.BIGINT);
                BigIntVector vector = (BigIntVector) curFieldVector;
                for (int rowIndex = 0; rowIndex < this.rowCountInOneBatch; ++rowIndex) {
                    addValueToRow(rowIndex, vector.isNull(rowIndex) ? null : Long.valueOf(vector.get(rowIndex)));
                }
                break;
            }
            case "FLOAT": {
                checkArrowType(currentType, mt, Types.MinorType.FLOAT4);
                Float4Vector vector = (Float4Vector) curFieldVector;
                for (int rowIndex = 0; rowIndex < this.rowCountInOneBatch; ++rowIndex) {
                    addValueToRow(rowIndex, vector.isNull(rowIndex) ? null : Float.valueOf(vector.get(rowIndex)));
                }
                break;
            }
            case "TIME":
            case "DOUBLE": {
                checkArrowType(currentType, mt, Types.MinorType.FLOAT8);
                Float8Vector vector = (Float8Vector) curFieldVector;
                for (int rowIndex = 0; rowIndex < this.rowCountInOneBatch; ++rowIndex) {
                    addValueToRow(rowIndex, vector.isNull(rowIndex) ? null : Double.valueOf(vector.get(rowIndex)));
                }
                break;
            }
            case "BINARY": {
                checkArrowType(currentType, mt, Types.MinorType.VARBINARY);
                VarBinaryVector vector = (VarBinaryVector) curFieldVector;
                for (int rowIndex = 0; rowIndex < this.rowCountInOneBatch; ++rowIndex) {
                    addValueToRow(rowIndex, vector.isNull(rowIndex) ? null : vector.get(rowIndex));
                }
                break;
            }
            case "DECIMAL": {
                // This decimal flavour is transported as its string representation.
                checkArrowType(currentType, mt, Types.MinorType.VARCHAR);
                VarCharVector vector = (VarCharVector) curFieldVector;
                for (int rowIndex = 0; rowIndex < this.rowCountInOneBatch; ++rowIndex) {
                    if (vector.isNull(rowIndex)) {
                        addValueToRow(rowIndex, null);
                        continue;
                    }
                    String decimalValue = readUtf8(vector, rowIndex);
                    Decimal decimal = new Decimal();
                    try {
                        decimal.set(new BigDecimal(new java.math.BigDecimal(decimalValue)));
                    } catch (NumberFormatException e) {
                        String errMsg = "Decimal response result '" + decimalValue + "' is illegal.";
                        logger.error(errMsg, e);
                        throw new StarrocksException(errMsg);
                    }
                    addValueToRow(rowIndex, decimal);
                }
                break;
            }
            case "DECIMALV2":
            case "DECIMAL32":
            case "DECIMAL64":
            case "DECIMAL128": {
                checkArrowType(currentType, mt, Types.MinorType.DECIMAL);
                DecimalVector vector = (DecimalVector) curFieldVector;
                for (int rowIndex = 0; rowIndex < this.rowCountInOneBatch; ++rowIndex) {
                    if (vector.isNull(rowIndex)) {
                        addValueToRow(rowIndex, null);
                    } else {
                        addValueToRow(rowIndex, Decimal.apply((java.math.BigDecimal) vector.getObject(rowIndex)));
                    }
                }
                break;
            }
            case "DATE": {
                checkArrowType(currentType, mt, Types.MinorType.VARCHAR);
                VarCharVector vector = (VarCharVector) curFieldVector;
                for (int rowIndex = 0; rowIndex < this.rowCountInOneBatch; ++rowIndex) {
                    if (vector.isNull(rowIndex)) {
                        addValueToRow(rowIndex, null);
                    } else {
                        String value = readUtf8(vector, rowIndex);
                        addValueToRow(rowIndex, new Date(this.dateFormatter.parse(value).getTime()));
                    }
                }
                break;
            }
            case "DATETIME": {
                checkArrowType(currentType, mt, Types.MinorType.VARCHAR);
                VarCharVector vector = (VarCharVector) curFieldVector;
                for (int rowIndex = 0; rowIndex < this.rowCountInOneBatch; ++rowIndex) {
                    if (vector.isNull(rowIndex)) {
                        addValueToRow(rowIndex, null);
                    } else {
                        String value = readUtf8(vector, rowIndex);
                        addValueToRow(rowIndex, Timestamp.from(ZonedDateTime.parse(value, this.dateTimeFormatter).toInstant()));
                    }
                }
                break;
            }
            case "LARGEINT":
            case "CHAR":
            case "VARCHAR": {
                checkArrowType(currentType, mt, Types.MinorType.VARCHAR);
                VarCharVector vector = (VarCharVector) curFieldVector;
                for (int rowIndex = 0; rowIndex < this.rowCountInOneBatch; ++rowIndex) {
                    addValueToRow(rowIndex, vector.isNull(rowIndex) ? null : readUtf8(vector, rowIndex));
                }
                break;
            }
            default: {
                // Report the resolved type name; field.getType() is an Optional and may be empty.
                String errMsg = "Unsupported type " + currentType;
                logger.error(errMsg);
                throw new StarrocksException(errMsg);
            }
        }
    }

    /**
     * Fails with {@link IllegalArgumentException} when the arrow vector type differs from the one
     * the Spark type is expected to be transported as.
     */
    private void checkArrowType(String sparkType, Types.MinorType actual, Types.MinorType expected) {
        Preconditions.checkArgument(actual.equals(expected), (Object) typeMismatchMessage(sparkType, actual));
    }

    /** Decodes a VarChar cell; Arrow Utf8 values are UTF-8, independent of the platform charset. */
    private static String readUtf8(VarCharVector vector, int rowIndex) {
        return new String(vector.get(rowIndex), StandardCharsets.UTF_8);
    }

    /**
     * Returns the column values of the next row.
     *
     * @throws NoSuchElementException if all rows have already been returned
     */
    public List<Object> next() throws StarrocksException {
        if (!hasNext()) {
            String errMsg = "Get row offset:" + this.offsetInRowBatch + " larger than row size: " + this.readRowCount;
            logger.error(errMsg);
            throw new NoSuchElementException(errMsg);
        }
        return this.rowBatch.get(this.offsetInRowBatch++).getCols();
    }

    /** Builds the error message used when a Spark type and an arrow vector type disagree. */
    private String typeMismatchMessage(String sparkType, Types.MinorType arrowType) {
        String messageTemplate = "Spark type is %1$s, but arrow type is %2$s.";
        return String.format(messageTemplate, sparkType, arrowType.name());
    }

    /** Returns the total number of rows decoded from the stream. */
    public int getReadRowCount() {
        return this.readRowCount;
    }

    /**
     * Closes the Arrow stream reader and then the root allocator.
     *
     * <p>An {@link IOException} from the reader is logged rather than propagated. In that case the
     * allocator close is skipped, as before.
     */
    public void close() {
        try {
            if (this.arrowStreamReader != null) {
                this.arrowStreamReader.close();
            }
            if (this.rootAllocator != null) {
                this.rootAllocator.close();
            }
        } catch (IOException e) {
            logger.warn("Failed to close arrow stream reader.", e);
        }
    }

    /** One decoded row: column values in schema order, appended column by column. */
    public static class Row {
        private final List<Object> cols;

        Row(int colCount) {
            this.cols = new ArrayList<>(colCount);
        }

        List<Object> getCols() {
            return this.cols;
        }

        public void put(Object o) {
            this.cols.add(o);
        }
    }
}

