package com.alibaba.schedulerx.worker.master.persistence;

import java.io.IOException;
import java.util.List;

import com.alibaba.schedulerx.common.domain.InstanceStatus;
import com.alibaba.schedulerx.common.domain.TaskStatus;
import com.alibaba.schedulerx.common.util.ConfigUtil;
import com.alibaba.schedulerx.protocol.Worker.ContainerReportTaskStatusRequest;
import com.alibaba.schedulerx.protocol.Worker.MasterStartContainerRequest;
import com.alibaba.schedulerx.worker.domain.TaskInfo;
import com.alibaba.schedulerx.worker.domain.TaskStatistics;
import com.alibaba.schedulerx.worker.domain.WorkerConstants;
import com.alibaba.schedulerx.worker.log.LogFactory;
import com.alibaba.schedulerx.worker.log.Logger;

import com.google.common.collect.Lists;
import com.google.protobuf.ByteString;
import org.apache.commons.collections.CollectionUtils;

/**
 * Base {@link TaskPersistence} implementation that stores tasks in an embedded H2 database through a
 * {@link TaskDao}.
 * <p>
 * H2Persistence is used as a singleton: {@link #initTable()} only drops and recreates the task table on its
 * first invocation for an instance; later invocations are no-ops.
 *
 * @author xiaomeng.hxm
 */
public abstract class H2Persistence implements TaskPersistence {
    private static final Logger LOGGER = LogFactory.getLogger(H2Persistence.class);

    /** Maximum number of attempts for DAO operations that are retried. */
    private static final int MAX_ATTEMPTS = 3;

    /** Pause between two consecutive attempts of a retried DAO operation. */
    private static final long RETRY_INTERVAL_MS = 1000L;

    protected H2ConnectionPool h2CP;
    protected TaskDao taskDao;
    private volatile boolean inited = false;

    public H2Persistence() {}

    /**
     * Drops and recreates the task table the first time it is called on this instance.
     * Subsequent calls return immediately.
     *
     * @throws Exception if the DAO fails to drop or create the table
     */
    @Override
    public void initTable() throws Exception {
        // Double-checked locking on a volatile flag: the common path avoids synchronization.
        if (!inited) {
            synchronized (this) {
                if (!inited) {
                    taskDao.dropTable();
                    taskDao.createTable();
                    inited = true;
                }
            }
        }
    }

    /**
     * Updates the status and worker of the given tasks of a job instance.
     *
     * @param jobInstanceId job instance the tasks belong to
     * @param taskIds       ids of the tasks to update; nothing is done when empty or null
     * @param status        new status
     * @param workerId      worker id to record
     * @param workerAddr    worker address to record
     * @return the value returned by the DAO update, or -1 when {@code taskIds} is empty or the update failed
     *         (failures are logged, not thrown)
     */
    @Override
    public int updateTaskStatus(long jobInstanceId, List<Long> taskIds, TaskStatus status, String workerId,
                                 String workerAddr) throws Exception {
        int res = -1;
        if (CollectionUtils.isEmpty(taskIds)) {
            return res;
        }
        try {
            res = taskDao.updateStatus(jobInstanceId, taskIds, status.getValue(), workerId, workerAddr);
        } catch (Throwable e) {
            LOGGER.error("jobInstanceId={}, updateTaskStatus error", jobInstanceId, e);
        }
        return res;
    }

    /**
     * !!!Attention!!! For Grid/Batch tasks, this method is invoked only when finished statuses are reported.
     * To keep the H2 database small, finished tasks are deleted instead of being updated; reports with
     * non-finished statuses are ignored.
     *
     * @param taskStatusInfos list of task status reports, all belonging to the same job instance
     * @throws Exception if the DAO delete fails
     */
    @Override
    public void updateTaskStatues(List<ContainerReportTaskStatusRequest> taskStatusInfos) throws Exception {
        if (CollectionUtils.isEmpty(taskStatusInfos)) {
            return;
        }
        // Status reports are always batched per job instance, so the first entry identifies the instance.
        long jobInstanceId = taskStatusInfos.get(0).getJobInstanceId();
        List<Long> finishedTaskIds = Lists.newArrayListWithCapacity(taskStatusInfos.size());
        for (ContainerReportTaskStatusRequest taskStatusInfo : taskStatusInfos) {
            TaskStatus taskStatus = TaskStatus.parseValue(taskStatusInfo.getStatus());
            if (taskStatus.isFinish()) {
                finishedTaskIds.add(taskStatusInfo.getTaskId());
            }
        }
        // No finished task in this batch: skip the pointless DAO round trip with an empty id list.
        if (!finishedTaskIds.isEmpty()) {
            taskDao.batchDeleteTasks(jobInstanceId, finishedTaskIds);
        }
    }

