Skip to content
Snippets Groups Projects
Commit 784d4e46 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Fixed bugs and passed tuning test

parent 8a727602
No related branches found
No related tags found
No related merge requests found
......@@ -76,6 +76,7 @@ class ApproxTuner:
# By default, keep_threshold == tuner_threshold
qos_keep_threshold = qos_keep_threshold or qos_tuner_threshold
opentuner_args = opentuner_default_args()
opentuner_args.test_limit = max_iter
tuner = TunerInterface(
opentuner_args, self.app, qos_tuner_threshold, qos_keep_threshold, max_iter,
)
......
......@@ -110,7 +110,8 @@ class TorchApp(ModeledApp, abc.ABC):
end = begin + len(target)
qos = self.tensor_to_qos(tensor_output[begin:end], target)
qoses.append(qos)
return self.combine_qos(np.array(qoses))
# float64 -> float
return float(self.combine_qos(np.array(qoses)))
return [
LinearPerfModel(self._op_costs, self._knob_speedups),
......@@ -127,10 +128,11 @@ class TorchApp(ModeledApp, abc.ABC):
qoses = []
for inputs, targets in dataloader:
inputs = move_to_device_recursively(inputs, self.device)
targets = move_to_device_recursively(targets, self.device)
outputs = approxed(inputs)
qoses.append(self.tensor_to_qos(outputs, targets))
qos = self.combine_qos(np.array(qoses))
return 0.0, qos
return float(qos), 0.0 # float64->float
def __repr__(self) -> str:
class_name = self.__class__.__name__
......
import unittest
from torch.utils.data.dataset import Subset
from predtuner.approxes import get_knobs_from_file
from predtuner.torchapp import TorchApp
from predtuner.torchutil import accuracy
......@@ -10,21 +12,29 @@ from torchvision.datasets import CIFAR10
from torchvision.models.vgg import vgg16
class TestTorchAppInit(unittest.TestCase):
class TestTorchApp(unittest.TestCase):
def setUp(self):
transform = transforms.Compose([transforms.ToTensor()])
self.dataset = CIFAR10("/tmp/cifar10", download=True, transform=transform)
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
transform = transforms.Compose([transforms.ToTensor(), normalize])
dataset = CIFAR10("/tmp/cifar10", download=True, transform=transform)
self.dataset = Subset(dataset, range(100))
self.module = vgg16(pretrained=True)
def test_init(self):
app = TorchApp(
"test",
def get_app(self):
return TorchApp(
"TestTorchApp",
self.module,
DataLoader(self.dataset),
DataLoader(self.dataset),
get_knobs_from_file(),
accuracy,
)
def test_init(self):
app = self.get_app()
n_knobs = {op: len(ks) for op, ks in app.op_knobs.items()}
for op_name, op in app.midx.name_to_module.items():
if isinstance(op, Conv2d):
......@@ -34,3 +44,19 @@ class TestTorchAppInit(unittest.TestCase):
else:
nknob = 1
self.assertEqual(n_knobs[op_name], nknob)
# def test_baseline_qos(self):
# app = self.get_app()
# qos, _ = app.measure_qos_perf({}, False)
def test_tuning(self):
app = TorchApp(
"test",
self.module,
DataLoader(self.dataset, batch_size=4),
DataLoader(self.dataset, batch_size=4),
get_knobs_from_file(),
accuracy,
)
tuner = app.get_tuner()
tuner.tune(10, 3.0)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment