from dataclasses import dataclass
import numpy as np
from pypylon import pylon

def connect_camera():
    """
    Open and return the first Basler camera that pylon enumerates.

    Only the first device reported by the transport-layer factory is used,
    so when several cameras are attached there is no control over which one
    gets opened.
    """
    factory = pylon.TlFactory.GetInstance()
    device = factory.CreateFirstDevice()
    cam = pylon.InstantCamera(device)
    cam.Open()
    return cam

def disconnect_camera(camera):
    """
    Stop acquisition on a single camera and close it.

    ``Close()`` is executed even if ``StopGrabbing()`` raises, so a failure
    while stopping does not leave the device open. Any exception from
    ``StopGrabbing()`` still propagates to the caller.
    """
    try:
        camera.StopGrabbing()
    finally:
        camera.Close()

def camera_settings(camera):
    """
    Apply a quick-and-dirty fixed configuration and return the camera.

    Nodes are written in this order:
    - PixelFormat: Mono12
    - ExposureTime: 10000.0 (μs)
    - Gain: 0.0 (dB)
    - AcquisitionMode: Continuous

    Grabbing is not started here; call ``start_grabbing`` for that.
    """
    fixed_config = (
        ("PixelFormat", "Mono12"),
        ("ExposureTime", 10000.0),
        ("Gain", 0.0),
        ("AcquisitionMode", "Continuous"),
    )
    for node_name, node_value in fixed_config:
        getattr(camera, node_name).SetValue(node_value)
    return camera

def camera_exposure(camera, exposure_us: float):
    """
    Write ``exposure_us`` (microseconds) to the camera's ExposureTime node.
    """
    exposure_node = camera.ExposureTime
    exposure_node.SetValue(exposure_us)

def camera_gain(camera, gain_db: float):
    """
    Write ``gain_db`` (decibels) to the camera's Gain node.
    """
    gain_node = camera.Gain
    gain_node.SetValue(gain_db)

def start_grabbing(camera):
    """
    Begin acquisition using pylon's LatestImageOnly grab strategy.

    With this strategy only the newest frame is kept for retrieval, so
    frames that are not retrieved in time are superseded by later ones.
    """
    strategy = pylon.GrabStrategy_LatestImageOnly
    camera.StartGrabbing(strategy)

@dataclass
class ImageData:
    """
    A single frame pulled from the camera.

    img: NumPy array with image data.
    w: Image width, as reported by the grab result.
    h: Image height, as reported by the grab result.
    t: Image timestamp in seconds.

    Instances unpack like a tuple in field order, so
    ``img, w, h, t = acquire_data(camera)`` works.
    """
    img: np.ndarray
    w: int
    h: int
    t: float

    def __iter__(self):
        # Yield the fields in declaration order without copying the image;
        # dataclasses.astuple would deep-copy the NumPy array.
        return iter((self.img, self.w, self.h, self.t))

def acquire_data(camera, timeout_ms: int = 5000) -> ImageData:
    """
    Pull data for a single image from the camera. See `ImageData` for more info.

    Waits up to ``timeout_ms`` milliseconds for a frame. pylon raises on
    timeout because TimeoutHandling_ThrowException is requested. Presumably
    grabbing must already be running (e.g. via ``start_grabbing``).

    Raises:
        RuntimeError: if the grab result reports a failed grab.
    """
    res = camera.RetrieveResult(
        timeout_ms, pylon.TimeoutHandling_ThrowException)
    try:
        if not res.GrabSucceeded():
            raise RuntimeError(
                f"Grab failed: {res.ErrorDescription} "
                f"(error code {res.ErrorCode})"
            )
        # NOTE(review): relies on `Array` returning a copy (pypylon's
        # GetArray) so the data stays valid after Release() below.
        img = res.Array
        w, h = res.Width, res.Height
        # GetTimeStamp() returns device ticks; dividing by 1e9 assumes a
        # 1 GHz (nanosecond) tick clock. TODO confirm for this camera model.
        t = res.GetTimeStamp() / 1e9
    finally:
        # Always hand the buffer back to pylon, even if reading it failed,
        # so grab buffers are not leaked.
        res.Release()
    return ImageData(img, w, h, t)

def save_image(camera, filename: str, timeout_ms: int = 5000):
    """
    Pull a single image from the camera and immediately write it to a file in
    TIFF format.

    Waits up to ``timeout_ms`` milliseconds for a frame. pylon raises on
    timeout because TimeoutHandling_ThrowException is requested.

    Raises:
        RuntimeError: if the grab result reports a failed grab.
    """
    res = camera.RetrieveResult(
        timeout_ms, pylon.TimeoutHandling_ThrowException)
    try:
        if not res.GrabSucceeded():
            raise RuntimeError(
                f"Grab failed: {res.ErrorDescription} "
                f"(error code {res.ErrorCode})"
            )
        img = pylon.PylonImage()
        try:
            img.AttachGrabResultBuffer(res)
            img.Save(pylon.ImageFileFormat_Tiff, filename)
        finally:
            # Detach the image from the grab buffer before the buffer itself
            # is released.
            img.Release()
    finally:
        # The original never released the grab result, leaking one pylon
        # buffer per saved image.
        res.Release()

def acquire_time_series(camera, Trec, rect, dA, K, *, full_scale=4095):
    """
    Record a region-of-interest intensity time series for ``Trec`` seconds.

    One frame is grabbed first to fix the start time and is then discarded.
    Frames are then grabbed until one has a timestamp at or past
    ``start + Trec``, and that last frame is included. For every kept frame,
    the region ``img[i0:i1, j0:j1]`` is handled as follows:
    - divided by ``full_scale``;
    - passed to ``integrate_area`` together with ``dA``;
    - multiplied by ``K``.

    camera: Opened camera; presumably already grabbing (see start_grabbing).
    Trec: Recording duration, in the same units as ``ImageData.t`` (seconds).
    rect: ``(i0, i1, j0, j1)`` slice bounds. ``i`` indexes the first axis
        of the image array and ``j`` the second.
    dA: Passed unchanged to ``integrate_area``; presumably the area of one
        pixel -- TODO confirm against ``integrate_area``.
    K: Scale factor applied to each integrated value.
    full_scale: Divisor used to normalize raw pixel values. The default 4095
        is the maximum of a 12-bit (Mono12) pixel.

    Returns ``(T, I)`` as NumPy arrays:
    - ``T``: frame times relative to the earliest kept frame.
    - ``I``: the corresponding scaled integrated intensities.
    """
    i0, i1, j0, j1 = rect
    # ImageData is a dataclass, so read the timestamp by attribute instead
    # of tuple-unpacking the result.
    t_end = acquire_data(camera).t + Trec

    times = []
    rois = []
    while True:
        frame = acquire_data(camera)
        times.append(frame.t)
        # Keep only a copy of the ROI so each full frame can be freed right
        # away instead of holding every frame until recording ends.
        rois.append(frame.img[i0:i1, j0:j1].copy())
        if frame.t >= t_end:
            break

    t_rel = np.asarray(times) - min(times)
    # integrate_area is not defined in this section; presumably it is
    # provided elsewhere in the module.
    intensity = np.array(
        [K * integrate_area(roi / full_scale, dA) for roi in rois]
    )
    return t_rel, intensity