import math
import pathlib
import socket
import struct
import time

"""
Main interface to the LabVIEW backend.
"""

class UdpMulticaster:
    """
    UDP Multicaster:
    Announces this Python process to the Entangleware Control software by
    sending the port of the local TCP server to a multicast group that
    Entangleware Control listens on while it is running.

    The time-to-live (MCAST_TTL) limits how far the datagram may travel:
    '1' keeps it on the local subnet, '255' lets it cross the whole intranet
    (and possibly leak to the internet).
    """
    def __init__(self, mcast_group='239.255.45.57', mcast_port=50100, mcast_ttl=1):
        # Public, upper-case attributes are kept for existing callers.
        self.MCAST_GRP, self.MCAST_PORT, self.MCAST_TTL = mcast_group, mcast_port, mcast_ttl

        udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
        self.sock = udp
        udp.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, mcast_ttl)

    def close(self):
        """Close the UDP socket."""
        self.sock.close()

    def sendmsg(self, tcpport):
        """Send `tcpport` as a 2-byte big-endian unsigned integer to the group."""
        destination = (self.MCAST_GRP, self.MCAST_PORT)
        payload = bytearray(struct.pack(">H", tcpport))
        self.sock.sendto(payload, destination)


class TcpServer:
    """
    TCP Server: a listening socket to which the Entangleware Control software
    connects as a client. Entangleware Control learns this server's port from
    the multicast message sent via UDP. The connection it opens is accepted
    here and then wrapped as the actual TCP endpoint used for communication
    to/from the Entangleware Control software.
    """
    def __init__(self, serveraddress=('', 0), timeout=1.0, backlog=1):
        # private data ####
        self._timeout, self._initaddress, self._backlog = timeout, serveraddress, backlog
        listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._listensock = listener

        # configure, bind (port 0 => OS picks a free port), then listen
        listener.settimeout(timeout)
        listener.bind(serveraddress)
        listener.listen(backlog)

        # public data ####
        _host, self.serverport = listener.getsockname()[:2]

    def accept(self):
        """Return ``(socket, address)`` for the next client; honours the timeout."""
        return self._listensock.accept()

    def close(self):
        """Close the listening socket."""
        self._listensock.close()


class TcpEndPoint:
    """
    TCP Endpoint:
    This is the endpoint to which the TCP server passes the connection. The
    endpoint will handle most communication to/from the Entangleware Control
    software.

    Wire format (both directions): a 16-byte big-endian header packed as
    ``>QLL`` -- message id (8 bytes), message type (4 bytes), payload length
    in bytes (4 bytes) -- immediately followed by the payload.
    """
    _HEADER = struct.Struct(">QLL")  # msgid, msgtype, payload length
    _CHUNKSIZE = 512  # max bytes offered to a single socket.send() call

    def __init__(self, connection):
        """
        Wrap an already-connected socket and set its timeout to 30 seconds.

        Raises
        ------
        TypeError
            If `connection` is not a ``socket.socket`` (TypeError subclasses
            Exception, so existing ``except Exception`` handlers still apply).
        """
        if not isinstance(connection, socket.socket):
            raise TypeError('connection is not a socket.socket type')
        self._sock = connection
        self._sock.settimeout(30)

    def close(self):
        """Close the underlying socket."""
        self._sock.close()

    def sendmsg(self, msg, msgid, msgtype):
        """
        Send one framed message.

        Parameters
        ----------
        msg : bytes-like
            Payload; converted with ``bytearray(msg)``.
        msgid : int
            Unsigned 64-bit message id.
        msgtype : int
            Unsigned 32-bit message type / instruction code.

        Raises
        ------
        ConnectionError
            If the socket stops accepting data (``send`` returns 0).
        """
        payload = bytearray(msg)
        completemsg = self._HEADER.pack(msgid, msgtype, len(payload)) + payload
        total = len(completemsg)
        with memoryview(completemsg) as view:
            sent = 0
            while sent < total:
                # send() may take fewer bytes than offered; resume exactly where
                # it stopped. memoryview slices avoid copying the message.
                count = self._sock.send(view[sent:sent + self._CHUNKSIZE])
                if count == 0:
                    raise ConnectionError('[entangleware] connection closed while sending')
                sent += count

    def _recv_exact(self, nbytes):
        """Receive exactly `nbytes` bytes into a new bytearray, however the stream splits them."""
        buffer = bytearray(nbytes)
        with memoryview(buffer) as view:
            received = 0
            while received < nbytes:
                # Write each piece after the previous one instead of at offset 0.
                count = self._sock.recv_into(view[received:], nbytes - received)
                if count == 0:
                    raise ConnectionError('[entangleware] connection closed by peer while receiving')
                received += count
        return buffer

    def getmsg(self):
        """
        Receive one framed message.

        Returns
        -------
        tuple
            ``(payload: bytearray, msgid: int, msgtype: int)``.

        Raises
        ------
        ConnectionError
            If the peer closes the connection mid-message.
        socket.timeout
            If no data arrives within the socket timeout.
        """
        header = self._recv_exact(self._HEADER.size)
        msgid, msgtype, msglength = self._HEADER.unpack(header)
        msg = self._recv_exact(msglength)
        return msg, msgid, msgtype


class ConnectionManager:
    """
    Connection Manager: holds the sockets that make up one session with the
    Entangleware Control client, plus an ``isConnected`` flag.

    Attributes
    ----------
    isConnected : bool
    tcp_server : TcpServer or None
    udp_local : UdpMulticaster or None
    tcp_endpoint : TcpEndPoint or None
    """
    def __init__(self):
        self.isConnected = False
        self.tcp_server = self.udp_local = self.tcp_endpoint = None

    def close(self):
        """Close whichever members are set (endpoint, then server, then UDP) and reset them to None."""
        self.isConnected = False
        for member in (self.tcp_endpoint, self.tcp_server, self.udp_local):
            if member:
                member.close()
        self.tcp_server = self.udp_local = self.tcp_endpoint = None


class DDS:
    """
    This is an example of how to use python and Entangleware Control to control
    a AD9959 DDS. It may not work.

    The DDS serial port is bit-banged through Entangleware Control digital
    outputs: every pin transition is issued with ``set_digital_state`` on
    ``board``/``connector``. Pin arguments are bit positions within that
    connector's 32-bit channel mask.
    """
    def __init__(self, board, connector, mosi_pin, sclk_pin, cs_pin, reset_pin, ioupdate_pin,
                 refclock=60e6, refclk_multiplier=6):
        """
        Parameters
        ----------
        board, connector : int
            Board and connector forwarded to ``set_digital_state``.
        mosi_pin, sclk_pin, cs_pin, reset_pin, ioupdate_pin : int
            Bit positions of the serial-data, serial-clock, chip-select,
            reset and I/O-update lines on that connector.
        refclock : float, optional
            Reference clock frequency in Hz (default 60 MHz).
        refclk_multiplier : int, optional
            Multiplier written to register 1 by ``initialize`` (default 6).
            The system clock used for tuning words is
            ``refclk_multiplier * refclock``.
        """
        self.board = board
        self.connector = connector
        self.mosipin = mosi_pin
        self.sclkpin = sclk_pin
        self.cspin = cs_pin
        self.resetpin = reset_pin
        self.ioupdatepin = ioupdate_pin
        self.spi_min_time = 2e-7  # seconds between successive pin transitions
        self._dds_refclock = refclock
        self._dds_refclkmultiplier = refclk_multiplier
        self._dds_sysclock = self._dds_refclkmultiplier * self._dds_refclock

    def _spi(self, spitime,  bytes_to_write):
        """
        Emit an SPI write of `bytes_to_write` that finishes at `spitime`.

        Events are generated backwards in time from `spitime`, so in forward
        time the transfer is: CS low, then for every bit (first byte first,
        MSB first) MOSI set with SCLK low followed by SCLK high; then CS and
        IOUpdate high at `spitime`, and IOUpdate low one step later.
        """
        step = self.spi_min_time
        ioupdate = 1 << self.ioupdatepin

        # IOUpdate Low, one step after the transfer ends
        set_digital_state(spitime + step, self.board, self.connector, ioupdate, ioupdate, 0)

        # CS high and ioupdate high
        cs_io_mask = ioupdate | (1 << self.cspin)
        set_digital_state(spitime, self.board, self.connector, cs_io_mask, cs_io_mask, cs_io_mask)

        spi_mask = (1 << self.mosipin) | (1 << self.sclkpin) | (1 << self.cspin)
        sclk_high = 1 << self.sclkpin
        thistime = spitime - step

        # set last sclk falling edge (MOSI, SCLK and CS all low)
        set_digital_state(thistime, self.board, self.connector, spi_mask, spi_mask, 0)

        # Walk bytes last-to-first and bits LSB-to-MSB while stepping back in
        # time, which yields MSB-first, first-byte-first in forward time.
        for individualbyte in reversed(bytes_to_write):
            for bitindex in range(8):
                mosi = ((individualbyte >> bitindex) & 1) << self.mosipin
                thistime -= step
                set_digital_state(thistime, self.board, self.connector, spi_mask, spi_mask, mosi | sclk_high)
                thistime -= step
                set_digital_state(thistime, self.board, self.connector, spi_mask, spi_mask, mosi)

    def set_freq(self, ddstime, channel_mask, freq):
        """
        Write a frequency tuning word for the channels in `channel_mask`.

        Sends register 0 with ``channel_mask << 4`` (so `channel_mask` must
        be < 16 or ``bytearray`` raises ValueError), followed by register 4
        with ``round(2**32 * freq / sysclock)`` as a 32-bit big-endian word.
        (Presumably CSR and CFTW0 on the AD9959 -- confirm with the datasheet.)
        """
        payload0 = bytearray([0, channel_mask << 4, 4])
        payload4 = struct.pack('>L', round((1 << 32) * freq / self._dds_sysclock))
        self._spi(ddstime, payload0 + payload4)

    def set_amplitude(self, ddstime, channel_mask, amplitude):
        """
        Write an amplitude word for the channels in `channel_mask`.

        Sends register 0 with ``channel_mask << 4``, then register 6 with a
        24-bit value ``0x00, (1 << 12) | (int(amplitude) & 0x3FF)`` --
        `amplitude` is truncated to its low 10 bits.
        """
        payload0 = bytearray([0, channel_mask << 4])
        payload6 = struct.pack('>BBH', 6, 0, ((1 << 12) | (int(amplitude) & 0x03FF)))
        self._spi(ddstime, payload0 + payload6)

    def set_phase(self, ddstime, channel_mask, phase):
        """Not implemented: currently does nothing."""
        pass

    def initialize(self, ddstime):
        """Write the FR1 register (register 1) with the refclock multiplier shifted left by 2."""
        register = 1
        payload1 = (self._dds_refclkmultiplier << 2)
        payload2 = 0
        payload3 = 0
        self._spi(ddstime, bytearray([register, payload1, payload2, payload3]))

    def reset(self, ddstime):
        """
        Pulse the reset pin high, ending at `ddstime`, with CS held high.

        Reset goes high ``800 * spi_min_time`` before `ddstime` and low at
        `ddstime`.
        """
        thistime = ddstime
        chanselect = ((1 << self.resetpin) | (1 << self.cspin))
        outenable = chanselect
        state = ((0 << self.resetpin) | (1 << self.cspin))

        # change resetpin to low (board was previously missing from this call)
        set_digital_state(thistime, self.board, self.connector, chanselect, outenable, state)

        thistime -= 800*self.spi_min_time
        state = ((1 << self.resetpin) | (1 << self.cspin))

        # reset pin high at the start of the pulse
        set_digital_state(thistime, self.board, self.connector, chanselect, outenable, state)


class Sequence:
    """
    Sequence:
    Local buffer for a sequence before it is sent to the Entangleware Control
    software. Events are fixed-size 24-byte records, written in place into a
    pre-allocated, zero-filled ``bytearray`` through a ``memoryview``. When the
    buffer fills, it grows by whole multiples of its initial size, so it can
    hold any number of events up to memory limits.

    ``seqendindex`` is the number of events stored so far (an int).
    """
    def __init__(self, lengthsequence=2**20):
        """
        Parameters
        ----------
        lengthsequence : int, optional
            Number of 24-byte events the initial buffer (and each growth step)
            holds. The default of 2**20 allocates 24 MiB.

        Raises
        ------
        ValueError
            If `lengthsequence` is less than 1 (the buffer could never grow).
        """
        if lengthsequence < 1:
            raise ValueError('lengthsequence must be >= 1')
        self.building = False
        self.local = True
        self.seqendindex = 0
        self.lengthpayload = 24
        self.lengthsequence = lengthsequence
        self.sizeofbytearray = self.lengthpayload * self.lengthsequence

        # Initial Settings
        self.seq = bytearray(self.sizeofbytearray)
        self.seqview = memoryview(self.seq)
        self.seqchainfirstcall = True
        self.seqchainlastruntime = 0

    def addElement(self, element):
        """
        Append one or more 24-byte events to the buffer.

        Parameters
        ----------
        element : bytes-like
            Length must be a multiple of ``lengthpayload`` (24).

        Raises
        ------
        ValueError
            If the length of `element` is not a multiple of 24.
        """
        length_element = len(element)
        if (length_element % self.lengthpayload) != 0:
            raise ValueError('Length of \'element\' is not correct')

        start_index = int(self.seqendindex) * self.lengthpayload
        length_empty_seq = len(self.seq) - start_index

        if length_element > length_empty_seq:
            self.seqview.release()
            growth = self.sizeofbytearray * math.ceil(length_element / self.sizeofbytearray)
            self.seq = self.seq + bytearray(growth)
            self.seqview = memoryview(self.seq)
            print('new seqview')

        self.seqview[start_index:(start_index + length_element)] = element
        # Integer division keeps seqendindex an int; with '/' it became a float
        # and broke slicing in callers.
        self.seqendindex = int(self.seqendindex) + length_element // self.lengthpayload

    def clear(self):
        """Drop all events, leave building mode and reallocate the initial-size buffer."""
        self.building = False
        self.seqview.release()
        self.seq = bytearray(self.lengthpayload * self.lengthsequence)
        self.seqview = memoryview(self.seq)
        self.seqendindex = 0


def connect(timeout_sec=None, address: str="127.0.0.1", port: int=50100, ttl: int=1):
    """
    Announce a local TCP server and wait for Entangleware Control to connect.

    Creates the UDP announcer and a TCP server, sends the server's port number
    to ``(address, port)``, then accepts one connection (``TcpServer`` is
    created with its default accept timeout). On success the module-level
    ``connmgr`` holds the open endpoint and ``connmgr.isConnected`` is True.

    Parameters
    ----------
    timeout_sec : float, optional
        Socket timeout for the accepted connection. If falsy, the endpoint
        keeps the timeout ``TcpEndPoint`` sets.
    address : str
        Destination of the announcement, passed to ``UdpMulticaster`` as the
        group address. The default ``127.0.0.1`` is a loopback unicast
        address, presumably for a Control instance on the same machine --
        confirm.
    port : int
        UDP destination port of the announcement.
    ttl : int
        Multicast time-to-live.

    Raises
    ------
    ConnectionError
        If any step fails. All sockets opened so far are closed first and the
        original error is chained as ``__cause__``.
    """
    # Create the Multicaster and the TCP Server. Multicast (i.e. send) the port
    # number of the TCP server to the Entangleware Control software so that it
    # can attempt to connect to python and receive data. If the Entangleware
    # Control software is running, it has already subscribed to the Multicast
    # location.
    try:
        connmgr.udp_local = UdpMulticaster(address, port, ttl)
        connmgr.tcp_server = TcpServer()
        connmgr.udp_local.sendmsg(int(connmgr.tcp_server.serverport))
        tcp_local, _peer = connmgr.tcp_server.accept()
        connmgr.tcp_endpoint = TcpEndPoint(tcp_local)
        # Apply the caller's timeout *after* building the endpoint: the
        # TcpEndPoint constructor resets the socket timeout, which previously
        # discarded timeout_sec.
        if timeout_sec:
            tcp_local.settimeout(timeout_sec)
        connmgr.isConnected = True
    except Exception as err:
        # Release any half-opened sockets so they do not leak.
        connmgr.close()
        raise ConnectionError("[entangleware] Connection failed, shutting down...") from err


def disconnect():
    """
    Close every socket held by the module-level ``connmgr``.

    Always delegates to ``connmgr.close()``, even when ``isConnected`` is
    False, so sockets left behind by a failed ``connect()`` are also
    released. ``ConnectionManager.close`` skips members that are unset, so
    calling this when nothing is open is harmless.
    """
    connmgr.close()


def build_sequence():
    """
    Enter sequence-building mode.

    After this call, ``set_digital_state``/``set_analog_state`` append events
    to the local ``msgseq`` buffer instead of sending them immediately.
    """
    msgseq.building = True


def clear_sequence():
    """Discard all locally buffered events by clearing ``msgseq``."""
    msgseq.clear()


def rerun_last_sequence(printflag: bool=True):
    """
    Resend the message saved in ``LastCompiledRun.dat`` and wait for completion.

    The file is read from the current working directory (it is written there
    by ``run_sequence``/``run_sequence_chain``). It is sent with instruction
    0x16. The first reply carries the expected runtime as a big-endian double,
    in seconds. The function then waits up to ``runtime + 20.5`` s for the
    completion message. The socket timeout is reset to 10 s afterwards, even
    if that wait fails.

    Parameters
    ----------
    printflag : bool
        Print the expected runtime.

    Returns
    -------
    tuple
        The completion message as ``(payload, msgid, msgtype)``.
    """
    lv_instruct = 0x16
    path_to_temp = pathlib.Path().absolute().joinpath("LastCompiledRun.dat")
    tcpmessage = path_to_temp.read_bytes()
    endpoint = connmgr.tcp_endpoint
    endpoint.sendmsg(tcpmessage, 0, lv_instruct)
    runreturn = endpoint.getmsg()
    (runtime,) = struct.unpack('>d', runreturn[0])
    if printflag:
        print("[entangleware] Expected runtime: ", runtime)
    endpoint._sock.settimeout(runtime + 20.5)
    try:
        donemsg = endpoint.getmsg()
    finally:
        endpoint._sock.settimeout(10)
    return donemsg


def run_sequence(printflag: bool=True):
    """
    Send the locally built sequence, wait for it to finish and return the reply.

    The message is a big-endian int32 cycle count (always 1) followed by the
    first ``msgseq.seqendindex`` 24-byte events. It is sent with instruction
    0x16. The first reply carries the expected runtime as a big-endian double,
    in seconds. The message is saved to ``LastCompiledRun.dat`` in the current
    working directory for ``rerun_last_sequence``. The function then waits up
    to ``runtime + 20`` s for the completion message. The socket timeout is
    reset to 10 s afterwards, even if that wait fails.

    Parameters
    ----------
    printflag : bool
        Print the event count and expected runtime.

    Returns
    -------
    tuple
        The completion message as ``(payload, msgid, msgtype)``.
    """
    number_cycles = 1  # Don't Change (feature not yet implemented)
    lv_instruct = 0x16
    msgseq.seqendindex = int(msgseq.seqendindex)
    nbytes = msgseq.seqendindex * msgseq.lengthpayload
    if printflag:
        print(f'[entangleware] Sent {msgseq.seqendindex} events to Entangleware Control')
    # Slice the view before copying: tobytes() on the whole view would copy the
    # entire (24 MiB+) buffer just to keep the used prefix.
    tcpmessage = bytearray(struct.pack('>l', number_cycles)) + msgseq.seqview[:nbytes].tobytes()
    endpoint = connmgr.tcp_endpoint
    endpoint.sendmsg(tcpmessage, 0, lv_instruct)

    runreturn = endpoint.getmsg()
    (runtime,) = struct.unpack('>d', runreturn[0])
    if printflag:
        print('[entangleware] Expected runtime: ', runtime)
    pathlib.Path().absolute().joinpath("LastCompiledRun.dat").write_bytes(tcpmessage)
    endpoint._sock.settimeout(runtime + 20.0)
    try:
        # Observed reply on success: (bytearray(b'Done'), 15, 15) -- see run_sequence_chain.
        donemsg = endpoint.getmsg()
    finally:
        endpoint._sock.settimeout(10)
    return donemsg


def run_sequence_chain(printflag: bool=True):
    """
    Send the locally built sequence without waiting for it to finish.

    On every call after the first, the function first waits up to the
    previous run's ``runtime + 20.5`` s for its completion message. It raises
    if that message is not ``(bytearray(b'Done'), 15, 15)``. It then sends the
    current buffer (instruction 0x16), clears the local buffer, and records the
    expected runtime returned by Entangleware Control. The sent message is
    saved to ``LastCompiledRun.dat`` in the current working directory.

    Parameters
    ----------
    printflag : bool
        Print the number of events sent (default True, as before).

    Raises
    ------
    ValueError
        If the previous run's completion message is unexpected.
    """
    endpoint = connmgr.tcp_endpoint
    if not msgseq.seqchainfirstcall:
        endpoint._sock.settimeout(msgseq.seqchainlastruntime + 20.5)
        try:
            donemsg = endpoint.getmsg()
        finally:
            endpoint._sock.settimeout(10)
        if donemsg != (bytearray(b'Done'), 15, 15):
            raise ValueError('Return from LV is unexpected')

    number_cycles = 1  # Don't Change (feature not yet implemented)
    lv_instruct = 0x16
    # seqendindex may be a float if it was advanced with true division; slice
    # indices must be ints.
    msgseq.seqendindex = int(msgseq.seqendindex)
    nbytes = msgseq.seqendindex * msgseq.lengthpayload
    if printflag:
        print('[entangleware] Number of elements sent to Entangleware Control: ', msgseq.seqendindex)
    tcpmessage = bytearray(struct.pack('>l', number_cycles)) + msgseq.seqview[:nbytes].tobytes()
    endpoint.sendmsg(tcpmessage, 0, lv_instruct)
    msgseq.clear()
    msgseq.building = False

    runreturn = endpoint.getmsg()
    (runtime,) = struct.unpack('>d', runreturn[0])
    msgseq.seqchainlastruntime = runtime
    pathlib.Path().absolute().joinpath("LastCompiledRun.dat").write_bytes(tcpmessage)
    msgseq.seqchainfirstcall = False


def stop_sequence():
    """
    Send instruction 0x13 with a big-endian int32 payload of 1.

    By its name, this asks Entangleware Control to stop the running sequence.
    """
    payload = bytearray(struct.pack('>l', 1))  # number of cycles, fixed at 1
    connmgr.tcp_endpoint.sendmsg(payload, 0, 0x13)


def set_digital_state(seqtime, board, connector, channel_mask, output_enable_state, output_state):
    """
    Sets the digital output state. If 'build_sequence' hasn't been executed
    before this method, the method will ignore 'time' and immediately change the
    digital state. If 'build_sequence' has been executed before this method,
    the method will queue this state into the sequence which will execute, in
    time-order, after 'run_sequence' has been executed.

    Parameters
    ----------
    seqtime : float
        Absolute time, in seconds, when state will change.
    board : 8-bit int >= 0
        7820R board number -- starts at '0'
    connector : 8-bit int >= 0
        Connector of the 7820R -- starts at '0'
    channel_mask : 32-bit int >= 0
        Mask of the channel(s) to be changed
    output_enable_state : 32-bit int >= 0
        State of output enable for the channel(s) to be changed
    output_state : 32-bit int >= 0
        State of the channel(s) starting at `time`

    Returns
    -------
    None
    """
    # Address word: 0x01 in the top byte, board in bits 15-8, connector in bits 7-0.
    address = (0x01 << 24) | ((board & 0xFF) << 8) | (connector & 0xFF)
    # 24-byte event: time (double) + four unsigned 32-bit words.
    packed = bytearray(struct.pack('>dLLLL', seqtime, address, channel_mask, output_enable_state, output_state))
    if not msgseq.building:
        connmgr.tcp_endpoint.sendmsg(packed, 0, 0x14)  # 0x14: instruction to Entangleware Control
        return
    msgseq.addElement(packed)


def _checked_byte(number, name):
    """Return `number` as an int after checking it is an integral value in [0, 255]."""
    if not isinstance(number, (int, float)):
        raise ValueError(f'[entangleware] {name} must be a number, got {type(number).__name__}')
    # Range test first so inf/nan never reach int().
    if not 0 <= number <= 255 or number != int(number):
        raise ValueError(f'[entangleware] {name} must be an integer in [0, 255], got {number!r}')
    return int(number)


def set_analog_state(seq_time, board, channel, value):
    """
    Set an analog output, immediately or as queued sequence events.

    Scalar mode: `seq_time` and `value` are numbers. While ``build_sequence``
    is active the event is queued in ``msgseq``; otherwise it is sent to
    Entangleware Control straight away (instruction 0x14).

    List mode: `seq_time` and `value` are lists. Entries are paired in order
    and extra entries of the longer list are ignored. Only allowed while
    building a sequence.

    Parameters
    ----------
    seq_time : float or list of float
        Absolute time(s), in seconds.
    board : int (an integral float is accepted)
        Board number, 0-255.
    channel : int (an integral float is accepted)
        Analog channel, 0-255.
    value : float or list of float
        Output value(s).

    Raises
    ------
    ValueError
        If the arguments have unsupported or mismatched types, if board or
        channel is out of range or non-integral (previously these were
        silently dropped or crashed with TypeError), or if list mode is used
        outside ``build_sequence``.
    """
    numtype = (int, float)
    lv_instruct = 0x14  # instruction to Entangleware Control software
    scalar_mode = isinstance(seq_time, numtype) and isinstance(value, numtype)
    list_mode = isinstance(seq_time, list) and isinstance(value, list)
    if not (scalar_mode or list_mode):
        raise ValueError('[entangleware] seq_time and value must both be numbers or both be lists')
    board = _checked_byte(board, 'board')
    channel = _checked_byte(channel, 'channel')

    # 0x02 in the top byte marks an analog address (set_digital_state uses 0x01).
    connector = (0x02 << 24) | (board << 8)
    channel_mask = (0x01 << 24) | channel  # Most-sig byte is the 'Enable' Byte: 1 = enabled 0 = disabled
    record = struct.Struct('>dLLd')  # 24 bytes: time, connector, channel mask, value

    if scalar_mode:
        to_send = bytearray(record.pack(seq_time, connector, channel_mask, value))
        if msgseq.building:
            msgseq.addElement(to_send)  # add to queue if building
        else:
            connmgr.tcp_endpoint.sendmsg(to_send, 0, lv_instruct)  # set default if not building
        return

    if not msgseq.building:
        raise ValueError('[entangleware] list mode requires build_sequence() first')
    # zip() truncates to the shorter list, matching the previous min() logic.
    to_send = bytearray(b''.join(
        record.pack(t, connector, channel_mask, v) for t, v in zip(seq_time, value)))
    msgseq.addElement(to_send)


def debug_mode(active):
    """
    Send b'True' or b'False' (from the truthiness of `active`) with
    instruction 0x15. By its name this toggles a debug mode in Entangleware
    Control -- confirm on the LabVIEW side.
    """
    flag = b'True' if active else b'False'
    connmgr.tcp_endpoint.sendmsg(flag, 0, 0x15)


connmgr = ConnectionManager()  # module-wide connection state used by the functions above
msgseq = Sequence()  # module-wide event buffer used by the functions above
msgseq.local = True # Always keep True

if __name__ == "__main__":
    # Manual smoke test, repeated five times: connect, send one message with
    # msgtype 0x80000000 to what the prints call an echo server, print the
    # reply, then disconnect.
    for attempt in range(5):
        print(str(attempt)+' connecting...')
        try:
            connect()
        except Exception as err:  # top-level demo boundary: report and keep going
            # connect() raises on failure, so without this the 'not connected'
            # branch below could never run.
            print(err)
        if connmgr.isConnected:
            print('sending message to echo server...')
            connmgr.tcp_endpoint.sendmsg(b'Hello LV', attempt, 0x80000000)
            print(connmgr.tcp_endpoint.getmsg())
            time.sleep(1)
        else:
            print('not connected')

        print('disconnecting...')
        disconnect()