Skip to content
Snippets Groups Projects
Commit 7a84b294 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Passed test for tuning & result fetching

parent a95a2982
No related branches found
No related tags found
No related merge requests found
...@@ -66,6 +66,7 @@ class ApproxTuner: ...@@ -66,6 +66,7 @@ class ApproxTuner:
self.app = app self.app = app
self.tune_sessions = [] self.tune_sessions = []
self.db = None self.db = None
self.keep_threshold = None
def tune( def tune(
self, self,
...@@ -81,8 +82,9 @@ class ApproxTuner: ...@@ -81,8 +82,9 @@ class ApproxTuner:
from opentuner.tuningrunmain import TuningRunMain from opentuner.tuningrunmain import TuningRunMain
# By default, keep_threshold == tuner_threshold # By default, keep_threshold == tuner_threshold
qos_keep_threshold = qos_keep_threshold or qos_tuner_threshold
opentuner_args = opentuner_default_args() opentuner_args = opentuner_default_args()
qos_keep_threshold = qos_keep_threshold or qos_tuner_threshold
self.keep_threshold = qos_keep_threshold
self.db = opentuner_args.database or f'opentuner.db/{socket.gethostname()}.db' self.db = opentuner_args.database or f'opentuner.db/{socket.gethostname()}.db'
opentuner_args.test_limit = max_iter opentuner_args.test_limit = max_iter
tuner = TunerInterface( tuner = TunerInterface(
...@@ -94,7 +96,7 @@ class ApproxTuner: ...@@ -94,7 +96,7 @@ class ApproxTuner:
def get_all_configs(self) -> List[Config]: def get_all_configs(self) -> List[Config]:
from ._dbloader import read_opentuner_db from ._dbloader import read_opentuner_db
if self.db is None: if self.db is None or self.keep_threshold is None:
raise RuntimeError( raise RuntimeError(
f"No tuning session has been run; call self.tune() first." f"No tuning session has been run; call self.tune() first."
) )
...@@ -102,6 +104,7 @@ class ApproxTuner: ...@@ -102,6 +104,7 @@ class ApproxTuner:
return [ return [
Config(result.accuracy, result.time, configuration.data) Config(result.accuracy, result.time, configuration.data)
for result, configuration in rets for result, configuration in rets
if result.accuracy > self.keep_threshold
] ]
def write_configs_to_dir(self, directory: PathLike): def write_configs_to_dir(self, directory: PathLike):
...@@ -113,6 +116,8 @@ class ApproxTuner: ...@@ -113,6 +116,8 @@ class ApproxTuner:
_, perf = self.app.measure_qos_perf({}, False) _, perf = self.app.measure_qos_perf({}, False)
fig, ax = plt.subplots() fig, ax = plt.subplots()
confs = self.get_all_configs() confs = self.get_all_configs()
if not confs:
return fig
qos_speedup = [(c.qos, perf / c.perf) for c in confs] qos_speedup = [(c.qos, perf / c.perf) for c in confs]
qoses, speedups = zip(*sorted(qos_speedup, key=lambda p: p[0])) qoses, speedups = zip(*sorted(qos_speedup, key=lambda p: p[0]))
ax.plot(qoses, speedups) ax.plot(qoses, speedups)
......
...@@ -22,8 +22,8 @@ class TestTorchApp(unittest.TestCase): ...@@ -22,8 +22,8 @@ class TestTorchApp(unittest.TestCase):
return TorchApp( return TorchApp(
"TestTorchApp", "TestTorchApp",
self.module, self.module,
DataLoader(self.dataset), DataLoader(self.dataset, batch_size=500),
DataLoader(self.dataset), DataLoader(self.dataset, batch_size=500),
get_knobs_from_file(), get_knobs_from_file(),
accuracy, accuracy,
) )
...@@ -47,13 +47,10 @@ class TestTorchApp(unittest.TestCase): ...@@ -47,13 +47,10 @@ class TestTorchApp(unittest.TestCase):
self.assertAlmostEqual(qos, 88.0) self.assertAlmostEqual(qos, 88.0)
def test_tuning(self): def test_tuning(self):
app = TorchApp( app = self.get_app()
"test", baseline, _ = app.measure_qos_perf({}, False)
self.module,
DataLoader(self.dataset, batch_size=4),
DataLoader(self.dataset, batch_size=4),
get_knobs_from_file(),
accuracy,
)
tuner = app.get_tuner() tuner = app.get_tuner()
tuner.tune(10, 3.0) tuner.tune(100, baseline - 3.0)
configs = tuner.get_all_configs()
for conf in configs:
self.assertTrue(conf.qos > baseline - 3.0)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment