/*
 * Decompiled with CFR 0.152.
 */
package org.apache.seatunnel.translation.spark.sink.writer;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.seatunnel.api.sink.SeaTunnelSink;
import org.apache.seatunnel.api.sink.SinkAggregatedCommitter;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.translation.spark.sink.writer.SparkDataWriterFactory;
import org.apache.seatunnel.translation.spark.sink.writer.SparkWriterCommitMessage;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.sources.v2.writer.DataSourceWriter;
import org.apache.spark.sql.sources.v2.writer.DataWriterFactory;
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage;

public class SparkDataSourceWriter<StateT, CommitInfoT, AggregatedCommitInfoT>
implements DataSourceWriter {
    protected final SeaTunnelSink<SeaTunnelRow, StateT, CommitInfoT, AggregatedCommitInfoT> sink;
    @Nullable
    protected final SinkAggregatedCommitter<CommitInfoT, AggregatedCommitInfoT> sinkAggregatedCommitter;

    public SparkDataSourceWriter(SeaTunnelSink<SeaTunnelRow, StateT, CommitInfoT, AggregatedCommitInfoT> sink) throws IOException {
        this.sink = sink;
        this.sinkAggregatedCommitter = sink.createAggregatedCommitter().orElse(null);
    }

    public DataWriterFactory<InternalRow> createWriterFactory() {
        return new SparkDataWriterFactory<CommitInfoT, StateT>(this.sink);
    }

    public void commit(WriterCommitMessage[] messages) {
        if (this.sinkAggregatedCommitter != null) {
            try {
                this.sinkAggregatedCommitter.commit(this.combineCommitMessage(messages));
            }
            catch (IOException e) {
                throw new RuntimeException("SinkAggregatedCommitter commit failed in driver", e);
            }
        }
    }

    public void abort(WriterCommitMessage[] messages) {
        if (this.sinkAggregatedCommitter != null) {
            try {
                this.sinkAggregatedCommitter.abort(this.combineCommitMessage(messages));
            }
            catch (Exception e) {
                throw new RuntimeException("SinkAggregatedCommitter abort failed in driver", e);
            }
        }
    }

    @Nonnull
    private List<AggregatedCommitInfoT> combineCommitMessage(WriterCommitMessage[] messages) {
        if (this.sinkAggregatedCommitter == null || messages.length == 0) {
            return Collections.emptyList();
        }
        List commitInfos = Arrays.stream(messages).map(m -> ((SparkWriterCommitMessage)m).getMessage()).filter(Objects::nonNull).collect(Collectors.toList());
        return Collections.singletonList(this.sinkAggregatedCommitter.combine(commitInfos));
    }
}

