package com.openfin.desktop.nix;

/**
 * Create named pipe for Runtime port discovery on Linux and Mac.
 *
 * This class relies on openfin-cli to launch the Runtime. com.openfin.installer.location should be set to the
 * location of a script that launches the Runtime with openfin-cli.
 *
 * Example bash script that launches the Runtime with openfin-cli:
 *
 * #!/bin/bash
 *
 * cd "$(dirname ${BASH_SOURCE[0]})"
 *
 * if [[ $1 == file:///* ]] ;
 * then
 * configFile=`echo $1 | cut -c8-`
 * else
 * configFile=$1
 * fi
 *
 * node cli.js -l -c $configFile
 *
 *
 * @author wche
 * @since 11/15/17
 *
 */

import com.openfin.desktop.*;
import com.sun.jna.Library;
import com.sun.jna.Native;
import com.sun.jna.Structure;
import com.sun.jna.ptr.IntByReference;
import org.json.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.lang.System;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.List;

public class NamedPipePortHandler implements PortDiscoveryHandler {
    private final static Logger logger = LoggerFactory.getLogger(NamedPipePortHandler.class.getName());

    private static final long RUNTIME_HELLO_MESSAGE = (long) Math.pow( 2, 16 ) - 1;  // UINT16_MAX
    private static final int RUNTIME_STRING_MESSAGE = 0;
    private static final int UINT32_SIZE = 4;  // 4 bytes
    private static final int MESSAGE_HEADER_SIZE = 5 * UINT32_SIZE;  // 5 * uint32

    // UNIX stuff
    private static final int AF_UNIX = 1;
    private static final int SOCK_STREAM = 1;
    private static final int PROTOCOL =0;
    private static final byte[] ZERO_BYTE = new byte[] {0};
    private static CLibrary cLibrary;

    private String pipeName, effectivePipeName;
    private PipeMessageThread pipeMessageThread;

    static {
        if (!DesktopUtils.isWindows()) {
            cLibrary = (CLibrary) Native.loadLibrary("c", CLibrary.class);
        }
    }

    public NamedPipePortHandler(String pipeName) {
        this.pipeName = pipeName;
        this.effectivePipeName = String.format("%s%s%s",
                java.lang.System.getProperty("java.io.tmpdir"),
                File.separator, this.pipeName);
        logger.debug(String.format("Created with %s", pipeName));
    }

    @Override
    public String getEffectivePipeName() {
        return this.effectivePipeName;
    }

    @Override
    public void registerEventListener(EventListener listener, int timeout) {
        this.pipeMessageThread = new PipeMessageThread(this.effectivePipeName, listener, timeout);
        pipeMessageThread.start();
    }

    @Override
    public void removeEventListener(EventListener listener) {
        if (this.pipeMessageThread != null) {
            this.pipeMessageThread.removeEventListener(listener);
        }
    }

    protected class TimeoutThread extends Thread {
        private PipeMessageThread messageThread;
        private int timeout;
        private volatile boolean interrupted = false;

        private TimeoutThread(PipeMessageThread messageThread, int timeout) {
            this.messageThread = messageThread;
            this.timeout = timeout;
            this.setName(com.openfin.desktop.win32.NamedPipePortHandler.class.getName() + ".TimeoutThread");
        }

        @Override
        public void run() {
            logger.debug("Starting timeout thread");
            try {
                Thread.sleep(timeout * 1000);
            } catch (InterruptedException ignored) {
            }
            messageThread.timeout();
            logger.debug("exiting");
        }
    }

    protected class PipeMessageThread extends Thread {
        private EventListener eventListener;
        private TimeoutThread timeoutThread;
        private String pipeName;
        private int pipefd, clientfd;

        private PipeMessageThread(String pipeName, EventListener listener, int timeout) {
            super();
            this.pipefd = -1;
            this.pipeName = pipeName;
            this.setDaemon(true);
            this.setName(com.openfin.desktop.win32.NamedPipePortHandler.class.getName() + ".PipeMessageThread");
            this.eventListener = listener;
            this.timeoutThread = new TimeoutThread(this, timeout);
            this.timeoutThread.start();
        }

        @Override
        public void run() {
            try {
                this.pipefd = createPipe();
                if (this.pipefd != -1) {
                    this.clientfd = acceptClient();
                    MessageHeader header = new MessageHeader();
                    header.extraInteger = null;
                    readMessageHeader(header);
                    if (header.messageType == RUNTIME_HELLO_MESSAGE) {
                        readRuntimeHello(header);
                        writeRuntimeHello(header);
                        readMessageHeader(header);
                        readRuntimeInfo(header);
                    } else {
                        logger.error(String.format("Invalid Runtime Hello message type %d", header.messageType));
                    }
                }
            } catch (Exception ex) {
                logger.error("Error processing port discovery", ex);
            }
            finally {
                closePipe();
            }
        }

        private int createPipe() throws Exception {
            int pipe, errno;
            try {
                logger.debug(String.format("Creating pipe %s", this.pipeName));
                pipe = cLibrary.socket(AF_UNIX, SOCK_STREAM, PROTOCOL);
                if (pipe > 0) {
                    SockAddr sockAddr = new SockAddr();
                    sockAddr.setSunPath(this.pipeName);
                    errno = cLibrary.bind(pipe, sockAddr, sockAddr.size());
                    if (errno >= 0) {
                        errno = cLibrary.listen(pipe, 1);
                        if (errno < 0) {
                            logger.error(String.format("socket.listen error %d %s", errno,
                                            this.pipeName));
                        }
                    } else {
                        logger.error(String.format("socket.bind error %d %s", errno,
                                this.pipeName));
                    }
                } else {
                    errno = Native.getLastError();
                    logger.error(String.format("socket.create error %d %s", errno,
                            this.pipeName));
                }
                if (errno < 0) {
                    throw new IOException(String.format("Error creating pipe %s errno %d",
                                this.pipeName, errno));
                }
            } catch (Exception ex) {
                logger.error(String.format("Error creating name pipe %s", this.pipeName), ex);
                throw ex;
            }
            return pipe;
        }

        private int acceptClient() throws Exception {
            SockAddr clientAddr = new SockAddr();
            IntByReference addrLen = new IntByReference(0);
            int clientfd = cLibrary.accept(this.pipefd, clientAddr, addrLen);
            if (clientfd >= 0) {
                logger.debug(String.format("Client connected %d", clientfd));
            } else {
                logger.error(String.format("Error socket.accept %d", Native.getLastError()));
            }
            return clientfd;
        }

        private synchronized void closePipe() {
            if (this.pipefd >= 0) {
                try {
                    logger.debug(String.format("Closing named pipe %s", this.pipeName));
                    cLibrary.close(this.pipefd);
                    cLibrary.unlink(this.pipeName);
                } catch (Exception ex) {
                    logger.debug(String.format("Error closing pipe %s", ex.getMessage()));
                }
            }
            this.pipefd = -1;
        }

        private void readMessageHeader(MessageHeader header) throws IOException {
            header.payloadSize     = readInt();
            header.routingId       = readInt();
            header.messageType     = readInt();
            header.flags           = readInt();
            header.attachmentCount = readInt();
            logger.debug(String.format("Runtime Header %d %d %d %d %d", header.payloadSize,
                    header.routingId, header.messageType, header.flags, header.attachmentCount));
        }

        private void writeRuntimeHello(MessageHeader header) throws IOException{
            int writeLength = MESSAGE_HEADER_SIZE + UINT32_SIZE;
            if (header.extraInteger != null) {
                writeLength += UINT32_SIZE;
            }
            byte[] writeBuffer = new byte[writeLength];
            ByteBuffer bb = ByteBuffer.wrap(writeBuffer);
            bb.order(ByteOrder.LITTLE_ENDIAN);
            bb.putInt(header.payloadSize);
            bb.putInt(header.routingId);
            bb.putInt(header.messageType);
            bb.putInt(header.flags);
            bb.putInt(header.attachmentCount);
            if (header.extraInteger != null) {
                bb.putInt(header.extraInteger);
            }
            bb.putInt(cLibrary.getpid());

            int length = cLibrary.write(this.clientfd, writeBuffer, writeBuffer.length);
            if (writeBuffer.length == length) {
                logger.debug(String.format("Wrote Runtime hello message %d ", cLibrary.getpid()));
            } else {
                throw new IOException(String.format("Error WriteFile length mismatch %d %d",
                        writeBuffer.length, length));
            }
        }

        private void readRuntimeHello(MessageHeader header) throws IOException {
            logger.debug("Reading Runtime Hello Payload");
            int helloPayload = readInt();
            if (helloPayload == 0) {
                logger.debug(String.format("Runtime Hello Payload extra int %d", helloPayload)); // Mac has an extra int for some reason
                header.extraInteger = helloPayload;
                helloPayload = readInt();
            }
            logger.debug(String.format("Runtime Hello Payload pid %d", helloPayload)); // supposed to be pid of Runtime
        }

        private String readRuntimeString() throws Exception {
            int strLength = readInt();
            if (strLength == 0) {
                logger.debug(String.format("Discovery Message length exxtra int %d", strLength));
                strLength = readInt();
            }
            logger.debug(String.format("Discovery Message length %d", strLength));
            byte[] data = new byte[strLength];
            ByteBuffer bb = ByteBuffer.wrap(data);
            bb.order(ByteOrder.LITTLE_ENDIAN);
            int length = cLibrary.read(this.clientfd, data, data.length);
            if (length != strLength) {
                throw new IOException(String.format("Runtime string length mismatch %d %d", length, strLength));
            }
            String value = new String(data);
            logger.debug(String.format("Runtime String %s", value));
            return value;
        }

        private void readRuntimeInfo(MessageHeader header) throws Exception{
            if (header.messageType == RUNTIME_STRING_MESSAGE) {
                String runtimeMsg = readRuntimeString();
                JSONObject jsonObject = new JSONObject(runtimeMsg);
                JSONObject payload = JsonUtils.getJsonValue(jsonObject,"payload", null);
                if (payload != null) {
                    ActionEvent actionEvent = new ActionEvent(this.pipeName, payload, this);
                    if (this.eventListener != null) {
                        this.eventListener.eventReceived(actionEvent);
                    }
                } else {
                    logger.error("Missing payload of Runtime info");
                }
            } else {
                logger.error(String.format("Invalid RUNTIME_STRING_MESSAGE %d", header.messageType));
            }
        }
        private int readInt() throws IOException {
            byte[] readBuffer = new byte[UINT32_SIZE];
            ByteBuffer bb = ByteBuffer.wrap(readBuffer);
            bb.order(ByteOrder.LITTLE_ENDIAN);
            int length = cLibrary.read(this.clientfd, readBuffer, readBuffer.length);
            if (length == readBuffer.length) {
                return bb.getInt();
            } else {
                throw new IOException(String.format("readInt failed with %d", Native.getLastError()));
            }
        }

        private void timeout() {
            if (this.pipefd >= 0) {
                this.closePipe();
                JSONObject jsonObject = new JSONObject();
                ActionEvent actionEvent = new ActionEvent("TIMEOUT", jsonObject, this);
                if (this.eventListener != null) {
                    this.eventListener.eventReceived(actionEvent);
                }
            }
        }
        private void removeEventListener(EventListener listener) {
            if (this.eventListener == listener) {
                this.eventListener = null;
            }
        }

    }

    private static class MessageHeader {
        int payloadSize;
        int routingId;
        int messageType;
        int flags;
        int attachmentCount;
        Integer extraInteger; // Mac has an extra int for some reason
    }

    private interface CLibrary extends Library {
        public int socket(int domain, int type, int protocol);
        public int bind(int socket, SockAddr sockAddr, int addrLen);
        public int listen(int socket, int queue);
        public int accept(int socket, SockAddr sockAddr, IntByReference addrLen);
        public int read(int socket, byte[] buffer, int length);
        public int write(int socket, byte[] buffer, int length);
        public int close(int socket);
        public int getpid();
        public int connect(int sockfd, SockAddr sockaddr, int addrlen);
        public int unlink(String name);
    }
    public static class SockAddr extends Structure {
        public short sun_family = 1;
        public byte[] sun_path = new byte[108];

        public void setSunPath(String sunPath) {
            System.arraycopy(sunPath.getBytes(), 0, this.sun_path, 0, sunPath.length());
            System.arraycopy(ZERO_BYTE, 0, this.sun_path, sunPath.length(), 1);
        }

        @Override
        protected List getFieldOrder() {
            String[] fields = {"sun_family", "sun_path"};
            return Arrays.asList(fields);
        }
    }

    public static void main(String[] argv) throws Exception {
        if ("server".equals(argv[0])) {
            serverMode();
        } else {
            clientMode();
        }
    }

    private static void clientMode() throws Exception {
        String pipeName = "/tmp/testpipe";
        logger.debug(String.format("server %s", pipeName));
        int sockfd = cLibrary.socket(AF_UNIX,SOCK_STREAM,PROTOCOL);
        SockAddr sockAddr = new SockAddr();
        sockAddr.setSunPath(pipeName);
        cLibrary.connect(sockfd, sockAddr, sockAddr.size());

        // 4 -2 65535 1879048194 0
        byte[] writeBuffer = new byte[MESSAGE_HEADER_SIZE + UINT32_SIZE];
        ByteBuffer bb = ByteBuffer.wrap(writeBuffer);
        bb.order(ByteOrder.LITTLE_ENDIAN);
        bb.putInt(4);
        bb.putInt(-2);
        bb.putInt(65535);
        bb.putInt(1879048194);
        bb.putInt(0);
        bb.putInt(cLibrary.getpid());

        for (int i = 0; i < writeBuffer.length; i += UINT32_SIZE) {
            logger.debug(String.format("Verify putInt %d", bb.getInt(i)));
        }

        int length = cLibrary.write(sockfd, writeBuffer, writeBuffer.length);
        if (writeBuffer.length == length) {
            logger.debug(String.format("Wrote Runtime hello message %d ", cLibrary.getpid()));
        }
    }

    private static void serverMode() throws Exception {
        String pipeName = "/tmp/testpipe";
        logger.debug(String.format("client %s", pipeName));
        int sockfd = cLibrary.socket(AF_UNIX,SOCK_STREAM,PROTOCOL);
        SockAddr sockAddr = new SockAddr();
        sockAddr.setSunPath("/tmp/testpipe");
        cLibrary.bind(sockfd, sockAddr, sockAddr.size());
        cLibrary.listen(sockfd, 1);
        SockAddr clientAddr = new SockAddr();
        IntByReference addrLen = new IntByReference(0);
        int clientfd = cLibrary.accept(sockfd, clientAddr, addrLen);
        if (clientfd >= 0) {
            logger.debug(String.format("Client connected %d", clientfd));
        }
        for (int i = 0; i < 1000; i++) {
            logger.debug(String.format("ReadInt %d", readInt(clientfd)));
        }

    }

    private static int readInt(int socketfd) throws IOException {
        byte[] readBuffer = new byte[UINT32_SIZE];
        ByteBuffer bb = ByteBuffer.wrap(readBuffer);
        bb.order(ByteOrder.LITTLE_ENDIAN);
        int length = cLibrary.read(socketfd, readBuffer, readBuffer.length);
        if (length == readBuffer.length) {
            return bb.getInt();
        } else {
            throw new IOException(String.format("readInt failed with %d", Native.getLastError()));
        }
    }


}
