package com.alibaba.schedulerx.worker.master.persistence;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.List;

import com.alibaba.schedulerx.common.domain.TaskStatus;
import com.alibaba.schedulerx.protocol.Worker.MasterStartContainerRequest;
import com.alibaba.schedulerx.worker.domain.TaskStatistics;

import com.google.common.collect.Lists;
import com.google.protobuf.ByteString;
import org.apache.commons.lang.StringUtils;
import org.joda.time.DateTime;

/**
 *
 * @author xiaomeng.hxm
 */
public class TaskDao {

    private final H2ConnectionPool h2CP;

    public TaskDao(H2ConnectionPool h2CP) {
        this.h2CP = h2CP;
    }

    /**
     * Drops the {@code task} table if it exists.
     *
     * @throws Exception if a connection cannot be obtained or the statement fails
     */
    public void dropTable() throws Exception {
        executeDdl("DROP TABLE IF EXISTS task");
    }

    /**
     * Creates the {@code task} table (with its unique key on (job_instance_id, task_id) and
     * secondary indexes) if it does not already exist.
     *
     * @throws Exception if a connection cannot be obtained or the statement fails
     */
    public void createTable() throws Exception {
        String sql = "CREATE TABLE IF NOT EXISTS task (" +
                "job_id bigint(20) unsigned NOT NULL," +
                "job_instance_id bigint(20) unsigned NOT NULL," +
                "task_id bigint(20) unsigned NOT NULL," +
                "batch_no bigint(20) unsigned DEFAULT NULL," +
                "task_name varchar(100) NOT NULL DEFAULT ''," +
                "status int(11) NOT NULL," +
                "progress float NOT NULL DEFAULT '0'," +
                "gmt_create datetime NOT NULL," +
                "gmt_modified datetime NOT NULL," +
                "worker_addr varchar(30) NOT NULL DEFAULT ''," +
                "worker_id varchar(30) NOT NULL DEFAULT ''," +
                "task_body blob DEFAULT NULL," +
                "UNIQUE KEY uk_instance_and_task (job_instance_id,task_id)," +
                "KEY idx_job_instance_id (job_instance_id)," +
                "KEY idx_status (status)" +
                ")";
        executeDdl(sql);
    }

    /**
     * Inserts a single task in {@code PULLED} status.
     *
     * <p>{@code batch_no} is set to 0. {@code gmt_create} and {@code gmt_modified} are both set to
     * the same "now" timestamp. The worker columns keep their table defaults.
     *
     * @param jobId job id
     * @param jobInstanceId owning job instance id
     * @param taskId task id, unique within the job instance
     * @param taskName task name
     * @param taskBody serialized task payload, stored in {@code task_body}
     * @return the update count reported by the driver
     * @throws SQLException on any database error, e.g. a duplicate (job_instance_id, task_id)
     */
    public int insert(long jobId, long jobInstanceId, long taskId, String taskName, ByteString taskBody) throws SQLException {
        String sql = "insert into task(job_id,job_instance_id,task_id,batch_no,task_name,status,"
                + "gmt_create,gmt_modified,task_body) "
                + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";
        try (Connection conn = h2CP.getConnection();
             PreparedStatement ps = conn.prepareStatement(sql)) {
            Timestamp now = currentTimestamp();
            int i = 1;
            ps.setLong(i++, jobId);
            ps.setLong(i++, jobInstanceId);
            ps.setLong(i++, taskId);
            ps.setLong(i++, 0);
            ps.setString(i++, taskName);
            ps.setInt(i++, TaskStatus.PULLED.getValue());
            ps.setTimestamp(i++, now);
            ps.setTimestamp(i++, now);
            ps.setBytes(i++, taskBody.toByteArray());
            return ps.executeUpdate();
        }
    }

