Skip to content
Snippets Groups Projects
sequences.py 17.27 KiB
from __future__ import annotations
from .structures import *
import numpy as np
import matplotlib.pyplot as pp
from .plotdefs import Plotter
import copy
from collections import defaultdict
import pathlib
import toml
import shutil
import os

"""
Builds on abstractions defined in lib.structures.
"""

def _titlecase(s: str) -> str:
    """
    Convert an underscore- or space-separated identifier to title-cased words.

    Each word has only its first character upper-cased; the rest of the word
    is kept as-is (e.g. "mog_rf table" -> "Mog Rf Table"). Consecutive
    separators produce empty words, i.e. extra spaces in the result.
    """
    pieces = []
    for word in s.replace(" ", "_").split("_"):
        head, tail = word[:1], word[1:]
        pieces.append(head.upper() + tail)
    return " ".join(pieces)

def _fresh_filename(path: pathlib.Path, overwrite: bool=False) -> pathlib.Path:
    """
    Resolve `path` to a location that is free in the filesystem.

    While the candidate exists, a warning is printed. Then either the existing
    entry is deleted (``overwrite=True``; directories are removed
    recursively), or an underscore is appended to the candidate's stem, e.g.
    ``run.toml`` -> ``run_.toml`` -> ``run__.toml``.

    Parameters
    ----------
    path : pathlib.Path
        Preferred location.
    overwrite : bool
        Delete whatever currently occupies the location instead of renaming.

    Returns
    -------
    pathlib.Path
        A path that did not exist at the time of the final check.
    """
    candidate = path
    while candidate.exists():
        print(f"WARNING: {candidate} exists in filesystem")
        if not overwrite:
            candidate = candidate.parent / f"{candidate.stem}_{candidate.suffix}"
            print(f"Write instead to {candidate}")
            continue
        print(f":: rm -rf {candidate}")
        remove = shutil.rmtree if candidate.is_dir() else os.remove
        remove(str(candidate))
    return candidate

def min_n(generator, default=None):
    """
    Return the smallest non-None item produced by `generator`.

    `None` items are skipped entirely. If no other items remain, `default`
    is returned. On ties the first occurrence is kept.
    """
    present = (item for item in generator if item is not None)
    return min(present, default=default)

def max_n(generator, default=None):
    """
    Return the largest non-None item produced by `generator`.

    `None` items are skipped entirely. If no other items remain, `default`
    is returned. On ties the first occurrence is kept.
    """
    present = (item for item in generator if item is not None)
    return max(present, default=default)

def _dict_to_defaultdict(D: dict, factory) -> defaultdict:
    """
    Build a new defaultdict with default factory `factory` holding a shallow
    copy of the items of `D`, in `D`'s iteration order.
    """
    return defaultdict(factory, D.items())

class SuperSequence:
    """
    Sequence of labeled Sequences. Mostly for visualization purposes.

    Fields
    ------
    outdir : pathlib.Path
        Parent directory used by `save` when no explicit target is given.
    name : str
        Sub-directory name used by `save` when no explicit target is given.
    sequences : collections.defaultdict[str, Sequence]
        Labeled Sequences; looking up a missing label via `[]` inserts an
        empty `Sequence()`.
    params : dict[str, ...]
        Free-form parameters written to `parameters.toml` by `save`.
    defaults : ConnectionLayout
        Optional layout whose `connections` are drawn by `draw_detailed` when
        none are passed explicitly.
    """

    def __init__(self, outdir: pathlib.Path, name: str,
            sequences: dict[str, Sequence]=None, params: dict[str, ...]=None,
            defaults: ConnectionLayout=None):
        """
        Constructor.

        Parameters
        ----------
        outdir : pathlib.Path
        name : str
        sequences : dict[str, Sequence] (optional)
            Items are copied into a fresh defaultdict; the mapping itself is
            not kept.
        params : dict[str, ...] (optional)
            Stored by reference.
        defaults : ConnectionLayout (optional)
        """
        self.outdir = outdir
        self.name = name
        self.sequences = (
            defaultdict(Sequence) if sequences is None
            else _dict_to_defaultdict(sequences, Sequence)
        )
        self.params = dict() if params is None else params
        self.defaults = defaults

    def __getitem__(self, key: str) -> Sequence:
        """Return the Sequence labeled `key`, inserting an empty one if absent."""
        assert isinstance(key, str)
        return self.sequences[key]

    def __setitem__(self, key: str, seq: Sequence):
        """Store `seq` under label `key`."""
        assert isinstance(key, str)
        assert isinstance(seq, Sequence)
        self.sequences[key] = seq

    def get(self, key, default=None) -> Sequence:
        """Return the Sequence labeled `key`, or `default` without inserting."""
        return self.sequences.get(key, default)

    def keys(self):
        return self.sequences.keys()

    def values(self):
        return self.sequences.values()

    def items(self):
        return self.sequences.items()

    def update(self, other):
        """
        Add or replace Sequences from a label -> Sequence mapping (another
        SuperSequence also works, via its `keys`/`__getitem__`). Returns self.
        """
        self.sequences.update(other)
        return self

    def __add__(self, other):
        """
        Return a new SuperSequence holding deep copies of this object's
        Sequences, updated with `other`'s Sequences (by reference); labels
        present in both take `other`'s Sequence. `params` are merged the same
        way, and `defaults` falls back to `other.defaults` when unset here.
        """
        assert isinstance(other, SuperSequence)
        sequences = copy.deepcopy(self.sequences)
        params = {**self.params, **other.params}
        defaults = self.defaults if self.defaults is not None else other.defaults
        return SuperSequence(
            self.outdir, self.name, sequences,
            params=params, defaults=defaults,
        ).update(other)

    def __iadd__(self, other):
        """
        In-place counterpart of `+`. Plain mappings are also accepted, in
        which case only the Sequences are updated.
        """
        if isinstance(other, SuperSequence):
            self.params.update(other.params)
            if self.defaults is None:
                self.defaults = other.defaults
        return self.update(other)

    def by_times(self) -> list[Sequence]:
        """
        Return Sequences in a list ordered by earliest Event.
        """
        return sorted(
            self.sequences.values(),
            key=lambda e: e.min_time()
        )

    def by_times_named(self) -> list[(str, Sequence)]:
        """
        Return Sequences with their names in a list ordered by earliest Event.
        """
        return sorted(
            self.sequences.items(),
            key=lambda name_seq: name_seq[1].min_time()
        )

    def min_time(self) -> float:
        """Earliest `min_time()` over all Sequences (ValueError if empty)."""
        return min(seq.min_time() for seq in self.sequences.values())

    def max_time(self) -> float:
        """Latest `max_time()` over all Sequences (ValueError if empty)."""
        return max(seq.max_time() for seq in self.sequences.values())

    def to_sequence(self) -> Sequence:
        """
        Condense to a single Sequence to pass to the computer.

        Sequences are combined with `+` in order of their earliest Event.
        """
        seq = Sequence()
        for S in self.by_times():
            seq = seq + S
        return seq

    def to_primitives(self) -> dict[str, dict]:
        """Map each label to its Sequence's `to_primitives()` representation."""
        return {
            label: seq.to_primitives() for label, seq in self.sequences.items()
        }

    @staticmethod
    def from_primitives(outdir: pathlib.Path, name: str,
            alldata: dict[str, dict], params: dict[str, ...]=None):
        """
        Inverse of `to_primitives`.

        Parameters
        ----------
        outdir : pathlib.Path
        name : str
        alldata : dict[str, dict]
            Label -> `Sequence.to_primitives()` output.
        params : dict[str, ...] (optional)
            Parameters to attach to the result.
        """
        return SuperSequence(outdir, name, {
            label: Sequence.from_primitives(seq_dict)
            for label, seq_dict in alldata.items()
        }, params=params)

    @staticmethod
    def _toml_value(value):
        """
        Convert NumPy scalars/arrays to plain Python values so that they are
        written as ordinary TOML values; other values pass through unchanged
        (Python floats are normalised with `float`).
        """
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, float):
            return float(value)
        return value

    def save(self, target: pathlib.Path=None, overwrite: bool=False,
            printflag: bool=True):
        """
        Write `sequence.toml` (from `to_primitives`) and `parameters.toml`
        (from `params`) into the directory `target` (default: outdir/name).

        If `target` already exists, it is deleted when `overwrite` is True;
        otherwise underscores are appended to its name until a free path is
        found. Returns self.
        """
        T = self.outdir.joinpath(self.name) if target is None else target
        if printflag: print(f"[sequencing] Saving sequence data to '{T}':")
        T = _fresh_filename(T, overwrite)
        T_seq = T.joinpath("sequence.toml")
        T_par = T.joinpath("parameters.toml")

        if not T.is_dir():
            if printflag: print(f":: mkdir -p '{T}'")
            T.mkdir(parents=True, exist_ok=True)
        with T_seq.open('w') as outfile:
            toml.dump(self.to_primitives(), outfile)
        with T_par.open('w') as outfile:
            toml.dump(
                {
                    k: SuperSequence._toml_value(v)
                    for k, v in self.params.items()
                },
                outfile
            )
        return self

    @staticmethod
    def load(target: pathlib.Path, outdir: pathlib.Path=None, name: str=None,
            printflag: bool=True):
        """
        Load a SuperSequence from a directory written by `save`.

        Parameters
        ----------
        target : pathlib.Path
            Directory containing `sequence.toml` and, optionally,
            `parameters.toml`.
        outdir : pathlib.Path (optional)
            Defaults to `target.parent`.
        name : str (optional)
            Defaults to `target.stem`.
        printflag : bool

        Raises
        ------
        Exception
            If `target` is not a directory or `sequence.toml` is missing.
        """
        outdir = target.parent if outdir is None else outdir
        name = target.stem if name is None else name
        if printflag: print(f"[sequencing] Loading data from target '{target}':")
        if not target.is_dir():
            raise Exception(
                f"Target '{target}' does not exist or is not a directory")

        # mirror the file layout produced by `save`
        sequence_file = target.joinpath("sequence.toml")
        if printflag: print("[sequencing] load sequence")
        if not sequence_file.is_file():
            raise Exception(f"Sequence file '{sequence_file}' does not exist")
        seq_dict = toml.load(sequence_file)

        params_file = target.joinpath("parameters.toml")
        params = dict()
        if params_file.is_file():
            if printflag: print("[sequencing] load parameters")
            params = toml.load(params_file)
        return SuperSequence.from_primitives(outdir, name, seq_dict, params)

    @staticmethod
    def _gen_timeline(detailed=False, layout=(1, 4, 32, 1, 32, -10, 10)) \
            -> Plotter:
        """
        Create an empty Plotter with grid on and no y ticks.

        NOTE: `detailed` and `layout` are currently unused.
        """
        P = Plotter().ggrid(True)
        P.set_yticks([]).set_yticklabels([])
        return P

    def draw_simple(self) -> Plotter:
        """
        Draw each Sequence as a labeled bar spanning `seq.when()`.

        Each bar goes on the lowest level whose latest bar ended no later
        than the new bar's start, unless the Sequence sets `stack_idx`.
        Sequences without an explicit `color` cycle through "C0".."C9".

        Returns
        -------
        P : Plotter
        """
        P = SuperSequence._gen_timeline(detailed=False)
        H = 0.2  # vertical spacing between levels
        c = 0
        tl_level = 0
        tl_max = 0
        t1_last = dict()  # level -> latest end time drawn on that level
        # the x-range always includes t = 0
        Tmin = 0
        Tmax = 0
        for label, seq in self.items():
            tmin, tmax = seq.when()
            Tmin = min(Tmin, tmin)
            Tmax = max(Tmax, tmax)
            k = 0
            while tmin < t1_last.get(k, tmin):
                k += 1
            tl_level = k if seq.stack_idx is None else seq.stack_idx
            tl_max = max(tl_max, tl_level)
            P.ax.add_patch(
                pp.Rectangle(
                    (tmin, tl_level * H),
                    tmax - tmin,
                    H / 2,
                    edgecolor="k",
                    facecolor=f"C{c}" if seq.color is None else seq.color,
                    zorder=100
                )
            )
            P.ax.text(
                tmin,
                (tl_level * H - H / 15),
                label,
                horizontalalignment="left", verticalalignment="top"
            )
            if seq.color is None:
                c = (c + 1) % 10
            t1_last[tl_level] = max(tmax, t1_last.get(tl_level, tmax))
        P.set_xlim(Tmin - (Tmax - Tmin) / 20, Tmax + (Tmax - Tmin) / 20)
        P.set_ylim(-3* H / 2, (tl_max + 2) * H)
        P.set_xlabel("Time [s]")
        return P

    def draw_detailed(
        self,
        ghost_delays: bool=True,
        connections: dict[str, Connection]=None,
        mogtables: list[(..., dict[str, ...] | None)]=None
    ) -> Plotter:
        """
        Draw the state of every connection over time, plus optional MOG tables.

        Parameters
        ----------
        ghost_delays : bool
            Also draw each state change shifted by the connection's delay
            (`delay_up`/`delay_down` for digital, `delay` for analog) as a
            dashed line.
        connections : dict[str, Connection] (optional)
            Connections to draw, keyed by label. Takes precedence over
            `self.defaults.connections`; at least one must be available.
        mogtables : list[(mogtable, dict | None)] (optional)
            Tables drawn below the connections. Each table is iterated for
            events with `time`, `frequency`, `power` and `phase` attributes,
            and must provide `color`, `min_time()` and `max_time()`.
            Recognized option keys: "offset" (time shift, default 0.0) and
            "name" (label).

        Returns
        -------
        P : Plotter
        """
        # explicit `connections` take precedence over `self.defaults`; only
        # require that at least one of them is available
        assert (connections is not None) or (self.defaults is not None), \
            "draw_detailed requires `connections` or `self.defaults`"
        _connections = list(
            (connections if connections is not None
                else self.defaults.connections).items()
        )
        P = SuperSequence._gen_timeline(detailed=True)
        T0 = self.min_time()
        T1 = self.max_time()
        color_C = 0
        hk = 3  # vertical space allotted to each connection trace
        linewidth = 0.7

        def ycoord(x, conn, k):
            # digital states are drawn as-is; analog values are mapped from
            # [-10, +10] onto [0, 1]
            y = x if isinstance(conn, DigitalConnection) else (x + 10) / 20
            return hk * k + 1 + y

        # set up background lines and histories on each connection
        # each history is a list of matplotlib line objects: a default-value
        # stub followed by one line per sequence touching the connection
        conn_history = dict()
        for k, (conn_label, conn) in enumerate(reversed(_connections)):
            if isinstance(conn, DigitalConnection):
                P.ax.axhline(
                    ycoord(0, conn, k),
                    color="0.85", linestyle="-",
                    linewidth=linewidth / 2, zorder=100)
                P.ax.axhline(
                    ycoord(1, conn, k),
                    color="0.85", linestyle="--",
                    linewidth=linewidth / 2, zorder=100)
            elif isinstance(conn, AnalogConnection):
                P.ax.axhline(
                    ycoord(-10.0, conn, k),
                    color="0.85", linestyle="--",
                    linewidth=linewidth / 2, zorder=100)
                P.ax.axhline(
                    ycoord(0.0, conn, k),
                    color="0.85", linestyle="-",
                    linewidth=linewidth / 2, zorder=100)
                P.ax.axhline(
                    ycoord(+10.0, conn, k),
                    color="0.85", linestyle="--",
                    linewidth=linewidth / 2, zorder=100)
            else:
                raise Exception("Invalid connection type")
            # passing the AnalogConnection class (not an instance) selects the
            # analog mapping, so labels sit at the trace's vertical center
            P.ax.text(
                T0, ycoord(0.0, AnalogConnection, k),
                conn_label + " ",
                horizontalalignment="right", verticalalignment="center",
                zorder=101)
            P.ax.text(
                T1, ycoord(0.0, AnalogConnection, k),
                " " + conn_label,
                horizontalalignment="left", verticalalignment="center",
                zorder=101)
            conn_history[conn] = [
                P.ax.plot(
                    [T0], [ycoord(conn.default, conn, k)], color="0.65",
                    linewidth=linewidth, zorder=103)[0]
            ]

        # add line objects to each history for each sequence they're a part of
        for seq_label, seq in self.by_times_named():
            if seq.color is None:
                color = f"C{color_C}"
                color_C = (color_C + 1) % 10
            else:
                color = seq.color
            t0 = seq.min_time()
            P.ax.axvline(
                t0, color="k", linestyle="--", linewidth=linewidth / 2,
                zorder=102)
            for k, (conn_label, conn) in enumerate(reversed(_connections)):
                # NOTE(review): `conn` is unpacked as keyword arguments, so it
                # is presumably a Mapping -- confirm in lib.structures
                ts = seq._get_states(**conn)
                if len(ts) == 0:
                    continue
                # extend the previous line to the midpoint between its last
                # point and this sequence's start, and continue from there
                prev = conn_history[conn][-1]
                t_prev = prev.get_xdata()
                s_prev = prev.get_ydata()
                t_mid = (t_prev[-1] + t0) / 2
                prev.set_xdata(list(t_prev) + [t_mid])
                prev.set_ydata(list(s_prev) + [s_prev[-1]])
                _t = [t_mid]
                _s = [s_prev[-1]]

                for t, s in ts:
                    if ghost_delays:
                        if isinstance(conn, DigitalConnection):
                            dt = conn.delay_up if s == 1 else conn.delay_down
                        elif isinstance(conn, AnalogConnection):
                            dt = conn.delay
                        if dt > 0:
                            P.ax.plot(
                                _t[-1:] + [t + dt, t + dt],
                                _s[-1:] + [_s[-1], ycoord(s, conn, k)],
                                color=color,
                                linewidth=linewidth,
                                linestyle="--",
                                zorder=103
                            )
                    # step: hold the previous level until t, then jump
                    _t += [t, t]
                    _s += [_s[-1], ycoord(s, conn, k)]
                conn_history[conn].append(
                    P.ax.plot(_t, _s, color=color, linewidth=linewidth,
                        zorder=103)[0])
                P.ax.text(t0, hk * k + 1, seq_label,
                    horizontalalignment="left", verticalalignment="top",
                    zorder=101)

        # extend the right-most line of every connection out to T1
        for history in conn_history.values():
            last = max(history, key=lambda h: max(h.get_xdata()))
            x = last.get_xdata()
            y = last.get_ydata()
            last.set_xdata(list(x) + [T1])
            last.set_ydata(list(y) + [y[-1]])

        # process any MOGTables; they are stacked below the connection traces
        # (negative k), each with phase/power/frequency sub-traces
        moghs = 1
        moghk = 4 + 1 + 3 * moghs

        def mogycoord(x, xlim, k, j):
            xmin, xmax = xlim
            span = xmax - xmin
            # a constant-valued parameter gives a zero span; draw it centered
            # instead of dividing by zero
            frac = 0.5 if span == 0 else (x - xmin) / span
            return moghk * k + j * (1 + moghs) + 1 + frac

        mogtables = list() if mogtables is None else mogtables

        # common horizontal extent over the sequences and all tables, so that
        # every table's labels line up and none is clipped by the x-limits
        mT0, mT1 = T0, T1
        for mogtable, opt in mogtables:
            offset = (dict() if opt is None else opt).get("offset", 0.0)
            mT0 = min(mT0, offset + mogtable.min_time())
            mT1 = max(mT1, offset + mogtable.max_time())

        for k, (mogtable, opt) \
                in zip(range(-1, -len(mogtables) - 1, -1), mogtables):
            opt = dict() if opt is None else opt
            t0 = opt.get("offset", 0.0)
            name = opt.get("name", f"mogtable {k}")
            color = "C7" if mogtable.color is None else mogtable.color

            # default min/max taken from the MOGRF manual
            fmin = min_n((e.frequency for e in mogtable), 10.0)
            fmax = max_n((e.frequency for e in mogtable), 200.0)
            pmin = min_n((e.power for e in mogtable), -70.0)
            pmax = max_n((e.power for e in mogtable), 33.0)
            phmin = min_n((e.phase for e in mogtable), 0.0)
            phmax = max_n((e.phase for e in mogtable), 360.0)

            # per-table start marker and label (drawn once per table)
            P.ax.axvline(
                t0, color="0.25", linestyle=":", linewidth=linewidth / 2,
                zorder=102)
            P.ax.text(
                t0, moghk * k + 1,
                name,
                horizontalalignment="left", verticalalignment="top",
                zorder=101)

            # set up lines in the analog style for frequency, power, and phase
            for j, (param, lim) in enumerate([
                ("phase", (phmin, phmax)),
                ("power", (pmin, pmax)),
                ("frequency", (fmin, fmax))
            ]):
                P.ax.axhline(
                    mogycoord(0.0, [0.0, 1.0], k, j),
                    color="0.85", linestyle="--",
                    linewidth=linewidth / 2, zorder=100)
                P.ax.axhline(
                    mogycoord(1.0, [0.0, 1.0], k, j),
                    color="0.85", linestyle="--",
                    linewidth=linewidth / 2, zorder=100)
                P.ax.text(
                    mT0, mogycoord(0.5, [0.0, 1.0], k, j),
                    param + " ",
                    horizontalalignment="right", verticalalignment="center",
                    zorder=101)
                P.ax.text(
                    mT1, mogycoord(0.5, [0.0, 1.0], k, j),
                    " " + param,
                    horizontalalignment="left", verticalalignment="center",
                    zorder=101)

                # step trace; events whose value is None hold the last level
                t = [(mT0 + t0 + mogtable.min_time()) / 2]
                s = [mogycoord(lim[0], lim, k, j)]
                for mogevent in mogtable:
                    value = getattr(mogevent, param)
                    t += [t0 + mogevent.time, t0 + mogevent.time]
                    s += [
                        s[-1],
                        s[-1] if value is None else mogycoord(value, lim, k, j)
                    ]
                t += [(t[-1] + mT1) / 2]
                s += [s[-1]]

                P.plot(
                    [mT0, t[0]],
                    2 * [mogycoord(lim[0], lim, k, j)],
                    color="0.65", linewidth=linewidth, zorder=103
                )
                P.plot(
                    [t[-1], mT1],
                    2 * [s[-1]],
                    color="0.65", linewidth=linewidth, zorder=103
                )
                P.plot(t, s, color=color, linewidth=linewidth, zorder=103)

        P.set_xlim(mT0, mT1)
        P.set_ylim(
            -moghk * len(mogtables) - 2 - hk / 2,
            hk * len(_connections) + 2 + hk / 2
        )
        return P