From dc862ace3a564c4ccffa108bf0c09565e53babb8 Mon Sep 17 00:00:00 2001 From: Yifan Zhao <yifanz16@illinois.edu> Date: Wed, 21 Jun 2023 20:55:26 -0500 Subject: [PATCH] Print and parse the type of size variables --- include/tvm/arith/var_context.h | 2 +- include/tvm/tir/var.h | 4 +--- src/arith/egg_simpl.cc | 22 +++++++++++++--------- src/arith/var_context.cc | 4 ++-- src/auto_scheduler/sym_feats/transform.cc | 4 ++-- src/auto_scheduler/transform_step.cc | 2 +- src/tir/ir/expr.cc | 17 ++--------------- 7 files changed, 22 insertions(+), 33 deletions(-) diff --git a/include/tvm/arith/var_context.h b/include/tvm/arith/var_context.h index a97595a69..33d82f90a 100644 --- a/include/tvm/arith/var_context.h +++ b/include/tvm/arith/var_context.h @@ -68,7 +68,7 @@ class OrderedVarMap { private: tir::SizeVar push_aux(const std::string& vname, const PrimExpr& expr) { this->vnames.insert(vname); - tir::SizeVar var(vname, true); + tir::SizeVar var(vname, expr->dtype); this->expr2var.emplace(expr, var); return var; } diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index fcfbd1340..221bbd9d2 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -153,9 +153,7 @@ class SizeVar : public Var { * \param span The location of this object in the source code. */ TVM_DLL explicit SizeVar(String name_hint = "s", DataType t = DataType::Int(32), - Span span = Span()); - - SizeVar(String name_hint, bool is_config_var); + Span span = Span(), bool is_config_var = false); /*! * \brief Get pointer to the internal value. diff --git a/src/arith/egg_simpl.cc b/src/arith/egg_simpl.cc index ad29fd8c7..e11d49a64 100644 --- a/src/arith/egg_simpl.cc +++ b/src/arith/egg_simpl.cc @@ -27,8 +27,9 @@ class PreorderPrinter : public ExprFunctor<void(const PrimExpr&)> { void VisitExpr_(const VarNode* op) override { this->var_map.Set(op->name_hint, GetRef<Var>(op)); - ss << op->name_hint; + ss << op->name_hint << ":" << (op->dtype.is_bool() ? "b" : "i"); } + void VisitExpr_(const IntImmNode* op) override { if (op->dtype.is_bool()) { ss << (op->value ? "true" : "false"); @@ -147,15 +148,18 @@ std::pair<PrimExpr, size_t> ParseExprPreorder(const std::string& str, return Bool(true); } else if (str == "false") { return Bool(false); - } else if (std::isalpha(str[0])) { - if (!var_map) { - return Var(str, DataType::Int(32)); - } else { - auto var = var_map.value().Get(str); - if (!var.defined()) { - throw std::runtime_error("Undefined variable: " + str); - } + } else if (str.size() > 1 && str[str.size() - 2] == ':') { + DataType datatype(str.back() == 'b' ? DataType::Bool() : DataType::Int(32)); + auto var_name = str.substr(0, str.size() - 2); + if (var_map) { + auto var = var_map.value().Get(var_name); + ICHECK(var.defined()) << "Undefined variable: " << str; + ICHECK(var.value()->dtype == datatype) + << "Variable " << str << " has type " << var.value()->dtype << " but expected " + << datatype; return var.value(); + } else { + return Var(var_name, datatype); } } auto is_digit = [](char c) { return std::isdigit(c); }; diff --git a/src/arith/var_context.cc b/src/arith/var_context.cc index 3a9acf75a..6e953cfd5 100644 --- a/src/arith/var_context.cc +++ b/src/arith/var_context.cc @@ -84,7 +84,7 @@ Array<SizeVar> VarContext::GetSplitVars(PrimExpr extent, size_t n_splits, bool w Array<String> var_names; for (size_t i = 0; i < n_splits; i++) { String name = "sp_" + std::to_string(data->split_counter) + "_" + std::to_string(i); - vars.push_back(SizeVar(name, true)); + vars.push_back(SizeVar(name, DataType::Int(32), Span(), true)); var_names.push_back(name); } ++data->split_counter; @@ -100,7 +100,7 @@ Array<SizeVar> VarContext::GetSplitVars(PrimExpr extent, size_t n_splits, bool w for (auto& var : vars) { product *= var; } - SizeVar quotient = data->MakeAndInsertVar(extent / product, false); + Var quotient = data->MakeAndInsertVar(extent / product, false); data->_div_map.emplace(extent, product * quotient); return vars; } diff --git a/src/auto_scheduler/sym_feats/transform.cc b/src/auto_scheduler/sym_feats/transform.cc index b669551e5..9f313f748 100644 --- a/src/auto_scheduler/sym_feats/transform.cc +++ b/src/auto_scheduler/sym_feats/transform.cc @@ -213,7 +213,7 @@ class ExpTransform { auto it = this->var_decomp.find(vname); if (it == this->var_decomp.end()) { auto new_name = vname + "_2"; - SizeVar var(new_name, true); + SizeVar var(new_name, DataType::Int(32), Span(), true); this->var_decomp[vname][2] = var; } } @@ -249,7 +249,7 @@ class ExpTransform { for (auto [prime, power] : factors) { PrimExpr con = Integer(0); for (auto& varname : vars) { - SizeVar vprime(varname + "_" + std::to_string(prime), true); + SizeVar vprime(varname + "_" + std::to_string(prime), DataType::Int(32), Span(), true); // sp_0_1 = 2**sp_0_1_2 * 3**sp_0_1_3 * 5**sp_0_1_5 * 7**sp_0_1_7 this->var_decomp[varname][prime] = vprime; con += vprime; diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 1aa5d0bce..37336345e 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -676,7 +676,7 @@ PrimExpr ParseUnrollStep(const std::string &pragma) { if (IsNumber(num_or_var)) { return std::stoi(num_or_var); } else { - return SizeVar(num_or_var, true); + return SizeVar(num_or_var, DataType::Int(32), Span(), true); } } diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index c45acd8a9..fb2c33898 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -123,31 +123,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // SizeVar -SizeVar::SizeVar(String name_hint, DataType dtype, Span span) { +SizeVar::SizeVar(String name_hint, DataType dtype, Span span, bool is_config_var) { auto n = make_object<SizeVarNode>(); n->name_hint = std::move(name_hint); n->dtype = std::move(dtype); n->span = std::move(span); - n->is_config_var = false; - data_ = std::move(n); -} - -SizeVar::SizeVar(String name_hint, bool is_config_var) { - auto n = make_object<SizeVarNode>(); - n->name_hint = std::move(name_hint); - n->dtype = DataType::Int(32); - n->span = Span(); n->is_config_var = is_config_var; data_ = std::move(n); } TVM_REGISTER_GLOBAL("tir.SizeVar") .set_body_typed([](String s, DataType t, Span span, bool is_config_var) { - if (is_config_var) { - return SizeVar(s, true); - } else { - return SizeVar(s, t, span); - } + return SizeVar(s, t, span, is_config_var); }); TVM_REGISTER_NODE_TYPE(SizeVarNode); -- GitLab