    /**
     * Inserts one {@code RUNNING} task row per container request, all in a single transaction.
     *
     * <p>{@code batch_no} is taken from each request's serial number. Every row is stamped with
     * the given worker id and address. If any row fails, the whole batch is rolled back.
     *
     * @param containers container requests to persist
     * @param workerId worker id stored on every row
     * @param workerAddr worker address stored on every row
     * @return sum of the per-statement update counts returned by the batch
     * @throws SQLException on any database error (the transaction is rolled back first)
     */
    public int batchInsert(List<MasterStartContainerRequest> containers, String workerId, String workerAddr) throws SQLException {
        String sql = "insert into task(job_id,job_instance_id,task_id,batch_no,task_name,status,"
                + "gmt_create,gmt_modified,task_body,worker_id,worker_addr) "
                + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
        try (Connection conn = h2CP.getConnection()) {
            conn.setAutoCommit(false);
            try (PreparedStatement ps = conn.prepareStatement(sql)) {
                Timestamp now = currentTimestamp();
                for (MasterStartContainerRequest snapshot : containers) {
                    int i = 1;
                    ps.setLong(i++, snapshot.getJobId());
                    ps.setLong(i++, snapshot.getJobInstanceId());
                    ps.setLong(i++, snapshot.getTaskId());
                    ps.setLong(i++, snapshot.getSerialNum());
                    ps.setString(i++, snapshot.getTaskName());
                    ps.setInt(i++, TaskStatus.RUNNING.getValue());
                    ps.setTimestamp(i++, now);
                    ps.setTimestamp(i++, now);
                    ps.setBytes(i++, snapshot.getTask().toByteArray());
                    ps.setString(i++, workerId);
                    ps.setString(i++, workerAddr);
                    ps.addBatch();
                }
                int result = sumUpdateCounts(ps.executeBatch());
                conn.commit();
                return result;
            } catch (SQLException | RuntimeException e) {
                rollbackQuietly(conn, e);
                throw e;
            } finally {
                // Never hand a pooled connection back in manual-commit mode.
                conn.setAutoCommit(true);
            }
        }
    }

    /**
     * Returns the number of distinct job instances and the total number of tasks in the table.
     * Each count is left at the {@link TaskStatistics} default if its query returns no row.
     *
     * @throws SQLException on any database error
     */
    public TaskStatistics getTaskStatistics() throws SQLException {
        TaskStatistics result = new TaskStatistics();
        try (Connection conn = h2CP.getConnection()) {
            try (PreparedStatement ps = conn.prepareStatement("select count(distinct job_instance_id) from task");
                 ResultSet rs = ps.executeQuery()) {
                if (rs.next()) {
                    result.setDistinctInstanceCount(rs.getLong(1));
                }
            }
            // Total task count. A separate statement is used so the first one is closed, not leaked.
            try (PreparedStatement ps = conn.prepareStatement("select count(*) from task");
                 ResultSet rs = ps.executeQuery()) {
                if (rs.next()) {
                    result.setTaskCount(rs.getLong(1));
                }
            }
        }
        return result;
    }

    /**
     * Sets status, worker address and {@code gmt_modified} for one task.
     *
     * @return number of rows updated (0 if the task does not exist)
     * @throws SQLException on any database error
     */
    public int updateStatus(long jobInstanceId, long taskId, int status, String workerAddr) throws SQLException {
        String sql = "update task set status=?,worker_addr=?,gmt_modified=? where job_instance_id=? and task_id=?";
        try (Connection conn = h2CP.getConnection();
             PreparedStatement ps = conn.prepareStatement(sql)) {
            ps.setInt(1, status);
            ps.setString(2, workerAddr);
            ps.setTimestamp(3, currentTimestamp());
            ps.setLong(4, jobInstanceId);
            ps.setLong(5, taskId);
            return ps.executeUpdate();
        }
    }

    /**
     * Sets status, worker id and worker address for each listed task, as one JDBC batch on an
     * auto-commit connection. {@code gmt_modified} is not touched.
     *
     * <p>When the target status is {@code INIT}, only rows currently in status 3 are updated.
     * Per the original author's note, status 3 is the running status, and only running tasks may
     * go back to INIT.
     *
     * @return sum of the per-statement update counts returned by the batch
     * @throws SQLException on any database error
     */
    public int updateStatus(long jobInstanceId, List<Long> taskIds, int status,
                            String workerId, String workerAddr) throws SQLException {
        String sql = "update task set status=?, worker_id=?, worker_addr=? WHERE job_instance_id=? and task_id =?";
        if (status == TaskStatus.INIT.getValue()) {
            // only running status can convert to init status;
            sql = sql + " and status = 3";
        }
        try (Connection conn = h2CP.getConnection();
             PreparedStatement ps = conn.prepareStatement(sql)) {
            for (Long taskId : taskIds) {
                ps.setInt(1, status);
                ps.setString(2, workerId);
                ps.setString(3, workerAddr);
                ps.setLong(4, jobInstanceId);
                ps.setLong(5, taskId);
                ps.addBatch();
            }
            return sumUpdateCounts(ps.executeBatch());
        }
    }

    /**
     * Sets {@code status} on every listed task of a job instance in a single transactional UPDATE.
     * {@code gmt_modified} is not touched.
     *
     * @param jobInstanceId job instance id
     * @param taskIdList task ids to update; a null or empty list updates nothing and returns 0
     * @param status new status value
     * @return number of rows updated
     * @throws SQLException on any database error (the transaction is rolled back first)
     */
    public int batchUpdateStatus(long jobInstanceId, List<Long> taskIdList, int status) throws SQLException {
        if (taskIdList == null || taskIdList.isEmpty()) {
            // "task_id in ()" is not valid SQL, and no ids means no rows can match anyway.
            return 0;
        }
        // Bind every id instead of concatenating values into the SQL text.
        String sql = "update task set status=? where job_instance_id=? and task_id in ("
                + placeholders(taskIdList.size()) + ")";
        try (Connection conn = h2CP.getConnection()) {
            conn.setAutoCommit(false);
            try (PreparedStatement ps = conn.prepareStatement(sql)) {
                int i = 1;
                ps.setInt(i++, status);
                ps.setLong(i++, jobInstanceId);
                for (Long taskId : taskIdList) {
                    ps.setLong(i++, taskId);
                }
                int result = ps.executeUpdate();
                conn.commit();
                return result;
            } catch (SQLException | RuntimeException e) {
                rollbackQuietly(conn, e);
                throw e;
            } finally {
                conn.setAutoCommit(true);
            }
        }
    }

    /**
     * Returns every distinct {@code job_instance_id} currently present in the table.
     *
     * @throws SQLException on any database error
     */
    public List<Long> getDistinctInstanceIds() throws SQLException {
        List<Long> result = new ArrayList<>();
        try (Connection conn = h2CP.getConnection();
             PreparedStatement ps = conn.prepareStatement("select distinct job_instance_id from task");
             ResultSet rs = ps.executeQuery()) {
            while (rs.next()) {
                result.add(rs.getLong(1));
            }
        }
        return result;
    }

    /**
     * Sets status and {@code gmt_modified} on tasks of a job instance in one transactional UPDATE.
     *
     * <p>When {@code workerId} is non-null, only rows matching both {@code workerId} and
     * {@code workerAddr} are updated; otherwise every row of the instance is. When the target
     * status is {@code INIT}, only rows currently in status 3 (running, per the original comment)
     * are updated.
     *
     * @return number of rows updated
     * @throws SQLException on any database error (the transaction is rolled back first)
     */
    public int batchUpdateStatus(long jobInstanceId, int status, String workerId, String workerAddr) throws SQLException {
        String sql;
        if (workerId != null) {
            sql = "update task set status=?,gmt_modified=? where job_instance_id=? and worker_id=? and worker_addr=?";
        } else {
            sql = "update task set status=?,gmt_modified=? where job_instance_id=?";
        }
        if (status == TaskStatus.INIT.getValue()) {
            // only running status can convert to init status;
            sql = sql + " and status = 3";
        }
        try (Connection conn = h2CP.getConnection()) {
            conn.setAutoCommit(false);
            try (PreparedStatement ps = conn.prepareStatement(sql)) {
                ps.setInt(1, status);
                ps.setTimestamp(2, currentTimestamp());
                ps.setLong(3, jobInstanceId);
                if (workerId != null) {
                    ps.setString(4, workerId);
                    ps.setString(5, workerAddr);
                }
                int result = ps.executeUpdate();
                conn.commit();
                return result;
            } catch (SQLException | RuntimeException e) {
                rollbackQuietly(conn, e);
                throw e;
            } finally {
                conn.setAutoCommit(true);
            }
        }
    }

    /**
     * Sets worker id, worker address and {@code gmt_modified} for one task.
     *
     * @return number of rows updated (0 if the task does not exist)
     * @throws SQLException on any database error
     */
    public int updateWorker(long jobInstanceId, long taskId, String workerId, String workerAddr) throws SQLException {
        String sql = "update task set worker_id=?,worker_addr=?,gmt_modified=? where job_instance_id=? and task_id=?";
        try (Connection conn = h2CP.getConnection();
             PreparedStatement ps = conn.prepareStatement(sql)) {
            ps.setString(1, workerId);
            ps.setString(2, workerAddr);
            ps.setTimestamp(3, currentTimestamp());
            ps.setLong(4, jobInstanceId);
            ps.setLong(5, taskId);
            return ps.executeUpdate();
        }
    }

    /**
     * Returns up to {@code pageSize} tasks of a job instance in the given status. No ORDER BY is
     * applied, so which rows are returned is up to the database.
     *
     * @throws SQLException on any database error
     */
    public List<TaskSnapshot> queryTaskList(long jobInstanceId, int status, int pageSize) throws SQLException {
        List<TaskSnapshot> taskSnapshots = Lists.newArrayList();
        try (Connection conn = h2CP.getConnection();
             PreparedStatement ps = conn.prepareStatement("select * from task where job_instance_id=? and status=? limit ?")) {
            ps.setLong(1, jobInstanceId);
            ps.setInt(2, status);
            ps.setInt(3, pageSize);
            try (ResultSet rs = ps.executeQuery()) {
                while (rs.next()) {
                    taskSnapshots.add(convert2TaskSnapshot(rs));
                }
            }
        }
        return taskSnapshots;
    }

    /**
     * Returns the distinct status values present for a job instance.
     *
     * @throws SQLException on any database error
     */
    public List<Integer> queryStatus(long jobInstanceId) throws SQLException {
        List<Integer> statusList = Lists.newArrayList();
        try (Connection conn = h2CP.getConnection();
             PreparedStatement ps = conn.prepareStatement("select distinct(status) from task where job_instance_id=?")) {
            ps.setLong(1, jobInstanceId);
            try (ResultSet rs = ps.executeQuery()) {
                while (rs.next()) {
                    statusList.add(rs.getInt(1));
                }
            }
        }
        return statusList;
    }

    /**
     * Returns the number of tasks stored for a job instance.
     *
     * @throws SQLException on any database error
     */
    public long queryTaskCount(long jobInstanceId) throws SQLException {
        try (Connection conn = h2CP.getConnection();
             PreparedStatement ps = conn.prepareStatement("select count(*) from task where job_instance_id=?")) {
            ps.setLong(1, jobInstanceId);
            try (ResultSet rs = ps.executeQuery()) {
                return rs.next() ? rs.getLong(1) : 0L;
            }
        }
    }

    /**
     * Returns whether at least one row exists with {@code job_instance_id = jobInstanceId}.
     * If {@code batchNo} is non-null, the row must also have {@code batch_no = batchNo}.
     *
     * @param jobInstanceId job instance id (must not be null; it is unboxed)
     * @param batchNo optional batch number filter; {@code null} means "any batch"
     * @return true if such a row exists, otherwise false
     * @throws SQLException on any database error
     */
    public boolean exist(Long jobInstanceId, Long batchNo) throws SQLException {
        String param = "job_instance_id=?";
        if (batchNo != null) {
            param += " and batch_no=?";
        }
        try (Connection conn = h2CP.getConnection();
             PreparedStatement ps = conn.prepareStatement("select EXISTS (select 1 from task where " + param + ")")) {
            int i = 1;
            ps.setLong(i++, jobInstanceId);
            if (batchNo != null) {
                ps.setLong(i++, batchNo);
            }
            try (ResultSet rs = ps.executeQuery()) {
                return rs.next() && rs.getBoolean(1);
            }
        }
    }

    /**
     * Deletes every task of a job instance.
     *
     * @return number of rows deleted
     * @throws SQLException on any database error
     */
    public int deleteByJobInstanceId(long jobInstanceId) throws SQLException {
        try (Connection conn = h2CP.getConnection();
             PreparedStatement ps = conn.prepareStatement("delete from task where job_instance_id=?")) {
            ps.setLong(1, jobInstanceId);
            return ps.executeUpdate();
        }
    }

    /**
     * Deletes the listed tasks of a job instance as one JDBC batch inside a single transaction.
     *
     * @return sum of the per-statement update counts returned by the batch
     * @throws SQLException on any database error (the transaction is rolled back first)
     */
    public int batchDeleteTasks(long jobInstanceId, List<Long> taskIds) throws SQLException {
        try (Connection conn = h2CP.getConnection()) {
            conn.setAutoCommit(false);
            try (PreparedStatement ps = conn.prepareStatement("delete from task where job_instance_id=? and task_id=?")) {
                for (Long taskId : taskIds) {
                    ps.setLong(1, jobInstanceId);
                    ps.setLong(2, taskId);
                    ps.addBatch();
                }
                int result = sumUpdateCounts(ps.executeBatch());
                conn.commit();
                return result;
            } catch (SQLException | RuntimeException e) {
                rollbackQuietly(conn, e);
                throw e;
            } finally {
                conn.setAutoCommit(true);
            }
        }
    }

    /**
     * Deletes every task of a job instance that belongs to the given worker (id and address must
     * both match), inside a single transaction.
     *
     * @return number of rows deleted
     * @throws SQLException on any database error (the transaction is rolled back first)
     */
    public int batchDeleteTasks(long jobInstanceId, String workerId, String workerAddr) throws SQLException {
        try (Connection conn = h2CP.getConnection()) {
            conn.setAutoCommit(false);
            try (PreparedStatement ps = conn.prepareStatement("delete from task where job_instance_id=? and worker_id=? and worker_addr=?")) {
                ps.setLong(1, jobInstanceId);
                ps.setString(2, workerId);
                ps.setString(3, workerAddr);
                int result = ps.executeUpdate();
                conn.commit();
                return result;
            } catch (SQLException | RuntimeException e) {
                rollbackQuietly(conn, e);
                throw e;
            } finally {
                conn.setAutoCommit(true);
            }
        }
    }

    /** Executes a single DDL/update statement on a freshly borrowed connection. */
    private void executeDdl(String sql) throws SQLException {
        try (Connection conn = h2CP.getConnection();
             PreparedStatement ps = conn.prepareStatement(sql)) {
            ps.executeUpdate();
        }
    }

    /** Current wall-clock time as a JDBC timestamp. */
    private static Timestamp currentTimestamp() {
        return new Timestamp(DateTime.now().getMillis());
    }

    /**
     * Rolls back the current transaction after a failure. A rollback error is attached to
     * {@code cause} as a suppressed exception so it does not mask the original failure.
     */
    private static void rollbackQuietly(Connection conn, Throwable cause) {
        try {
            conn.rollback();
        } catch (SQLException rollbackError) {
            cause.addSuppressed(rollbackError);
        }
    }

    /**
     * Sums the per-statement counts returned by {@code executeBatch}. Values are added verbatim,
     * as before; NOTE(review): a driver reporting {@code Statement.SUCCESS_NO_INFO} (-2) would
     * skew the total.
     */
    private static int sumUpdateCounts(int[] counts) {
        int total = 0;
        for (int count : counts) {
            total += count;
        }
        return total;
    }

    /** Builds {@code "?,?,...,?"} with {@code count} placeholders (count must be positive). */
    private static String placeholders(int count) {
        StringBuilder sb = new StringBuilder(count * 2);
        for (int i = 0; i < count; i++) {
            if (i > 0) {
                sb.append(',');
            }
            sb.append('?');
        }
        return sb.toString();
    }

    /** Maps the current row of {@code rs} (a {@code select *} on {@code task}) to a snapshot. */
    private TaskSnapshot convert2TaskSnapshot(ResultSet rs) throws SQLException {
        TaskSnapshot taskSnapshot = new TaskSnapshot();
        taskSnapshot.setJob_id(rs.getLong("job_id"));
        taskSnapshot.setJob_instance_id(rs.getLong("job_instance_id"));
        taskSnapshot.setTask_id(rs.getLong("task_id"));
        taskSnapshot.setBatch_no(rs.getLong("batch_no"));
        taskSnapshot.setTask_name(rs.getString("task_name"));
        taskSnapshot.setStatus(rs.getInt("status"));
        taskSnapshot.setProgress(rs.getFloat("progress"));
        // NOTE(review): getDate on a datetime column may drop the time of day; getTimestamp would
        // keep it if TaskSnapshot's setters accept java.util.Date. Confirm before changing.
        taskSnapshot.setGmt_create(rs.getDate("gmt_create"));
        taskSnapshot.setGmt_modified(rs.getDate("gmt_modified"));
        taskSnapshot.setTask_body(rs.getBytes("task_body"));
        return taskSnapshot;
    }
}
