Skip to content
Snippets Groups Projects
Commit 1046ce59 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Fixed tune and test set QoS mixing

parent beadbc66
No related branches found
No related tags found
No related merge requests found
......@@ -159,8 +159,8 @@ class ApproxTuner(Generic[T]):
self.kept_configs = []
self.best_configs = []
# The following will be filled after self.tune() is called
self.keep_threshold = None
self.baseline_qos = None
self.baseline_tune_qos, self.baseline_test_qos = None, None
self.tune_keep_threshold, self.test_keep_threshold = None, None
@property
def tuned(self) -> bool:
......@@ -200,7 +200,7 @@ class ApproxTuner(Generic[T]):
is_threshold_relative,
app_kwargs or {},
)
assert self.keep_threshold is not None
assert self.tune_keep_threshold is not None
trm = TuningRunMain(tuner, opentuner_args)
# TuningRunMain.__init__ initializes its own logger, so we'll override it and use ours
override_opentuner_config()
......@@ -219,7 +219,7 @@ class ApproxTuner(Generic[T]):
for result, configuration in read_opentuner_db(opentuner_args.database)
]
self.kept_configs = [
cfg for cfg in self.all_configs if cfg.qos > self.keep_threshold
cfg for cfg in self.all_configs if cfg.qos > self.tune_keep_threshold
]
self.best_configs = self.take_best_configs(self.kept_configs, take_best_n)
msg_logger.info(
......@@ -240,7 +240,7 @@ class ApproxTuner(Generic[T]):
from tqdm import tqdm
assert self.keep_threshold is not None
assert self.test_keep_threshold is not None
if not configs:
return []
ret_configs = []
......@@ -251,7 +251,7 @@ class ApproxTuner(Generic[T]):
cfg.test_qos, _ = self.app.measure_qos_cost(cfg.knobs, True)
msg_logger.debug(f"Calibration: {cfg.qos} (mean) -> {cfg.test_qos} (mean)")
total_error += abs(cfg.qos - cfg.test_qos)
if cfg.test_qos > self.keep_threshold:
if cfg.test_qos > self.test_keep_threshold:
ret_configs.append(cfg)
else:
msg_logger.debug("Config removed")
......@@ -292,7 +292,11 @@ class ApproxTuner(Generic[T]):
)
def qos_speedup(conf):
return conf.test_qos if use_test_qos else conf.qos, conf.speedup
qos = conf.test_qos if use_test_qos else conf.qos
baseline_qos = (
self.baseline_test_qos if use_test_qos else self.baseline_tune_qos
)
return baseline_qos - qos if show_qos_loss else qos, conf.speedup
def get_points(confs):
if not confs:
......@@ -300,8 +304,6 @@ class ApproxTuner(Generic[T]):
sorted_points = np.array(
sorted([qos_speedup(c) for c in confs], key=lambda p: p[0])
).T
if show_qos_loss:
sorted_points[0] = self.baseline_qos - sorted_points[0]
return sorted_points
fig, ax = plt.subplots()
......@@ -325,22 +327,26 @@ class ApproxTuner(Generic[T]):
app_kwargs: dict,
) -> "TunerInterface":
# By default, keep_threshold == tuner_threshold
self.keep_threshold = qos_keep_threshold or qos_tuner_threshold
keep_threshold = qos_keep_threshold or qos_tuner_threshold
if is_threshold_relative:
self.baseline_qos, _ = self.app.measure_qos_cost({}, False)
qos_tuner_threshold = self.baseline_qos - qos_tuner_threshold
self.keep_threshold = self.baseline_qos - self.keep_threshold
self.baseline_tune_qos, _ = self.app.measure_qos_cost({}, False)
self.baseline_test_qos, _ = self.app.measure_qos_cost({}, True)
# Now abs threshold
qos_tuner_threshold = self.baseline_tune_qos - qos_tuner_threshold
# These are also abs thresholds
self.tune_keep_threshold = self.baseline_tune_qos - keep_threshold
self.test_keep_threshold = self.baseline_test_qos - keep_threshold
opentuner_args.test_limit = max_iter
msg_logger.info(
"Tuner QoS threshold: %f; keeping configurations with QoS >= %f",
"Tuner QoS threshold: %f; keeping configurations with QoS >= %f (tune dataset)",
qos_tuner_threshold,
self.keep_threshold,
self.tune_keep_threshold,
)
return TunerInterface(
opentuner_args,
self.app,
qos_tuner_threshold,
self.keep_threshold,
self.tune_keep_threshold,
max_iter,
**app_kwargs,
)
......
......@@ -24,7 +24,9 @@ class ModeledApp(ApproxApp, abc.ABC):
for non-modeling application, inherit from `ApproxApp` instead.
"""
def __init__(self, op_knobs: Dict[str, List[ApproxKnob]], tuning_device: str = None) -> None:
def __init__(
self, op_knobs: Dict[str, List[ApproxKnob]], tuning_device: str = None
) -> None:
super().__init__(op_knobs, tuning_device)
models = self.get_models()
self._name_to_model = {m.name: m for m in models}
......@@ -411,10 +413,13 @@ class ApproxModeledTuner(ApproxTuner):
from tqdm import tqdm
assert self.keep_threshold is not None
if not configs:
msg_logger.info("No configurations found.")
return []
keep_threshold = (
self.test_keep_threshold if test_mode else self.tune_keep_threshold
)
assert keep_threshold is not None
ret_configs = []
total_error = 0
for cfg in tqdm(configs, leave=False):
......@@ -429,7 +434,7 @@ class ApproxModeledTuner(ApproxTuner):
cfg.validated_qos = qos
msg_logger.debug(f"Validation: {cfg.qos} (mean) -> {qos} (mean)")
total_error += abs(cfg.qos - qos)
if qos > self.keep_threshold:
if qos > keep_threshold:
ret_configs.append(cfg)
else:
msg_logger.debug("Config removed")
......@@ -451,7 +456,7 @@ class ApproxModeledTuner(ApproxTuner):
# For empirical tuning there's no `validated_qos`.
# We first check, and if that's the case, we pass on to our parent class instead.
val_qos_nones = [conf.validated_qos is None for conf in self.kept_configs]
val_qos_nones = [conf.validated_qos is None for conf in self.best_configs]
if any(val_qos_nones):
assert all(val_qos_nones)
return super().plot_configs(show_qos_loss, connect_best_points, False)
......@@ -466,7 +471,7 @@ class ApproxModeledTuner(ApproxTuner):
sorted([qos_speedup(c) for c in confs], key=lambda p: p[0])
).T
if show_qos_loss:
sorted_points[0] = self.baseline_qos - sorted_points[0]
sorted_points[0] = self.baseline_tune_qos - sorted_points[0]
return sorted_points
fig, ax = plt.subplots()
......@@ -478,7 +483,7 @@ class ApproxModeledTuner(ApproxTuner):
ax.plot(best_confs[0], best_confs[1], mode, label="best")
mode = "-o" if connect_best_points else "o"
ax.plot(best_confs_val[0], best_confs_val[1], mode, label="best_validated")
ax.set_xlabel("QoS Loss" if show_qos_loss else "QoS")
ax.set_xlabel("QoS Loss (Tune dataset)" if show_qos_loss else "QoS (Tune dataset)")
ax.set_ylabel("Speedup (x)")
ax.legend()
return fig
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment