Skip to content
Snippets Groups Projects
sequences.py 10.22 KiB
from __future__ import annotations
from lib.structures import *
import matplotlib.pyplot as pp
import lib.plotdefs as pd
import copy
import pathlib
import toml

"""
Builds on abstractions defined in lib.structures.
"""

def _titlecase(s: str) -> str:
    return " ".join(_s[:1].upper() + _s[1:]
        for _s in s.replace(" ", "_").split("_"))

def _fresh_filename(path: pathlib.Path, overwrite: bool=False) -> pathlib.Path:
    if overwrite:
        return path
    _path = path
    while _path.is_file():
        print(f"WARNING: found existing file {_path}")
        _path = _path.with_stem(_path.stem + "_")
        print(f"Write instead to {_path}")
    return _path

class SuperSequence:
    """
    Sequence of labeled Sequences. Mostly for visualization purposes.

    Fields
    ------
    outdir : pathlib.Path
        Default output directory used by `save`.
    name : str
        Default file stem used by `save`.
    sequences : dict[str, Sequence]
        Labeled sub-sequences, kept in insertion order.
    defaults : ConnectionLayout
        Optional layout whose `.connections` mapping is used by
        `draw_detailed` when no connections are passed explicitly.
    """

    def __init__(self, outdir: pathlib.Path, name: str,
        sequences: dict[str, Sequence]=None, defaults: ConnectionLayout=None):
        """
        Constructor.

        Parameters
        ----------
        outdir : pathlib.Path
            Directory into which `save` writes by default.
        name : str
            File stem (without extension) used by `save` by default.
        sequences : dict[str, Sequence] (optional)
            Initial labeled Sequences. The dict is stored by reference, not
            copied; a fresh empty dict is used if None.
        defaults : ConnectionLayout (optional)
            Layout whose `.connections` mapping `draw_detailed` falls back to.
        """
        self.outdir = outdir
        self.name = name
        self.sequences = dict() if sequences is None else sequences
        self.defaults = defaults

    def __getitem__(self, key: str) -> Sequence:
        """Return the Sequence stored under label `key`."""
        assert isinstance(key, str)
        return self.sequences[key]

    def __setitem__(self, key: str, seq: Sequence):
        """Store `seq` under label `key`."""
        assert isinstance(key, str)
        assert isinstance(seq, Sequence)
        self.sequences[key] = seq

    def get(self, key, default=None) -> Sequence:
        """
        Return the Sequence under `key`, or `default` (None if not given)
        if the label is absent.
        """
        return self.sequences.get(key, default)

    def keys(self):
        """Return a view of the sequence labels."""
        return self.sequences.keys()

    def values(self):
        """Return a view of the stored Sequences."""
        return self.sequences.values()

    def items(self):
        """Return a view of (label, Sequence) pairs."""
        return self.sequences.items()

    def update(self, other):
        """
        Insert or overwrite entries from `other` in place and return `self`.

        `other` may be a dict or another SuperSequence. The latter works
        because this class provides `keys()` and `__getitem__`.
        """
        self.sequences.update(other)
        return self

    def __add__(self, other):
        """
        Return a new SuperSequence made of deep copies of this object's
        Sequences, updated with the entries of `other`.

        `outdir`, `name` and `defaults` are carried over from `self`. Entries
        from `other` are inserted by reference (not copied) and override
        entries with the same label.
        """
        sequences = copy.deepcopy(self.sequences)
        result = SuperSequence(self.outdir, self.name, sequences, self.defaults)
        return result.update(other)

    def __iadd__(self, other):
        """In-place `+=`; equivalent to `self.update(other)`."""
        return self.update(other)

    def by_times(self) -> list[Sequence]:
        """
        Return Sequences in a list ordered by earliest Event.
        """
        return sorted(
            self.sequences.values(),
            key=lambda e: e.min_time()
        )

    def to_sequence(self) -> Sequence:
        """
        Condense to a single Sequence to pass to the computer.

        Sequences are concatenated with `+` in order of their earliest time
        (see `by_times`), starting from an empty `Sequence()`.
        """
        seq = Sequence()
        for S in self.by_times():
            seq = seq + S
        return seq

    def to_primitives(self) -> dict[str, dict]:
        """Return {label: seq.to_primitives()} for every stored Sequence."""
        return {
            label: seq.to_primitives() for label, seq in self.sequences.items()
        }

    @staticmethod
    def from_primitives(outdir: pathlib.Path, name: str,
            alldata: dict[str, dict]):
        """
        Inverse of `to_primitives`: rebuild each Sequence with
        `Sequence.from_primitives`. The result has no `defaults`.
        """
        return SuperSequence(outdir, name, {
            label: Sequence.from_primitives(seq_dict)
            for label, seq_dict in alldata.items()
        })

    def save(self, target: pathlib.Path=None, overwrite: bool=False,
            printflag: bool=True):
        """
        Write `to_primitives()` to a TOML file and return `self`.

        Parameters
        ----------
        target : pathlib.Path (optional)
            Output file path. Defaults to `outdir / name`. ".toml" is
            appended if the path does not already end in ".toml".
        overwrite : bool
            If False, an existing file is kept; underscores are appended to
            the new file's stem until the name is free.
        printflag : bool
            Print progress messages.
        """
        T = self.outdir.joinpath(self.name) if target is None else target
        # Append rather than replace the suffix so names containing dots
        # keep their full stem.
        T = T.parent.joinpath(T.name + ".toml") if T.suffix != ".toml" else T
        if printflag: print(f"[sequencing] Saving sequence data to '{T}':")
        if not T.parent.is_dir():
            if printflag: print(f"[sequencing] mkdir {T.parent}")
            T.parent.mkdir(parents=True, exist_ok=True)

        sequence_file = _fresh_filename(T, overwrite)
        # Report the stem actually written, which may differ from T's.
        if printflag: print(f"[sequencing] save sequence '{sequence_file.stem}'")
        with sequence_file.open('w') as outfile:
            toml.dump(self.to_primitives(), outfile)
        return self

    @staticmethod
    def load(target: pathlib.Path, outdir: pathlib.Path=None, name: str=None,
            printflag: bool=True):
        """
        Construct a SuperSequence from a TOML file such as one written by
        `save`.

        Parameters
        ----------
        target : pathlib.Path
            Path to the TOML file, including its suffix.
        outdir : pathlib.Path (optional)
            `outdir` of the result; defaults to `target.parent`.
        name : str (optional)
            `name` of the result; defaults to `target.stem`.
        printflag : bool
            Print progress messages.

        Raises
        ------
        Exception
            If `target` does not exist or is not a regular file.
        """
        outdir = target.parent if outdir is None else outdir
        name = target.stem if name is None else name
        if printflag: print(f"[sequencing] Loading data from target '{target}':")
        # `target` is the sequence file itself, so test for existence
        # rather than for a directory.
        if not target.exists():
            raise Exception(f"Target '{target}' does not exist")

        sequence_file = target
        if printflag: print("[sequencing] load sequence")
        if not sequence_file.is_file():
            raise Exception(f"Sequence file '{sequence_file}' does not exist")
        seq_dict = toml.load(sequence_file)
        return SuperSequence.from_primitives(outdir, name, seq_dict)

    @staticmethod
    def _gen_timeline(detailed=False, layout=(1, 4, 32, 1, 32, -10, 10)) \
            -> pd.Plotter:
        """
        Create a blank Plotter with y-ticks hidden and the grid enabled.

        The Plotter is created while `font.size` is temporarily 3. The
        previous value is restored afterwards, even if creation fails.
        `detailed` and `layout` are currently unused.
        """
        FS = pp.rcParams["font.size"]
        pp.rcParams["font.size"] = 3
        try:
            P = pd.Plotter()
            P.set_yticks([]).set_yticklabels([]).ggrid(True)
        finally:
            pp.rcParams["font.size"] = FS
        return P

    def draw_simple(self) -> pd.Plotter:
        """
        Draw each Sequence as a labeled bar spanning `seq.when()`.

        Bars are stacked on levels. A Sequence is placed on the lowest level
        (scanning upward from 0) whose latest drawn end time is not after
        the Sequence's start, unless its `stack_idx` is set, which forces
        the level. Sequences without their own `color` cycle through the
        default colors C0-C9. The x-range always includes t = 0.
        `font.size` is 3 while drawing and is restored afterwards.

        Returns
        -------
        pd.Plotter
        """
        P = SuperSequence._gen_timeline(detailed=False)
        FS = pp.rcParams["font.size"]
        pp.rcParams["font.size"] = 3
        try:
            H = 0.2  # vertical spacing between levels; bars are H/2 tall
            c = 0  # index into the default color cycle
            tl_max = 0  # highest level used so far
            t1_last = dict()  # level -> latest end time drawn on that level
            Tmin = 0
            Tmax = 0
            for label, seq in self.items():
                tmin, tmax = seq.when()
                Tmin = min(Tmin, tmin)
                Tmax = max(Tmax, tmax)
                k = 0
                while tmin < t1_last.get(k, tmin):
                    k += 1
                tl_level = k if seq.stack_idx is None else seq.stack_idx
                tl_max = max(tl_max, tl_level)
                P.ax.add_patch(
                    pp.Rectangle(
                        (tmin, tl_level * H),
                        tmax - tmin,
                        H / 2,
                        edgecolor="k",
                        facecolor=f"C{c}" if seq.color is None else seq.color,
                        zorder=100
                    )
                )
                P.ax.text(
                    tmin,
                    (tl_level * H - H / 15),
                    _titlecase(label),
                    horizontalalignment="left", verticalalignment="top"
                )
                if seq.color is None:
                    c = (c + 1) % 10
                t1_last[tl_level] = max(tmax, t1_last.get(tl_level, tmax))
            # 5% horizontal margin on each side.
            P.set_xlim(Tmin - (Tmax - Tmin) / 20, Tmax + (Tmax - Tmin) / 20)
            P.set_ylim(-3* H / 2, (tl_max + 2) * H)
            P.set_xlabel("Time [s]")
        finally:
            pp.rcParams["font.size"] = FS
        return P

    def draw_detailed(self, connections: dict[str, Connection]=None) \
            -> pd.Plotter:
        """
        Draw the state history of every connection across all Sequences.

        Each connection k gets a horizontal track with guide lines at
        3k + 1 and 3k + 2 (plus 3k + 1.5 for analog connections). Digital
        values are drawn as an offset from the track base. Analog values
        are mapped linearly so that -10 lands on the base and +10 on
        base + 1. Each Sequence's state changes are drawn as a step line in
        that Sequence's color, continuing from where the previous line on
        that connection ended. A dashed vertical line marks each Sequence's
        start time. `font.size` is 3 while drawing and is restored
        afterwards.

        Parameters
        ----------
        connections : dict[str, Connection] (optional)
            Labeled connections to draw. If omitted,
            `self.defaults.connections` is used.

        Returns
        -------
        pd.Plotter

        Raises
        ------
        ValueError
            If neither `connections` nor `self.defaults` is available.
        Exception
            If a connection is neither a DigitalConnection nor an
            AnalogConnection.
        """
        if connections is None and self.defaults is None:
            raise ValueError(
                "draw_detailed requires `connections` or `self.defaults`")
        _connections = list(
            (connections if connections is not None
                else self.defaults.connections).items()
        )
        P = SuperSequence._gen_timeline(detailed=True)
        FS = pp.rcParams["font.size"]
        pp.rcParams["font.size"] = 3
        try:
            T0 = min([seq.min_time() for seq in self.values()])
            T1 = max([seq.max_time() for seq in self.values()])
            color_C = 0  # index into the default color cycle
            hk = 3  # vertical spacing between connection tracks

            def ycoord(x, conn, k):
                # Track k has its base at hk*k + 1; analog values in [-10, 10]
                # are rescaled onto [0, 1] above the base.
                y = x if isinstance(conn, DigitalConnection) else (x + 10) / 20
                return hk * k + 1 + y

            # set up background lines and histories on each connection
            # each history is a list of matplotlib line objects, one per
            # sequence that touches the connection
            conn_history = dict()
            for k, (conn_label, conn) in enumerate(_connections):
                if isinstance(conn, DigitalConnection):
                    P.ax.axhline(
                        hk * k + 1, color="0.85", linestyle="-", linewidth=0.4,
                        zorder=100)
                    P.ax.axhline(
                        hk * k + 2, color="0.85", linestyle="--", linewidth=0.4,
                        zorder=100)
                elif isinstance(conn, AnalogConnection):
                    P.ax.axhline(
                        hk * k + 1, color="0.85", linestyle="--", linewidth=0.4,
                        zorder=100)
                    P.ax.axhline(
                        hk * k + 1.5, color="0.85", linestyle="-", linewidth=0.4,
                        zorder=100)
                    P.ax.axhline(
                        hk * k + 2, color="0.85", linestyle="--", linewidth=0.4,
                        zorder=100)
                else:
                    raise Exception("Invalid connection type")
                P.ax.text(
                    T0, hk * k + 1 + 0.5, _titlecase(conn_label) + " ",
                    horizontalalignment="right", verticalalignment="center",
                    zorder=101)
                P.ax.text(
                    T1, hk * k + 1 + 0.5, " " + _titlecase(conn_label),
                    horizontalalignment="left", verticalalignment="center",
                    zorder=101)
                # Each history starts as a single point at the default value.
                conn_history[conn] = [
                    P.ax.plot(
                        [T0], [ycoord(conn.default, conn, k)], color="0.65",
                        zorder=103)[0]
                ]

            # add line objects to each history for each sequence they're a
            # part of
            for seq_label, seq in self.items():
                if seq.color is None:
                    color = f"C{color_C}"
                    color_C = (color_C + 1) % 10
                else:
                    color = seq.color
                t0 = seq.min_time()
                P.ax.axvline(
                    t0, color="k", linestyle="--", linewidth=0.4, zorder=102)
                for k, (conn_label, conn) in enumerate(_connections):
                    # NOTE(review): `**conn` requires Connection to behave as
                    # a mapping (keys()/__getitem__); confirm in lib.structures.
                    ts = seq._get_states(**conn)
                    if len(ts) > 0:
                        # Continue from the last point already drawn on this
                        # connection, emitting a horizontal-then-vertical step
                        # for each (time, state) pair.
                        _t = [conn_history[conn][-1].get_xdata()[-1]]
                        _s = [conn_history[conn][-1].get_ydata()[-1]]
                        for t, s in ts:
                            _t += [t, t]
                            _s += [_s[-1], ycoord(s, conn, k)]
                        conn_history[conn].append(
                            P.ax.plot(_t, _s, color=color, zorder=103)[0])
                        P.ax.text(t0, hk * k + 1, _titlecase(seq_label),
                            horizontalalignment="left", verticalalignment="top",
                            zorder=101)

            # add tails to the final history items going until T1
            for history in conn_history.values():
                x = history[-1].get_xdata()
                y = history[-1].get_ydata()
                history[-1].set_xdata(list(x) + [T1])
                history[-1].set_ydata(list(y) + [y[-1]])

            P.set_xlim(T0, T1)
            P.set_ylim(-2 - hk / 2, hk * len(_connections) + 2 + hk / 2)
        finally:
            pp.rcParams["font.size"] = FS
        return P