/*
 * Decompiled with CFR 0.152.
 */
package com.tigergraph.spark.write;

import com.tigergraph.spark.TigerGraphConnection;
import com.tigergraph.spark.client.Write;
import com.tigergraph.spark.util.Options;
import com.tigergraph.spark.util.Utils;
import com.tigergraph.spark.write.TigerGraphWriterCommitMessage;
import java.io.IOException;
import java.sql.Date;
import java.sql.Timestamp;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.write.DataWriter;
import org.apache.spark.sql.types.BooleanType;
import org.apache.spark.sql.types.ByteType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DateType;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.FloatType;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.LongType;
import org.apache.spark.sql.types.ShortType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.TimestampNTZType;
import org.apache.spark.sql.types.TimestampType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class TigerGraphDataWriter
implements DataWriter<InternalRow> {
    private static final Logger logger = LoggerFactory.getLogger(TigerGraphDataWriter.class);
    private final StructType schema;
    private final int partitionId;
    private final long taskId;
    private final long epochId;
    private final Write write;
    private final String version;
    private final String jobId;
    private final String graph;
    private final String sep;
    private final String eol;
    private final int maxBatchSizeInBytes;
    private final Map<String, Object> queryMap;
    private final List<BiFunction<InternalRow, Integer, String>> converters;
    private final StringBuilder sb = new StringBuilder();
    private int sbOffset = 0;
    private long totalLines = 0L;

    TigerGraphDataWriter(StructType schema, TigerGraphConnection conn, int partitionId, long taskId, long epochId) {
        this.schema = schema;
        this.partitionId = partitionId;
        this.taskId = taskId;
        this.epochId = epochId;
        this.write = conn.getWrite();
        this.version = conn.getVersion();
        this.jobId = conn.getLoadingJobId();
        Options opts = conn.getOpts();
        this.graph = opts.getString("graph");
        this.maxBatchSizeInBytes = opts.getInt("loading.batch.size.bytes");
        this.sep = opts.getString("loading.separator");
        this.eol = opts.getString("loading.eol");
        this.queryMap = new HashMap<String, Object>();
        this.queryMap.put("tag", opts.getString("loading.job"));
        this.queryMap.put("filename", opts.getString("loading.filename"));
        this.queryMap.put("sep", opts.getString("loading.separator"));
        this.queryMap.put("eol", opts.getString("loading.eol"));
        this.queryMap.put("timeout", opts.getInt("loading.timeout.ms"));
        if (Utils.versionCmp(this.version, "3.9.4") >= 0) {
            this.queryMap.put("jobid", this.jobId);
            if (opts.containsOption("loading.max.num.error")) {
                this.queryMap.put("max_num_error", opts.getInt("loading.max.num.error"));
            }
            if (opts.containsOption("loading.max.percent.error")) {
                this.queryMap.put("max_percent_error", opts.getInt("loading.max.percent.error"));
            }
        }
        this.converters = TigerGraphDataWriter.getConverters(schema);
        logger.info("Created data writer for partition {}, task {}, epochId {}", new Object[]{partitionId, taskId, epochId});
    }

    TigerGraphDataWriter(StructType schema, TigerGraphConnection conn, int partitionId, long taskId) {
        this(schema, conn, partitionId, taskId, -1L);
    }

    public void close() throws IOException {
    }

    public void write(InternalRow record) throws IOException {
        String line = IntStream.range(0, record.numFields()).mapToObj(i -> record.isNullAt(i) ? "" : this.converters.get(i).apply(record, i)).collect(Collectors.joining(this.sep));
        if (this.sb.length() + line.length() + this.eol.length() > this.maxBatchSizeInBytes) {
            this.postToDDL();
        }
        this.sb.append(line).append(this.eol);
        ++this.sbOffset;
    }

    private void postToDDL() {
        Write.LoadingResponse resp = this.write.ddl(this.graph, this.sb.toString(), this.queryMap);
        logger.info("Upsert {} rows to TigerGraph graph {}", (Object)this.sbOffset, (Object)this.graph);
        resp.panicOnFail();
        Utils.removeUserData(resp.results);
        if (resp.hasInvalidRecord()) {
            logger.error("Found rejected rows, it won't abort the loading: ");
            logger.error(resp.results.toPrettyString());
        } else {
            logger.debug(resp.results.toPrettyString());
        }
        this.totalLines += (long)this.sbOffset;
        this.sbOffset = 0;
        this.sb.setLength(0);
    }

    public TigerGraphWriterCommitMessage commit() throws IOException {
        if (this.sb.length() > 0) {
            this.postToDDL();
        }
        logger.info("Finished writing {} rows to TigerGraph. Partition {}, task {}, epoch {}.", new Object[]{this.totalLines, this.partitionId, this.taskId, this.epochId});
        return new TigerGraphWriterCommitMessage(this.totalLines, this.partitionId, this.taskId);
    }

    public void abort() throws IOException {
        logger.error("Write aborted with {} records loaded. Partition {}, task {}, epoch {}", new Object[]{this.totalLines, this.partitionId, this.taskId, this.epochId});
    }

    protected static List<BiFunction<InternalRow, Integer, String>> getConverters(StructType schema) {
        return Stream.of(schema.fields()).map(f -> TigerGraphDataWriter.getConverter(f.dataType())).collect(Collectors.toList());
    }

    private static BiFunction<InternalRow, Integer, String> getConverter(DataType dt) {
        if (dt instanceof IntegerType) {
            return (row, idx) -> String.valueOf(row.getInt(idx.intValue()));
        }
        if (dt instanceof LongType) {
            return (row, idx) -> String.valueOf(row.getLong(idx.intValue()));
        }
        if (dt instanceof DoubleType) {
            return (row, idx) -> String.valueOf(row.getDouble(idx.intValue()));
        }
        if (dt instanceof FloatType) {
            return (row, idx) -> String.valueOf(row.getFloat(idx.intValue()));
        }
        if (dt instanceof ShortType) {
            return (row, idx) -> String.valueOf(row.getShort(idx.intValue()));
        }
        if (dt instanceof ByteType) {
            return (row, idx) -> String.valueOf(row.getByte(idx.intValue()));
        }
        if (dt instanceof BooleanType) {
            return (row, idx) -> String.valueOf(row.getBoolean(idx.intValue()));
        }
        if (dt instanceof StringType) {
            return (row, idx) -> row.getString(idx.intValue());
        }
        if (dt instanceof TimestampType || dt instanceof TimestampNTZType) {
            return (row, idx) -> new Timestamp(row.getLong(idx.intValue()) / 1000L).toString();
        }
        if (dt instanceof DateType) {
            return (row, idx) -> new Date((long)(row.getInt(idx.intValue()) * 24 * 60 * 60) * 1000L).toString();
        }
        if (dt instanceof DecimalType) {
            return (row, idx) -> row.getDecimal(idx.intValue(), ((DecimalType)dt).precision(), ((DecimalType)dt).scale()).toString();
        }
        throw new UnsupportedOperationException("Unsupported Spark type: " + dt.typeName() + ", please convert it to a string that matches the LOAD statement defined in the loading job: https://docs.tigergraph.com/gsql-ref/current/ddl-and-loading/creating-a-loading-job#_more_complex_attribute_expressions");
    }
}

