Skip to content
Snippets Groups Projects
Commit 7a84b294 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Passed test for tuning & result fetching

parent a95a2982
No related branches found
No related tags found
No related merge requests found
......@@ -66,6 +66,7 @@ class ApproxTuner:
self.app = app
self.tune_sessions = []
self.db = None
self.keep_threshold = None
def tune(
self,
......@@ -81,8 +82,9 @@ class ApproxTuner:
from opentuner.tuningrunmain import TuningRunMain
# By default, keep_threshold == tuner_threshold
qos_keep_threshold = qos_keep_threshold or qos_tuner_threshold
opentuner_args = opentuner_default_args()
qos_keep_threshold = qos_keep_threshold or qos_tuner_threshold
self.keep_threshold = qos_keep_threshold
self.db = opentuner_args.database or f'opentuner.db/{socket.gethostname()}.db'
opentuner_args.test_limit = max_iter
tuner = TunerInterface(
......@@ -94,7 +96,7 @@ class ApproxTuner:
def get_all_configs(self) -> List[Config]:
from ._dbloader import read_opentuner_db
if self.db is None:
if self.db is None or self.keep_threshold is None:
raise RuntimeError(
f"No tuning session has been run; call self.tune() first."
)
......@@ -102,6 +104,7 @@ class ApproxTuner:
return [
Config(result.accuracy, result.time, configuration.data)
for result, configuration in rets
if result.accuracy > self.keep_threshold
]
def write_configs_to_dir(self, directory: PathLike):
......@@ -113,6 +116,8 @@ class ApproxTuner:
_, perf = self.app.measure_qos_perf({}, False)
fig, ax = plt.subplots()
confs = self.get_all_configs()
if not confs:
return fig
qos_speedup = [(c.qos, perf / c.perf) for c in confs]
qoses, speedups = zip(*sorted(qos_speedup, key=lambda p: p[0]))
ax.plot(qoses, speedups)
......
......@@ -22,8 +22,8 @@ class TestTorchApp(unittest.TestCase):
return TorchApp(
"TestTorchApp",
self.module,
DataLoader(self.dataset),
DataLoader(self.dataset),
DataLoader(self.dataset, batch_size=500),
DataLoader(self.dataset, batch_size=500),
get_knobs_from_file(),
accuracy,
)
......@@ -47,13 +47,10 @@ class TestTorchApp(unittest.TestCase):
self.assertAlmostEqual(qos, 88.0)
def test_tuning(self):
app = TorchApp(
"test",
self.module,
DataLoader(self.dataset, batch_size=4),
DataLoader(self.dataset, batch_size=4),
get_knobs_from_file(),
accuracy,
)
app = self.get_app()
baseline, _ = app.measure_qos_perf({}, False)
tuner = app.get_tuner()
tuner.tune(10, 3.0)
tuner.tune(100, baseline - 3.0)
configs = tuner.get_all_configs()
for conf in configs:
self.assertTrue(conf.qos > baseline - 3.0)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment