From dc862ace3a564c4ccffa108bf0c09565e53babb8 Mon Sep 17 00:00:00 2001
From: Yifan Zhao <yifanz16@illinois.edu>
Date: Wed, 21 Jun 2023 20:55:26 -0500
Subject: [PATCH] Print and parse the type of size variables

---
 include/tvm/arith/var_context.h           |  2 +-
 include/tvm/tir/var.h                     |  4 +---
 src/arith/egg_simpl.cc                    | 22 +++++++++++++---------
 src/arith/var_context.cc                  |  4 ++--
 src/auto_scheduler/sym_feats/transform.cc |  4 ++--
 src/auto_scheduler/transform_step.cc      |  2 +-
 src/tir/ir/expr.cc                        | 17 ++---------------
 7 files changed, 22 insertions(+), 33 deletions(-)

diff --git a/include/tvm/arith/var_context.h b/include/tvm/arith/var_context.h
index a97595a69..33d82f90a 100644
--- a/include/tvm/arith/var_context.h
+++ b/include/tvm/arith/var_context.h
@@ -68,7 +68,7 @@ class OrderedVarMap {
  private:
   tir::SizeVar push_aux(const std::string& vname, const PrimExpr& expr) {
     this->vnames.insert(vname);
-    tir::SizeVar var(vname, true);
+    tir::SizeVar var(vname, expr->dtype);
     this->expr2var.emplace(expr, var);
     return var;
   }
diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h
index fcfbd1340..221bbd9d2 100644
--- a/include/tvm/tir/var.h
+++ b/include/tvm/tir/var.h
@@ -153,9 +153,7 @@ class SizeVar : public Var {
    * \param span The location of this object in the source code.
    */
   TVM_DLL explicit SizeVar(String name_hint = "s", DataType t = DataType::Int(32),
-                           Span span = Span());
-
-  SizeVar(String name_hint, bool is_config_var);
+                           Span span = Span(), bool is_config_var = false);
 
   /*!
    * \brief Get pointer to the internal value.
diff --git a/src/arith/egg_simpl.cc b/src/arith/egg_simpl.cc
index ad29fd8c7..e11d49a64 100644
--- a/src/arith/egg_simpl.cc
+++ b/src/arith/egg_simpl.cc
@@ -27,8 +27,9 @@ class PreorderPrinter : public ExprFunctor<void(const PrimExpr&)> {
 
   void VisitExpr_(const VarNode* op) override {
     this->var_map.Set(op->name_hint, GetRef<Var>(op));
-    ss << op->name_hint;
+    ss << op->name_hint << ":" << (op->dtype.is_bool() ? "b" : "i");
   }
+
   void VisitExpr_(const IntImmNode* op) override {
     if (op->dtype.is_bool()) {
       ss << (op->value ? "true" : "false");
@@ -147,15 +148,18 @@ std::pair<PrimExpr, size_t> ParseExprPreorder(const std::string& str,
       return Bool(true);
     } else if (str == "false") {
       return Bool(false);
-    } else if (std::isalpha(str[0])) {
-      if (!var_map) {
-        return Var(str, DataType::Int(32));
-      } else {
-        auto var = var_map.value().Get(str);
-        if (!var.defined()) {
-          throw std::runtime_error("Undefined variable: " + str);
-        }
+    } else if (str.size() > 1 && str[str.size() - 2] == ':') {
+      DataType datatype(str.back() == 'b' ? DataType::Bool() : DataType::Int(32));
+      auto var_name = str.substr(0, str.size() - 2);
+      if (var_map) {
+        auto var = var_map.value().Get(var_name);
+        ICHECK(var.defined()) << "Undefined variable: " << str;
+        ICHECK(var.value()->dtype == datatype)
+            << "Variable " << str << " has type " << var.value()->dtype << " but expected "
+            << datatype;
         return var.value();
+      } else {
+        return Var(var_name, datatype);
       }
     }
     auto is_digit = [](char c) { return std::isdigit(c); };
diff --git a/src/arith/var_context.cc b/src/arith/var_context.cc
index 3a9acf75a..6e953cfd5 100644
--- a/src/arith/var_context.cc
+++ b/src/arith/var_context.cc
@@ -84,7 +84,7 @@ Array<SizeVar> VarContext::GetSplitVars(PrimExpr extent, size_t n_splits, bool w
   Array<String> var_names;
   for (size_t i = 0; i < n_splits; i++) {
     String name = "sp_" + std::to_string(data->split_counter) + "_" + std::to_string(i);
-    vars.push_back(SizeVar(name, true));
+    vars.push_back(SizeVar(name, DataType::Int(32), Span(), true));
     var_names.push_back(name);
   }
   ++data->split_counter;
@@ -100,7 +100,7 @@ Array<SizeVar> VarContext::GetSplitVars(PrimExpr extent, size_t n_splits, bool w
   for (auto& var : vars) {
     product *= var;
   }
-  SizeVar quotient = data->MakeAndInsertVar(extent / product, false);
+  Var quotient = data->MakeAndInsertVar(extent / product, false);
   data->_div_map.emplace(extent, product * quotient);
   return vars;
 }
diff --git a/src/auto_scheduler/sym_feats/transform.cc b/src/auto_scheduler/sym_feats/transform.cc
index b669551e5..9f313f748 100644
--- a/src/auto_scheduler/sym_feats/transform.cc
+++ b/src/auto_scheduler/sym_feats/transform.cc
@@ -213,7 +213,7 @@ class ExpTransform {
     auto it = this->var_decomp.find(vname);
     if (it == this->var_decomp.end()) {
       auto new_name = vname + "_2";
-      SizeVar var(new_name, true);
+      SizeVar var(new_name, DataType::Int(32), Span(), true);
       this->var_decomp[vname][2] = var;
     }
   }
@@ -249,7 +249,7 @@ class ExpTransform {
     for (auto [prime, power] : factors) {
       PrimExpr con = Integer(0);
       for (auto& varname : vars) {
-        SizeVar vprime(varname + "_" + std::to_string(prime), true);
+        SizeVar vprime(varname + "_" + std::to_string(prime), DataType::Int(32), Span(), true);
         // sp_0_1 = 2**sp_0_1_2 * 3**sp_0_1_3 * 5**sp_0_1_5 * 7**sp_0_1_7
         this->var_decomp[varname][prime] = vprime;
         con += vprime;
diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc
index 1aa5d0bce..37336345e 100644
--- a/src/auto_scheduler/transform_step.cc
+++ b/src/auto_scheduler/transform_step.cc
@@ -676,7 +676,7 @@ PrimExpr ParseUnrollStep(const std::string &pragma) {
   if (IsNumber(num_or_var)) {
     return std::stoi(num_or_var);
   } else {
-    return SizeVar(num_or_var, true);
+    return SizeVar(num_or_var, DataType::Int(32), Span(), true);
   }
 }
 
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index c45acd8a9..fb2c33898 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -123,31 +123,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     });
 
 // SizeVar
-SizeVar::SizeVar(String name_hint, DataType dtype, Span span) {
+SizeVar::SizeVar(String name_hint, DataType dtype, Span span, bool is_config_var) {
   auto n = make_object<SizeVarNode>();
   n->name_hint = std::move(name_hint);
   n->dtype = std::move(dtype);
   n->span = std::move(span);
-  n->is_config_var = false;
-  data_ = std::move(n);
-}
-
-SizeVar::SizeVar(String name_hint, bool is_config_var) {
-  auto n = make_object<SizeVarNode>();
-  n->name_hint = std::move(name_hint);
-  n->dtype = DataType::Int(32);
-  n->span = Span();
   n->is_config_var = is_config_var;
   data_ = std::move(n);
 }
 
 TVM_REGISTER_GLOBAL("tir.SizeVar")
     .set_body_typed([](String s, DataType t, Span span, bool is_config_var) {
-      if (is_config_var) {
-        return SizeVar(s, true);
-      } else {
-        return SizeVar(s, t, span);
-      }
+      return SizeVar(s, t, span, is_config_var);
     });
 
 TVM_REGISTER_NODE_TYPE(SizeVarNode);
-- 
GitLab