From b6784183666e46b7cebaa9f463ac47e0a6e3a390 Mon Sep 17 00:00:00 2001
From: "Hsieh, Chiao" <chsieh16@illinois.edu>
Date: Tue, 4 Jan 2022 20:55:14 -0600
Subject: [PATCH] Add ploting script for abs learned via dtree

---
 plot_abstractions.py | 60 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 60 insertions(+)
 create mode 100644 plot_abstractions.py

diff --git a/plot_abstractions.py b/plot_abstractions.py
new file mode 100644
index 0000000..9a68cbf
--- /dev/null
+++ b/plot_abstractions.py
@@ -0,0 +1,60 @@
+from typing import List, Tuple
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+
+def plot_dtree_abs(state: np.ndarray, dnf: List[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]):
+    d_lim = 2.0  # PRE_Y_LIM
+    psi_lim = np.pi/3  # PRE_YAW_LIM
+    d_space = np.linspace(-d_lim, d_lim, 800)
+    psi_space = np.linspace(-psi_lim, psi_lim, 800)
+    d, psi = np.meshgrid(d_space, psi_space)
+
+    disjunct_bool_arr = np.zeros((len(d_space)*len(psi_space),))
+    for a_mat, b_vec, coeff_mat, cut_vec in dnf:
+        center = a_mat @ state + b_vec
+        dbar = d - center[0]
+        psibar = psi - center[1]
+        v_arr = np.dot(coeff_mat, [dbar.ravel(), psibar.ravel()])
+        bool_arr = np.all(v_arr.T <= cut_vec, axis=1)
+        disjunct_bool_arr = np.logical_or(disjunct_bool_arr, bool_arr)
+
+    disjunct_bool_arr = disjunct_bool_arr.reshape(d.shape)
+    im = plt.imshow(disjunct_bool_arr.astype(int),
+                    extent=(d.min(), d.max(), psi.min(), psi.max()),
+                    aspect="auto",
+                    origin="lower", cmap="Greens")
+
+    plt.savefig("temp.png")
+
+
+def plot_sygus_abs():
+    pass
+
+
+state = np.array([0., 0., 0.])
+
+a_mat_0 = np.array([[0., -1., 0.],
+                    [0., 0., -1.]])
+b_vec_0 = np.zeros(2)
+coeff_mat_0 = np.array(
+    [[1., 1.],
+     [-1., 0.],
+     [0., 1.],
+     [1., -1.]])
+
+coeff_mat_1 = np.array(
+    [[1., 0.],
+     [-1., 0.],
+     [0., 1.],
+     [0., -1.]])
+
+cut_vec_0 = np.array(
+    [0.5, 0.5, 0.75, 0.75])
+
+candidate_dnf_0 = [
+    # (a_mat_0, b_vec_0, coeff_mat_0, cut_vec_0),
+    (a_mat_0, b_vec_0, coeff_mat_1, cut_vec_0)
+]
+plot_dtree_abs(state, candidate_dnf_0)
-- 
GitLab