Skip to content
Snippets Groups Projects
Commit 585bd049 authored by Yifan Zhao
Browse files

Remove uses of default hw params to allow running without CUDA

parent ddd8929d
Branches main
No related tags found
No related merge requests found
...@@ -33,7 +33,8 @@ def main(): ...@@ -33,7 +33,8 @@ def main():
features = np.array(features, dtype=object) features = np.array(features, dtype=object)
# features: [batch_size, n_buf, n_features] # features: [batch_size, n_buf, n_features]
dtest, pack_ids = ansor_xgb.feature_to_pack_sum_xgbmatrix(features) dtest, pack_ids = ansor_xgb.feature_to_pack_sum_xgbmatrix(features)
raw_preds = bst_model.predict(dtest) # TODO: run model prediction here.
raw_preds = None # HINT: bst_model.?.?(dtest)
predicted_perf = ansor_xgb.predict_throughput_pack_sum(raw_preds, pack_ids) predicted_perf = ansor_xgb.predict_throughput_pack_sum(raw_preds, pack_ids)
# predicted: [batch_size] # predicted: [batch_size]
# TODO: compare predicted_perf and groundtruth_perf # TODO: compare predicted_perf and groundtruth_perf
......
...@@ -6,14 +6,29 @@ from torch import Tensor ...@@ -6,14 +6,29 @@ from torch import Tensor
from torch.nn import Module from torch.nn import Module
from tvm import auto_scheduler as ansor from tvm import auto_scheduler as ansor
from tvm import relay from tvm import relay
from tvm.auto_scheduler.search_task import HardwareParams
from tvm.target import Target from tvm.target import Target
PathLike = Union[Path, str] PathLike = Union[Path, str]
def load_tuned_configs(dnn: Module, example_input: Tensor, json_file: str, target: Target): def load_tuned_configs(
dnn: Module, example_input: Tensor, json_file: str, target: Target
):
mod, params = gen_tvm_model(dnn, example_input) mod, params = gen_tvm_model(dnn, example_input)
tasks, task_weights = ansor.extract_tasks(mod["main"], params, target) hw_params = HardwareParams(
num_cores=-1,
vector_unit_bytes=16,
cache_line_bytes=64,
max_shared_memory_per_block=49152,
max_local_memory_per_block=2147483647,
max_threads_per_block=1024,
max_vthread_extent=8,
warp_size=32,
)
tasks, task_weights = ansor.extract_tasks(
mod["main"], params, target, hardware_params=hw_params
)
tuner = ansor.TaskScheduler(tasks, task_weights, load_log_file=json_file) tuner = ansor.TaskScheduler(tasks, task_weights, load_log_file=json_file)
return _load_tuned_configs(tuner) return _load_tuned_configs(tuner)
...@@ -98,7 +113,6 @@ def _convert_to_nhwc(model: Module, mod): ...@@ -98,7 +113,6 @@ def _convert_to_nhwc(model: Module, mod):
return mod return mod
def parse_task(task): def parse_task(task):
import ast import ast
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment