/*
 * Decompiled with CFR 0.152.
 */
package com.lancedb.lance.spark.write;

import com.lancedb.lance.FragmentMetadata;
import com.lancedb.lance.fragment.FragmentMergeResult;
import com.lancedb.lance.spark.LanceConfig;
import com.lancedb.lance.spark.LanceDataset;
import com.lancedb.lance.spark.arrow.LanceArrowWriter;
import com.lancedb.lance.spark.arrow.LanceArrowWriter$;
import com.lancedb.lance.spark.internal.LanceDatasetAdapter;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.arrow.c.ArrowArrayStream;
import org.apache.arrow.c.Data;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters;
import org.apache.spark.sql.connector.write.BatchWrite;
import org.apache.spark.sql.connector.write.DataWriter;
import org.apache.spark.sql.connector.write.DataWriterFactory;
import org.apache.spark.sql.connector.write.PhysicalWriteInfo;
import org.apache.spark.sql.connector.write.WriterCommitMessage;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.LanceArrowUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Spark {@link BatchWrite} that backfills values for newly added columns into an existing Lance
 * dataset.
 *
 * <p>The write schema must contain {@code LanceDataset.FRAGMENT_ID_COLUMN} and
 * {@code LanceDataset.ROW_ADDRESS_COLUMN} in addition to the new columns. Each task merges the
 * new-column values into the fragments it receives (see {@link AddColumnsWriter}). The driver then
 * passes every merged fragment, plus the dataset's fragments no task merged, to
 * {@code LanceDatasetAdapter.mergeFragments}.
 */
public class AddColumnsBackfillBatchWrite implements BatchWrite {
    private static final Logger logger = LoggerFactory.getLogger(AddColumnsBackfillBatchWrite.class);

    private final StructType schema;
    private final LanceConfig config;
    private final List<String> newColumns;

    /**
     * @param schema Spark schema of the rows handed to each task's writer
     * @param config configuration of the target Lance dataset
     * @param newColumns names of the columns whose values are being backfilled
     */
    public AddColumnsBackfillBatchWrite(StructType schema, LanceConfig config, List<String> newColumns) {
        this.schema = schema;
        this.config = config;
        this.newColumns = newColumns;
    }

    @Override
    public DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) {
        return new AddColumnsWriterFactory(this.schema, this.config, this.newColumns);
    }

    /** Tasks commit without consulting Spark's output commit coordinator. */
    @Override
    public boolean useCommitCoordinator() {
        return false;
    }

    /**
     * Commits the fragments merged by all tasks as one {@code mergeFragments} call.
     *
     * <p>Does nothing (beyond logging) when no task merged any fragment.
     *
     * @param messages one {@link TaskCommit} per successful task
     * @throws RuntimeException if fragments were merged but no task reported a merged schema
     */
    @Override
    public void commit(WriterCommitMessage[] messages) {
        // Must be mutable: the unmerged fragments are appended below, and Collectors.toList()
        // makes no guarantee about the mutability of the list it returns.
        List<FragmentMetadata> fragments = Arrays.stream(messages)
            .map(m -> (TaskCommit) m)
            .map(TaskCommit::getFragments)
            .flatMap(Collection::stream)
            .collect(Collectors.toCollection(ArrayList::new));
        if (fragments.isEmpty()) {
            logger.info("No merged fragments to commit.");
            return;
        }
        StructType sparkSchema = Arrays.stream(messages)
            .map(m -> (TaskCommit) m)
            .map(TaskCommit::getSchema)
            .filter(Objects::nonNull)
            .findFirst()
            .orElse(null);
        if (sparkSchema == null) {
            throw new RuntimeException("No merged schema found in commit messages.");
        }
        // Carry over every existing fragment that no task merged (e.g. one that received no
        // rows), so the committed fragment list still covers the whole dataset.
        Set<Integer> mergedFragmentIds = fragments.stream()
            .map(FragmentMetadata::getId)
            .collect(Collectors.toSet());
        LanceDatasetAdapter.getFragments(this.config).stream()
            .filter(f -> !mergedFragmentIds.contains(f.getId()))
            .forEach(fragments::add);
        Schema arrowSchema = LanceArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false);
        LanceDatasetAdapter.mergeFragments(this.config, fragments, arrowSchema);
    }

    /**
     * Not supported; always throws {@link UnsupportedOperationException}.
     *
     * <p>NOTE(review): nothing produced by task-side {@code mergeFragmentColumn} calls is cleaned
     * up here. Confirm whether that output is safe to leave behind on a failed job.
     */
    @Override
    public void abort(WriterCommitMessage[] messages) {
        throw new UnsupportedOperationException();
    }

    @Override
    public String toString() {
        return String.format("AddColumnsBackfillBatchWrite(datasetUri=%s)", this.config.getDatasetUri());
    }

    /** Creates one {@link AddColumnsWriter} per Spark task. */
    public static class AddColumnsWriterFactory implements DataWriterFactory {
        private final LanceConfig config;
        private final StructType schema;
        private final List<String> newColumns;

        protected AddColumnsWriterFactory(StructType schema, LanceConfig config, List<String> newColumns) {
            this.schema = schema;
            this.config = config;
            this.newColumns = newColumns;
        }

        @Override
        public DataWriter<InternalRow> createWriter(int partitionId, long taskId) {
            return new AddColumnsWriter(this.config, this.schema, this.newColumns);
        }
    }

    /**
     * Per-task commit message.
     *
     * <p>It carries the fragment metadata produced by the task's merges and the Spark form of the
     * last merged schema. The schema is {@code null} when the task merged nothing.
     */
    public static class TaskCommit implements WriterCommitMessage {
        private final List<FragmentMetadata> fragments;
        private final StructType schema;

        TaskCommit(List<FragmentMetadata> fragments, StructType schema) {
            this.fragments = fragments;
            this.schema = schema;
        }

        List<FragmentMetadata> getFragments() {
            return this.fragments;
        }

        StructType getSchema() {
            return this.schema;
        }
    }

    /**
     * Task-side writer that buffers row addresses and new-column values for one fragment at a
     * time. When the fragment id changes, or at commit, it merges the buffer into that fragment.
     *
     * <p>NOTE(review): this assumes the rows for a given fragment arrive contiguously. If a
     * fragment's rows were interleaved with another's, that fragment would be merged more than
     * once. Confirm that upstream ordering guarantees this.
     */
    public static class AddColumnsWriter implements DataWriter<InternalRow> {
        /** Sentinel for "no row seen yet". */
        private static final int NO_FRAGMENT = -1;

        private final LanceConfig config;
        private final StructType schema;
        private final int fragmentIdField;
        private final List<FragmentMetadata> fragments;
        /** Columns buffered per row: the new columns plus the row-address column. */
        private final StructType writerSchema;
        /** writerOrdinals[i] is the ordinal in {@link #schema} of writerSchema field i. */
        private final int[] writerOrdinals;
        private Schema mergedSchema;
        private int fragmentId = NO_FRAGMENT;
        /** Arrow buffer for the current fragment; {@code null} once merged or released. */
        private VectorSchemaRoot data;
        private LanceArrowWriter writer = null;

        public AddColumnsWriter(LanceConfig config, StructType schema, List<String> newColumns) {
            this.config = config;
            this.schema = schema;
            this.fragmentIdField = schema.fieldIndex(LanceDataset.FRAGMENT_ID_COLUMN.name());
            this.fragments = new ArrayList<>();

            String rowAddressColumn = LanceDataset.ROW_ADDRESS_COLUMN.name();
            StructType projected = new StructType();
            List<Integer> ordinals = new ArrayList<>();
            for (int i = 0; i < schema.fields().length; ++i) {
                String name = schema.fields()[i].name();
                if (newColumns.contains(name) || name.equals(rowAddressColumn)) {
                    projected = projected.add(schema.fields()[i]);
                    ordinals.add(i);
                }
            }
            this.writerSchema = projected;
            // Resolved once here rather than by name for every column of every row.
            this.writerOrdinals = new int[ordinals.size()];
            for (int i = 0; i < this.writerOrdinals.length; ++i) {
                this.writerOrdinals[i] = ordinals.get(i);
            }
            this.createWriter();
        }

        /**
         * Buffers one row's row address and new-column values.
         *
         * <p>A row whose fragment id differs from the current fragment first triggers a merge of
         * the current buffer and starts a fresh buffer.
         */
        @Override
        public void write(InternalRow record) throws IOException {
            int fragId = record.getInt(this.fragmentIdField);
            if (this.fragmentId == NO_FRAGMENT) {
                this.fragmentId = fragId;
            } else if (fragId != this.fragmentId) {
                this.mergeFragment();
                this.fragmentId = fragId;
                this.createWriter();
            }
            for (int i = 0; i < this.writerOrdinals.length; ++i) {
                this.writer.field(i).write((SpecializedGetters) record, this.writerOrdinals[i]);
            }
        }

        /** Allocates a fresh Arrow root for {@link #writerSchema} and a writer over it. */
        private void createWriter() {
            this.data = VectorSchemaRoot.create(
                LanceArrowUtils.toArrowSchema(this.writerSchema, "UTC", false, false),
                LanceDatasetAdapter.allocator);
            this.writer = LanceArrowWriter$.MODULE$.create(this.data, this.writerSchema);
        }

        /**
         * Merges the buffered rows into fragment {@link #fragmentId}, matching rows on the
         * row-address column. The result's fragment metadata and schema are recorded.
         *
         * <p>The Arrow root is always released, even when serialization or the merge fails.
         */
        private void mergeFragment() {
            try {
                this.writer.finish();
                byte[] arrowData = this.serializeBatch();
                try (ArrowStreamReader reader = new ArrowStreamReader(
                            new ByteArrayInputStream(arrowData), LanceDatasetAdapter.allocator);
                     ArrowArrayStream stream = ArrowArrayStream.allocateNew(LanceDatasetAdapter.allocator)) {
                    Data.exportArrayStream(LanceDatasetAdapter.allocator, reader, stream);
                    String rowAddressColumn = LanceDataset.ROW_ADDRESS_COLUMN.name();
                    FragmentMergeResult result = LanceDatasetAdapter.mergeFragmentColumn(
                        this.config, this.fragmentId, stream, rowAddressColumn, rowAddressColumn);
                    this.fragments.add(result.getFragmentMetadata());
                    this.mergedSchema = result.getSchema().asArrowSchema();
                } catch (Exception e) {
                    throw new RuntimeException("Cannot read arrow stream.", e);
                }
            } finally {
                this.data.close();
                this.data = null;
            }
        }

        /** Serializes the current root as a single-batch Arrow IPC stream. */
        private byte[] serializeBatch() {
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            try (ArrowStreamWriter streamWriter = new ArrowStreamWriter(this.data, null, out)) {
                streamWriter.start();
                streamWriter.writeBatch();
                streamWriter.end();
            } catch (IOException e) {
                throw new RuntimeException("Cannot write schema root", e);
            }
            return out.toByteArray();
        }

        /**
         * Merges any still-buffered fragment and reports every fragment this task merged.
         *
         * <p>The reported schema is the last merged one, or {@code null} if nothing was merged.
         */
        @Override
        public WriterCommitMessage commit() {
            if (this.fragmentId >= 0 && this.data != null) {
                this.mergeFragment();
            }
            return new TaskCommit(
                this.fragments,
                this.mergedSchema == null ? null : LanceArrowUtils.fromArrowSchema(this.mergedSchema));
        }

        /** Nothing to undo here; the Arrow buffer is released in {@link #close()}. */
        @Override
        public void abort() {
        }

        /**
         * Releases the Arrow buffer if it is still allocated. That happens when the task saw no
         * rows or was aborted before merging.
         */
        @Override
        public void close() throws IOException {
            if (this.data != null) {
                this.data.close();
                this.data = null;
            }
        }
    }
}

