package com.ksyun.ks3.service.common;

import com.ksyun.ks3.AutoAbortInputStream;
import com.ksyun.ks3.dto.GetObjectResult;
import com.ksyun.ks3.exception.Ks3ClientException;
import com.ksyun.ks3.service.Ks3Client;
import com.ksyun.ks3.service.Ks3ClientConfig;
import com.ksyun.ks3.service.request.GetObjectRequest;
import com.ksyun.ks3.utils.CommonUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import java.io.*;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Queue;
import java.util.concurrent.*;

/**
 * Multi-threaded download.
 *
 * <p>Splits the requested object (or byte range) into fixed-size blocks, downloads the blocks
 * concurrently on a fixed thread pool, and writes them in order into a pipe whose read end is
 * exposed through {@link #getInputStream()}.
 *
 * <p>Constructing an instance starts a background thread immediately. Every block is buffered
 * completely in memory before it is written to the pipe, and up to {@code threadNum + 1} blocks
 * may be in flight at the same time.
 */
public class MultiThreadDownloader {
    private static final Log log = LogFactory.getLog(MultiThreadDownloader.class);
    /** Capacity, in bytes, of the pipe between the download thread and the reader. */
    private static final int PIPE_SIZE = 1024 * 1024;

    private final Ks3Client client;
    private final GetObjectRequest request;
    private final GetObjectResult object;
    private PipedOutputStream pipedOut;
    private InputStream downloadStream;
    /**
     * Error that aborted the download, or {@code null} while none occurred. It is written by the
     * download thread before the pipe is closed and read by the consumer when it reaches EOF.
     */
    private volatile Throwable failure;

    /**
     * Creates the downloader and immediately starts downloading in a background thread.
     *
     * @param client  client used to issue the per-block requests
     * @param request original request; its bucket, key and optional range are used
     * @param object  result whose object metadata supplies the length to download
     * @throws IllegalArgumentException if the requested range is negative or inverted
     * @throws Ks3ClientException       if the pipe cannot be created
     */
    public MultiThreadDownloader(Ks3Client client, GetObjectRequest request, GetObjectResult object) {
        this.client = client;
        this.request = request;
        this.object = object;
        this.init();
    }

    /**
     * Checks the range of the original request.
     *
     * <p>A {@code null} range means "entire object". A range that does not have exactly two
     * elements is logged and ignored. Negative or inverted bounds are rejected.
     *
     * @throws IllegalArgumentException if a bound is negative or the start is after the end
     */
    private void validateRange() {
        long[] range = request.getRange();
        if (range == null) {
            return;
        }

        if (range.length != 2) {
            // Commons Logging has no "{}" placeholder support, so the value is concatenated.
            log.warn("Invalid range value: " + Arrays.toString(range)
                    + ", ignore it and request for entire object");
            return;
        }

        if (range[0] < 0) {
            throw new IllegalArgumentException("The start of range must not be negative");
        }

        if (range[1] < 0) {
            throw new IllegalArgumentException("The end of range must not be negative");
        }

        if (range[0] > range[1]) {
            throw new IllegalArgumentException("The start of range must not be greater than the end");
        }
    }

    /**
     * Validates the request, wires up the pipe and starts the download thread.
     *
     * @throws Ks3ClientException if the pipe cannot be created
     */
    private void init() throws Ks3ClientException {
        log.debug("use multiThread...");
        validateRange();

        try {
            pipedOut = new PipedOutputStream();
            PipedInputStream pipedIn = new PipedInputStream(pipedOut, PIPE_SIZE);
            downloadStream = new FailureAwareInputStream(new BufferedInputStream(pipedIn));
        } catch (IOException e) {
            throw new Ks3ClientException(e);
        }

        new Thread(new MultipartDownloadWorker(), "ks3-multithread-download").start();
    }

    /**
     * Returns the stream from which the downloaded bytes can be read in order.
     *
     * <p>If the download fails, reads throw an {@link IOException} once the bytes written before
     * the failure have been consumed, instead of reporting a normal end of stream.
     *
     * @return the read end of the download pipe
     */
    public InputStream getInputStream() {
        return this.downloadStream;
    }

    /**
     * Stream wrapper that turns "end of stream after a failed download" into an exception, so a
     * truncated object is never mistaken for a complete one.
     */
    private final class FailureAwareInputStream extends FilterInputStream {

        FailureAwareInputStream(InputStream in) {
            super(in);
        }

        @Override
        public int read() throws IOException {
            int value = super.read();
            if (value == -1) {
                throwIfFailed();
            }
            return value;
        }

        @Override
        public int read(byte[] b, int off, int len) throws IOException {
            int count = super.read(b, off, len);
            if (count == -1) {
                throwIfFailed();
            }
            return count;
        }

        /** Throws if the download thread recorded an error before closing the pipe. */
        private void throwIfFailed() throws IOException {
            Throwable cause = failure;
            if (cause != null) {
                throw new IOException("Multi-thread download failed before the object was fully transferred", cause);
            }
        }
    }

    /**
     * Multipart download worker.
     * Splits the download into blocks, schedules them on the pool and writes the results into
     * the pipe in block order.
     */
    class MultipartDownloadWorker implements Runnable {

        @Override
        public void run() {
            ExecutorService downThread = null;
            try {
                // Everything that can throw lives inside the try block: the pipe must always be
                // closed in the finally block, otherwise the reader would wait forever.
                final Ks3ClientConfig.MultiThreadDownloadConf conf = client.getKs3config().getMultiThreadDownloadConf();
                final int threadNum = conf.getThreadNum();
                final long blockSize = conf.getBlockSize();
                if (blockSize <= 0) {
                    throw new IllegalArgumentException("The block size must be positive: " + blockSize);
                }
                downThread = Executors.newFixedThreadPool(threadNum);

                final long[] range = request.getRange();
                final boolean isRange = range != null && range.length == 2;

                long length = object.getObject().getObjectMetadata().getContentLength();
                if (isRange) {
                    // NOTE(review): with a range, the part count is derived from getInstanceLength()
                    // rather than from the span of the requested range - confirm this is the
                    // intended length for ranged downloads.
                    length = object.getObject().getObjectMetadata().getInstanceLength();
                }

                // Integer ceiling division; a non-positive length yields no parts at all.
                final long totalParts = length <= 0
                        ? 0
                        : length / blockSize + (length % blockSize == 0 ? 0 : 1);
                final long firstByte = isRange ? range[0] : 0;
                final long lastByte = isRange ? range[1] : length - 1;

                // Keep at most threadNum + 1 blocks in flight: fill the window first, then
                // submit one new block each time the oldest one has been written out.
                final long window = (long) threadNum + 1;
                final Queue<Future<byte[]>> pending = new ArrayDeque<>();
                long nextPart = 0;
                while (nextPart < totalParts || !pending.isEmpty()) {
                    while (nextPart < totalParts && pending.size() < window) {
                        long start = firstByte + nextPart * blockSize;
                        long end = nextPart + 1 == totalParts
                                ? lastByte
                                : firstByte + (nextPart + 1) * blockSize - 1;
                        log.debug("download block:" + nextPart + ", start=" + start + ", end=" + end);
                        pending.add(downThread.submit(new FilePartDownloadWorker(start, end)));
                        nextPart++;
                    }
                    // Futures are consumed in submission order, so bytes reach the pipe in order.
                    pipedOut.write(pending.poll().get());
                }
            } catch (Throwable t) {
                // Record the failure BEFORE the pipe is closed below, so that a reader which
                // observes EOF is guaranteed to see it.
                failure = t;
                if (t instanceof InterruptedException) {
                    Thread.currentThread().interrupt();
                }
                log.error("Multi-thread download failed", t);
                if (t instanceof Error) {
                    throw (Error) t;
                }
            } finally {
                CommonUtils.closeQuietly(pipedOut);
                if (downThread != null) {
                    downThread.shutdownNow();
                }
            }
        }
    }

    /**
     * Block download worker.
     * Downloads one block of the object, the inclusive byte range {@code [start, end]}, into a
     * byte array.
     */
    class FilePartDownloadWorker implements Callable<byte[]> {
        public static final int EOF = -1;

        private final long start;
        private final long end;

        /**
         * @param start offset of the first byte of the block
         * @param end   offset of the last byte of the block (inclusive)
         */
        public FilePartDownloadWorker(long start, long end) {
            this.start = start;
            this.end = end;
        }

        /**
         * Downloads the block, retrying with a linearly growing delay between attempts.
         *
         * @return the bytes of the block, exactly {@code end - start + 1} of them
         * @throws IllegalArgumentException if the block size is negative or does not fit in an array
         * @throws Exception                the error of the last attempt when every attempt failed
         */
        @Override
        public byte[] call() throws Exception {
            final long size = end - start + 1;
            if (size < 0 || size > Integer.MAX_VALUE) {
                throw new IllegalArgumentException(
                        "Invalid part size " + size + ", part info: start = " + start + ", end = " + end);
            }
            final byte[] bytes = new byte[(int) size];

            // Always try at least once; with zero attempts an all-zero buffer would be returned.
            final int maxTryTimes = Math.max(1, client.getKs3config().getMultiThreadDownloadConf().getMaxTryTimes());
            Exception lastError = null;
            for (int attempt = 0; attempt < maxTryTimes; attempt++) {
                if (attempt > 0) {
                    CommonUtils.sleep(100L * attempt);
                }
                AutoAbortInputStream inputStream = null;
                try {
                    // NOTE(review): only bucket and key are copied from the original request;
                    // confirm no other request settings need to be carried over.
                    GetObjectRequest req = new GetObjectRequest(request.getBucket(), request.getKey());
                    req.setRange(start, end);
                    req.setMultiThread(false);
                    GetObjectResult result = client.getObject(req);
                    inputStream = result.getObject().getObjectContent();
                    int received = read(inputStream, bytes);
                    if (received != bytes.length) {
                        // A short response must not be passed on as a zero-padded block.
                        throw new EOFException("Expected " + bytes.length + " bytes but received " + received);
                    }
                    return bytes;
                } catch (Exception e) {
                    String message = e.getMessage() + ", part info: start = " + start + ", end = " + end;
                    lastError = new IOException(message, e);
                } finally {
                    CommonUtils.closeQuietly(inputStream);
                }
            }

            throw lastError;
        }

        /**
         * Reads from the stream until the whole array is filled or the stream ends.
         *
         * @param input  stream to read from
         * @param buffer destination byte array
         * @return number of bytes actually read
         * @throws IOException if reading from the stream fails
         */
        public int read(final InputStream input, final byte[] buffer) throws IOException {
            return read(input, buffer, 0, buffer.length);
        }

        /**
         * Reads up to {@code length} bytes into {@code buffer} starting at {@code offset},
         * stopping early only when the stream ends.
         *
         * @param input  stream to read from
         * @param buffer destination byte array
         * @param offset index in {@code buffer} of the first byte to write
         * @param length maximum number of bytes to read
         * @return number of bytes actually read
         * @throws IllegalArgumentException if {@code length} is negative
         * @throws IOException              if reading from the stream fails
         */
        public int read(final InputStream input, final byte[] buffer, final int offset, final int length)
                throws IOException {
            if (length < 0) {
                throw new IllegalArgumentException("Length must not be negative: " + length);
            }
            int remaining = length;
            while (remaining > 0) {
                final int location = length - remaining;
                final int count = input.read(buffer, offset + location, remaining);
                if (EOF == count) { // EOF
                    break;
                }
                remaining -= count;
            }
            return length - remaining;
        }
    }
}
