diff --git a/predtuner/approxapp.py b/predtuner/approxapp.py index 8bd2d33466c9076ab4ca320d0c4c984ed197f90f..6ada08df00eb84051e9d98e510d3017b3132e14c 100644 --- a/predtuner/approxapp.py +++ b/predtuner/approxapp.py @@ -417,15 +417,17 @@ class ApproxTuner(Generic[T]): ax.set_ylabel("Speedup (x)") ax.legend() - def plot_test_phase(self, ax: plt.Axes = None, dot_format: str = "o"): + def plot_test_phase( + self, ax: plt.Axes = None, dot_format: str = "o", _tune_key: str = "qos" + ): configs = self.best_configs_prefilter tested = [conf.test_qos is not None for conf in configs] can_plot = all(tested) if not ax: return can_plot assert can_plot - tune_x, tune_y = self._config_qos_speedups(configs, "qos", True, False) - test_x, test_y = self._config_qos_speedups(configs, "test_qos", True, False) + tune_x, tune_y = self._config_qos_speedups(configs, _tune_key, True, False) + test_x, test_y = self._config_qos_speedups(configs, "test_qos", True, True) ax.plot(tune_x, tune_y, dot_format, label="Tune-set QoS") ax.plot(test_x, test_y, dot_format, label="Test-set QoS") self._set_xy_limit(ax) @@ -510,6 +512,11 @@ class ApproxTuner(Generic[T]): # These are also abs thresholds self.tune_keep_threshold = self.baseline_tune_qos - keep_threshold self.test_keep_threshold = self.baseline_test_qos - keep_threshold + msg_logger.info( + "Using relative thresholds: baseline QoS = %f (tune set) and %f (test set)", + self.baseline_tune_qos, + self.baseline_test_qos, + ) else: self.tune_keep_threshold = self.test_keep_threshold = keep_threshold opentuner_args.test_limit = max_iter diff --git a/predtuner/modeledapp.py b/predtuner/modeledapp.py index 085703c78706dd675120c2ea2b8e7395dc2bfe61..ce1e0d82cea0b9dc162883bf4d44bdb02c8843e1 100644 --- a/predtuner/modeledapp.py +++ b/predtuner/modeledapp.py @@ -520,7 +520,7 @@ class ApproxModeledTuner(ApproxTuner): ret_configs = [ cfg for cfg in configs - if cfg.validated_qos >= self.tune_keep_threshold + if (not cfg.validated_qos or cfg.validated_qos >= self.tune_keep_threshold) and cfg.test_qos >= self.test_keep_threshold ] msg_logger.info( @@ -558,22 +558,28 @@ class ApproxModeledTuner(ApproxTuner): raise RuntimeError( f"No tuning session has been run; call self.tune() first." ) + dot_format = "-o" if connect_best_points else "o" # Without `ax` argument, this function returns if we can # do the second/third plot or not. # plot_test_phase returns True implies plot_validation_phase returning True. - dot_format = "-o" if connect_best_points else "o" - if self.plot_test_phase(): - fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(14, 6), dpi=300) - self.plot_kept_and_best(ax0, show_qos_loss) - self.plot_validation_phase(ax1, show_qos_loss, dot_format) - self.plot_test_phase(ax2, dot_format) - elif self.plot_validation_phase(): - fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 6), dpi=300) - self.plot_kept_and_best(ax0, show_qos_loss) - self.plot_validation_phase(ax1, show_qos_loss, dot_format) - else: - fig, ax0 = plt.subplots(1, 1, figsize=(6, 6), dpi=300) - self.plot_kept_and_best(ax0, show_qos_loss) + val_phase = self.plot_validation_phase() + test_phase = self.plot_test_phase() + n_subplots = 1 + int(val_phase) + int(test_phase) + fig, axes = plt.subplots( + 1, n_subplots, squeeze=False, figsize=(6 + 4 * n_subplots, 6), dpi=300 + ) + + i = 1 + self.plot_kept_and_best(axes[0, 0], show_qos_loss) + if val_phase: + ax = axes[0, i] + self.plot_validation_phase(ax, show_qos_loss, dot_format) + i += 1 + if test_phase: + ax = axes[0, i] + tuneset_key = "validated_qos" if val_phase else "qos" + self.plot_test_phase(ax, dot_format, tuneset_key) + i += 1 fig.tight_layout() return fig @@ -602,27 +608,6 @@ class ApproxModeledTuner(ApproxTuner): ax.set_ylabel("Speedup (x)") ax.legend() - def plot_test_phase(self, ax: plt.Axes = None, dot_format: str = "o"): - configs = self.best_configs_prefilter - tested = [conf.test_qos is not None for conf in configs] - validated = [conf.validated_qos is not None for conf in configs] - can_plot = all(tested) and all(validated) - if not ax: - return can_plot - assert can_plot - tune_x, tune_y = self._config_qos_speedups( - configs, "validated_qos", True, False - ) - test_x, test_y = self._config_qos_speedups(configs, "test_qos", True, False) - ax.plot(tune_x, tune_y, dot_format, label="Tune-set QoS") - ax.plot(test_x, test_y, dot_format, label="Test-set QoS") - self._set_xy_limit(ax) - rthres = self.baseline_tune_qos - self.tune_keep_threshold - self._draw_qos_line(ax, rthres, f"Relative threshold: {rthres:.2f}") - ax.set_xlabel("Empirical QoS Loss") - ax.set_ylabel("Speedup (x)") - ax.legend() - @classmethod def _get_config_class(cls) -> Type[Config]: return ValConfig