# sequences.py
from __future__ import annotations
from lib.structures import *
import matplotlib.pyplot as pp
import lib.plotdefs as pd
import copy
import pathlib
import toml

"""
Builds on abstractions defined in lib.structures.
"""

def _titlecase(s: str) -> str:
    """
    Turn a snake_case / space-separated label into a display title.

    Spaces and underscores are both treated as word separators. The first
    character of every word is upper-cased, the rest of the word is left
    untouched, and the words are re-joined with single spaces.
    """
    words = s.replace(" ", "_").split("_")
    capitalized = []
    for word in words:
        # word[:1] (not word[0]) so empty words from doubled separators
        # do not raise.
        head, tail = word[:1], word[1:]
        capitalized.append(head.upper() + tail)
    return " ".join(capitalized)

def _fresh_filename(path: pathlib.Path, overwrite: bool=False) -> pathlib.Path:
    """
    Pick a file path that can be written without clobbering existing data.

    Parameters
    ----------
    path : pathlib.Path
        Desired output file.
    overwrite : bool
        If True, `path` is returned unchanged even if it already exists.

    Returns
    -------
    pathlib.Path
        `path` itself when it is free (or `overwrite` is set); otherwise the
        first variant with underscores appended to the stem that does not
        name an existing file. A warning is printed for every collision.
    """
    if overwrite:
        return path
    candidate = path
    while True:
        if not candidate.is_file():
            return candidate
        print(f"WARNING: found existing file {candidate}")
        # Keep the suffix, lengthen the stem by one underscore and retry.
        candidate = candidate.with_stem(candidate.stem + "_")
        print(f"Write instead to {candidate}")