    /**
     * Deletes every task of the given job instance.
     *
     * @param jobInstanceId job instance whose tasks are removed
     * @throws Exception if the DAO delete fails
     */
    @Override
    public void clearTasks(long jobInstanceId) throws Exception {
        taskDao.deleteByJobInstanceId(jobInstanceId);
    }

    /**
     * Inserts a single task.
     *
     * @throws Exception if the DAO insert fails
     */
    @Override
    public void createTask(long jobId, long jobInstanceId, long taskId, String taskName, ByteString taskBody)
            throws Exception {
        taskDao.insert(jobId, jobInstanceId, taskId, taskName, taskBody);
    }

    /**
     * Batch-inserts tasks, retrying up to {@link #MAX_ATTEMPTS} times with {@link #RETRY_INTERVAL_MS} between
     * attempts.
     *
     * @param containers task start requests to persist
     * @param workerId   worker id recorded with the tasks
     * @param workerAddr worker address recorded with the tasks
     * @throws IOException          if every attempt failed; the last DAO failure is attached as the cause
     * @throws InterruptedException if the thread is interrupted while waiting to retry
     */
    @Override
    public void createTasks(List<MasterStartContainerRequest> containers, String workerId, String workerAddr)
            throws Exception {
        Exception lastError = null;
        for (int attempt = 1; attempt <= MAX_ATTEMPTS; attempt++) {
            try {
                taskDao.batchInsert(containers, workerId, workerAddr);
                return;
            } catch (Exception e) {
                lastError = e;
                // Only wait when another attempt will follow; the last failure is reported via the exception.
                if (attempt < MAX_ATTEMPTS) {
                    LOGGER.warn("batch insert tasks error, try after 1000ms", e);
                    Thread.sleep(RETRY_INTERVAL_MS);
                }
            }
        }
        throw new IOException("batch insert tasks error, workerId=" + workerId + ", workerAddr=" + workerAddr,
                lastError);
    }

    /**
     * Fetches up to {@code pageSize} INIT tasks of a job instance and marks them as PULLED.
     * <p>
     * Marking is retried up to {@link #MAX_ATTEMPTS} times. If every attempt fails, the fetched tasks are
     * still returned and the failure is logged; those rows then remain INIT in H2.
     *
     * @param jobInstanceId job instance to pull tasks for
     * @param pageSize      maximum number of tasks to pull
     * @return the pulled tasks, never null
     * @throws Exception if the query fails, or the thread is interrupted while waiting to retry
     */
    @Override
    public List<TaskInfo> pull(long jobInstanceId, int pageSize) throws Exception {
        List<TaskInfo> taskInfoList = Lists.newArrayList();
        List<TaskSnapshot> taskSnapshots = taskDao.queryTaskList(jobInstanceId, TaskStatus.INIT.getValue(), pageSize);
        if (CollectionUtils.isEmpty(taskSnapshots)) {
            return taskInfoList;
        }

        List<Long> taskIdList = Lists.newArrayListWithCapacity(taskSnapshots.size());
        for (TaskSnapshot taskSnapshot : taskSnapshots) {
            taskIdList.add(taskSnapshot.getTask_id());
            taskInfoList.add(convert2TaskInfo(taskSnapshot));
        }

        Exception lastError = null;
        boolean marked = false;
        for (int attempt = 1; attempt <= MAX_ATTEMPTS && !marked; attempt++) {
            try {
                taskDao.batchUpdateStatus(jobInstanceId, taskIdList, TaskStatus.PULLED.getValue());
                marked = true;
            } catch (Exception e) {
                lastError = e;
                if (attempt < MAX_ATTEMPTS) {
                    LOGGER.warn("batchUpdateStatus error, try after 1000ms", e);
                    Thread.sleep(RETRY_INTERVAL_MS);
                }
            }
        }
        if (!marked) {
            // The tasks are still handed out, but they remain INIT in H2 and may be pulled again.
            LOGGER.error("jobInstanceId={}, batchUpdateStatus to PULLED failed after all retries", jobInstanceId,
                    lastError);
        }
        return taskInfoList;
    }

    /**
     * Derives the instance status from whether any task of the instance remains in H2.
     *
     * @return {@link InstanceStatus#RUNNING} if tasks remain, otherwise {@link InstanceStatus#SUCCESS}
     * @throws Exception if the DAO query fails
     */
    @Override
    public InstanceStatus checkInstanceStatus(long jobInstanceId) throws Exception {
        return taskDao.exist(jobInstanceId) ? InstanceStatus.RUNNING : InstanceStatus.SUCCESS;
    }

    /** Converts a persisted task row into the {@link TaskInfo} handed out by {@link #pull(long, int)}. */
    private TaskInfo convert2TaskInfo(TaskSnapshot taskSnapshot) {
        return TaskInfo.newBuilder()
                .setTaskId(taskSnapshot.getTask_id())
                .setTaskName(taskSnapshot.getTask_name())
                .setTaskBody(taskSnapshot.getTask_body())
                .setJobId(taskSnapshot.getJob_id())
                .setJobInstanceId(taskSnapshot.getJob_instance_id())
                .build();
    }

    /**
     * Handles all tasks of a job instance that are assigned to the given worker, with retries.
     * <p>
     * When {@link WorkerConstants#MAP_MASTER_FAILOVER_ENABLE} is enabled (default {@code true}), their status is
     * updated to {@code status}; otherwise they are deleted. Failures are logged, never thrown. If the thread is
     * interrupted while waiting to retry, the interrupt flag is restored and no further attempt is made.
     *
     * @return the value returned by the DAO call, or -1 if every attempt failed
     */
    @Override
    public int batchUpdateTaskStatus(long jobInstanceId, TaskStatus status, String workerId, String workerAddr) {
        int res = -1;
        for (int attempt = 1; attempt <= MAX_ATTEMPTS; attempt++) {
            try {
                if (ConfigUtil.getWorkerConfig().getBoolean(WorkerConstants.MAP_MASTER_FAILOVER_ENABLE, true)) {
                    res = taskDao.batchUpdateStatus(jobInstanceId, status.getValue(), workerId, workerAddr);
                } else {
                    res = taskDao.batchDeleteTasks(jobInstanceId, workerId, workerAddr);
                }
                break;
            } catch (Throwable e) {
                if (attempt >= MAX_ATTEMPTS) {
                    LOGGER.error("batchUpdateTaskStatus error, giving up after all retries", e);
                    break;
                }
                LOGGER.error("batchUpdateTaskStatus error, try after 1000ms", e);
                try {
                    Thread.sleep(RETRY_INTERVAL_MS);
                } catch (InterruptedException ie) {
                    // Preserve the interrupt for the caller and stop retrying.
                    Thread.currentThread().interrupt();
                    LOGGER.error("batchUpdateTaskStatus interrupted while waiting to retry", ie);
                    break;
                }
            }
        }
        return res;
    }

    /**
     * Gets the summary statistics of the tasks stored in H2.
     *
     * @return task statistics as reported by the DAO
     * @throws Exception if the DAO query fails
     */
    public TaskStatistics getTaskStatistics() throws Exception {
        return taskDao.getTaskStatistics();
    }

    /**
     * Gets the ids of instances still retained in H2 that have finished but were not deleted.
     *
     * @return distinct instance ids as reported by the DAO
     * @throws Exception if the DAO query fails
     */
    public List<Long> getDistinctInstanceIds() throws Exception {
        return taskDao.getDistinctInstanceIds();
    }

    /** @return whether {@link #initTable()} has completed on this instance */
    public boolean isInited() {
        return inited;
    }
}
