Commit ae767db6 authored by Yifan Zhao

Add deepcopy / use implicit device
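This commit makes two fixes: PruningMeasurer now receives a deep copy of the module, so tuning trials cannot mutate the caller's model, and the explicit .cuda() call in calibration is dropped, so the module stays on whatever device it already occupies.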

parent b3b5a2a2
Branch: main
 import logging
+from copy import deepcopy
 from typing import Dict, List, NamedTuple
 import opentuner as ot
@@ -61,7 +62,7 @@ class AutotunedPruner(StructuredPruner):
         from opentuner.tuningrunmain import TuningRunMain

         # 1. Run tuning
-        measurer = PruningMeasurer(self.ot_args, self, module, groups)
+        measurer = PruningMeasurer(self.ot_args, self, deepcopy(module), groups)
         trm = TuningRunMain(measurer, self.ot_args)
         # A little bit of a hack to get the _real_ progress when duplicated configs exist
         measurer.set_progress_getter(lambda: trm.search_driver.test_count)
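Why the deepcopy above matters, in a minimal sketch (ToyMeasurer is a hypothetical stand-in for PruningMeasurer, assuming its trials prune the module in place): handing the measurer a private copy keeps the caller's module intact.

from copy import deepcopy

import torch.nn as nn

class ToyMeasurer:
    """Hypothetical stand-in for PruningMeasurer: its trials mutate the module."""

    def __init__(self, module: nn.Module):
        self.module = module

    def run_trial(self):
        for p in self.module.parameters():
            p.data.zero_()  # destructive, like pruning in place

original = nn.Linear(4, 2)
measurer = ToyMeasurer(deepcopy(original))  # trials see a private copy only
measurer.run_trial()
assert original.weight.abs().sum().item() > 0  # caller's weights survive

The same reasoning applies to make_pruned_copy below, which deepcopies before pruning.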
@@ -75,9 +76,7 @@ class AutotunedPruner(StructuredPruner):
     def make_pruned_copy(
         self, module: pl.LightningModule, config: Dict[str, List[int]]
     ):
-        import copy
-
-        module = copy.deepcopy(module)
+        module = deepcopy(module)
         assert self.args is not None
         self._intern_prune_module_(module, self.args, config)
         return module
@@ -86,7 +85,7 @@ class AutotunedPruner(StructuredPruner):
         # Calibrate on a small portion of the training set and get accuracy on the validation set.
         # (We're already inside torch.no_grad() here)
         n_im = 0
-        module = module.train().cuda()
+        module = module.train()
         loader = module.train_dataloader()
         assert isinstance(loader, DataLoader)  # Not a list of DataLoaders
         for batch in loader:
......
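The dropped .cuda() is the "use implicit device" half of the commit: calibration is no longer pinned to CUDA, and the module runs on whatever device it already lives on (under PyTorch Lightning, the trainer decides placement). A minimal sketch of that pattern, with calibrate as a hypothetical helper rather than this repo's code:

import torch
import torch.nn as nn

def calibrate(module: nn.Module, batch: torch.Tensor) -> torch.Tensor:
    # Send the data to the module's device instead of forcing the module to CUDA.
    device = next(module.parameters()).device
    return module(batch.to(device))

model = nn.Linear(8, 3)  # stays on CPU here; the same code works unchanged on CUDA
out = calibrate(model, torch.randn(2, 8))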