diff --git a/predtuner/approxapp.py b/predtuner/approxapp.py
index c8a9825f535b507e87afd51688f0b23d78832b0d..b1edc136f37c445f054e73022eeb89d97c2be9a9 100644
--- a/predtuner/approxapp.py
+++ b/predtuner/approxapp.py
@@ -159,8 +159,8 @@ class ApproxTuner(Generic[T]):
         self.kept_configs = []
         self.best_configs = []
         # The following will be filled after self.tune() is called
-        self.keep_threshold = None
-        self.baseline_qos = None
+        self.baseline_tune_qos, self.baseline_test_qos = None, None
+        self.tune_keep_threshold, self.test_keep_threshold = None, None
 
     @property
     def tuned(self) -> bool:
@@ -200,7 +200,7 @@ class ApproxTuner(Generic[T]):
             is_threshold_relative,
             app_kwargs or {},
         )
-        assert self.keep_threshold is not None
+        assert self.tune_keep_threshold is not None
         trm = TuningRunMain(tuner, opentuner_args)
         # TuningRunMain.__init__ initializes its own logger, so we'll override it and use ours
         override_opentuner_config()
@@ -219,7 +219,7 @@ class ApproxTuner(Generic[T]):
             for result, configuration in read_opentuner_db(opentuner_args.database)
         ]
         self.kept_configs = [
-            cfg for cfg in self.all_configs if cfg.qos > self.keep_threshold
+            cfg for cfg in self.all_configs if cfg.qos > self.tune_keep_threshold
        ]
         self.best_configs = self.take_best_configs(self.kept_configs, take_best_n)
         msg_logger.info(
@@ -240,7 +240,7 @@ class ApproxTuner(Generic[T]):
 
         from tqdm import tqdm
 
-        assert self.keep_threshold is not None
+        assert self.test_keep_threshold is not None
         if not configs:
             return []
         ret_configs = []
@@ -251,7 +251,7 @@ class ApproxTuner(Generic[T]):
             cfg.test_qos, _ = self.app.measure_qos_cost(cfg.knobs, True)
             msg_logger.debug(f"Calibration: {cfg.qos} (mean) -> {cfg.test_qos} (mean)")
             total_error += abs(cfg.qos - cfg.test_qos)
-            if cfg.test_qos > self.keep_threshold:
+            if cfg.test_qos > self.test_keep_threshold:
                 ret_configs.append(cfg)
             else:
                 msg_logger.debug("Config removed")
@@ -292,7 +292,11 @@ class ApproxTuner(Generic[T]):
         )
 
         def qos_speedup(conf):
-            return conf.test_qos if use_test_qos else conf.qos, conf.speedup
+            qos = conf.test_qos if use_test_qos else conf.qos
+            baseline_qos = (
+                self.baseline_test_qos if use_test_qos else self.baseline_tune_qos
+            )
+            return baseline_qos - qos if show_qos_loss else qos, conf.speedup
 
         def get_points(confs):
             if not confs:
@@ -300,8 +304,6 @@ class ApproxTuner(Generic[T]):
             sorted_points = np.array(
                 sorted([qos_speedup(c) for c in confs], key=lambda p: p[0])
             ).T
-            if show_qos_loss:
-                sorted_points[0] = self.baseline_qos - sorted_points[0]
             return sorted_points
 
         fig, ax = plt.subplots()
@@ -325,22 +327,29 @@ class ApproxTuner(Generic[T]):
         app_kwargs: dict,
     ) -> "TunerInterface":
         # By default, keep_threshold == tuner_threshold
-        self.keep_threshold = qos_keep_threshold or qos_tuner_threshold
+        keep_threshold = qos_keep_threshold or qos_tuner_threshold
         if is_threshold_relative:
-            self.baseline_qos, _ = self.app.measure_qos_cost({}, False)
-            qos_tuner_threshold = self.baseline_qos - qos_tuner_threshold
-            self.keep_threshold = self.baseline_qos - self.keep_threshold
+            self.baseline_tune_qos, _ = self.app.measure_qos_cost({}, False)
+            self.baseline_test_qos, _ = self.app.measure_qos_cost({}, True)
+            # Convert the tuner threshold into an absolute QoS value
+            qos_tuner_threshold = self.baseline_tune_qos - qos_tuner_threshold
+            # The keep thresholds are likewise absolute, one per dataset
+            self.tune_keep_threshold = self.baseline_tune_qos - keep_threshold
+            self.test_keep_threshold = self.baseline_test_qos - keep_threshold
+        else:
+            # Thresholds were given as absolute QoS values; use them directly
+            self.tune_keep_threshold = self.test_keep_threshold = keep_threshold
         opentuner_args.test_limit = max_iter
         msg_logger.info(
-            "Tuner QoS threshold: %f; keeping configurations with QoS >= %f",
%f", + "Tuner QoS threshold: %f; keeping configurations with QoS >= %f (tune dataset)", qos_tuner_threshold, - self.keep_threshold, + self.tune_keep_threshold, ) return TunerInterface( opentuner_args, self.app, qos_tuner_threshold, - self.keep_threshold, + self.tune_keep_threshold, max_iter, **app_kwargs, ) diff --git a/predtuner/modeledapp.py b/predtuner/modeledapp.py index a188eb9ad0c9ecc0fd4f66920fb5d15540dc1fd4..7b678a47abf49bb696f4d490ae34c9ad451c6a0a 100644 --- a/predtuner/modeledapp.py +++ b/predtuner/modeledapp.py @@ -24,7 +24,9 @@ class ModeledApp(ApproxApp, abc.ABC): for non-modeling application, inherit from `ApproxApp` instead. """ - def __init__(self, op_knobs: Dict[str, List[ApproxKnob]], tuning_device: str = None) -> None: + def __init__( + self, op_knobs: Dict[str, List[ApproxKnob]], tuning_device: str = None + ) -> None: super().__init__(op_knobs, tuning_device) models = self.get_models() self._name_to_model = {m.name: m for m in models} @@ -411,10 +413,13 @@ class ApproxModeledTuner(ApproxTuner): from tqdm import tqdm - assert self.keep_threshold is not None if not configs: msg_logger.info("No configurations found.") return [] + keep_threshold = ( + self.test_keep_threshold if test_mode else self.tune_keep_threshold + ) + assert keep_threshold is not None ret_configs = [] total_error = 0 for cfg in tqdm(configs, leave=False): @@ -429,7 +434,7 @@ class ApproxModeledTuner(ApproxTuner): cfg.validated_qos = qos msg_logger.debug(f"Validation: {cfg.qos} (mean) -> {qos} (mean)") total_error += abs(cfg.qos - qos) - if qos > self.keep_threshold: + if qos > keep_threshold: ret_configs.append(cfg) else: msg_logger.debug("Config removed") @@ -451,7 +456,7 @@ class ApproxModeledTuner(ApproxTuner): # For empirical tuning there's no `validated_qos`. # We first check, and if that's the case, we pass on to our parent class instead. - val_qos_nones = [conf.validated_qos is None for conf in self.kept_configs] + val_qos_nones = [conf.validated_qos is None for conf in self.best_configs] if any(val_qos_nones): assert all(val_qos_nones) return super().plot_configs(show_qos_loss, connect_best_points, False) @@ -466,7 +471,7 @@ class ApproxModeledTuner(ApproxTuner): sorted([qos_speedup(c) for c in confs], key=lambda p: p[0]) ).T if show_qos_loss: - sorted_points[0] = self.baseline_qos - sorted_points[0] + sorted_points[0] = self.baseline_tune_qos - sorted_points[0] return sorted_points fig, ax = plt.subplots() @@ -478,7 +483,7 @@ class ApproxModeledTuner(ApproxTuner): ax.plot(best_confs[0], best_confs[1], mode, label="best") mode = "-o" if connect_best_points else "o" ax.plot(best_confs_val[0], best_confs_val[1], mode, label="best_validated") - ax.set_xlabel("QoS Loss" if show_qos_loss else "QoS") + ax.set_xlabel("QoS Loss (Tune dataset)" if show_qos_loss else "QoS (Tune dataset)") ax.set_ylabel("Speedup (x)") ax.legend() return fig