class SuperSequence:
    """
    Sequence of labeled Sequences. Mostly for visualization purposes.

    Wraps a dict mapping labels to Sequences and adds helpers to merge,
    flatten, serialize (TOML) and draw the collection.

    Fields
    ------
    outdir : pathlib.Path
        Default directory used by `save`.
    name : str
        Default file name (without suffix) used by `save`.
    sequences : dict[str, Sequence]
        The labeled Sequences.
    """

    def __init__(self, outdir: pathlib.Path, name: str,
            sequences: dict[str, Sequence]=None):
        """
        Constructor.

        Parameters
        ----------
        outdir : pathlib.Path
        name : str
        sequences : dict[str, Sequence]
            Optional initial contents; a fresh empty dict is used if omitted.
            The dict is stored as-is (not copied).
        """
        self.outdir = outdir
        self.name = name
        self.sequences = dict() if sequences is None else sequences

    def __getitem__(self, key: str) -> Sequence:
        """Return the Sequence stored under `key` (KeyError if absent)."""
        assert isinstance(key, str)
        return self.sequences[key]

    def __setitem__(self, key: str, seq: Sequence):
        """Store `seq` under the label `key`."""
        assert isinstance(key, str)
        assert isinstance(seq, Sequence)
        self.sequences[key] = seq

    def get(self, key, default=None) -> Sequence:
        """
        Return the Sequence stored under `key`, or `default` if the label is
        not present. `default` is optional, mirroring dict.get.
        """
        return self.sequences.get(key, default)

    def keys(self):
        """Return a view of the labels."""
        return self.sequences.keys()

    def items(self):
        """Return a view of the (label, Sequence) pairs."""
        return self.sequences.items()

    def update(self, other):
        """
        Merge the labeled Sequences of `other` into this object in place.

        `other` may be another SuperSequence or anything accepted by
        dict.update. Entries of `other` win on label collisions. Returns
        self so calls can be chained.
        """
        if isinstance(other, SuperSequence):
            other = other.sequences
        self.sequences.update(other)
        return self

    def __add__(self, other):
        """
        Return a new SuperSequence holding a deep copy of this object's
        Sequences merged with those of `other`. Neither operand is modified;
        `outdir` and `name` are carried over from the left operand.
        """
        merged = SuperSequence(
            self.outdir, self.name, copy.deepcopy(self.sequences))
        return merged.update(other)

    def __iadd__(self, other):
        """In-place merge; equivalent to `update(other)`."""
        return self.update(other)

    def by_times(self) -> list[Sequence]:
        """
        Return Sequences in a list ordered by earliest Event.
        """
        return sorted(
            self.sequences.values(),
            key=lambda e: e.min_time()
        )

    def to_sequence(self) -> Sequence:
        """
        Condense to a single Sequence to pass to the computer.

        The contained Sequences are added together in the order given by
        `by_times`, starting from an empty Sequence.
        """
        seq = Sequence()
        for S in self.by_times():
            seq = seq + S
        return seq

    def to_primitives(self) -> dict[str, dict]:
        """
        Return a dict mapping each label to its Sequence's `to_primitives()`
        representation (the form written to disk by `save`).
        """
        return {
            label: seq.to_primitives() for label, seq in self.sequences.items()
        }

    @staticmethod
    def from_primitives(outdir: pathlib.Path, name: str,
            alldata: dict[str, dict]):
        """
        Inverse of `to_primitives`: rebuild a SuperSequence from a dict
        mapping labels to the primitive form of each Sequence.
        """
        return SuperSequence(outdir, name, {
            label: Sequence.from_primitives(seq_dict)
            for label, seq_dict in alldata.items()
        })

    def save(self, target: pathlib.Path=None, overwrite: bool=False,
            printflag: bool=True):
        """
        Write the collection to a TOML file.

        Parameters
        ----------
        target : pathlib.Path
            Destination path; defaults to `outdir / name`. ".toml" is
            appended to the file name unless the suffix already is ".toml".
        overwrite : bool
            If False, an existing file is never replaced: the data goes to
            the alternative name chosen by `_fresh_filename`.
        printflag : bool
            Print progress messages.

        Returns
        -------
        SuperSequence
            self, to allow chaining.
        """
        T = self.outdir.joinpath(self.name) if target is None else target
        T = T.parent.joinpath(T.name + ".toml") if T.suffix != ".toml" else T
        if printflag: print(f"[sequencing] Saving sequence data to '{T}':")
        if not T.parent.is_dir():
            if printflag: print(f"[sequencing] mkdir {T.parent}")
            T.parent.mkdir(parents=True, exist_ok=True)

        sequence_file = _fresh_filename(T, overwrite)
        if printflag: print(f"[sequencing] save sequence '{T.stem}'")
        with sequence_file.open('w') as outfile:
            toml.dump(self.to_primitives(), outfile)
        return self

    @staticmethod
    def load(target: pathlib.Path, outdir: pathlib.Path=None, name: str=None,
            printflag: bool=True):
        """
        Read a SuperSequence from a TOML file such as one written by `save`.

        Parameters
        ----------
        target : pathlib.Path
            Path of the TOML file to read.
        outdir : pathlib.Path
            `outdir` of the returned object; defaults to `target.parent`.
        name : str
            `name` of the returned object; defaults to `target.stem`.
        printflag : bool
            Print progress messages.

        Raises
        ------
        FileNotFoundError
            If `target` does not exist or is not a regular file.
        """
        outdir = target.parent if outdir is None else outdir
        name = target.stem if name is None else name
        if printflag: print(f"[sequencing] Loading data from target '{target}':")
        # The target is the sequence file itself, so it must exist as a path
        # here and be a regular file below; requiring a directory at this
        # point would reject every loadable target.
        if not target.exists():
            raise FileNotFoundError(f"Target '{target}' does not exist")

        sequence_file = target
        if printflag: print("[sequencing] load sequence")
        if not sequence_file.is_file():
            raise FileNotFoundError(
                f"Sequence file '{sequence_file}' does not exist")
        seq_dict = toml.load(sequence_file)
        return SuperSequence.from_primitives(outdir, name, seq_dict)

    @staticmethod
    def _gen_timeline(detailed=False, layout=(1, 4, 32, 1, 32, -10, 10)) \
            -> pd.Plotter:
        """
        Create the empty Plotter used as the canvas for timeline drawings:
        no y ticks, grid enabled, built with a temporary font size of 3.

        `detailed` and `layout` are accepted but currently unused.
        """
        FS = pp.rcParams["font.size"]
        pp.rcParams["font.size"] = 3
        try:
            P = pd.Plotter()
            P.set_yticks([]).set_yticklabels([]).ggrid(True)
        finally:
            # Restore the global font size even if plotting fails.
            pp.rcParams["font.size"] = FS
        return P

    def draw_simple(self) -> pd.Plotter:
        """
        Draw every Sequence as a labeled bar on a timeline.

        Each bar spans the interval returned by the Sequence's `when()`.
        A Sequence is placed on the row given by its `stack_idx` if that is
        not None, otherwise on the lowest row whose latest recorded end time
        is not after the Sequence's start. Bars use the Sequence's `color`
        if set, otherwise they cycle through matplotlib's "C0".."C9".

        Returns
        -------
        pd.Plotter
            The Plotter holding the drawing.
        """
        P = SuperSequence._gen_timeline(detailed=False)
        FS = pp.rcParams["font.size"]
        pp.rcParams["font.size"] = 3
        try:
            H = 0.2  # vertical pitch between rows; bars are H / 2 tall
            c = 0  # index into the default color cycle
            tl_max = 0  # highest row used so far
            t1_last = dict()  # row -> latest end time drawn on that row
            # Both extents start at 0, so the x-range always includes t = 0.
            Tmin = 0
            Tmax = 0
            for label, seq in self.items():
                tmin, tmax = seq.when()
                Tmin = min(Tmin, tmin)
                Tmax = max(Tmax, tmax)
                # Lowest row that is free at tmin; a row with no entry yet
                # ends the search because the default makes the test False.
                k = 0
                while tmin < t1_last.get(k, tmin):
                    k += 1
                tl_level = k if seq.stack_idx is None else seq.stack_idx
                tl_max = max(tl_max, tl_level)
                P.ax.add_patch(
                    pp.Rectangle(
                        (tmin, tl_level * H),
                        tmax - tmin,
                        H / 2,
                        edgecolor="k",
                        facecolor=f"C{c}" if seq.color is None else seq.color,
                        zorder=100
                    )
                )
                # Label goes just below the bar's lower-left corner.
                P.ax.text(
                    tmin,
                    (tl_level * H - H / 15),
                    _titlecase(label),
                    horizontalalignment="left", verticalalignment="top"
                )
                # Only advance the cycle when a default color was consumed.
                if seq.color is None:
                    c = (c + 1) % 10
                t1_last[tl_level] = max(tmax, t1_last.get(tl_level, tmax))
            # Pad the time axis by 5% of the total span on both sides.
            P.set_xlim(Tmin - (Tmax - Tmin) / 20, Tmax + (Tmax - Tmin) / 20)
            P.set_ylim(-3* H / 2, (tl_max + 2) * H)
            P.set_xlabel("Time [s]")
        finally:
            # Restore the global font size even if drawing fails.
            pp.rcParams["font.size"] = FS
        return P

    def draw_detailed(self) -> pd.Plotter:
        """
        Placeholder for a detailed timeline drawing.

        NOTE: no detailed drawing is implemented yet; the empty timeline
        canvas is returned so the result matches the declared return type.
        """
        P = SuperSequence._gen_timeline(detailed=True)
        return P