Skip to content
Snippets Groups Projects
Commit 784d4e46 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Fixed bugs and passed tuning test

parent 8a727602
No related branches found
No related tags found
No related merge requests found
...@@ -76,6 +76,7 @@ class ApproxTuner: ...@@ -76,6 +76,7 @@ class ApproxTuner:
# By default, keep_threshold == tuner_threshold # By default, keep_threshold == tuner_threshold
qos_keep_threshold = qos_keep_threshold or qos_tuner_threshold qos_keep_threshold = qos_keep_threshold or qos_tuner_threshold
opentuner_args = opentuner_default_args() opentuner_args = opentuner_default_args()
opentuner_args.test_limit = max_iter
tuner = TunerInterface( tuner = TunerInterface(
opentuner_args, self.app, qos_tuner_threshold, qos_keep_threshold, max_iter, opentuner_args, self.app, qos_tuner_threshold, qos_keep_threshold, max_iter,
) )
......
...@@ -110,7 +110,8 @@ class TorchApp(ModeledApp, abc.ABC): ...@@ -110,7 +110,8 @@ class TorchApp(ModeledApp, abc.ABC):
end = begin + len(target) end = begin + len(target)
qos = self.tensor_to_qos(tensor_output[begin:end], target) qos = self.tensor_to_qos(tensor_output[begin:end], target)
qoses.append(qos) qoses.append(qos)
return self.combine_qos(np.array(qoses)) # float64 -> float
return float(self.combine_qos(np.array(qoses)))
return [ return [
LinearPerfModel(self._op_costs, self._knob_speedups), LinearPerfModel(self._op_costs, self._knob_speedups),
...@@ -127,10 +128,11 @@ class TorchApp(ModeledApp, abc.ABC): ...@@ -127,10 +128,11 @@ class TorchApp(ModeledApp, abc.ABC):
qoses = [] qoses = []
for inputs, targets in dataloader: for inputs, targets in dataloader:
inputs = move_to_device_recursively(inputs, self.device) inputs = move_to_device_recursively(inputs, self.device)
targets = move_to_device_recursively(targets, self.device)
outputs = approxed(inputs) outputs = approxed(inputs)
qoses.append(self.tensor_to_qos(outputs, targets)) qoses.append(self.tensor_to_qos(outputs, targets))
qos = self.combine_qos(np.array(qoses)) qos = self.combine_qos(np.array(qoses))
return 0.0, qos return float(qos), 0.0 # float64->float
def __repr__(self) -> str: def __repr__(self) -> str:
class_name = self.__class__.__name__ class_name = self.__class__.__name__
......
import unittest import unittest
from torch.utils.data.dataset import Subset
from predtuner.approxes import get_knobs_from_file from predtuner.approxes import get_knobs_from_file
from predtuner.torchapp import TorchApp from predtuner.torchapp import TorchApp
from predtuner.torchutil import accuracy from predtuner.torchutil import accuracy
...@@ -10,21 +12,29 @@ from torchvision.datasets import CIFAR10 ...@@ -10,21 +12,29 @@ from torchvision.datasets import CIFAR10
from torchvision.models.vgg import vgg16 from torchvision.models.vgg import vgg16
class TestTorchAppInit(unittest.TestCase): class TestTorchApp(unittest.TestCase):
def setUp(self): def setUp(self):
transform = transforms.Compose([transforms.ToTensor()]) normalize = transforms.Normalize(
self.dataset = CIFAR10("/tmp/cifar10", download=True, transform=transform) mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
transform = transforms.Compose([transforms.ToTensor(), normalize])
dataset = CIFAR10("/tmp/cifar10", download=True, transform=transform)
self.dataset = Subset(dataset, range(100))
self.module = vgg16(pretrained=True) self.module = vgg16(pretrained=True)
def test_init(self): def get_app(self):
app = TorchApp( return TorchApp(
"test", "TestTorchApp",
self.module, self.module,
DataLoader(self.dataset), DataLoader(self.dataset),
DataLoader(self.dataset), DataLoader(self.dataset),
get_knobs_from_file(), get_knobs_from_file(),
accuracy, accuracy,
) )
def test_init(self):
app = self.get_app()
n_knobs = {op: len(ks) for op, ks in app.op_knobs.items()} n_knobs = {op: len(ks) for op, ks in app.op_knobs.items()}
for op_name, op in app.midx.name_to_module.items(): for op_name, op in app.midx.name_to_module.items():
if isinstance(op, Conv2d): if isinstance(op, Conv2d):
...@@ -34,3 +44,19 @@ class TestTorchAppInit(unittest.TestCase): ...@@ -34,3 +44,19 @@ class TestTorchAppInit(unittest.TestCase):
else: else:
nknob = 1 nknob = 1
self.assertEqual(n_knobs[op_name], nknob) self.assertEqual(n_knobs[op_name], nknob)
# def test_baseline_qos(self):
# app = self.get_app()
# qos, _ = app.measure_qos_perf({}, False)
def test_tuning(self):
app = TorchApp(
"test",
self.module,
DataLoader(self.dataset, batch_size=4),
DataLoader(self.dataset, batch_size=4),
get_knobs_from_file(),
accuracy,
)
tuner = app.get_tuner()
tuner.tune(10, 3.0)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment