# utils.py
import matplotlib.pyplot as plt
import numpy as np
import skimage.draw

def specify_mask(im):
    """Interactively outline a polygon on ``im`` and return its vertices.

    Opens a figure showing ``im`` (gray colormap for single-channel input).
    Every mouse click inside the image axes adds a polygon vertex and marks
    it with a red ``+``. Key bindings while the figure has focus:

    * ``a`` -- publish the vertices collected so far into the returned list
      (useful when the figure is never closed, e.g. in a notebook).
    * ``z`` -- open a preview figure showing ``im`` cropped to the polygon's
      bounding box, with pixels outside the polygon set to 0.

    Closing the figure also publishes the vertices (publication happens at
    most once; later clicks still show up because the same lists are shared).

    This function does not call ``plt.show()``: it returns immediately and
    the returned list is filled in later by the event handlers, so it is
    meant for an interactive session (presumably a notebook or ``plt.ion()``
    -- confirm with callers).

    Args:
        im: Image passed to ``imshow``; indexed as ``im[row, col]`` or
            ``im[row, col, channel]`` by the preview.

    Returns:
        A list that is empty until the vertices are published, after which
        ``result[0]`` is the list of x (column) and ``result[1]`` the list of
        y (row) vertex coordinates, in data (pixel) coordinates.
    """
    print("If it doesn't get you to the drawing mode, then rerun this function again.")
    fig, ax = plt.subplots(1, 1, figsize=(7, 7))
    fig.set_label('Draw polygon around source object')
    ax.axis('off')
    ax.imshow(im, cmap='gray')

    xs = []  # x (column) coordinate of each clicked vertex
    ys = []  # y (row) coordinate of each clicked vertex
    clicked = []  # handed to the caller; receives [xs, ys] once published

    def publish_coords(event=None):
        """Append ``xs`` and ``ys`` to ``clicked``, only the first time."""
        # Both the 'a' key and closing the window publish; the guard keeps
        # the caller's view as exactly [xs, ys].
        if not clicked:
            clicked.extend((xs, ys))

    def on_mouse_pressed(event):
        """Record a click inside the image axes as a polygon vertex."""
        # Clicks outside the axes carry xdata/ydata == None; ignore them.
        if event.inaxes is not ax:
            return
        xs.append(event.xdata)
        ys.append(event.ydata)
        ax.plot(event.xdata, event.ydata, 'r+')
        # Request a repaint so the marker appears right away.
        fig.canvas.draw_idle()

    def show_preview():
        """Show the polygon-masked bounding-box crop of ``im`` in a new figure."""
        if len(xs) < 3:
            print('Click at least 3 points before previewing.')
            return
        arr = np.asarray(im)
        rows, cols = arr.shape[:2]
        # skimage.draw.polygon takes (row, col) == (y, x); ``shape`` clips
        # the rasterised polygon to the image bounds.
        rr, cc = skimage.draw.polygon(ys, xs, shape=(rows, cols))
        if rr.size == 0:
            print('The polygon does not cover any pixel.')
            return
        mask = np.zeros((rows, cols), dtype=bool)
        mask[rr, cc] = True
        # Broadcast the 2-D mask over any trailing channel axis; np.where
        # keeps im's dtype (multiplying by a float mask would not).
        mask = mask.reshape(mask.shape + (1,) * (arr.ndim - 2))
        masked = np.where(mask, arr, 0)
        # +1 because slice ends are exclusive.
        crop = masked[rr.min():rr.max() + 1, cc.min():cc.max() + 1]
        _, preview_ax = plt.subplots(1, 1)
        preview_ax.imshow(crop, cmap='gray')
        preview_ax.axis('off')
        plt.show(block=False)

    def on_key(event):
        """Dispatch key presses: 'a' publishes vertices, 'z' previews."""
        if event.key == 'a':
            publish_coords()
        elif event.key == 'z':
            show_preview()

    # Matplotlib's callback registry holds plain functions strongly (only
    # bound methods are weakly referenced), so these closures stay alive
    # after this function returns without keeping the connection ids.
    fig.canvas.mpl_connect('button_press_event', on_mouse_pressed)
    fig.canvas.mpl_connect('key_press_event', on_key)
    fig.canvas.mpl_connect('close_event', publish_coords)
    return clicked