Skip to content
Snippets Groups Projects
Commit b7c92d82 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Added some code to interface with opentuner

parent a992b34a
No related branches found
No related tags found
No related merge requests found
import abc
import logging
from pathlib import Path
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
from opentuner.measurement.interface import MeasurementInterface
from opentuner.search.manipulator import ConfigurationManipulator, EnumParameter
msg_logger = logging.getLogger(__name__)
KnobsT = Dict[str, str]
PathLike = Union[Path, str]
TunerConfigT = Dict[int, int]
class ApproxKnob:
......@@ -15,10 +24,6 @@ class ApproxKnob:
return f'Knob"{self.name}"({self.kwargs})'
KnobsT = Dict[str, str]
PathLike = Union[Path, str]
class ApproxApp(abc.ABC):
"""Generic approximable application with operator & knob enumeration,
and measures its own QoS and performance given a configuration."""
......@@ -38,7 +43,14 @@ class ApproxApp(abc.ABC):
def get_tuner(self) -> "ApproxTuner":
"""We implement this function. Sets up an ApproxTuner instance
which the user can directly call `tune()` on with opentuner parameters."""
return ApproxTuner() # TODO
return ApproxTuner(self)
@property
@abc.abstractmethod
def name(self) -> str:
"""Name of application. Acts as an identifier in many places, so
the user should try to make it unique."""
return ""
class Config:
......@@ -46,19 +58,34 @@ class Config:
class ApproxTuner:
def __init__(self, app: ApproxApp) -> None:
self.app = app
self.tune_sessions = []
def tune(
self,
qos_threshold: float,
accuracy_convention: str = "absolute",
**kwargs, # many opentuner parameters with defaults, omitted
max_iter: int,
qos_tuner_threshold: float,
qos_keep_threshold: Optional[float] = None,
accuracy_convention: str = "absolute" # TODO: this
# TODO: more parameters + opentuner param forwarding
):
"""Generate an optimal set of approximation configurations for the model."""
pass # TODO
from opentuner.tuningrunmain import TuningRunMain
# By default, keep_threshold == tuner_threshold
qos_keep_threshold = qos_keep_threshold or qos_tuner_threshold
opentuner_args = opentuner_default_args()
tuner = TunerInterface(
opentuner_args, self.app, qos_tuner_threshold, qos_keep_threshold, max_iter,
)
# This is where opentuner runs
TuningRunMain(tuner, opentuner_args).main()
# More helpers for selecting a config omitted for brevity
def get_all_configs(self) -> List[Config]:
return [] # TODO
return [] # TODO: parse opentuner database (do they have helpers?)
# TODO
# Work out details of saving / loading
......@@ -69,3 +96,70 @@ class ApproxTuner:
def load_configs(self, path: PathLike):
pass
def opentuner_default_args():
from opentuner import default_argparser
return default_argparser().parse_args([])
class TunerInterface(MeasurementInterface):
def __init__(
self,
args,
app: ApproxApp,
tuner_thres: float,
keep_thres: float,
test_limit: int,
):
from opentuner.measurement.inputmanager import FixedInputManager
from opentuner.search.objective import ThresholdAccuracyMinimizeTime
from tqdm import tqdm
self.app = app
self.tune_thres = tuner_thres
self.keep_thres = keep_thres
self.pbar = tqdm(total=test_limit, leave=False)
objective = ThresholdAccuracyMinimizeTime(tuner_thres)
input_manager = FixedInputManager(size=len(self.app.op_knobs))
super(TunerInterface, self).__init__(
args,
program_name=self.app.name,
input_manager=input_manager,
objective=objective,
)
def manipulator(self) -> ConfigurationManipulator:
"""Define the search space by creating a ConfigurationManipulator."""
manipulator = ConfigurationManipulator()
for op, knobs in self.app.op_knobs.items():
knob_names = [knob.name for knob in knobs]
manipulator.add_parameter(EnumParameter(op, knob_names))
return manipulator
def run(self, desired_result, input_, limit):
"""Run a given configuration then return performance and accuracy."""
from opentuner.resultsdb.models import Result
cfg = desired_result.configuration.data
qos, perf = self.app.measure_qos_perf(cfg, False)
# Print a debug message for each config in tuning and keep threshold
self.print_debug_config(qos, perf)
self.pbar.update()
return Result(time=perf, accuracy=qos)
def print_debug_config(self, qos: float, perf: float):
gt_tune, gt_keep = qos > self.tune_thres, qos > self.keep_thres
if not gt_tune and not gt_keep:
return
if gt_tune and not gt_keep:
kind = "tuning"
elif gt_keep and not gt_tune:
kind = "keep"
else:
kind = "tuning and keep"
msg_logger.debug(
f"Found config in {kind} threshold: QoS = {qos}, perf = {perf}"
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment