package com.alibaba.schedulerx.worker.batch;

import com.alibaba.schedulerx.protocol.Worker.MasterStartContainerRequest;
import com.alibaba.schedulerx.worker.log.LogFactory;
import com.alibaba.schedulerx.worker.log.Logger;
import com.alibaba.schedulerx.worker.master.StreamTaskMaster;

import java.util.Comparator;
import java.util.List;
import java.util.concurrent.Semaphore;

/**
 * StreamTaskPushReqHandler
 *
 * <p>Dispatch-request handler for stream jobs. Each batch handed to {@link #process} is dispatched
 * to the job's {@link StreamTaskMaster} by a runnable submitted to {@code batchProcessSvc}. The
 * number of in-flight tasks is bounded by a {@link Semaphore} created with
 * {@code globalConcurrency} permits. Permits are acquired before a batch is dispatched and are
 * returned through {@link #release(int)} / {@link #release()}.
 *
 * <p>After each batch, the handler records the serial number of the last non-failover request it
 * processed. {@link #allTasksPushed(Long)} uses it to tell whether an earlier batch has already
 * been pushed.
 *
 * @author yaohui
 * @create 2023/5/18 11:19 AM
 */
public class StreamTaskPushReqHandler<T> extends TaskDispatchReqHandler<T> {

    private static final Logger LOGGER = LogFactory.getLogger(StreamTaskPushReqHandler.class);

    /** Bounds the number of tasks that have been dispatched but whose permits were not yet released. */
    private final Semaphore semaphore;

    /**
     * Serial number of the most recently processed non-failover request; {@code null} until the
     * first batch containing such a request has been processed.
     *
     * <p>Guarded by {@code this}. It is written by the dispatch runnable and read by
     * {@link #allTasksPushed(Long)}, so every access goes through a synchronized method.
     */
    private Long currentBatchNo;

    /**
     * Creates a handler for one stream job instance.
     *
     * @param jobInstanceId     job instance whose tasks are dispatched
     * @param globalConcurrency initial number of semaphore permits, i.e. the maximum number of
     *                          tasks allowed in flight at once
     * @param batchSize         batch size passed to the base handler; further capped by the
     *                          currently available permits in {@link #getBatchSize()}
     * @param queue             request queue handed to the base handler
     */
    public StreamTaskPushReqHandler(long jobInstanceId, int globalConcurrency, int batchSize, ReqQueue<T> queue) {
        // NOTE(review): the two literal 1s presumably size the base handler's dispatch pool
        // (so batches are dispatched one at a time); confirm against TaskDispatchReqHandler.
        super(jobInstanceId, 1, 1, batchSize, queue,
                "Schedulerx-Batch-Tasks-Dispatch-Thread-", "Schedulerx-Batch-Tasks-Retrieve-Thread-");
        semaphore = new Semaphore(globalConcurrency);
    }

    /**
     * Returns {@code permits} concurrency slots to the semaphore.
     *
     * @param permits number of permits to release
     */
    public void release(int permits) {
        semaphore.release(permits);
    }

    /** Returns a single concurrency slot to the semaphore. */
    public void release() {
        release(1);
    }

    /**
     * Asynchronously dispatches {@code reqs} to the job's {@link StreamTaskMaster}.
     *
     * @param jobInstanceId job instance the requests belong to
     * @param reqs          requests to dispatch
     * @param workerAddr    unused by this handler
     */
    @Override
    @SuppressWarnings("unchecked") // assumes T is MasterStartContainerRequest for this handler
    public void process(long jobInstanceId, List<T> reqs, String workerAddr) {
        batchProcessSvc.submit(new BatchTasksDispatchRunnable(jobInstanceId, (List<MasterStartContainerRequest>) reqs));
    }

    /**
     * Caps the base batch size by the number of currently available permits, so a batch never
     * asks for more slots than are free at the moment it is sized.
     */
    @Override
    protected int getBatchSize() {
        int batchSize = super.getBatchSize();
        return Math.min(batchSize, semaphore.availablePermits());
    }

    /**
     * Tells whether every request of batch {@code batchNo} has been pushed.
     *
     * @param batchNo batch serial number to check (must not be {@code null})
     * @return {@code true} if a later batch has already been processed, i.e. {@code batchNo} is
     *         strictly less than the recorded current batch number; {@code false} otherwise,
     *         including before any batch has been recorded
     */
    public synchronized boolean allTasksPushed(Long batchNo) {
        return currentBatchNo != null && batchNo < currentBatchNo;
    }

    /**
     * Records the serial number of the last non-failover request in {@code reqs} as the current
     * batch number. The value is left unchanged when {@code reqs} is null or contains only
     * failover requests. Synchronized on the same monitor as {@link #allTasksPushed(Long)} so the
     * update is visible to its readers.
     */
    private synchronized void markBatchPushed(List<MasterStartContainerRequest> reqs) {
        if (reqs == null) {
            return;
        }
        for (MasterStartContainerRequest req : reqs) {
            if (!req.getFailover()) {
                currentBatchNo = req.getSerialNum();
            }
        }
    }

    /** Dispatches one batch of requests, then updates the batch bookkeeping. */
    private class BatchTasksDispatchRunnable implements Runnable {
        private final long jobInstanceId;
        private final List<MasterStartContainerRequest> reqs;

        BatchTasksDispatchRunnable(long jobInstanceId, List<MasterStartContainerRequest> reqs) {
            this.jobInstanceId = jobInstanceId;
            this.reqs = reqs;
        }

        @Override
        public void run() {
            try {
                long startTime = System.currentTimeMillis();
                // TODO: support continuously keeping the maximum concurrency saturated
                // (currently each batch blocks until all of its permits are available).
                semaphore.acquire(reqs.size());
                // NOTE(review): if batchDispatchTasks throws, the permits acquired above are only
                // returned if StreamTaskMaster calls release(); confirm this cannot leak permits.
                ((StreamTaskMaster) taskMasterPool.get(jobInstanceId)).batchDispatchTasks(reqs);
                LOGGER.info("jobInstance={}, batch dispatch cost:{} ms, dispatchSize:{}, size:{}",
                        jobInstanceId, System.currentTimeMillis() - startTime, dispatchSize, reqs.size());
            } catch (InterruptedException e) {
                // Restore the interrupt status so the executor can observe interruption/shutdown.
                Thread.currentThread().interrupt();
                LOGGER.error(e);
            } catch (Throwable e) {
                LOGGER.error(e);
            } finally {
                try {
                    // Batch number is updated even when dispatch failed, matching prior behavior.
                    markBatchPushed(reqs);
                } finally {
                    // Always decrement, even if the bookkeeping above throws.
                    activeRunnableNum.decrementAndGet();
                }
            }
        }
    }

}
