From e610f76c9505c4cda5e05b588a08f622a602e61d Mon Sep 17 00:00:00 2001
From: Yifan Zhao <yifanz16@illinois.edu>
Date: Thu, 8 Apr 2021 02:29:42 -0500
Subject: [PATCH] Improved config plotting

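Config plotting now produces multi-panel figures: the first subplot
always shows the kept configurations and the Pareto-front "best"
configurations; optional subplots compare predicted vs. validated QoS
and tune-set vs. test-set QoS loss. Testing and validation fill in QoS
fields on the configs in place (_test_configs_, _update_configs_), and
filtering against the keep thresholds is done in a separate step.

A typical call, assuming `tuner` holds a finished tuning session:

    fig = tuner.plot_configs(show_qos_loss=True, connect_best_points=True)
    fig.savefig("configs.png")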
---
 predtuner/approxapp.py  | 141 ++++++++++++++++++++++++-----------
 predtuner/modeledapp.py | 158 +++++++++++++++++++++++++---------------
 2 files changed, 199 insertions(+), 100 deletions(-)

diff --git a/predtuner/approxapp.py b/predtuner/approxapp.py
index 222d3cb..8bd2d33 100644
--- a/predtuner/approxapp.py
+++ b/predtuner/approxapp.py
@@ -318,7 +318,9 @@ class ApproxTuner(Generic[T]):
         self.kept_configs = [
             cfg for cfg in self.all_configs if cfg.qos > self.tune_keep_threshold
         ]
-        self.best_configs_prefilter = self._take_best_configs(self.kept_configs, take_best_n)
+        self.best_configs_prefilter = self._take_best_configs(
+            self.kept_configs, take_best_n
+        )
         msg_logger.info(
             "Tuning finished with %d configs in total, "
             "%d configs above keeping threshold, "
@@ -329,7 +331,8 @@ class ApproxTuner(Generic[T]):
         )
         if test_configs:
             msg_logger.info("Running configurations on test inputs")
-            self.best_configs = self._test_configs(self.best_configs_prefilter)
+            # Also fills in the test QoS of self.best_configs_prefilter
+            self.best_configs = self._test_configs_(self.best_configs_prefilter)
         else:
             self.best_configs = self.best_configs_prefilter
         return self.best_configs
@@ -361,48 +364,110 @@ class ApproxTuner(Generic[T]):
         self,
         show_qos_loss: bool = False,
         connect_best_points: bool = False,
-        use_test_qos: bool = False,
     ) -> plt.Figure:
-        """Plots the QoS and speedup of configurations into a scatter plot.
+        """Plots 1 or 2 QoS-vs-speedup scatter plot of configurations.
+
+        All kept configurations and all "best" configurations (before test-set filtering if any)
+        are always plotted in the first subplot.
+        If test-set filtering was used, the second subplot contains the "best" configurations
+        plotted twice, with tune-set and test-set QoS loss respectively.
 
         :param show_qos_loss: If True, uses the loss of QoS (compared to the baseline)
-               instead of the absolute QoS.
-        :param connect_best_points: If True, draw a line connecting all the "best"
-               configurations (otherwise just plot a scatter plot).
-        :param use_test_qos: If True, plots with the test set QoS (`Config.test_qos`);
-               otherwise plots the tuning QoS (`Config.qos`).
+               instead of the absolute QoS in the first graph.
+               *This does not apply to the second graph* if it exists,
+               which always uses QoS loss for ease of comparison.
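+
+        A minimal usage sketch (assuming `tuner` is an `ApproxTuner`
+        on which `tune()` has been run):
+
+        .. code-block:: python
+
+            fig = tuner.plot_configs(show_qos_loss=True)
+            fig.savefig("configs.png")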
         """
 
         if not self.tuned:
             raise RuntimeError(
                 f"No tuning session has been run; call self.tune() first."
             )
+        # Called without an `ax` argument, plot_test_phase only reports
+        # whether the second plot can be drawn.
+        dot_format = "-o" if connect_best_points else "o"
+        if self.plot_test_phase():
+            fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 6), dpi=300)
+            self.plot_kept_and_best(ax0, show_qos_loss)
+            self.plot_test_phase(ax1, dot_format)
+        else:
+            fig, ax0 = plt.subplots(1, 1, figsize=(6, 6), dpi=300)
+            self.plot_kept_and_best(ax0, show_qos_loss)
+        fig.tight_layout()
+        return fig
 
-        def qos_speedup(conf):
-            qos = conf.test_qos if use_test_qos else conf.qos
-            baseline_qos = (
-                self.baseline_test_qos if use_test_qos else self.baseline_tune_qos
-            )
-            return baseline_qos - qos if show_qos_loss else qos, conf.speedup
-
-        def get_points(confs):
-            if not confs:
-                return np.zeros((2, 0))
-            sorted_points = np.array(
-                sorted([qos_speedup(c) for c in confs], key=lambda p: p[0])
-            ).T
-            return sorted_points
-
-        fig, ax = plt.subplots()
-        kept_confs = get_points(self.kept_configs)
-        best_confs = get_points(self.best_configs_prefilter)
-        ax.plot(kept_confs[0], kept_confs[1], "o", label="kept")
-        mode = "-o" if connect_best_points else "o"
-        ax.plot(best_confs[0], best_confs[1], mode, label="best")
-        ax.set_xlabel("QoS Loss" if show_qos_loss else "QoS")
+    def plot_kept_and_best(self, ax: plt.Axes, show_qos_loss: bool):
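+        # First subplot: all kept configs and the Pareto-front "best" configs,
+        # plotted against the tune-set QoS (or QoS loss).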
+        kept_confs = self._config_qos_speedups(
+            self.kept_configs, "qos", show_qos_loss, False
+        )
+        best_confs = self._config_qos_speedups(
+            self.best_configs_prefilter, "qos", show_qos_loss, False
+        )
+        ax.plot(kept_confs[0], kept_confs[1], "o", label="Kept Configs")
+        ax.plot(best_confs[0], best_confs[1], "o", label="Best Configs")
+        self._set_xy_limit(ax, show_qos_loss)
+        if show_qos_loss:
+            rthres = self.baseline_tune_qos - self.tune_keep_threshold
+            self._draw_qos_line(ax, rthres, f"Relative threshold: {rthres:.2f}")
+            ax.set_xlabel("QoS Loss (tune dataset)")
+        else:
+            bqos, thres = self.baseline_tune_qos, self.tune_keep_threshold
+            self._draw_qos_line(ax, bqos, f"Baseline QoS: {bqos:.2f}")
+            self._draw_qos_line(ax, thres, f"Threshold: {thres:.2f}")
+            ax.set_xlabel("QoS (tune dataset)")
         ax.set_ylabel("Speedup (x)")
         ax.legend()
-        return fig
+
+    def plot_test_phase(self, ax: plt.Axes = None, dot_format: str = "o"):
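+        # Dual-purpose: without `ax`, only report whether every "best" config
+        # has a test-set QoS (i.e., whether this plot can be drawn).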
+        configs = self.best_configs_prefilter
+        tested = [conf.test_qos is not None for conf in configs]
+        can_plot = all(tested)
+        if not ax:
+            return can_plot
+        assert can_plot
+        tune_x, tune_y = self._config_qos_speedups(configs, "qos", True, False)
+        # Test-set QoS loss is computed against the test-set baseline.
+        test_x, test_y = self._config_qos_speedups(configs, "test_qos", True, True)
+        ax.plot(tune_x, tune_y, dot_format, label="Tune-set QoS")
+        ax.plot(test_x, test_y, dot_format, label="Test-set QoS")
+        self._set_xy_limit(ax)
+        rthres = self.baseline_tune_qos - self.tune_keep_threshold
+        self._draw_qos_line(ax, rthres, f"Relative threshold: {rthres:.2f}")
+        ax.set_xlabel("QoS Loss")
+        ax.set_ylabel("Speedup (x)")
+        ax.legend()
+
+    def _set_xy_limit(self, ax: plt.Axes, show_qos_loss: bool = True):
+        # Anchor the QoS-loss axis at 0 and the speedup axis at 1x,
+        # unless the data already extends below these values.
+        xmin, _ = ax.get_xlim()
+        ymin, _ = ax.get_ylim()
+        if show_qos_loss:
+            ax.set_xlim(xmin=min(0, xmin))
+        ax.set_ylim(ymin=min(1, ymin))
+
+    def _config_qos_speedups(
+        self,
+        configs: List[Config],
+        qos_attr: str,
+        qos_loss: bool,
+        baseline_is_test: bool,
+    ):
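+        # Returns a 2 x N array: row 0 is QoS (or QoS loss), row 1 is speedup,
+        # with columns sorted by the QoS value.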
+        def qos_speedup(conf: Config):
+            qos = getattr(conf, qos_attr)
+            bqos = (
+                self.baseline_test_qos if baseline_is_test else self.baseline_tune_qos
+            )
+            return bqos - qos if qos_loss else qos, conf.speedup
+
+        if not configs:
+            return np.zeros((2, 0))
+        sorted_points = np.array(
+            sorted([qos_speedup(c) for c in configs], key=lambda p: p[0])
+        ).T
+        return sorted_points
+
+    @staticmethod
+    def _draw_qos_line(ax: plt.Axes, qos: float, text: str):
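+        # Vertical line at `qos`, labeled with rotated text at mid-height.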
+        ymin, ymax = ax.get_ylim()
+        ymid = (ymin + ymax) / 2
+        ax.axvline(qos)
+        ax.annotate(text, (qos, ymid), rotation=90, verticalalignment="center")
 
     @staticmethod
     def _take_best_configs(configs: List[T], n: Optional[int] = None) -> List[T]:
@@ -410,29 +475,21 @@ class ApproxTuner(Generic[T]):
         taken_idx = is_pareto_efficient(points, take_n=n)
         return [configs[i] for i in taken_idx]
 
-    def _test_configs(self, configs: List[Config]):
-        from copy import deepcopy
-
+    def _test_configs_(self, configs: List[Config]):
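+        # Trailing underscore: mutates `configs` in place by filling in
+        # `test_qos`; returns those configs above the test keep threshold.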
         from tqdm import tqdm
 
         assert self.test_keep_threshold is not None
         if not configs:
             return []
-        ret_configs = []
         total_error = 0
         for cfg in tqdm(configs, leave=False):
-            cfg = deepcopy(cfg)
             assert cfg.test_qos is None
             cfg.test_qos, _ = self.app.measure_qos_cost(cfg.knobs, True)
             msg_logger.debug(f"Test dataset: {cfg.qos:.3f} -> {cfg.test_qos:.3f}")
             total_error += abs(cfg.qos - cfg.test_qos)
-            if cfg.test_qos > self.test_keep_threshold:
-                ret_configs.append(cfg)
-            else:
-                msg_logger.debug("Config removed")
         mean_err = total_error / len(configs)
         msg_logger.debug("QoS changed by %f on test dataset (mean abs diff)", mean_err)
-        return ret_configs
+        return [cfg for cfg in configs if cfg.test_qos > self.test_keep_threshold]
 
     def _get_tuner_interface(
         self,
diff --git a/predtuner/modeledapp.py b/predtuner/modeledapp.py
index 0b14cd9..085703c 100644
--- a/predtuner/modeledapp.py
+++ b/predtuner/modeledapp.py
@@ -481,29 +481,22 @@ class ApproxModeledTuner(ApproxTuner):
             msg_logger.info(
                 'Validating configurations due to using qos model "%s"', qos_model
             )
-            self.best_configs = self._update_configs(self.best_configs_prefilter, False)
+            self._update_configs_(self.best_configs_prefilter, False)
         elif validate_configs:
             msg_logger.info("Validating configurations as user requested")
-            self.best_configs = self._update_configs(self.best_configs_prefilter, False)
-        else:
-            self.best_configs = self.best_configs_prefilter
+            self._update_configs_(self.best_configs_prefilter, False)
         if test_configs:
             msg_logger.info("Calibrating configurations on test inputs")
-            self.best_configs = self._update_configs(self.best_configs, True)
+            self._update_configs_(self.best_configs_prefilter, True)
+        self.best_configs = self._filter_configs(self.best_configs_prefilter)
         return self.best_configs
 
-    def _update_configs(self, configs: List[ValConfig], test_mode: bool):
-        from copy import deepcopy
-
+    def _update_configs_(self, configs: List[ValConfig], test_mode: bool):
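+        # Trailing underscore: mutates `configs` in place, filling in
+        # `validated_qos` (tune mode) or `test_qos` (test mode); filtering
+        # against the keep thresholds happens later in `_filter_configs`.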
         from tqdm import tqdm
 
         if not configs:
             msg_logger.info("No configurations found.")
-            return []
-        keep_threshold = (
-            self.test_keep_threshold if test_mode else self.tune_keep_threshold
-        )
-        assert keep_threshold is not None
+            return
-        ret_configs = []
         total_error = 0
         for cfg in tqdm(configs, leave=False):
@@ -517,69 +510,118 @@ class ApproxModeledTuner(ApproxTuner):
                 cfg.validated_qos = qos
                 msg_logger.debug(f"Validation: {cfg.qos} (mean) -> {qos} (mean)")
             total_error += abs(cfg.qos - qos)
-            if qos > keep_threshold:
-                ret_configs.append(cfg)
-            else:
-                msg_logger.debug("Config removed")
         mean_err = total_error / len(configs)
         dataset_name = "test" if test_mode else "tune"
-        msg_logger.info("QoS changed by %f on %s dataset (mean abs diff)", mean_err, dataset_name)
-        msg_logger.info("%d of %d configs remain", len(ret_configs), len(configs))
+        msg_logger.info(
+            "QoS changed by %f on %s dataset (mean abs diff)", mean_err, dataset_name
+        )
+
+    def _filter_configs(self, configs: List[ValConfig]):
+        # Keep a config only if every QoS value that was actually measured
+        # stays above its threshold; a None field means that phase was skipped.
+        ret_configs = [
+            cfg
+            for cfg in configs
+            if (
+                cfg.validated_qos is None
+                or cfg.validated_qos > self.tune_keep_threshold
+            )
+            and (cfg.test_qos is None or cfg.test_qos > self.test_keep_threshold)
+        ]
+        msg_logger.info(
+            "%d of %d configs remain after validation and testing",
+            len(ret_configs),
+            len(configs),
+        )
         return ret_configs
 
     def plot_configs(
         self,
         show_qos_loss: bool = False,
         connect_best_points: bool = False,
-        use_test_qos: bool = False,
     ) -> plt.Figure:
+        """Plots 1 to 3 QoS-vs-speedup scatter plot of configurations.
+
+        All kept configurations and all "best" configurations (before test-set filtering if any)
+        are always plotted in the first subplot.
+
+        If there was a validation phase during tuning,
+        the second subplot contains the "best" configurations plotted twice,
+        with predicted and empirically measured QoS (on tune set) respectively.
+
+        If both validation and test-set filtering were used,
+        the last subplot contains the "best" configurations plotted twice,
+        with *empirically measured* tune-set and test-set QoS loss respectively.
+
+        :param show_qos_loss: If True, uses the loss of QoS (compared to the baseline)
+               instead of the absolute QoS in the first two graphs.
+               *This does not apply to the third graph* if it exists,
+               which always uses QoS loss for ease of comparison.
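+
+        A minimal usage sketch (assuming `tuner` is a tuned
+        `ApproxModeledTuner`):
+
+        .. code-block:: python
+
+            fig = tuner.plot_configs(show_qos_loss=True)
+            fig.savefig("configs.png")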
+        """
+
         if not self.tuned:
             raise RuntimeError(
                 f"No tuning session has been run; call self.tune() first."
             )
+        # Called without an `ax` argument, plot_validation_phase and
+        # plot_test_phase only report whether their plot can be drawn;
+        # plot_test_phase returning True implies plot_validation_phase does too.
+        dot_format = "-o" if connect_best_points else "o"
+        if self.plot_test_phase():
+            fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(14, 6), dpi=300)
+            self.plot_kept_and_best(ax0, show_qos_loss)
+            self.plot_validation_phase(ax1, show_qos_loss, dot_format)
+            self.plot_test_phase(ax2, dot_format)
+        elif self.plot_validation_phase():
+            fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 6), dpi=300)
+            self.plot_kept_and_best(ax0, show_qos_loss)
+            self.plot_validation_phase(ax1, show_qos_loss, dot_format)
+        else:
+            fig, ax0 = plt.subplots(1, 1, figsize=(6, 6), dpi=300)
+            self.plot_kept_and_best(ax0, show_qos_loss)
+        fig.tight_layout()
+        return fig
 
-        # For empirical tuning there's no `validated_qos`.
-        # We first check, and if that's the case, we pass on to our parent class instead.
-        val_qos_nones = [conf.validated_qos is None for conf in self.best_configs]
-        if any(val_qos_nones):
-            assert all(val_qos_nones)
-            return super().plot_configs(
-                show_qos_loss, connect_best_points, use_test_qos
-            )
-
-        if use_test_qos:
-            raise ValueError(
-                "use_test_qos is not yet supported for plotting predictive tuning session."
-            )
+    def plot_validation_phase(
+        self, ax: plt.Axes = None, show_qos_loss: bool = False, dot_format: str = "o"
+    ):
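+        # Dual-purpose: without `ax`, only report whether every "best" config
+        # has a validated QoS (i.e., whether the validation plot can be drawn).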
+        configs = self.best_configs_prefilter
+        validated = [conf.validated_qos is not None for conf in configs]
+        can_plot = all(validated)
+        if not ax:
+            return can_plot
+        assert can_plot
+        pred_x, pred_y = self._config_qos_speedups(configs, "qos", show_qos_loss, False)
+        measured_x, measured_y = self._config_qos_speedups(
+            configs, "validated_qos", show_qos_loss, False
+        )
+        ax.plot(pred_x, pred_y, dot_format, label="Predicted QoS")
+        ax.plot(measured_x, measured_y, dot_format, label="Validated QoS")
+        self._set_xy_limit(ax, show_qos_loss)
+        if show_qos_loss:
+            ax.set_xlabel("QoS Loss (tune dataset)")
+            rthres = self.baseline_tune_qos - self.tune_keep_threshold
+            self._draw_qos_line(ax, rthres, f"Relative threshold: {rthres:.2f}")
+        else:
+            ax.set_xlabel("QoS (tune dataset)")
+        ax.set_ylabel("Speedup (x)")
+        ax.legend()
 
-        def get_points(confs, validated):
-            def qos_speedup(conf):
-                return conf.validated_qos if validated else conf.qos, conf.speedup
-
-            if not confs:
-                return np.zeros((2, 0))
-            sorted_points = np.array(
-                sorted([qos_speedup(c) for c in confs], key=lambda p: p[0])
-            ).T
-            if show_qos_loss:
-                sorted_points[0] = self.baseline_tune_qos - sorted_points[0]
-            return sorted_points
-
-        fig, ax = plt.subplots()
-        kept_confs = get_points(self.kept_configs, False)
-        best_confs = get_points(self.best_configs_prefilter, False)
-        best_confs_val = get_points(self.best_configs, True)
-        ax.plot(kept_confs[0], kept_confs[1], "o", label="kept")
-        mode = "-o" if connect_best_points else "o"
-        ax.plot(best_confs[0], best_confs[1], mode, label="best")
-        mode = "-o" if connect_best_points else "o"
-        ax.plot(best_confs_val[0], best_confs_val[1], mode, label="best_validated")
-        ax.set_xlabel(
-            "QoS Loss (Tune dataset)" if show_qos_loss else "QoS (Tune dataset)"
+    def plot_test_phase(self, ax: plt.Axes = None, dot_format: str = "o"):
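+        # Dual-purpose: without `ax`, only report whether both validation
+        # and testing ran on every "best" config.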
+        configs = self.best_configs_prefilter
+        tested = [conf.test_qos is not None for conf in configs]
+        validated = [conf.validated_qos is not None for conf in configs]
+        can_plot = all(tested) and all(validated)
+        if not ax:
+            return can_plot
+        assert can_plot
+        tune_x, tune_y = self._config_qos_speedups(
+            configs, "validated_qos", True, False
         )
+        # Test-set QoS loss is computed against the test-set baseline.
+        test_x, test_y = self._config_qos_speedups(configs, "test_qos", True, True)
+        ax.plot(tune_x, tune_y, dot_format, label="Tune-set QoS")
+        ax.plot(test_x, test_y, dot_format, label="Test-set QoS")
+        self._set_xy_limit(ax)
+        rthres = self.baseline_tune_qos - self.tune_keep_threshold
+        self._draw_qos_line(ax, rthres, f"Relative threshold: {rthres:.2f}")
+        ax.set_xlabel("Empirical QoS Loss")
         ax.set_ylabel("Speedup (x)")
         ax.legend()
-        return fig
 
     @classmethod
     def _get_config_class(cls) -> Type[Config]:
-- 
GitLab