Skip to content
Snippets Groups Projects
Commit 7102f5b1 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Improved plotting of configurations

parent c7cd0f44
No related branches found
No related tags found
No related merge requests found
...@@ -114,8 +114,8 @@ class Config: ...@@ -114,8 +114,8 @@ class Config:
self.test_qos: Optional[float] = test_qos self.test_qos: Optional[float] = test_qos
@property @property
def qos_speedup(self): def speedup(self):
return self.qos, 1 / self.cost return 1 / self.cost
T = TypeVar("T", bound=Config) T = TypeVar("T", bound=Config)
...@@ -232,7 +232,7 @@ class ApproxTuner(Generic[T]): ...@@ -232,7 +232,7 @@ class ApproxTuner(Generic[T]):
@staticmethod @staticmethod
def take_best_configs(configs: List[T], n: Optional[int] = None) -> List[T]: def take_best_configs(configs: List[T], n: Optional[int] = None) -> List[T]:
points = np.array([c.qos_speedup for c in configs]) points = np.array([(c.qos, c.speedup) for c in configs])
taken_idx = is_pareto_efficient(points, take_n=n) taken_idx = is_pareto_efficient(points, take_n=n)
return [configs[i] for i in taken_idx] return [configs[i] for i in taken_idx]
...@@ -252,16 +252,22 @@ class ApproxTuner(Generic[T]): ...@@ -252,16 +252,22 @@ class ApproxTuner(Generic[T]):
f.write(encode(confs, indent=2)) f.write(encode(confs, indent=2))
def plot_configs( def plot_configs(
self, show_qos_loss: bool = False, connect_best_points: bool = False self,
show_qos_loss: bool = False,
connect_best_points: bool = False,
use_test_qos: bool = False,
) -> plt.Figure: ) -> plt.Figure:
if not self.tuned: if not self.tuned:
raise RuntimeError( raise RuntimeError(
f"No tuning session has been run; call self.tune() first." f"No tuning session has been run; call self.tune() first."
) )
def qos_speedup(conf):
return conf.test_qos if use_test_qos else conf.qos, conf.speedup
def get_points(confs): def get_points(confs):
sorted_points = np.array( sorted_points = np.array(
sorted([c.qos_speedup for c in confs], key=lambda p: p[0]) sorted([qos_speedup(c) for c in confs], key=lambda p: p[0])
).T ).T
if show_qos_loss: if show_qos_loss:
sorted_points[0] = self.baseline_qos - sorted_points[0] sorted_points[0] = self.baseline_qos - sorted_points[0]
......
...@@ -5,6 +5,7 @@ import pickle ...@@ -5,6 +5,7 @@ import pickle
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union
import matplotlib.pyplot as plt
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import torch import torch
...@@ -390,7 +391,7 @@ class ApproxModeledTuner(ApproxTuner): ...@@ -390,7 +391,7 @@ class ApproxModeledTuner(ApproxTuner):
is_threshold_relative=is_threshold_relative, is_threshold_relative=is_threshold_relative,
take_best_n=take_best_n, take_best_n=take_best_n,
test_configs=False, # Test configs below by ourselves test_configs=False, # Test configs below by ourselves
app_kwargs={"cost_model": cost_model, "qos_model": qos_model} app_kwargs={"cost_model": cost_model, "qos_model": qos_model},
) )
if validate_configs is None and qos_model != "none": if validate_configs is None and qos_model != "none":
msg_logger.info( msg_logger.info(
...@@ -440,6 +441,39 @@ class ApproxModeledTuner(ApproxTuner): ...@@ -440,6 +441,39 @@ class ApproxModeledTuner(ApproxTuner):
msg_logger.info("%d of %d configs remain", len(ret_configs), len(configs)) msg_logger.info("%d of %d configs remain", len(ret_configs), len(configs))
return ret_configs return ret_configs
def plot_configs(
self, show_qos_loss: bool = False, connect_best_points: bool = False
) -> plt.Figure:
if not self.tuned:
raise RuntimeError(
f"No tuning session has been run; call self.tune() first."
)
def get_points(confs, validated):
def qos_speedup(conf):
return conf.validated_qos if validated else conf.qos, conf.speedup
sorted_points = np.array(
sorted([qos_speedup(c) for c in confs], key=lambda p: p[0])
).T
if show_qos_loss:
sorted_points[0] = self.baseline_qos - sorted_points[0]
return sorted_points
fig, ax = plt.subplots()
kept_confs = get_points(self.kept_configs, False)
best_confs = get_points(self.best_configs, False)
best_confs_val = get_points(self.best_configs, True)
ax.plot(kept_confs[0], kept_confs[1], "o", label="valid")
mode = "-o" if connect_best_points else "o"
ax.plot(best_confs[0], best_confs[1], mode, label="best")
mode = "-o" if connect_best_points else "o"
ax.plot(best_confs_val[0], best_confs_val[1], mode, label="best_validated")
ax.set_xlabel("QoS Loss" if show_qos_loss else "QoS")
ax.set_ylabel("Speedup (x)")
ax.legend()
return fig
@classmethod @classmethod
def _get_config_class(cls) -> Type[Config]: def _get_config_class(cls) -> Type[Config]:
return ValConfig return ValConfig
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment