/*
 * Decompiled with CFR 0.152.
 */
package org.apache.seatunnel.connectors.seatunnel.common.source.arrow;

import java.io.ByteArrayOutputStream;
import java.math.BigDecimal;
import java.nio.channels.Channels;
import java.nio.charset.StandardCharsets;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.connectors.seatunnel.common.source.arrow.SeaTunnelDataTypeHolder;
import org.apache.seatunnel.connectors.seatunnel.common.source.arrow.reader.ArrowToSeatunnelRowReader;
import org.apache.seatunnel.shade.com.google.common.base.Stopwatch;
import org.apache.seatunnel.shade.io.netty.util.CharsetUtil;
import org.apache.seatunnel.shade.org.apache.arrow.memory.ArrowBuf;
import org.apache.seatunnel.shade.org.apache.arrow.memory.BufferAllocator;
import org.apache.seatunnel.shade.org.apache.arrow.memory.RootAllocator;
import org.apache.seatunnel.shade.org.apache.arrow.vector.BigIntVector;
import org.apache.seatunnel.shade.org.apache.arrow.vector.BitVector;
import org.apache.seatunnel.shade.org.apache.arrow.vector.DateDayVector;
import org.apache.seatunnel.shade.org.apache.arrow.vector.DateMilliVector;
import org.apache.seatunnel.shade.org.apache.arrow.vector.DecimalVector;
import org.apache.seatunnel.shade.org.apache.arrow.vector.FieldVector;
import org.apache.seatunnel.shade.org.apache.arrow.vector.Float4Vector;
import org.apache.seatunnel.shade.org.apache.arrow.vector.Float8Vector;
import org.apache.seatunnel.shade.org.apache.arrow.vector.IntVector;
import org.apache.seatunnel.shade.org.apache.arrow.vector.LargeVarCharVector;
import org.apache.seatunnel.shade.org.apache.arrow.vector.SmallIntVector;
import org.apache.seatunnel.shade.org.apache.arrow.vector.TimeMicroVector;
import org.apache.seatunnel.shade.org.apache.arrow.vector.TimeStampMicroVector;
import org.apache.seatunnel.shade.org.apache.arrow.vector.TimeStampMilliTZVector;
import org.apache.seatunnel.shade.org.apache.arrow.vector.TinyIntVector;
import org.apache.seatunnel.shade.org.apache.arrow.vector.ValueVector;
import org.apache.seatunnel.shade.org.apache.arrow.vector.VarBinaryVector;
import org.apache.seatunnel.shade.org.apache.arrow.vector.VarCharVector;
import org.apache.seatunnel.shade.org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.seatunnel.shade.org.apache.arrow.vector.complex.ListVector;
import org.apache.seatunnel.shade.org.apache.arrow.vector.complex.impl.UnionListWriter;
import org.apache.seatunnel.shade.org.apache.arrow.vector.complex.impl.UnionMapWriter;
import org.apache.seatunnel.shade.org.apache.arrow.vector.holders.TimeMilliHolder;
import org.apache.seatunnel.shade.org.apache.arrow.vector.holders.VarCharHolder;
import org.apache.seatunnel.shade.org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.seatunnel.shade.org.apache.arrow.vector.types.TimeUnit;
import org.apache.seatunnel.shade.org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.seatunnel.shade.org.apache.arrow.vector.types.pojo.Field;
import org.apache.seatunnel.shade.org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ArrowToSeatunnelRowReaderTest {
    // SLF4J logger for this test class; nothing in the code shown here logs through it.
    private static final Logger log = LoggerFactory.getLogger(ArrowToSeatunnelRowReaderTest.class);
    // Arrow fixture shared by the tests; assigned in beforeAll() and closed in afterAll().
    private static VectorSchemaRoot root;
    // Allocator that backs `root`; assigned in beforeAll() and closed in afterAll().
    private static RootAllocator rootAllocator;
    // Column name + flag descriptors, appended in beforeAll() and filtered by getSeatunnelRowType().
    private static final List<SeaTunnelDataTypeHolder> seaTunnelDataTypeHolder;
    // Fixed date-time every temporal column is derived from (parsed in the static initializer).
    private static final LocalDateTime localDateTime;
    // Expected values for the "string1" column; filled by buildVectorSchemaRoot() while it writes the vector.
    // testSeatunnelRow() also compares the "string2" and "string3" columns against this list.
    private static final List<String> stringData;
    // Expected values for the "byte" column; filled by buildVectorSchemaRoot().
    private static final List<Byte> byteData;
    // Expected values for the "short" column; filled by buildVectorSchemaRoot().
    private static final List<Short> shortData;
    // Expected values for the "int" column; filled by buildVectorSchemaRoot().
    private static final List<Integer> intData;
    // Expected values for the "long" column; filled by buildVectorSchemaRoot().
    private static final List<Long> longData;
    // Value written to every row of the "float" column and expected back from the reader.
    private static final float floatData = 1.23f;
    // Value written to every row of the "double" column and expected back from the reader.
    private static final double doubleData = 1.23456789;
    // Value written to every row of the "decimal" column (assigned in the static initializer).
    private static final BigDecimal decimalData;
    // Expected per-row contents of the "array1" list column; filled by buildVectorSchemaRoot().
    private static final List<List<Integer>> arrayData1;
    // Expected per-row contents of the "array2" list column; filled by buildVectorSchemaRoot().
    private static final List<List<LocalDateTime>> arrayData2;
    // Declared for map-column data; not read or appended to anywhere in the code shown here.
    private static final List<Map<String, LocalDateTime>> mapData;

    /**
     * Creates the shared allocator and the 10-row all-type Arrow fixture, then registers one
     * {@link SeaTunnelDataTypeHolder} per column name.
     *
     * <p>Holders registered with flag {@code 1} are selected by {@code getSeatunnelRowType} for
     * both values of its {@code allType} argument; holders registered with flag {@code 0} are
     * selected only when {@code allType} is {@code true}.
     */
    @BeforeAll
    public static void beforeAll() throws Exception {
        rootAllocator = new RootAllocator(Long.MAX_VALUE);
        root = buildVectorSchemaRoot(rootAllocator, 10, true);

        // Registration order matters: it defines the field order of the row type.
        String[] flagOneColumns = {
            "boolean", "byte", "short", "int", "long",
            "float", "double", "string1", "decimal", "timestamp1"
        };
        String[] flagZeroColumns = {
            "string2", "string3", "timestamp2", "time", "date1",
            "date2", "array1", "array2", "map"
        };
        for (String column : flagOneColumns) {
            seaTunnelDataTypeHolder.add(new SeaTunnelDataTypeHolder(column, 1));
        }
        for (String column : flagZeroColumns) {
            seaTunnelDataTypeHolder.add(new SeaTunnelDataTypeHolder(column, 0));
        }
    }

    /**
     * Builds an Arrow {@link VectorSchemaRoot} with one vector per column and {@code count} rows
     * of deterministic data.
     *
     * <p>Ten scalar columns are always created; when {@code allType} is {@code true} the binary,
     * large-string, zoned-timestamp, time, date and two list columns are added as well.
     *
     * <p>The shared expected-value lists ({@code byteData}, {@code shortData}, {@code intData},
     * {@code longData}, {@code stringData}, {@code arrayData1}, {@code arrayData2}) are only
     * appended to when {@code allType} is {@code true}. {@code testConvertArrowSpeed} builds a
     * 1,000,000-row fixture with {@code allType == false} and never reads those lists; letting it
     * append to them would make {@code testSeatunnelRow}, which compares against the whole lists,
     * fail whenever it runs afterwards.
     *
     * @param rootAllocator allocator that owns every created vector
     * @param count number of rows written to every vector
     * @param allType whether to add the extended column set and record expected values
     * @return a root wrapping all created vectors with its row count set to {@code count}
     */
    private static VectorSchemaRoot buildVectorSchemaRoot(RootAllocator rootAllocator, int count, boolean allType) {
        ZoneId zoneId = ZoneId.systemDefault();
        List<FieldVector> vectors = new ArrayList<>();
        vectors.add(new BitVector("boolean", rootAllocator));
        vectors.add(new TinyIntVector("byte", rootAllocator));
        vectors.add(new SmallIntVector("short", rootAllocator));
        vectors.add(new IntVector("int", rootAllocator));
        vectors.add(new BigIntVector("long", rootAllocator));
        vectors.add(new Float4Vector("float", rootAllocator));
        vectors.add(new Float8Vector("double", rootAllocator));
        vectors.add(new VarCharVector("string1", rootAllocator));
        vectors.add(new DecimalVector(Field.nullable("decimal", new ArrowType.Decimal(10, 2, 128)), rootAllocator));
        vectors.add(new TimeStampMicroVector("timestamp1", rootAllocator));
        if (allType) {
            vectors.add(new VarBinaryVector("string2", rootAllocator));
            vectors.add(new LargeVarCharVector("string3", rootAllocator));
            vectors.add(new TimeStampMilliTZVector(
                    Field.nullable("timestamp2", new ArrowType.Timestamp(TimeUnit.MILLISECOND, zoneId.getId())),
                    rootAllocator));
            vectors.add(new TimeMicroVector("time", rootAllocator));
            vectors.add(new DateMilliVector("date1", rootAllocator));
            vectors.add(new DateDayVector("date2", rootAllocator));
            vectors.add(ListVector.empty("array1", rootAllocator));
            vectors.add(ListVector.empty("array2", rootAllocator));
        }

        // Allocate every vector before the first write, as the original did.
        for (FieldVector vector : vectors) {
            vector.allocateNew();
        }

        long epochMilli = localDateTime.atZone(zoneId).toInstant().toEpochMilli();
        for (FieldVector vector : vectors) {
            if (vector instanceof ListVector) {
                fillListVector((ListVector) vector, count, epochMilli, allType);
                continue;
            }
            for (int row = 0; row < count; row++) {
                fillScalarCell(vector, row, epochMilli, allType);
            }
        }

        for (FieldVector vector : vectors) {
            vector.setValueCount(count);
        }
        List<Field> fields = vectors.stream().map(FieldVector::getField).collect(Collectors.toList());
        return new VectorSchemaRoot(new Schema(fields), vectors, count);
    }

    /**
     * Writes the value for one row of a scalar (non-list) vector, chosen by the vector's type.
     * Vector types without a branch here are left untouched.
     *
     * @param vector vector to write into
     * @param row zero-based row index
     * @param epochMilli epoch milliseconds of the fixture date-time in the system default zone
     * @param recordExpected whether to append the written value to the matching expected-value list
     */
    private static void fillScalarCell(FieldVector vector, int row, long epochMilli, boolean recordExpected) {
        if (vector instanceof BitVector) {
            // Even rows are written as 0, odd rows as 1.
            ((BitVector) vector).setSafe(row, row % 2 == 0 ? 0 : 1);
        } else if (vector instanceof TinyIntVector) {
            // Byte values start at 97 and grow by one per row.
            int byteValue = 97 + row;
            if (recordExpected) {
                byteData.add((byte) byteValue);
            }
            ((TinyIntVector) vector).setSafe(row, byteValue);
        } else if (vector instanceof SmallIntVector) {
            if (recordExpected) {
                shortData.add((short) row);
            }
            ((SmallIntVector) vector).setSafe(row, row);
        } else if (vector instanceof IntVector) {
            if (recordExpected) {
                intData.add(row);
            }
            ((IntVector) vector).setSafe(row, row);
        } else if (vector instanceof BigIntVector) {
            if (recordExpected) {
                longData.add(Long.valueOf(row));
            }
            ((BigIntVector) vector).setSafe(row, (long) row);
        } else if (vector instanceof Float4Vector) {
            ((Float4Vector) vector).setSafe(row, floatData);
        } else if (vector instanceof Float8Vector) {
            ((Float8Vector) vector).setSafe(row, doubleData);
        } else if (vector instanceof DecimalVector) {
            ((DecimalVector) vector).setSafe(row, decimalData);
        } else if (vector instanceof VarCharVector) {
            String text = "test" + row;
            if (recordExpected) {
                stringData.add(text);
            }
            ((VarCharVector) vector).setSafe(row, text.getBytes(StandardCharsets.UTF_8));
        } else if (vector instanceof TimeStampMicroVector) {
            // Microsecond column: scale the millisecond value by 1000.
            ((TimeStampMicroVector) vector).setSafe(row, epochMilli * 1000L);
        } else if (vector instanceof VarBinaryVector) {
            ((VarBinaryVector) vector).setSafe(row, ("test" + row).getBytes(StandardCharsets.UTF_8));
        } else if (vector instanceof LargeVarCharVector) {
            ((LargeVarCharVector) vector).setSafe(row, ("test" + row).getBytes(StandardCharsets.UTF_8));
        } else if (vector instanceof TimeStampMilliTZVector) {
            ((TimeStampMilliTZVector) vector).setSafe(row, epochMilli);
        } else if (vector instanceof TimeMicroVector) {
            // NOTE(review): the unscaled epoch-millisecond value is stored in a microsecond time
            // vector; kept as in the original - confirm against the reader's conversion.
            ((TimeMicroVector) vector).setSafe(row, epochMilli);
        } else if (vector instanceof DateMilliVector) {
            ((DateMilliVector) vector).setSafe(row, epochMilli);
        } else if (vector instanceof DateDayVector) {
            ((DateDayVector) vector).setSafe(row, (int) localDateTime.toLocalDate().toEpochDay());
        }
    }

    /**
     * Writes {@code count} five-element lists into a list vector. The column named "array1"
     * receives the ints {@code row .. row + 4}; the column named "array2" receives five copies
     * of the fixture timestamp.
     *
     * @param listVector list vector to fill
     * @param count number of rows (lists) to write
     * @param epochMilli epoch milliseconds written as each "array2" element
     * @param recordExpected whether to append each written list to {@code arrayData1}/{@code arrayData2}
     */
    private static void fillListVector(ListVector listVector, int count, long epochMilli, boolean recordExpected) {
        String name = listVector.getField().getName();
        UnionListWriter writer = listVector.getWriter();
        for (int row = 0; row < count; row++) {
            // Call order (startList, then setPosition) is kept exactly as in the original.
            writer.startList();
            writer.setPosition(row);
            if ("array1".equals(name)) {
                List<Integer> expected = new ArrayList<>();
                for (int j = 0; j < 5; j++) {
                    int element = j + row;
                    writer.writeInt(element);
                    expected.add(element);
                }
                writer.setValueCount(5);
                writer.endList();
                if (recordExpected) {
                    arrayData1.add(expected);
                }
            } else if ("array2".equals(name)) {
                List<LocalDateTime> expected = new ArrayList<>();
                for (int j = 0; j < 5; j++) {
                    writer.writeTimeStampMilliTZ(epochMilli);
                    expected.add(localDateTime);
                }
                writer.setValueCount(5);
                writer.endList();
                if (recordExpected) {
                    arrayData2.add(expected);
                }
            }
        }
    }

    /**
     * Writes a single map key or value at {@code rowIndex} through the given map writer.
     *
     * <p>{@link String} values are UTF-8 encoded into a temporary {@link ArrowBuf} and written as
     * a VarChar; {@link LocalDateTime} values are written through a {@link TimeMilliHolder}.
     * Values of any other type are ignored. No caller of this helper is visible in this class.
     *
     * @param writer map writer to write through
     * @param value the key or value to write
     * @param rowIndex position the writer is moved to before writing
     * @param allocator allocator used for the temporary string buffer
     */
    private static void writeKeyAndValue(UnionMapWriter writer, Object value, int rowIndex, BufferAllocator allocator) {
        writer.setPosition(rowIndex);
        if (value instanceof String) {
            byte[] bytes = ((String) value).getBytes(CharsetUtil.UTF_8);
            // The buffer is only a staging area for the write; release it afterwards so the
            // allocator does not report leaked memory when it is closed.
            try (ArrowBuf buffer = allocator.buffer((long) bytes.length)) {
                buffer.writeBytes(bytes);
                VarCharHolder holder = new VarCharHolder();
                holder.start = 0;
                holder.end = bytes.length;
                holder.buffer = buffer;
                writer.write(holder);
            }
        } else if (value instanceof LocalDateTime) {
            LocalDateTime dateTime = (LocalDateTime) value;
            TimeMilliHolder holder = new TimeMilliHolder();
            // NOTE(review): epoch milliseconds are narrowed to int here, which loses the high
            // bits for present-day dates; kept as in the original - confirm the intended value.
            holder.value = (int) dateTime.atZone(ZoneId.systemDefault()).toInstant().toEpochMilli();
            writer.write(holder);
        }
    }

    /**
     * Serialises the shared 10-row fixture to the Arrow stream format, reads it back through
     * {@link ArrowToSeatunnelRowReader} and checks every column (fields 0-17) against the values
     * that were written.
     */
    @Test
    public void testSeatunnelRow() throws Exception {
        try (ByteArrayOutputStream out = new ByteArrayOutputStream();
                ArrowStreamWriter writer = new ArrowStreamWriter(root, null, Channels.newChannel(out))) {
            writer.writeBatch();
            out.flush();

            List<SeaTunnelRow> rows = new ArrayList<>();
            try (ArrowToSeatunnelRowReader reader =
                    new ArrowToSeatunnelRowReader(out.toByteArray(), this.getSeatunnelRowType(true)).readArrow()) {
                while (reader.hasNext()) {
                    rows.add(reader.next());
                }
                Assertions.assertEquals(10, rows.size());
            }

            // Scalar columns shared with the reduced fixture (fields 0-9).
            Assertions.assertEquals(Arrays.asList(Boolean.FALSE, Boolean.TRUE), distinctColumnValues(rows, 0));
            Assertions.assertEquals(byteData, distinctColumnValues(rows, 1));
            Assertions.assertEquals(shortData, distinctColumnValues(rows, 2));
            Assertions.assertEquals(intData, distinctColumnValues(rows, 3));
            Assertions.assertEquals(longData, distinctColumnValues(rows, 4));
            Assertions.assertEquals(Collections.singletonList(Float.valueOf(1.23f)), distinctColumnValues(rows, 5));
            Assertions.assertEquals(Collections.singletonList(1.23456789), distinctColumnValues(rows, 6));
            Assertions.assertEquals(stringData, columnValues(rows, 7));
            Assertions.assertEquals(Collections.singletonList(decimalData), distinctColumnValues(rows, 8));
            Assertions.assertEquals(Collections.singletonList(localDateTime), distinctColumnValues(rows, 9));

            // Columns that exist only in the all-type fixture (fields 10-17).
            Assertions.assertEquals(stringData, columnValues(rows, 10));
            Assertions.assertEquals(stringData, columnValues(rows, 11));
            Assertions.assertEquals(Collections.singletonList(localDateTime), distinctColumnValues(rows, 12));
            Assertions.assertEquals(
                    Collections.singletonList(localDateTime.toLocalTime()), distinctColumnValues(rows, 13));
            Assertions.assertEquals(
                    Collections.singletonList(localDateTime.toLocalDate()), distinctColumnValues(rows, 14));
            Assertions.assertEquals(
                    Collections.singletonList(localDateTime.toLocalDate()), distinctColumnValues(rows, 15));
            Assertions.assertIterableEquals(arrayData1, columnValues(rows, 16));
            Assertions.assertIterableEquals(arrayData2, columnValues(rows, 17));
        }
    }

    /** Returns the value of field {@code fieldIndex} from every row, in row order. */
    private static List<Object> columnValues(List<SeaTunnelRow> rows, int fieldIndex) {
        return rows.stream().map(row -> row.getField(fieldIndex)).collect(Collectors.toList());
    }

    /** Returns the distinct values of field {@code fieldIndex}, in first-seen row order. */
    private static List<Object> distinctColumnValues(List<SeaTunnelRow> rows, int fieldIndex) {
        return rows.stream().map(row -> row.getField(fieldIndex)).distinct().collect(Collectors.toList());
    }

    /**
     * Times building a 1,000,000-row fixture with the reduced column set and reading it back
     * through {@link ArrowToSeatunnelRowReader}. Both durations are printed to standard output;
     * the only assertion is that every row was read.
     */
    @Test
    public void testConvertArrowSpeed() throws Exception {
        Stopwatch timer = Stopwatch.createStarted();
        final int rowCount = 1000000;
        // The test uses its own allocator and root so the shared fixture stays untouched.
        try (RootAllocator speedAllocator = new RootAllocator(Integer.MAX_VALUE);
                VectorSchemaRoot speedRoot = buildVectorSchemaRoot(speedAllocator, rowCount, false);
                ByteArrayOutputStream arrowBytes = new ByteArrayOutputStream();
                ArrowStreamWriter streamWriter =
                        new ArrowStreamWriter(speedRoot, null, Channels.newChannel(arrowBytes))) {
            timer.stop();
            long buildMillis = timer.elapsed(java.util.concurrent.TimeUnit.MILLISECONDS);
            System.out.printf("build %s rows vectorSchemaRoot cost %s ms \n", rowCount, buildMillis);

            streamWriter.writeBatch();
            arrowBytes.flush();

            List<SeaTunnelRow> rows = new ArrayList<>();
            timer.reset().start();
            SeaTunnelRowType rowType = this.getSeatunnelRowType(false);
            try (ArrowToSeatunnelRowReader rowReader =
                    new ArrowToSeatunnelRowReader(arrowBytes.toByteArray(), rowType).readArrow()) {
                while (rowReader.hasNext()) {
                    rows.add(rowReader.next());
                }
                timer.stop();
                long readMillis = timer.elapsed(java.util.concurrent.TimeUnit.MILLISECONDS);
                System.out.printf("read %s rows cost %s ms ", rows.size(), readMillis);
                Assertions.assertEquals(rowCount, rows.size());
            }
        }
    }

    /**
     * Builds the row type from the registered column holders, in registration order.
     *
     * @param allType when {@code true}, holders whose flag is {@code >= 0} are used; otherwise
     *     only holders whose flag equals {@code 1}
     * @return a row type built from the selected holders' field names and data types
     */
    private SeaTunnelRowType getSeatunnelRowType(boolean allType) {
        List<SeaTunnelDataTypeHolder> selected = new ArrayList<>();
        for (SeaTunnelDataTypeHolder holder : seaTunnelDataTypeHolder) {
            boolean wanted = allType ? holder.getFlag() >= 0 : holder.getFlag() == 1;
            if (wanted) {
                selected.add(holder);
            }
        }

        String[] fieldNames = new String[selected.size()];
        SeaTunnelDataType<?>[] fieldTypes = new SeaTunnelDataType<?>[selected.size()];
        for (int i = 0; i < selected.size(); i++) {
            SeaTunnelDataTypeHolder holder = selected.get(i);
            fieldNames[i] = holder.getFiledName();
            fieldTypes[i] = holder.getSeatunnelDataType();
        }
        return new SeaTunnelRowType(fieldNames, fieldTypes);
    }

    /**
     * Releases the shared Arrow fixture: the schema root first, then the allocator that backs it.
     *
     * <p>Both are closed through try-with-resources so the allocator is still closed when
     * closing the root throws. Fields that were never assigned are skipped.
     *
     * @throws RuntimeException wrapping any failure raised while closing either resource
     */
    @AfterAll
    public static void afterAll() throws Exception {
        // Resources close in reverse declaration order: the root first, then the allocator.
        // A null resource is skipped by try-with-resources.
        try (RootAllocator allocatorToClose = rootAllocator;
                VectorSchemaRoot rootToClose = root) {
            // Nothing to do here; the block exists only to close both resources.
        } catch (Exception e) {
            throw new RuntimeException("failed to close arrow stream reader.", e);
        }
    }

    // Initialises the constant fixture values and creates the empty collections that
    // beforeAll() and buildVectorSchemaRoot() append to.
    static {
        // Constant values written into the temporal and decimal columns.
        DateTimeFormatter fixtureFormat = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
        localDateTime = LocalDateTime.parse("2025-02-15 02:21:23", fixtureFormat);
        decimalData = new BigDecimal("1234567.89");

        // Column descriptors.
        seaTunnelDataTypeHolder = new ArrayList<>();

        // Expected-value lists, all starting empty.
        stringData = new ArrayList<>();
        byteData = new ArrayList<>();
        shortData = new ArrayList<>();
        intData = new ArrayList<>();
        longData = new ArrayList<>();
        arrayData1 = new ArrayList<>();
        arrayData2 = new ArrayList<>();
        mapData = new ArrayList<>();
    }
}

