From c6f208df6675e73ad5394781db5492d9d54d9f12 Mon Sep 17 00:00:00 2001
From: Yifan Zhao <yifanz16@illinois.edu>
Date: Tue, 4 Jul 2023 23:02:29 -0500
Subject: [PATCH] Added feature extraction

---
 include/tvm/arith/egg_simpl.h   |    3 +-
 include/tvm/arith/var_context.h |   44 +-
 python/tvm/felix/features.py    |  349 ++++++++++
 python/tvm/felix/ffi.py         |   39 ++
 python/tvm/felix/sketch.py      |   28 +-
 src/arith/egg_simpl.cc          |   15 +-
 src/arith/egg_simpl/src/lang.rs |  112 +++-
 src/arith/egg_simpl/src/lib.rs  |    8 +-
 src/arith/var_context.cc        |  102 ++-
 src/felix/constraints.cc        |  207 ++++++
 src/felix/feat_transform.cc     | 1003 +++++++++++++++++++++++++++++
 src/felix/features.cc           | 1060 +++++++++++++++++++++++++++++++
 src/felix/features.h            |   29 +
 src/felix/rangeinfer.h          |  262 ++++++++
 src/felix/utils.h               |   49 ++
 src/tir/ir/expr.cc              |   49 +-
 src/tir/op/op.cc                |   14 +-
 17 files changed, 3208 insertions(+), 165 deletions(-)
 create mode 100644 python/tvm/felix/features.py
 create mode 100644 src/felix/constraints.cc
 create mode 100644 src/felix/feat_transform.cc
 create mode 100644 src/felix/features.cc
 create mode 100644 src/felix/features.h
 create mode 100644 src/felix/rangeinfer.h

diff --git a/include/tvm/arith/egg_simpl.h b/include/tvm/arith/egg_simpl.h
index 28eaf3e2f..ce2a891ea 100644
--- a/include/tvm/arith/egg_simpl.h
+++ b/include/tvm/arith/egg_simpl.h
@@ -22,8 +22,7 @@ PrimExpr SimplifyExpr(const PrimExpr& expr, size_t max_n_iters = 30, size_t max_
 
 PrimExpr SubAndSimplify(const PrimExpr& expr,
                         const std::unordered_map<std::string, PrimExpr>& subst,
-                        bool simpl_only_on_change, size_t max_n_iters = 30,
-                        size_t max_n_nodes = 10000);
+                        size_t max_n_iters = 30, size_t max_n_nodes = 10000);
 
 bool IsExprEquivalent(const PrimExpr& lhs, const PrimExpr& rhs, size_t max_n_iters = 30,
                       size_t max_n_nodes = 10000, bool diff_approx = false);
diff --git a/include/tvm/arith/var_context.h b/include/tvm/arith/var_context.h
index bd8139840..4a0b7d2f3 100644
--- a/include/tvm/arith/var_context.h
+++ b/include/tvm/arith/var_context.h
@@ -6,6 +6,8 @@
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt_functor.h>
 
+#include <optional>
+
 namespace tvm {
 namespace arith {
 
@@ -35,14 +37,11 @@ class VarExprPair : public ObjectRef {
  public:
   VarExprPair(tir::SizeVar var, PrimExpr expr)
       : VarExprPair(make_object<VarExprPairNode>(var, expr)) {}
-  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(VarExprPair, ObjectRef, VarExprPairNode);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(VarExprPair, ObjectRef, VarExprPairNode);
 };
 
 class VarDefStackNode : public Object {
-  using ContainerT = Array<VarExprPair>;
-  using ExprVisitor = std::function<PrimExpr(const PrimExpr& e)>;
-  using ThreadedExprVisitor = std::function<PrimExpr(const PrimExpr& e, size_t)>;
-  using VarExprVisitor = std::function<PrimExpr(const Var& v, const PrimExpr& e)>;
+  using ExprMutator = std::function<PrimExpr(const PrimExpr& e)>;
 
  public:
   void VisitAttrs(tvm::AttrVisitor* v) {
@@ -53,32 +52,36 @@ class VarDefStackNode : public Object {
 
   tir::SizeVar Append(const std::string& vname, const PrimExpr& expr);
   void Append(const tir::SizeVar& var, const PrimExpr& expr);
-  tir::SizeVar FindOrAppend(const std::string& vname, const PrimExpr& expr);
-
-  VarDefStackNode Prepend(const VarMapT& vmap) const;
-  VarMapT IntoVarMap() const;
+  PrimExpr DefineConstShorthand(PrimExpr expr);
 
   bool Contains(const std::string& vname) const {
     return this->var2idx.find(vname) != this->var2idx.end();
   }
   size_t Size() const { return this->exprs.size(); }
-  const ContainerT& GetExprs() const { return this->exprs; }
-  PrimExpr& GetExprAt(const std::string& vname) {
+
+  const Array<VarExprPair>& GetExprs() const { return this->exprs; }
+  const PrimExpr& GetExprAt(const std::string& vname) const {
     auto it = this->var2idx.find(vname);
     ICHECK(it != this->var2idx.end()) << "Var " << vname << " not found in VarDefStack";
     return this->exprs[(*it).second]->expr;
   }
 
-  std::vector<Var> FreeVars() const;
-  bool HasUndefVars(const PrimExpr& expr) const;
+  VarMapT IntoUnwindedVarMap() const;
+  std::unordered_set<std::string> GetAllUsedVars(std::optional<tir::SizeVarKind> kind) const;
+
+  void MapExprs(ExprMutator func);
+  void MapExprsParallel(ExprMutator func);
 
   static constexpr const char* _type_key = "arith.VarDefStack";
   TVM_DECLARE_FINAL_OBJECT_INFO(VarDefStackNode, Object);
 
  private:
-  ContainerT exprs;
+  Array<VarExprPair> exprs;
   Map<String, Integer> var2idx;
   std::unordered_map<PrimExpr, size_t, StructuralHash, StructuralEqual> expr2idx;
+
+  friend class VarDefStack;
+  friend class dmlc::json::Handler<VarDefStackNode>;
 };
 
 class VarDefStack : public ObjectRef {
@@ -125,20 +128,17 @@ class VarContextNode : public Object {
   Array<tir::SizeVar> GetSplitVars(const PrimExpr& extent, size_t n_splits, bool whole_div);
   std::pair<PrimExpr, PrimExpr> GetSplitSizes(const PrimExpr& extent, PrimExpr factor,
                                               bool no_tighten_factor);
+  PrimExpr DefineConstShorthand(PrimExpr expr) {
+    return this->var_defs->DefineConstShorthand(expr);
+  }
 
-  void DefineVar(const std::string& name, PrimExpr expr);
-
-  PrimExpr DefineConstShorthand(PrimExpr expr);
+  static constexpr const char* _type_key = "arith.VarContext";
+  TVM_DECLARE_FINAL_OBJECT_INFO(VarContextNode, Object);
 
  private:
-  Var AllocVarForExpr(PrimExpr expr);
-
   std::pair<PrimExpr, PrimExpr> SymbolicDiv(PrimExpr numer, PrimExpr denom, bool no_tighten_factor);
 
  public:
-  static constexpr const char* _type_key = "arith.VarContext";
-  TVM_DECLARE_FINAL_OBJECT_INFO(VarContextNode, Object);
-
   Array<SplitGroup> split_groups{};
   VarDefStack var_defs{};
 
diff --git a/python/tvm/felix/features.py b/python/tvm/felix/features.py
new file mode 100644
index 000000000..4ad16937b
--- /dev/null
+++ b/python/tvm/felix/features.py
@@ -0,0 +1,349 @@
+import logging
+import typing as ty
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Sequence, Tuple, cast
+
+import numpy as np
+import torch
+from sympy.ntheory import factorint
+from torch import Tensor, nn
+from torch.fx._symbolic_trace import symbolic_trace
+from torch.fx.graph_module import GraphModule
+from tvm import tir
+
+from . import ffi
+from .utils import transpose2
+
+__all__ = ["TorchFeatures"]
+logger = logging.getLogger(__name__)
+
+Number = ty.Union[int, float]
+T = ty.TypeVar("T")
+
+
+class TorchFeatures(nn.Module):
+    def __init__(
+        self,
+        features_f: GraphModule,
+        constraints_f: GraphModule,
+        lin_cons: list,
+        var_order: Dict[str, int],
+        var_decomp: Dict[str, Dict[int, tir.SizeVar]],
+        n_feats: int,
+        n_bufs: int,
+    ):
+        super().__init__()
+        self.features_f = features_f
+        self.constraints_f = constraints_f
+        self.lin_cons = lin_cons
+        self.var_order = var_order
+        self.var_decomp = {
+            k: {int(prime): v.name for prime, v in ds.items()} for k, ds in var_decomp.items()
+        }
+        self.n_feats, self.n_bufs = n_feats, n_bufs
+        self.n_consts = 0
+        self._dummy_param = nn.Parameter(torch.empty(0))  # type: ignore
+        self._conf_maker: Optional[RandConfigMaker] = None
+
+    @classmethod
+    def from_feat_pack(cls, feat: ffi.FeaturePack) -> "TorchFeatures":
+        var_order = {var: i for i, var in enumerate(sorted(feat.free_vars))}
+        features = defaultdict(list)
+        other_cons = []
+        vdefs = {}
+        for vname, expr in feat.expressions:
+            if vname.startswith("BS"):
+                bufstore_idx = int(vname.split(".")[0][2:])
+                features[bufstore_idx].append(expr)
+            elif vname.startswith("con_"):
+                other_cons.append(expr)
+            else:
+                vdefs[vname] = expr
+        features_ = np.array([features[i] for i in range(len(features))])
+        n_bufs, n_feats = features_.shape
+        features_f = TorchExprRunner(features_, var_order, vdefs).get_traced()
+        lin_cons = list(feat.linear_cons)
+        all_cons = np.array([c.as_primexpr() for c in lin_cons] + other_cons)
+        cons_f = TorchExprRunner(all_cons, var_order, vdefs).get_traced()
+        return cls(features_f, cons_f, lin_cons, var_order, feat.var_decomp, n_feats, n_bufs)
+
+    @property
+    def device(self):
+        return self._dummy_param.device
+
+    def rand_configs(self, n_configs: int):
+        if self._conf_maker is None:
+            self._conf_maker = RandConfigMaker(
+                list(self.var_order.keys()),
+                self.lin_cons,
+                self.constraints_f,
+                self.device,
+            )
+        return self._conf_maker.rand_configs(n_configs)
+
+    def forward(self, params: Tensor) -> Tuple[Tensor, Tensor]:
+        # params: [batch_size, n_vars]
+        # features: [batch_size, n_bufs, n_feats]
+        features = self.features_f(params)
+        # leq_0s: [batch_size, n_constraints]
+        leq_0s = self.constraints_f(params)
+        return features, leq_0s
+
+    def run_on_initial_configs(self, initial_configs: Sequence[dict]):
+        configs = [self.transform_config(c) for c in initial_configs]
+        return self(torch.stack(configs, dim=0))
+
+    def transform_config(self, config: Dict[str, Number]):
+        stage1: Dict[str, Number] = {}
+        for var_name, value in config.items():
+            if var_name not in self.var_decomp:
+                stage1[var_name] = value
+                continue
+            decomposed = self.var_decomp[var_name]
+            if value == 0:
+                for vname in decomposed.values():
+                    stage1[vname] = -10  # HACK
+                continue
+            if len(decomposed) == 1:
+                ((b, vname),) = decomposed.items()
+                if b < 2:
+                    raise ValueError()
+                stage1[vname] = np.log(value) / np.log(b)
+                continue
+            factors: Dict[int, int] = factorint(value)
+            uncovered_ps = set(factors.keys()) - set(decomposed.keys())
+            if uncovered_ps:
+                bases = list(decomposed.keys())
+                raise ValueError(f"Cannot factorize {value} into {bases}")
+            for basis, vname in decomposed.items():
+                power = factors.get(basis, 0)
+                stage1[vname] = power
+        stage2: List[Optional[float]] = [None] * len(self.var_order)
+        for name, value in stage1.items():
+            if name not in self.var_order:
+                continue
+            idx = self.var_order[name]
+            stage2[idx] = value
+        if any(x is None for x in stage2):
+            required = set(self.var_order.keys())
+            got = set(stage1.keys())
+            raise ValueError(f"Missing value for some symbols {required - got}")
+        return torch.tensor(cast(List[float], stage2)).float()
+
+    def inv_transform_config(self, config: Tensor):
+        def expr_to_int(config, d: Dict[int, str]):
+            v = np.prod([prime ** config[vname] for prime, vname in d.items()])
+            if not np.isclose(v, int(v)):
+                raise ValueError(f"Cannot convert {v} to int")
+            return int(v)
+
+        stage1 = {k: config[idx].item() for k, idx in self.var_order.items()}
+        return {var: expr_to_int(stage1, d) for var, d in self.var_decomp.items()}
+
+    def _concat_knobs(self, consts: Tensor, knobs: Tensor):
+        assert consts.shape[0] == knobs.shape[0]
+        assert consts.shape[1] == self.n_consts
+        assert consts.shape[1] + knobs.shape[1] == self.n_vars
+        return torch.cat([consts.float(), knobs.float()], dim=1)
+
+
+class RandConfigMaker:
+    def __init__(
+        self,
+        variables: List[str],
+        lin_cons: List[ffi.LinearExpr],
+        constraints_f,
+        device,
+    ) -> None:
+        self.variables = variables
+        self.lin_cons = lin_cons
+        self.constraints_f = constraints_f
+        self.device = device
+        coefs, biases, self.bounds = self._estimate_var_bounds()
+        # Make configs in [0, 1]^n but zoom them before returning.
+        # Consider a single constraint: \sum A_ij x_j + b <= 0.
+        # Now x_i \in [0, bound_i]; make x'_i := x_i / bound_i, then x'_i \in [0, 1].
+        # And the constraint becomes: \sum (A_ij * bound_j) x'_j + b <= 0.
+        # Therefore:
+        self.coefs, self.biases = coefs * self.bounds.unsqueeze(0), biases
+
+    def rand_configs(self, n_configs: int):
+        # This point (0) should be in the polyhedron (on the boundary).
+        current = torch.zeros((1, self.coefs.shape[1]), device=self.device)
+        while len(current) < n_configs + 1:
+            next = self.rand_next_points(current)
+            # Restore configs into the original scale (multiplying with `bounds`; see above)
+            # Still needs to check as we have non-linear constraints too.
+            valid_mask = torch.all(self.constraints_f(next * self.bounds) <= 0, dim=1)
+            next = next[valid_mask]
+            # Don't insert scale-restored points into `configs` yet.
+            # We need to work with [0, 1]^n points here.
+            current = torch.cat([current, next], dim=0)
+        # ret: [n_configs, n_vars] (discard the first point which is (0))
+        return current[1 : n_configs + 1] * self.bounds
+
+    def rand_next_points(self, xs: Tensor):
+        # https://mathoverflow.net/questions/9854
+        assert xs.shape[1] == self.coefs.shape[1]
+        alpha = torch.rand_like(xs, device=self.device)
+        alpha = alpha / torch.linalg.norm(alpha, dim=1, keepdim=True)
+        # Solve A_i . (x + t_i alpha) + b_i <= 0 (. is for dot prod)
+        #   -> t_i ? (-b - A_i . x) / (A_i . alpha)
+        #      (whether t_i is a lower or upper bound depends on the sign of A_i . alpha)
+        # alpha, xs: [n_configs, n_vars]; self.coefs: [n_constraints, n_vars]
+        denoms = torch.tensordot(alpha, self.coefs, dims=([1], [1]))  # type: ignore
+        a_dot_x = torch.tensordot(xs, self.coefs, dims=([1], [1]))  # type: ignore
+        # denoms, a_dot_x: [n_configs, n_constraints]; self.biases: [n_constraints]
+        ts: Tensor = (-self.biases - a_dot_x) / denoms
+        # ts: [n_configs, n_constraints]
+        # NOTE: for each row, we would have at least one column where denoms > 0
+        # and at least one column where denoms < 0.
+        # Otherwise the problem would be unbounded.
+        # (If that actually happens, this code doesn't detect it.)
+        min_ubounds = torch.where(denoms > 0, ts, +1e9).min(dim=1).values
+        max_lbounds = torch.where(denoms < 0, ts, -1e9).max(dim=1).values
+        t_rands = []
+        for i in range(len(xs)):
+            t_min, t_max = max_lbounds[i], min_ubounds[i]
+            assert t_min <= 0 <= t_max
+            # Pick a uniformly random value in [t_min, -d] \union [d, t_max]
+            # (where d == 0.05) if there is such space, to prevent too small
+            # of a step from previous configuration.
+            # (If space is not enough for this operation, t is then just uniformly
+            # sampled from [t_min, t_max]).
+            t_rands.append(self._rand_uniform_with_gap(t_min, t_max))
+        return xs + torch.tensor(t_rands).unsqueeze(1) * alpha
+
+    @staticmethod
+    def _rand_uniform_with_gap(left: float, right: float, half_gap: float = 0.05):
+        assert left <= 0 <= right
+
+        def rand_range_f(l0, r0):
+            return np.random.rand() * (r0 - l0) + l0
+
+        if right <= half_gap and left >= -half_gap:
+            # Not enough space to evade [-half_gap, half_gap].
+            return rand_range_f(left, right)
+        l_size = max(0, -left - half_gap)
+        r_size = max(0, right - half_gap)
+        x = rand_range_f(-l_size, r_size)
+        return -half_gap + x if x < 0 else half_gap + x
+
+    def _estimate_var_bounds(self):
+        from scipy.optimize import linprog
+
+        def _get_coefs(poly: ffi.LinearExpr):
+            lin_terms = {repr(k): v for k, v in poly.lin_terms.items()}
+            coefs = [float(lin_terms.get(s, 0)) for s in self.variables]
+            return coefs, float(poly.constant)
+
+        coefs_biases = [_get_coefs(poly) for poly in self.lin_cons]
+        coefs, biases = transpose2(coefs_biases)
+        coefs, biases = np.array(coefs), np.array(biases)
+        # Run linear optimization with scipy.
+        vars = self.variables
+        var_bounds = []
+        for i in range(len(vars)):
+            coef_c = np.zeros(len(vars))
+            coef_c[i] = -1  # To maximize x[i]
+            result = linprog(coef_c, coefs, -biases)
+            if result.success:
+                var_bounds.append(result.x[i])
+            else:
+                raise RuntimeError(f"Bound inference failed. Diagnostics: \n{result}")
+
+        def np2torch(x):
+            return torch.tensor(x, device=self.device).float()
+
+        return np2torch(coefs), np2torch(biases), np2torch(var_bounds)
+
+
+def safe_log(x: torch.Tensor):
+    # To prevent Infs and NaNs in the result, use this:
+    # log(x)        if x >= 1
+    # -log(2 - x)   otherwise (x <= 1, 2 - x >= 1)
+    return torch.clamp_min(x, 1).log() - torch.clamp_min(2 - x, 1).log()
+
+
+class TorchExprRunner:
+    TIR_OP_TORCH_FN = {
+        tir.Add: (torch.add, ["a", "b"]),
+        tir.Sub: (torch.sub, ["a", "b"]),
+        tir.Mul: (torch.mul, ["a", "b"]),
+        tir.Div: (torch.div, ["a", "b"]),
+        tir.Min: (torch.minimum, ["a", "b"]),
+        tir.Max: (torch.maximum, ["a", "b"]),
+        tir.Cast: (lambda x: x, ["value"]),
+    }
+    TIR_FN_TORCH_FN = {
+        "tir.pow": torch.pow,
+        "tir.exp": torch.exp,
+        "tir.log": safe_log,
+        "tir.logk": lambda base, x: safe_log(x) / torch.log(base),
+        "tir.sigmoid": lambda x: x * torch.pow(x**2 + 1, -0.5) / 2 + 0.5,
+        "tir.hump": lambda x: torch.pow(x**2 + 1, -0.5),
+    }
+
+    def __init__(self, exprs: np.ndarray, var2idx: dict, var2expr: dict) -> None:
+        self.exprs = exprs.ravel().tolist()
+        self.shape = exprs.shape
+        self.var2idx = var2idx
+        self.var2expr = var2expr
+        self.memo: Dict[str, Tensor] = {}
+
+    def get_traced(self) -> GraphModule:
+        self.memo.clear()
+        # `self` is coded in the graph.
+        # Use a lambda to convince symbolic_trace that it's a function, not a member
+        ret = symbolic_trace(lambda input: self.run(input))
+        self.memo.clear()
+        return ret
+
+    def _run_expr_memoized(self, expr: tir.PrimExpr, inputs: Tensor, batch_size: int):
+        estr = repr(expr)
+        if (result := self.memo.get(estr)) is not None:
+            return result
+        self.memo[estr] = result = self._run_expr(expr, inputs, batch_size)
+        return result
+
+    def _run_expr(self, expr: tir.PrimExpr, inputs: Tensor, nb: int) -> Tensor:
+        if isinstance(expr, tir.Var):
+            var_idx = self.var2idx.get(expr.name)
+            if var_idx is not None:
+                return inputs[:, var_idx]
+            var_expr = self.var2expr.get(expr.name)
+            if var_expr is not None:
+                return self._run_expr_memoized(var_expr, inputs, nb)
+            raise KeyError(f"Stray variable {expr.name}")
+        if isinstance(expr, tir.IntImm):
+            dtype = torch.bool if expr.dtype == "bool" else torch.int64
+            return inputs.new_full((nb,), expr.value, dtype=dtype)
+        if isinstance(expr, tir.FloatImm):
+            return inputs.new_full((nb,), expr.value, dtype=torch.float32)
+        if isinstance(expr, tir.Call):
+            if (func := self.TIR_FN_TORCH_FN.get(expr.op.name)) is None:
+                raise ValueError(f"Function call {expr.op.name} is unsupported")
+            args = expr.args
+        else:
+            if (func_fields := self.TIR_OP_TORCH_FN.get(type(expr))) is None:
+                raise ValueError(f"Operator {type(expr)} is unsupported")
+            func, fields = func_fields
+            args = [getattr(expr, f) for f in fields]
+        args = [self._run_expr_memoized(arg, inputs, nb) for arg in args]
+        return func(*args)
+
+    def run(self, inputs: Tensor):
+        batch_size = inputs.shape[0]
+        if not self.exprs:
+            return inputs.new_zeros(batch_size, *self.shape)
+        # inputs: [batch_size, n_vars]
+        ret = []
+        for expr in self.exprs:
+            result = self._run_expr_memoized(expr, inputs, batch_size)
+            if result.dtype is torch.float and (result.abs() > 1e4).any():
+                raise ValueError(f"Result too large: {result} in {expr}")
+            ret.append(result)
+        ret_ = torch.stack(ret, dim=1)
+        # return [batch_size, *shape] (shape = [n_buffers, n_exprs])
+        return ret_.view(-1, *self.shape)
diff --git a/python/tvm/felix/ffi.py b/python/tvm/felix/ffi.py
index ec876c9c4..591016302 100644
--- a/python/tvm/felix/ffi.py
+++ b/python/tvm/felix/ffi.py
@@ -17,6 +17,23 @@ class VarContext(tvm.Object):
         return dict(_arith.VarContextGetVarDefs(self))
 
 
+@tvm._ffi.register_object("ansor.LinearExpr")
+class LinearExpr(tvm.Object):
+    lin_terms: Dict[tir.Var, float]
+    constant: float
+
+    def as_primexpr(self) -> tir.PrimExpr:
+        return _ansor.LinearExprAsPrimExpr(self)
+
+
+@tvm._ffi.register_object("ansor.FeaturePackPy")
+class FeaturePack(tvm.Object):
+    expressions: List[Tuple[str, tir.PrimExpr]]
+    free_vars: List[str]
+    linear_cons: List[LinearExpr]
+    var_decomp: Dict[str, Dict[int, tir.SizeVar]]
+
+
 # PrimExpr Utils
 
 
@@ -57,5 +74,27 @@ def print_state_tr_steps(state: StateObject) -> str:
     return "\n".join(_ansor.PrintTrStep(s) for s in state.transform_steps)
 
 
+def get_feature_pack(
+    code: tir.Stmt,
+    context: VarContext,
+    hw_params: ansor.HardwareParams,
+    sizes: Dict[str, int],
+    cache_line_size: int,
+    max_n_buf: int,
+    factorize: bool,
+    save_load_path: str,
+) -> FeaturePack:
+    return _ansor.GetFeaturePack(
+        code,
+        context,
+        hw_params,
+        sizes,
+        cache_line_size,
+        max_n_buf,
+        factorize,
+        save_load_path,
+    )
+
+
 def get_loop_bounds(code: tir.Stmt) -> List[Tuple[str, tir.PrimExpr]]:
     return list(_ansor.GetLoopBounds(code))
diff --git a/python/tvm/felix/sketch.py b/python/tvm/felix/sketch.py
index 225953779..570cd959f 100644
--- a/python/tvm/felix/sketch.py
+++ b/python/tvm/felix/sketch.py
@@ -1,14 +1,17 @@
 import logging
 from pathlib import Path
+from typing import Dict
 
 from tvm.auto_scheduler.loop_state import StateObject
 from tvm.tir import Stmt
 
 from . import ffi
+from .features import TorchFeatures
+from .utils import HW_PARAMS
 
 _logger = logging.getLogger(__name__)
 __all__ = ["Sketch"]
-CACHE_PATH = (Path(__file__).parent / "../../lightning_logs/features").resolve()
+FEATURE_CACHE_PATH = Path(Path("~").expanduser(), ".tvm", "felix", "features")
 
 
 class Sketch:
@@ -39,7 +42,28 @@ class Sketch:
         return ffi.generate_code_for_state(task, state, False)[0]
 
     def save_path(self) -> Path:
-        return CACHE_PATH / f"{self.state_hash()}.json"
+        return FEATURE_CACHE_PATH / f"{self.state_hash()}.json"
+
+    def fetch_features(
+        self,
+        sizes: Dict[str, int],
+        prime_factorize: bool = True,
+        max_n_buf: int = 5,
+        cache_line_size: int = 64,
+    ):
+        path = self.save_path()
+        path.parent.mkdir(exist_ok=True, parents=True)
+        features = ffi.get_feature_pack(
+            self.code,
+            self.context,
+            HW_PARAMS,
+            sizes,
+            cache_line_size,
+            max_n_buf,
+            prime_factorize,
+            path.as_posix(),
+        )
+        return TorchFeatures.from_feat_pack(features)
 
     def __str__(self) -> str:
         return f"Sketch({self.backbone} from {self.parent_task})"
diff --git a/src/arith/egg_simpl.cc b/src/arith/egg_simpl.cc
index 71bce8435..8dad2ab7a 100644
--- a/src/arith/egg_simpl.cc
+++ b/src/arith/egg_simpl.cc
@@ -293,6 +293,8 @@ std::pair<PrimExpr, size_t> ParseExprPreorder(const std::string& str,
     return {MakeUnOp(op_str, exp, args), loc};
   } else if (op_str == "sigmoid") {
     return {MakeUnOp(op_str, sigmoid, args), loc};
+  } else if (op_str == "hump") {
+    return {MakeUnOp(op_str, hump, args), loc};
   } else {
     throw std::runtime_error("Unknown operator " + op_str + " in " + str);
   }
@@ -313,20 +315,15 @@ PrimExpr SimplifyExpr(const PrimExpr& expr, size_t max_n_iters, size_t max_n_nod
   char* simpl_str = simplify_expr(expr_str.c_str(), max_n_iters, max_n_nodes, diff_approx);
   PrimExpr simplified = ParseExprPreorder(simpl_str, printer.var_map).first;
   free_str(simpl_str);
-  // If not really simplified, don't return the simplified version
-  // because the order of the variables etc. may be different.
-  return CountOps(simplified) < CountOps(expr) ? simplified : expr;
+  return simplified;
 }
 
 PrimExpr SubAndSimplify(const PrimExpr& expr,
-                        const std::unordered_map<std::string, PrimExpr>& subst,
-                        bool simpl_only_on_change, size_t max_n_iters, size_t max_n_nodes) {
+                        const std::unordered_map<std::string, PrimExpr>& subst, size_t max_n_iters,
+                        size_t max_n_nodes) {
   bool changed = false;
   auto expr_ = SubstByName(expr, subst, &changed);
-  if (!changed && simpl_only_on_change) {
-    return expr;
-  }
-  return SimplifyExpr(expr_, max_n_iters, max_n_nodes, true);
+  return changed ? SimplifyExpr(expr_, max_n_iters, max_n_nodes, true) : expr;
 }
 
 bool IsExprEquivalent(const PrimExpr& lhs, const PrimExpr& rhs, size_t max_n_iters,
diff --git a/src/arith/egg_simpl/src/lang.rs b/src/arith/egg_simpl/src/lang.rs
index 576763586..83aff3e5e 100644
--- a/src/arith/egg_simpl/src/lang.rs
+++ b/src/arith/egg_simpl/src/lang.rs
@@ -86,6 +86,7 @@ define_language! {
         "!" = Not([Id; 1]),
         "select" = Select([Id; 3]),
         "sigmoid" = Sigmoid([Id; 1]),
+        "hump" = Hump([Id; 1]),
     }
 }
 
@@ -264,7 +265,7 @@ impl egg::Analysis<Math> for ConstantFold {
     }
 }
 
-fn is_const_with_pred(
+fn _is_const_with_pred(
     var: &str,
     fop: impl Fn(f64) -> bool,
 ) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
@@ -279,11 +280,29 @@ fn is_const_with_pred(
     }
 }
 
-fn is_geq_2(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
-    is_const_with_pred(var, |x| x >= 2.0)
+fn is_not_zero(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
+    let var = var.parse().unwrap();
+    move |egraph, _, subst| {
+        if let Some(n) = &egraph[subst[var]].data {
+            if let Const::Float(OrderedFloat(f)) = n.0 {
+                return f != 0.0;
+            }
+        }
+        true
+    }
+}
+
+fn is_symbol(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
+    let var = var.parse().unwrap();
+    move |egraph, _, subst| {
+        egraph[subst[var]]
+            .nodes
+            .iter()
+            .any(|n| matches!(n, Math::Symbol(..)))
+    }
 }
 
-fn is_pow_of_(egraph: &mut EGraph, subst: &Subst, x: Var, base: Var) -> Option<i64> {
+fn _is_pow_of(egraph: &mut EGraph, subst: &Subst, x: Var, base: Var) -> Option<i64> {
     let (x, base) = (
         egraph[subst[x]].data.as_ref()?.0,
         egraph[subst[base]].data.as_ref()?.0,
@@ -302,10 +321,10 @@ fn is_pow_of_(egraph: &mut EGraph, subst: &Subst, x: Var, base: Var) -> Option<i
 
 fn is_pow_of(x: &str, base: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
     let (x, base) = (x.parse().unwrap(), base.parse().unwrap());
-    move |egraph, _, subst| is_pow_of_(egraph, subst, x, base).is_some()
+    move |egraph, _, subst| _is_pow_of(egraph, subst, x, base).is_some()
 }
 
-pub static STABLE_RULES: Lazy<Vec<Rewrite>> = Lazy::new(|| {
+pub static BASIC_RULES: Lazy<Vec<Rewrite>> = Lazy::new(|| {
     vec![
         rewrite!("mul-comm";    "(* ?a ?b)"             => "(* ?b ?a)"),
         rewrite!("mul-assoc";   "(* ?a (* ?b ?c))"      => "(* (* ?a ?b) ?c)"),
@@ -318,7 +337,7 @@ pub static STABLE_RULES: Lazy<Vec<Rewrite>> = Lazy::new(|| {
         rewrite!("sub-intro";   "(+ ?a (* -1 ?b))"      => "(- ?a ?b)"),
         rewrite!("sub-cancel";  "(+ ?a (* -1 ?a))"      => "0"),
         rewrite!("div-canon";   "(/ ?a ?b)"             => "(* ?a (pow ?b -1))"),
-        rewrite!("div-intro";   "(* ?a (pow ?b -1))"    => "(/ ?a ?b)"),
+        rewrite!("div-intro";   "(* ?a (pow ?b -1))"    => "(/ ?a ?b)"  if is_not_zero("?b")),
         rewrite!("div-cancel";  "(* ?a (pow ?a -1))"    => "1"),
         rewrite!("add-mul-distrib";     "(* ?a (+ ?b ?c))"              => "(+ (* ?a ?b) (* ?a ?c))"),
         rewrite!("add-mul-factor";      "(+ (* ?a ?b) (* ?a ?c))"       => "(* ?a (+ ?b ?c))"),
@@ -327,12 +346,16 @@ pub static STABLE_RULES: Lazy<Vec<Rewrite>> = Lazy::new(|| {
         rewrite!("pow-pow";     "(pow (pow ?a ?b) ?c)"  => "(pow ?a (* ?b ?c))"),
         rewrite!("const-mul-pow";      "(* ?a (pow ?b ?c))"     => "(pow ?b (+ (logk ?b ?a) ?c))" if is_pow_of("?a", "?b")),
         rewrite!("const-div-pow";      "(/ ?a (pow ?b ?c))"     => "(pow ?b (- (logk ?b ?a) ?c))" if is_pow_of("?a", "?b")),
+        rewrite!("max-self";    "(max ?a ?a)"           => "?a"),
+        rewrite!("max-a-add-b"; "(max ?a (+ ?a ?b))"    => "(+ ?a ?b)"  if _is_const_with_pred("?b", |x| x >= 0.0)),
+        rewrite!("min-self";    "(max ?a ?a)"           => "?a"),
         rewrite!("min-pow";     "(min (pow ?a ?b) (pow ?a ?c))" => "(pow ?a (min ?b ?c))"),
-        rewrite!("logk-canon";  "(logk ?a ?b)"      => "(/ (log ?b) (log ?a))"),
-        rewrite!("log-prod";    "(log (* ?a ?b))"   => "(+ (log ?a) (log ?b))"),
-        rewrite!("log-pow";     "(log (pow ?a ?b))" => "(* ?b (log ?a))"),
-        rewrite!("floordiv-cancel"; "(// ?a ?a)"        => "1"),
+        rewrite!("logk-canon";  "(logk ?a ?b)"          => "(/ (log ?b) (log ?a))"),
+        rewrite!("log-prod";    "(log (* ?a ?b))"       => "(+ (log ?a) (log ?b))"),
+        rewrite!("log-pow";     "(log (pow ?a ?b))"     => "(* ?b (log ?a))"),
+        rewrite!("floordiv-cancel"; "(// (* ?a ?b) ?a)"     => "?b"),
         rewrite!("floordiv-merge";  "(// (// ?a ?b) ?c)"    => "(// (/ ?a ?b) ?c)"),
+        rewrite!("floordiv-neg-1";  "(// -1 ?a)"            => "-1"),
         rewrite!("select-same";     "(select ?a ?b ?b)"     => "?b"),
         rewrite!("select-true";     "(select true ?a ?b)"   => "?a"),
         rewrite!("select-false";    "(select false ?a ?b)"  => "?b"),
@@ -357,32 +380,57 @@ pub static STABLE_RULES: Lazy<Vec<Rewrite>> = Lazy::new(|| {
         rewrite!("gt-canon";    "(> ?a ?b)"     => "(< ?b ?a)"),
         rewrite!("le-canon";    "(<= ?a ?b)"    => "(! (< ?b ?a))"),
         rewrite!("eq-comm";     "(== ?a ?b)"    => "(== ?b ?a)"),
+        rewrite!("lt-min";      "(< ?a (min ?b ?c))"    => "(&& (< ?a ?b) (< ?a ?c))"),
+        rewrite!("eq-min";      "(== ?a (min ?b ?c))"   => "(|| (== ?a ?b) (== ?a ?c))"),
     ]
 });
 
-pub static ALL_RULES: Lazy<Vec<Rewrite>> = Lazy::new(|| {
+pub static DIFF_APPROX_RULES: Lazy<Vec<Rewrite>> = Lazy::new(|| {
     let mut ret_rules = vec![];
-    ret_rules.extend_from_slice(STABLE_RULES.as_slice());
+    ret_rules.extend_from_slice(BASIC_RULES.as_slice());
     ret_rules.extend(vec![
         // Differentiability-approx-specific simplification rules
-        rewrite!("lt-pow-1";    "(< (pow ?a ?b) ?c)"    => "(< ?b (logk ?a ?c))"    if is_geq_2("?a")),
-        rewrite!("lt-pow-2";    "(< ?c (pow ?a ?b))"    => "(< (logk ?a ?c) ?b)"    if is_geq_2("?a")),
-        rewrite!("lt-add-1";    "(< (+ ?a ?b) ?c)"      => "(< ?a (- ?c ?b))"),
-        rewrite!("lt-add-2";    "(< ?c (+ ?a ?b))"      => "(< (- ?c ?b) ?a)"),
-        rewrite!("lt-sub-1";    "(< (- ?a ?b) ?c)"      => "(< ?a (+ ?b ?c))"),
-        rewrite!("lt-sub-2";    "(< ?c (- ?a ?b))"      => "(< (+ ?b ?c) ?a)"),
-        rewrite!("lt-div-pow";  "(< ?a (/ ?b (pow ?c ?d)))"      => "(< (* ?a (pow ?c ?d)) ?b)"),
-        rewrite!("lt-mul-pow";  "(< ?a (* ?b (pow ?c ?d)))"      => "(< (/ ?a (pow ?c ?d)) ?b)"),
-        rewrite!("lt-min";      "(< ?a (min ?b ?c))"    => "(&& (< ?a ?b) (< ?a ?c))"),
-        rewrite!("eq-pow";      "(== (pow ?a ?b) ?c)"   => "(== ?b (logk ?a ?c))"   if is_geq_2("?a")),
-        rewrite!("eq-pow-0";    "(== (pow ?a ?b) 0)"    => "false"   if is_geq_2("?a")),
-        rewrite!("eq-sub";      "(== (- ?a ?b) ?c)"     => "(== ?a (+ ?b ?c))"),
-        rewrite!("eq-to-sub";   "(== ?a ?b)"            => "(== (- ?a ?b) 0)"),
-        rewrite!("eq-div";      "(== ?a (/ ?b ?c))"     => "(== (* ?a ?c) ?b)"),
+        rewrite!("1-lt-prod";   "(< 1 (* ?a ?b))"       => "(|| (< 1 ?a) (< 1 ?b))"),
+        rewrite!("lt-pow-1";    "(< (pow ?a ?b) ?c)"    => "(< ?b (logk ?a ?c))"    if _is_const_with_pred("?a", |x| x >= 2.0)),
+        rewrite!("lt-pow-2";    "(< ?c (pow ?a ?b))"    => "(< (logk ?a ?c) ?b)"    if _is_const_with_pred("?a", |x| x >= 2.0)),
+        rewrite!("eq-to-lt";    "(== 1 ?a)"             => "(! (< 1 ?a))"           if is_symbol("?a")),
+        rewrite!("1-eq-prod";   "(== 1 (* ?a ?b))"      => "(&& (== 1 ?a) (== 1 ?b))"),
+        rewrite!("neg-1-never-0";   "(== 0 (* -1 ?a))"  => "false"),
+        rewrite!("hump-0";      "(hump 0)"              => "1"),
+        rewrite!("hump-const";  "(hump ?a)"             => "0" if _is_const_with_pred("?a", |x| x != 0.0)),
     ]);
     ret_rules
 });
 
+pub struct DiffApproxSimplCostFn;
+impl egg::CostFunction<Math> for DiffApproxSimplCostFn {
+    type Cost = usize;
+    fn cost<C>(&mut self, enode: &Math, mut costs: C) -> Self::Cost
+    where
+        C: FnMut(Id) -> Self::Cost,
+    {
+        // Prefer the form (|| (< 1 ?a) (< 1 ?b)) over (< 1 (* a b)).
+        let op_cost = match enode {
+            // Thus all logical and lt / eq operators are cheap,
+            // and everything else is expensive.
+            Math::Const(..) => 1,
+            Math::Symbol(..) => 1,
+            Math::And(..) => 1,
+            Math::Or(..) => 1,
+            Math::Not(..) => 1,
+            Math::Lt(..) => 1,
+            // Prefer Lt over any other comparators
+            Math::Eq(..) => 5,
+            Math::Ne(..) => 5,
+            Math::Le(..) => 5,
+            Math::Gt(..) => 5,
+            Math::Ge(..) => 5,
+            _ => 10,
+        };
+        enode.fold(op_cost, |sum, i| sum + costs(i))
+    }
+}
+
 #[cfg(feature = "count")]
 pub static GLOBAL_RULE_COUNTER: Lazy<Mutex<HashMap<Symbol, usize>>> =
     Lazy::new(|| Mutex::new(HashMap::new()));
@@ -447,9 +495,9 @@ fn make_runner(
     explain: bool,
 ) -> (Runner, &'static Vec<Rewrite>) {
     let rules = if diff_approx {
-        &*ALL_RULES
+        &*DIFF_APPROX_RULES
     } else {
-        &*STABLE_RULES
+        &*BASIC_RULES
     };
     let mut runner = Runner::default()
         .with_iter_limit(n_iters)
@@ -465,7 +513,11 @@ pub fn simplify(expr: &str, n_iters: usize, n_nodes: usize, diff_approx: bool) -
     let (mut runner, rules) = make_runner(n_iters, n_nodes, diff_approx, cfg!(feature = "count"));
     runner = runner.with_expr(&expr).run(rules);
     let root = runner.roots[0];
-    let (_, best) = Extractor::new(&runner.egraph, AstSize).find_best(root);
+    let (_, best) = if diff_approx {
+        Extractor::new(&runner.egraph, DiffApproxSimplCostFn).find_best(root)
+    } else {
+        Extractor::new(&runner.egraph, AstSize).find_best(root)
+    };
     if cfg!(feature = "count") {
         add_rules_to_counter(&mut runner.explain_equivalence(&expr, &best));
     }
diff --git a/src/arith/egg_simpl/src/lib.rs b/src/arith/egg_simpl/src/lib.rs
index c6a6e46a4..e0d674071 100644
--- a/src/arith/egg_simpl/src/lib.rs
+++ b/src/arith/egg_simpl/src/lib.rs
@@ -18,14 +18,14 @@ pub extern "C" fn simplify_expr(
     s_raw: *const c_char,
     n_iters: u64,
     n_nodes: u64,
-    simpl_rel: bool,
+    diff_approx: bool,
 ) -> *mut c_char {
     let s = str_from_ptr(s_raw);
     let simplified = lang::simplify(
         &s,
         n_iters.try_into().unwrap(),
         n_nodes.try_into().unwrap(),
-        simpl_rel,
+        diff_approx,
     );
     let c_str = CString::new(simplified).unwrap();
     c_str.into_raw()
@@ -49,7 +49,7 @@ pub extern "C" fn is_equivalent(
     explain: bool,
     n_iters: u64,
     n_nodes: u64,
-    simpl_rel: bool,
+    diff_approx: bool,
 ) -> bool {
     let lhs = str_from_ptr(lhs_raw);
     let rhs = str_from_ptr(rhs_raw);
@@ -59,7 +59,7 @@ pub extern "C" fn is_equivalent(
         explain,
         n_iters.try_into().unwrap(),
         n_nodes.try_into().unwrap(),
-        simpl_rel,
+        diff_approx,
     )
 }
 
diff --git a/src/arith/var_context.cc b/src/arith/var_context.cc
index 406f5e1be..0f42861ce 100644
--- a/src/arith/var_context.cc
+++ b/src/arith/var_context.cc
@@ -95,63 +95,55 @@ void VarDefStackNode::Append(const SizeVar& var, const PrimExpr& expr) {
   this->exprs.push_back(VarExprPair(var, expr));
 }
 
-SizeVar VarDefStackNode::FindOrAppend(const std::string& vname, const PrimExpr& expr) {
-  auto it = this->expr2idx.find(expr);
-  if (it != this->expr2idx.end()) {
-    return this->exprs[it->second]->var;
-  }
-  return this->Append(vname, expr);
-}
-
-VarDefStackNode VarDefStackNode::Prepend(const VarMapT& vmap) const {
-  VarDefStackNode ret;
-  for (auto& [vname, expr] : vmap) {
-    ret.Append(vname, expr);
-  }
-  for (auto& pair : this->exprs) {
-    ret.Append(pair->var->name_hint, pair->expr);
+PrimExpr VarDefStackNode::DefineConstShorthand(PrimExpr expr) {
+  if (ExprIsConstant(expr) && CountOps(expr) >= 10) {
+    std::string name = "v" + std::to_string(this->exprs.size());
+    expr = Analyzer().canonical_simplify(arith::SimplifyExpr(expr));
+    auto it = this->expr2idx.find(expr);
+    if (it == this->expr2idx.end()) {
+      expr = this->Append(name, expr);
+    } else {
+      expr = this->exprs[it->second]->var;
+    }
   }
-  return ret;
+  return expr;
 }
 
-VarMapT VarDefStackNode::IntoVarMap() const {
+VarMapT VarDefStackNode::IntoUnwindedVarMap() const {
   VarMapT vmap;
   for (const auto& pair : this->exprs) {
-    vmap.emplace(pair->var->name_hint, tir::SubstByName(pair->expr, vmap));
+    vmap.emplace(pair->var->name_hint, SubstByName(pair->expr, vmap));
   }
   return vmap;
 }
 
-std::vector<Var> VarDefStackNode::FreeVars() const {
-  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> vars;
-  auto CollectVars = [&vars](const ObjectRef& node) {
-    if (const VarNode* op = node.as<VarNode>()) {
-      vars.insert(GetRef<Var>(op));
-    }
-  };
-  for (const auto& pair : this->GetExprs()) {
-    tir::PostOrderVisit(pair->expr, CollectVars);
+std::unordered_set<std::string> VarDefStackNode::GetAllUsedVars(
+    std::optional<SizeVarKind> kind) const {
+  std::unordered_set<std::string> ret;
+  for (auto& pair : this->exprs) {
+    PostOrderVisit(pair->expr, [this, &ret, kind](const ObjectRef& node) {
+      if (auto* svnode = node.as<SizeVarNode>()) {
+        if (!kind || svnode->kind == kind) {
+          ret.insert(svnode->name_hint);
+        }
+      }
+    });
   }
-  for (const auto& pair : this->GetExprs()) {
-    vars.erase(pair->var);
+  return ret;
+}
+
+void VarDefStackNode::MapExprs(ExprMutator func) {
+  for (size_t i = 0; i < this->exprs.size(); i++) {
+    auto& pair = this->exprs[i];
+    this->exprs.Set(i, VarExprPair(pair->var, func(pair->expr)));
   }
-  return std::vector<Var>(vars.begin(), vars.end());
 }
 
-bool VarDefStackNode::HasUndefVars(const PrimExpr& expr) const {
-  // SizeVar is seen as constant and doesn't count unless it has kOther type.
-  bool has_undef = false;
-  auto CheckUndef = [&has_undef, this](const ObjectRef& obj) {
-    if (auto* vnode = obj.as<VarNode>()) {
-      auto* svnode = obj.as<SizeVarNode>();
-      if (this->var2idx.count(vnode->name_hint) == 0 &&
-          (!svnode || svnode->kind == SizeVarKind::kOther)) {
-        has_undef = true;
-      }
-    }
-  };
-  tir::PostOrderVisit(expr, CheckUndef);
-  return has_undef;
+void VarDefStackNode::MapExprsParallel(ExprMutator func) {
+  support::parallel_for(0, this->exprs.size(), [this, &func](int i) {
+    auto& pair = this->exprs[i];
+    this->exprs.Set(i, VarExprPair(pair->var, func(pair->expr)));
+  });
 }
 
 inline std::pair<PrimExpr, PrimExpr> ConservativeDiv(PrimExpr extent, PrimExpr factor,
@@ -174,8 +166,8 @@ Array<SizeVar> VarContextNode::GetSplitVars(const PrimExpr& extent, size_t n_spl
     var_names.push_back(name);
     product *= var;
   }
-  // Declare a quotient variable which would be equal to Extent_i / (sp_i_0*sp_i_1*...sp_i_j), as qi.
-  // For example q3 could be Cout / (sp_3_0 * sp_3_1 * sp_3_2...) for group 3.
+  // Declare a quotient variable which would be equal to Extent_i / (sp_i_0*sp_i_1*...sp_i_j), as
+  // qi. For example q3 could be Cout / (sp_3_0 * sp_3_1 * sp_3_2...) for group 3.
   // * Don't even define this variable in this->var_defs. We'll need to delay the expansion of q{i}
   //   as much as possible, and when finally it's needed it can be derived from the SplitGroup.
   SizeVar quotient("q" + group_idx, SizeVarKind::kShorthand);
@@ -213,29 +205,13 @@ std::pair<PrimExpr, PrimExpr> VarContextNode::GetSplitSizes(const PrimExpr& exte
   return this->SymbolicDiv(extent, factor, no_tighten_factor);
 }
 
-void VarContextNode::DefineVar(const std::string& name, PrimExpr expr) {
-  this->var_defs->Append(name, expr);
-}
-
-PrimExpr VarContextNode::DefineConstShorthand(PrimExpr expr) {
-  if (!this->var_defs->HasUndefVars(expr) && CountOps(expr) >= 10) {
-    expr = this->AllocVarForExpr(arith::SimplifyExpr(expr));
-  }
-  return expr;
-}
-
-Var VarContextNode::AllocVarForExpr(PrimExpr expr) {
-  std::string name = "v" + std::to_string(this->var_defs->Size());
-  return this->var_defs->FindOrAppend(name, expr);
-}
-
 std::pair<PrimExpr, PrimExpr> VarContextNode::SymbolicDiv(PrimExpr numer, PrimExpr denom,
                                                           bool no_tighten_factor) {
   PrimExpr simpl = SimplifyExpr(numer / denom);
   if (!HasDiv(simpl)) {
     return {denom, simpl};
   }
-  for (auto &[extent, subst]: this->div_extents) {
+  for (auto& [extent, subst] : this->div_extents) {
     if (IsExprEquivalent(numer, extent)) {
       PrimExpr simpl = SimplifyExpr(subst / denom);
       if (!HasDiv(simpl)) {
diff --git a/src/felix/constraints.cc b/src/felix/constraints.cc
new file mode 100644
index 000000000..6fd5f2503
--- /dev/null
+++ b/src/felix/constraints.cc
@@ -0,0 +1,207 @@
+#include <tvm/auto_scheduler/search_task.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <numeric>
+
+#include "../tir/transforms/ir_utils.h"
+#include "features.h"
+
+namespace tvm {
+namespace felix {
+
+using namespace tvm::tir;
+
+class GPUConstraintsMaker : public StmtExprVisitor {
+  // TODO(jcf94): Add support of detecting CUDA Misaligned Address error
+ public:
+  explicit GPUConstraintsMaker(size_t max_local_memory_per_block,
+                               size_t max_shared_memory_per_block, size_t max_threads_per_block,
+                               size_t max_vthread, size_t max_vector_size, size_t max_vector_bytes)
+      : max_local_memory_per_block(max_local_memory_per_block),
+        max_shared_memory_per_block(max_shared_memory_per_block),
+        max_threads_per_block(max_threads_per_block),
+        max_vthread(max_vthread),
+        max_vector_size(max_vector_size),
+        max_vector_bytes(max_vector_bytes) {}
+
+  void RunOnStmt(Stmt stmt) {
+    Reset_();
+    this->VisitStmt(stmt);
+  }
+
+  void VisitStmt_(const AllocateNode* op) final {
+    StmtVisitor::VisitStmt_(op);
+    // visit an allocation of a buffer in shared memory, record its size
+    auto scope = GetPtrStorageScope(op->buffer_var);
+    CountBufferSize_(op->extents, op->dtype, scope);
+    if (op->dtype.lanes() > 1) {
+      CheckVectorBytes_(op->dtype);
+    }
+  }
+
+  void VisitStmt_(const BufferRealizeNode* op) final {
+    StmtVisitor::VisitStmt_(op);
+    auto scope = GetPtrStorageScope(op->buffer->data);
+    Array<PrimExpr> extents;
+    for (auto& range : op->bounds) {
+      extents.push_back(range->extent);
+    }
+    CountBufferSize_(extents, op->buffer->dtype, scope);
+    if (op->buffer->dtype.lanes() > 1) {
+      CheckVectorBytes_(op->buffer->dtype);
+    }
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    auto attr = op->attr_key;
+    bool is_thread = attr == tir::attr::thread_extent;
+    bool is_vthread = attr == tir::attr::virtual_thread;
+    bool is_unroll = attr == tir::attr::pragma_auto_unroll_max_step;
+    if (!is_thread && !is_vthread && !is_unroll) {
+      StmtVisitor::VisitStmt_(op);
+      return;
+    }
+
+    if (this->nest_level == 0) {
+      // enter a new kernel, reset statistics
+      Reset_();
+      kernels_launched++;
+    }
+
+    // record the number of threads in a block
+    Var var = op->node.as<IterVarNode>()->var;
+    std::string name = var.get()->name_hint;
+    PrimExpr extent = op->value;
+    if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z" ||
+        name == "vthread") {
+      // record the number of threads in a block
+      if (!this->visited_threads.count(name)) {
+        this->visited_threads.insert(name);
+        this->thread_per_block *= extent;
+      }
+      // else: the thread should be bound to axes with the same length
+      // but we don't check this here, as it can be difficult to
+      // compare the equality of two expressions.
+    }
+
+    this->nest_level++;
+    StmtVisitor::VisitStmt_(op);
+    this->nest_level--;
+
+    if (this->nest_level == 0) {
+      // exit a kernel, check the validity
+      AddConstraint_(this->thread_per_block, this->max_threads_per_block);
+      AddConstraint_(this->local_memory_per_block, this->max_local_memory_per_block);
+      AddConstraint_(this->shared_memory_per_block, this->max_shared_memory_per_block);
+    }
+  }
+
+  void VisitStmt_(const ForNode* op) {
+    if (op->kind == ForKind::kVectorized) {
+      AddConstraint_(op->extent, this->max_vector_size);
+    }
+    StmtVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const LoadNode* op) {
+    if (op->dtype.lanes() > 1) {
+      CheckVectorBytes_(op->dtype);
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const StoreNode* op) {
+    if (op->value->dtype.lanes() > 1) {
+      CheckVectorBytes_(op->value->dtype);
+    }
+    StmtVisitor::VisitStmt_(op);
+  }
+
+ private:
+  size_t max_local_memory_per_block;
+  size_t max_shared_memory_per_block;
+  size_t max_threads_per_block;
+  size_t max_vthread;
+  size_t max_vector_size;
+  size_t max_vector_bytes;
+
+  std::unordered_set<std::string> visited_threads{};
+  PrimExpr local_memory_per_block = 0;
+  PrimExpr shared_memory_per_block = 0;
+  PrimExpr thread_per_block = 1;
+  size_t kernels_launched = 0;
+  size_t nest_level = 0;
+
+ public:
+  std::vector<PrimExpr> constraints{};
+  std::vector<String> errors{};
+
+ private:
+  void AddConstraint_(PrimExpr lhs, size_t rhs) {
+    auto con = lhs <= Integer(rhs);
+    if (auto* simp_bool = con.as<IntImmNode>()) {
+      if (simp_bool->value == 0) {
+        std::stringstream s;
+        s << "Constraint " << lhs << " <= " << rhs << " is trivially False";
+        this->errors.push_back(s.str());
+      }
+    } else {
+      this->constraints.push_back(con);
+    }
+  }
+
+  void CountBufferSize_(const Array<PrimExpr>& extents, DataType dtype, const String& scope) {
+    PrimExpr one = Integer(1);
+    PrimExpr alloc_count =
+        std::accumulate(extents.begin(), extents.end(), one,
+                        [](const PrimExpr& a, const PrimExpr& b) { return a * b; });
+    PrimExpr alloc_size = alloc_count * dtype.bytes() * dtype.lanes();
+    if (scope == "local") {
+      this->local_memory_per_block += alloc_size;
+    } else if (scope == "shared") {
+      this->shared_memory_per_block += alloc_size;
+    }
+  }
+
+  void CheckVectorBytes_(DataType dtype) {
+    if (static_cast<size_t>(dtype.lanes() * dtype.bytes()) > this->max_vector_bytes) {
+      std::stringstream s;
+      s << "Number of lanes (" << dtype.lanes() << ") times number of bytes (" << dtype.bytes()
+        << ") for dtype " << dtype << " is greater than the maximum number of vector bytes ("
+        << this->max_vector_bytes << ")";
+      this->errors.push_back(s.str());
+    }
+  }
+
+  void Reset_() {
+    this->local_memory_per_block = 0;
+    this->shared_memory_per_block = 0;
+    this->visited_threads.clear();
+  }
+};
+
+std::vector<PrimExpr> GetConstraints(const Stmt& code,
+                                     const auto_scheduler::HardwareParams& hw_params) {
+  // Run GPU verification pass to inject constraints.
+  // HACK: size of 4 is based on src/target/source/codegen_cuda.cc.
+  //   - line 246: bool vector is only supported when size is less than 4
+  //   - line 351: vector of int/uint is only supported when size is less than 8.
+  //   While this is true for all CUDA devices and will stay so for a while,
+  //   it's not a good idea to hardcode it here.
+  GPUConstraintsMaker cmaker(hw_params->max_local_memory_per_block,
+                             hw_params->max_shared_memory_per_block,
+                             hw_params->max_threads_per_block, hw_params->max_vthread_extent,
+                             /*max_vector_size*/ 4, hw_params->vector_unit_bytes);
+  cmaker.RunOnStmt(code);
+  if (!cmaker.errors.empty()) {
+    LOG_WARNING << "Code constraint check failed: ";
+    for (auto& err : cmaker.errors) {
+      LOG_WARNING << "  " << err;
+    }
+    LOG_WARNING << "Failed code: " << code;
+    return {};
+  }
+  return cmaker.constraints;
+}
+}  // namespace felix
+}  // namespace tvm
\ No newline at end of file
diff --git a/src/felix/feat_transform.cc b/src/felix/feat_transform.cc
new file mode 100644
index 000000000..8a9f04380
--- /dev/null
+++ b/src/felix/feat_transform.cc
@@ -0,0 +1,1003 @@
+#include <tvm/arith/egg_simpl.h>
+#include <tvm/arith/var_context.h>
+#include <tvm/auto_scheduler/feature.h>
+#include <tvm/support/parallel_for.h>
+
+#include <optional>
+#include <variant>
+
+#include "features.h"
+#include "utils.h"
+
+namespace tvm {
+namespace felix {
+using namespace tvm::arith;
+
+std::unordered_map<uint64_t, uint64_t> Factorize(uint64_t n) {
+  std::unordered_map<uint64_t, uint64_t> factors;
+  for (uint64_t i = 2; i <= n; ++i) {
+    while (n % i == 0) {
+      factors[i] += 1;
+      n /= i;
+    }
+  }
+  return factors;
+}
+
+template <typename T>
+void CollectSameOps(const T* e, std::vector<PrimExpr>& ret) {
+  auto *lhs = e->a.template as<T>(), *rhs = e->b.template as<T>();
+  if (lhs) {
+    CollectSameOps(lhs, ret);
+  } else {
+    ret.push_back(e->a);
+  }
+  if (rhs) {
+    CollectSameOps(rhs, ret);
+  } else {
+    ret.push_back(e->b);
+  }
+}
+
+template <typename T>
+std::vector<PrimExpr> CollectSameOps(const PrimExpr& e) {
+  std::vector<PrimExpr> ret;
+  auto* node = e.as<T>();
+  if (node) {
+    CollectSameOps(node, ret);
+  } else {
+    ret.push_back(e);
+  }
+  return ret;
+}
+
+// size_t is not (de)serializable, so we use uint64_t instead
+using DecompT = std::unordered_map<std::string, std::unordered_map<uint64_t, SizeVar>>;
+
+class LinearExprNode : public Object {
+ public:
+  Map<SizeVar, FloatImm> lin_terms;
+  FloatImm constant;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("lin_terms", &lin_terms);
+    v->Visit("constant", &constant);
+  }
+
+  static constexpr const char* _type_key = "ansor.LinearExpr";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LinearExprNode, Object);
+};
+
+TVM_REGISTER_NODE_TYPE(LinearExprNode);
+
+class LinearExpr : public ObjectRef {
+ public:
+  explicit LinearExpr(double constant) {
+    auto node = make_object<LinearExprNode>();
+    node->constant = ToFloatImm(constant);
+    this->data_ = std::move(node);
+  }
+
+  explicit LinearExpr(SizeVar var) {
+    auto node = make_object<LinearExprNode>();
+    node->constant = ToFloatImm(0.0f);
+    node->lin_terms.Set(var, ToFloatImm(1.0f));
+    this->data_ = std::move(node);
+  }
+
+  LinearExpr(double constant, const std::unordered_map<std::string, double>& name2coef) {
+    auto node = make_object<LinearExprNode>();
+    node->constant = ToFloatImm(0.0f);
+    for (auto& [name, coef] : name2coef) {
+      node->lin_terms.Set(SizeVar(name, SizeVarKind::kScheduleKnob), ToFloatImm(coef));
+    }
+    this->data_ = std::move(node);
+  }
+
+  PrimExpr ToPrimExpr() const {
+    auto node = this->operator->();
+    PrimExpr ret = node->constant;
+    for (auto& [var, coef] : node->lin_terms) {
+      ret = ret + var * coef;
+    }
+    return ret;
+  }
+
+  LinearExpr& operator+=(const LinearExpr& other) {
+    ICHECK(this->defined() && other.defined());
+    auto this_ = this->CopyOnWrite();
+    this_->constant = ToFloatImm(this_->constant->value + other->constant->value);
+    for (auto& [var, coef1] : other->lin_terms) {
+      auto coef2 = this_->lin_terms.Get(var).value_or(ToFloatImm(0.0f));
+      this_->lin_terms.Set(var, ToFloatImm(coef1->value + coef2->value));
+    }
+    return *this;
+  }
+  LinearExpr& operator*=(double other) {
+    ICHECK(this->defined());
+    auto this_ = this->CopyOnWrite();
+    this_->constant = ToFloatImm(this_->constant->value * other);
+    for (auto& [var, coef] : this_->lin_terms) {
+      this_->lin_terms.Set(var, ToFloatImm(coef->value * other));
+    }
+    return *this;
+  }
+  LinearExpr& operator-=(const LinearExpr& other) { return *this += LinearExpr(-1.0f) * other; }
+  LinearExpr& operator/=(double other) { return *this *= (1.0f / other); }
+
+#define DEF_BINARY_OP(def, with, other_t, check)        \
+  LinearExpr operator def(const other_t& other) const { \
+    if (!this->defined() || check) {                    \
+      return LinearExpr();                              \
+    }                                                   \
+    LinearExpr ret = *this;                             \
+    ret with other;                                     \
+    return ret;                                         \
+  }
+
+  DEF_BINARY_OP(+, +=, LinearExpr, !other.defined())
+  DEF_BINARY_OP(-, -=, LinearExpr, !other.defined())
+  DEF_BINARY_OP(*, *=, double, false)
+  DEF_BINARY_OP(/, /=, double, false)
+
+  LinearExpr operator*(LinearExpr other) {
+    if (!this->defined() || !other.defined()) {
+      return LinearExpr();
+    } else if ((*this)->lin_terms.empty()) {
+      return other * (*this)->constant->value;
+    } else if (other->lin_terms.empty()) {
+      return (*this) * other->constant->value;
+    } else {
+      return LinearExpr();
+    }
+  }
+  LinearExpr operator/(LinearExpr other) {
+    if (!this->defined() || !other.defined()) {
+      return LinearExpr();
+    } else if (other->lin_terms.empty()) {
+      return (*this) / other->constant->value;
+    } else {
+      return LinearExpr();
+    }
+  }
+
+  TVM_DEFINE_OBJECT_REF_METHODS(LinearExpr, ObjectRef, LinearExprNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(LinearExprNode);
+};
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<LinearExprNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* expr = static_cast<const LinearExprNode*>(node.get());
+      auto& os = p->stream;
+      os << "(" << expr->constant << "; ";
+      bool first = true;
+      for (auto& [var, coef] : expr->lin_terms) {
+        if (!first) {
+          os << ", ";
+        }
+        os << "(" << coef << " * " << var << ")";
+        first = false;
+      }
+      os << ")";
+    });
+
+class LinExprExtractor : public ExprFunctor<LinearExpr(const PrimExpr&)> {
+  LinearExpr VisitExpr_(const SizeVarNode* e) override { return LinearExpr(GetRef<SizeVar>(e)); }
+  LinearExpr VisitExpr_(const IntImmNode* e) override { return LinearExpr((double)e->value); }
+  LinearExpr VisitExpr_(const FloatImmNode* e) override { return LinearExpr(e->value); }
+  LinearExpr VisitExpr_(const VarNode* e) override {
+    LOG_FATAL << "Do not use LinExprExtractor on expressions with non-sizevar variable; got "
+              << GetRef<Var>(e);
+    return LinearExpr();
+  }
+
+  LinearExpr VisitExpr_(const AddNode* e) override { return VisitExpr(e->a) + VisitExpr(e->b); }
+  LinearExpr VisitExpr_(const SubNode* e) override { return VisitExpr(e->a) - VisitExpr(e->b); }
+  LinearExpr VisitExpr_(const MulNode* e) override { return VisitExpr(e->a) * VisitExpr(e->b); }
+  LinearExpr VisitExpr_(const DivNode* e) override { return VisitExpr(e->a) / VisitExpr(e->b); }
+  LinearExpr VisitExpr_(const CastNode* e) override { return VisitExpr(e->value); }
+
+  LinearExpr VisitExprDefault_(const Object* e) override { return LinearExpr(); }
+};
+
+class FloorRemover : public ExprMutator {
+  PrimExpr VisitExpr_(const FloorDivNode* e) final {
+    // Simply drop the floor() operator; it's not differentiable.
+    return div(VisitExpr(e->a), VisitExpr(e->b));
+  }
+  PrimExpr VisitExpr_(const FloorModNode* e) final {
+    LOG_WARNING << "Mod operator is not differentiable and will be dropped.";
+    return Integer(0);
+  }
+};
+
+class DiffableApprox : public MemoizedExprFunctor<PrimExpr> {
+ public:
+  DiffableApprox(const DiffableApprox& other) = delete;
+  DiffableApprox() = delete;
+
+  // * Important to use Float values (Range(1.0, 5.0)) for range inf default range.
+  DiffableApprox(const std::unordered_set<std::string>& new_knobs, const VarMapT& exp_subst,
+                 const VarMapT& shorthands, const VarMapT& quotients)
+      : rinf(Range(ToFloatImm(1.0), ToFloatImm(5.0))),
+        new_knobs(new_knobs),
+        exp_subst(exp_subst),
+        shorthands(shorthands),
+        quotients(quotients) {}
+
+  // Simplification strategies:
+  // 1. `K < sp_i_j_b + sp_i_j_b + ...` (K is actual constant)
+  //    - We're good; just return sigmoid(RHS - LHS).
+  // 2. `1 < sp_i_j`
+  //    - Replace sp_i_j with its exp decomposition, simplify, should give us Form 1.
+  // 3. Anything expr with shorthand variables
+  //    - Substitute with shorthand vars and simplify; if changed, revisit.
+  // 4. Now safe to use special simplification (safe on expressions that only contain `sp_i_j` and
+  //    quotient vars `d{i}`, and shape constants)
+  //    - This helps reduce forms such as `1 < sp_i_j * sp_i'_j' * ...` to smaller forms.
+  //    - If expr contains quotients d{i}, try twice: expanding d{i} = E{i} / sp_i_j / sp_i_j' ... ,
+  //      see which one is shorter.
+  // 5. If everything fall through, try taking log on both sides (with range inference to help with
+  // safety).
+  //
+  // Step 3-5 also applies to VisitExpr_(EQ).
+
+  PrimExpr VisitExpr_(const LTNode* e) final {
+    auto expr = GetRef<PrimExpr>(e);
+    LinearExpr diff = IsTermLinear(e->b - e->a);
+    if (diff.defined()) {
+      return MakeDiffableCond(diff.ToPrimExpr(), /* sigmoid_or_hump */ true);
+    }
+    PrimExpr response;
+    if ((response = ConstVsSingleSchedVar(e->a, e->b)).defined()) {
+      response = SimplifyExpr(LT(e->a, response), 20, 1000, true);
+      return VisitExpr(response);
+    }
+    if ((response = SubstShorthandsAway(expr)).defined()) {
+      return VisitExpr(response);
+    }
+    if ((response = SpecialSimplAndVisit(expr)).defined()) {
+      return response;
+    }
+    if ((response = TakeLogDiffIfSafe(e->b, e->a)).defined()) {
+      return MakeDiffableCond(response, /* sigmoid_or_hump */ true);
+    }
+    LOG_FATAL << "Cannot simplify " << e->a << " < " << e->b;
+    return PrimExpr();
+  }
+
+  PrimExpr VisitExpr_(const EQNode* e) final {
+    auto expr = GetRef<PrimExpr>(e);
+    PrimExpr response;
+    if ((response = SubstShorthandsAway(expr)).defined()) {  // Step 1
+      return VisitExpr(response);
+    }
+    if ((response = SpecialSimplAndVisit(expr)).defined()) {
+      return response;
+    }
+    if ((response = TakeLogDiffIfSafe(e->a, e->b)).defined()) {
+      return MakeDiffableCond(response, /* sigmoid_or_hump */ false);
+    }
+    LOG_ERROR << "Cannot simplify " << e->a << " == " << e->b;
+    return PrimExpr();
+  }
+
+  PrimExpr VisitExpr_(const IntImmNode* e) final {
+    // Converts true into 1 and false into 0.
+    return Integer(e->value);
+  }
+
+  PrimExpr VisitExpr_(const SelectNode* e) final {
+    if (ExprIsShapeConstant(e->condition)) {
+      return select(SimplifyExpr(e->condition, 20, 1000), VisitExpr(e->true_value),
+                    VisitExpr(e->false_value));
+    }
+    auto cond = VisitExpr(e->condition), tv = VisitExpr(e->true_value),
+         fv = VisitExpr(e->false_value);
+    return cond * (tv - fv) + fv;
+  }
+
+  // 1 - x below stands for `not x`
+  PrimExpr VisitExpr_(const LENode* e) final { return 1 - this->VisitExpr(LT(e->b, e->a)); }
+  PrimExpr VisitExpr_(const GENode* e) final { return 1 - this->VisitExpr(LT(e->a, e->b)); }
+  PrimExpr VisitExpr_(const GTNode* e) final { return this->VisitExpr(LT(e->b, e->a)); }
+  PrimExpr VisitExpr_(const NENode* e) final { return 1 - this->VisitExpr(EQ(e->a, e->b)); }
+
+  PrimExpr VisitExpr_(const AndNode* e) final { return VisitExpr(e->a) * VisitExpr(e->b); }
+  PrimExpr VisitExpr_(const OrNode* e) final {
+    std::vector<PrimExpr> chained_ors;
+    CollectSameOps(e, chained_ors);
+    // Using de Morgan's law:
+    PrimExpr ret = Integer(1);
+    for (auto& expr : chained_ors) {
+      ret *= 1 - VisitExpr(expr);
+    }
+    return 1 - ret;
+  }
+  PrimExpr VisitExpr_(const NotNode* e) final { return 1 - VisitExpr(e->a); }
+
+  LinearExpr IsTermLinear(const PrimExpr& diff) {
+    auto lin_diff = LinExprExtractor()(diff);
+    if (!lin_diff.defined()) {
+      return LinearExpr();
+    }
+    bool all_vars_allowed = true;
+    for (auto& [var, _] : lin_diff->lin_terms) {
+      if (!this->new_knobs.count(var->name_hint)) {
+        all_vars_allowed = false;
+      }
+    }
+    return all_vars_allowed ? lin_diff : LinearExpr();
+  }
+
+ private:
+  PrimExpr ConstVsSingleSchedVar(const PrimExpr& lhs, const PrimExpr& rhs) {
+    auto* rhs_v = rhs.as<SizeVarNode>();
+    if (lhs->IsInstance<IntImmNode>() && rhs_v) {
+      auto it = this->exp_subst.find(rhs_v->name_hint);
+      if (it != this->exp_subst.end()) {
+        return it->second;
+      }
+    }
+    return PrimExpr();
+  }
+
+  PrimExpr SubstShorthandsAway(const PrimExpr& expr) {
+    t4.Start();
+    bool changed = false;
+    auto ret = SubstByName(expr, this->shorthands, &changed);
+    ret = changed ? SimplifyExpr(ret) : PrimExpr();
+    t4.Stop();
+    return ret;
+  }
+
+  PrimExpr SpecialSimplAndVisit(const PrimExpr& expr) {
+    if (!this->progressed) {
+      t5.Start();
+      auto direct_simpl = SpecialSimplIfSafe(expr);
+      t5.Stop();
+      if (direct_simpl.defined()) {
+        this->progressed = true;
+        auto ret = VisitExpr(direct_simpl);
+        this->progressed = false;
+        return ret;
+      }
+    }
+    return PrimExpr();
+  }
+
+  PrimExpr SpecialSimplIfSafe(const PrimExpr& expr) {
+    bool safe_vars_only = true, has_quotients = false;
+    PostOrderVisit(expr, [this, &safe_vars_only, &has_quotients](const ObjectRef& obj) {
+      if (auto* var = obj.as<VarNode>()) {
+        if (!obj->IsInstance<SizeVarNode>() || this->shorthands.count(var->name_hint)) {
+          safe_vars_only = false;
+        } else if (this->quotients.count(var->name_hint)) {
+          has_quotients = true;
+        }
+      }
+    });
+    if (!safe_vars_only) {
+      return PrimExpr();
+    }
+    auto simpl = SimplifyExpr(expr, 30, 10000, true);
+    if (has_quotients) {
+      auto subbed = SubAndSimplify(expr, this->quotients);
+      return CountOps(subbed) < CountOps(expr) ? subbed : simpl;
+    } else {
+      return simpl;
+    }
+  }
+
+  PrimExpr TakeLogDiffIfSafe(PrimExpr a, PrimExpr b) {
+    double amin, amax, bmin, bmax;
+    ICHECK(ToConstRange(this->rinf(a), amin, amax));
+    ICHECK(ToConstRange(this->rinf(b), bmin, bmax));
+    if (amin <= 0 || bmin <= 0) {
+      LOG_WARNING << "Cannot take diff of " << PrintExprPrefix(a) << " (" << amin << ") and "
+                  << PrintExprPrefix(b) << " (" << bmin << ") to a range stable form";
+    }
+    auto ret = SimplifyExpr(log(a) - log(b), 20, 5000, true);
+    return ret;
+  }
+
+  PrimExpr MakeDiffableCond(const PrimExpr& e, bool sigmoid_or_hump) {
+    auto* e_int = e.as<IntImmNode>();
+    if (e_int) {
+      if (sigmoid_or_hump) {
+        return e_int->value > 0 ? 1 : 0;
+      } else {
+        return e_int->value == 0 ? 1 : 0;
+      }
+    }
+    String var_name = "da_" + std::to_string(this->vdefs_out.size());
+    this->vdefs_out[var_name] = sigmoid_or_hump ? sigmoid(e) : hump(e);
+    return SizeVar(var_name, SizeVarKind::kShorthand);
+  }
+
+ public:
+  VarMapT vdefs_out;
+
+ private:
+  Timer t4{"SubstShorthandsAway"}, t5{"SpecialSimplAndVisit"};
+  RangeInfer rinf;
+  std::unordered_set<std::string> new_knobs;
+  const VarMapT &exp_subst, &shorthands, &quotients;
+  bool progressed{};
+};
+
+class FeaturePackPyNode : public Object {
+ public:
+  Array<Array<ObjectRef>> expressions;
+  Array<String> free_vars;
+  Array<LinearExpr> linear_cons;
+  Map<String, Map<Integer, SizeVar>> var_decomp;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("expressions", &expressions);
+    v->Visit("free_vars", &free_vars);
+    v->Visit("linear_cons", &linear_cons);
+    v->Visit("var_decomp", &var_decomp);
+  }
+
+  static constexpr const char* _type_key = "ansor.FeaturePackPy";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FeaturePackPyNode, Object);
+};
+
+TVM_REGISTER_NODE_TYPE(FeaturePackPyNode);
+
+class FeaturePackPy : public ObjectRef {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(FeaturePackPy, ObjectRef, FeaturePackPyNode);
+};
+
+class FeaturePack {
+ public:
+  FeaturePack() = default;
+
+  FeaturePack(VarDefStack vdefs, VarDefStack features, VarDefStack constraints,
+              Array<SplitGroup> sp_groups)
+      : variables({{"vdefs", std::move(vdefs)},
+                   {"features", std::move(features)},
+                   {"constraints", std::move(constraints)}}),
+        split_groups(std::move(sp_groups)) {}
+
+  void RunRollingSimplify() {
+    // For every expression `e` in vdefs,
+    // for each variable `v` in `e` that's defined in vdefs as `v = e_v`,
+    // if substituting `v = e_v` into `e` (and simplifying) makes `e` smaller, do it.
+    auto& vdefs = this->variables.at("vdefs");
+    vdefs->MapExprs([&vdefs](const PrimExpr& expr) {
+      std::unordered_set<const VarNode*> vars;
+      PostOrderVisit(expr, [&vdefs, &vars](const ObjectRef& node) {
+        auto* var = node.as<SizeVarNode>();
+        if (var && vdefs->Contains(var->name_hint)) {
+          vars.insert(node.as<VarNode>());
+        }
+      });
+      auto expr_ = expr;
+      for (auto& var : vars) {
+        auto& to_sub = vdefs->GetExprAt(var->name_hint);
+        auto subbed = arith::SubAndSimplify(expr, {{var->name_hint, to_sub}});
+        if (CountOps(subbed) < CountOps(expr)) {
+          expr_ = subbed;
+        }
+      }
+      return expr_;
+    });
+  }
+
+  void RunExpTransform(const std::vector<size_t>& prime_bases) {
+    // Concat `features` to the end of `vdefs` to form `vi_and_features`,
+    // and split out E{i} variables from vdefs into `ei_vars`, the rest into `vi_vars`.
+    VarDefStack& vi_and_features = this->variables["vi_and_features"];
+    {
+      VarDefStack &vi_vars = this->variables["vi_vars"], &ei_vars = this->variables["ei_vars"];
+      for (auto& pair : this->variables.at("vdefs")->GetExprs()) {
+        if (pair->var->kind == SizeVarKind::kShapeVar) {
+          ei_vars->Append(pair->var, pair->expr);
+        } else {
+          vi_vars->Append(pair->var, pair->expr);
+          vi_and_features->Append(pair->var, pair->expr);
+        }
+      }
+      for (auto& pair : this->variables.at("features")->GetExprs()) {
+        vi_and_features->Append(pair->var, pair->expr);
+      }
+      this->variables.erase("vdefs");
+      this->variables.erase("features");
+    }
+    // List all schedule vars (such as sp_i_j), and create exp decomposition for them,
+    // inserting into this->var_decomp and `exp_decomp`.
+    auto& exp_decomp = this->variables["exp_decomp"];
+    auto all_var_names = vi_and_features->GetAllUsedVars(SizeVarKind::kScheduleKnob);
+    // 1. For variables in `SplitGroup`s, we create a new variable for each prime base p:
+    for (auto& group : this->split_groups) {
+      auto extent_vname = group->extent->name_hint;
+      // Create sp_i_j_2, sp_i_j_3, ... for each sp_i_j and prime base p.
+      for (auto& vname : group->vars) {
+        exp_decomp->Append(vname,
+                           this->DecomposeOneVar(prime_bases, vname, SizeVarKind::kScheduleKnob));
+        all_var_names.erase(vname);
+      }
+      // Create Ei_2, Ei_3, ... for the group's extent Ei and each prime base p.
+      exp_decomp->Append(extent_vname,
+                         this->DecomposeOneVar(prime_bases, extent_vname, SizeVarKind::kShapeVar));
+      // Create representation of the quotient Ei / (sp_i_0 * sp_i_1 * ...)
+      // as a sum in the power: 2**(Ei_2 - sp_i_0_2 - sp_i_1_2 - ...) * 3**(...) * ...
+      // Also insert the constraints that Ei_b >= sp_i_0_b + sp_i_1_b + ...
+      // (essentially meaning that each power consisting q{i} is non-negative)
+      auto quotient = Integer(1);
+      for (size_t prime : prime_bases) {
+        LinearExpr q_total_power(this->var_decomp[extent_vname][prime]);
+        for (auto& vname : group->vars) {
+          q_total_power -= LinearExpr(this->var_decomp[vname][prime]);
+        }
+        this->linear_cons.push_back(q_total_power);  // Defaults to >= 0.
+        quotient *= pow(Integer(prime), q_total_power.ToPrimExpr());
+      }
+      exp_decomp->Append(group->quotient, quotient);
+    }
+    // 2. For all other variables, just do a log2.
+    for (auto& vname : all_var_names) {
+      SizeVar sv(vname + "_2", SizeVarKind::kScheduleKnob);
+      this->var_decomp[vname][2] = sv;
+      exp_decomp->Append(vname, pow(Integer(2), sv));
+    }
+  }
+
+  void RunDiffTransform(size_t n_threads) {
+    // Remove floordiv and floormod from all variables except Ei variables.
+    for (auto& [table_name, vdefs] : this->variables) {
+      if (table_name != "ei_vars") {
+        vdefs->MapExprs(FloorRemover());
+      }
+    }
+    // Prepare 2 substitution maps for the differentiability transformation:
+    // 1. exp_decomp: sp_i_j = 2**sp_i_j_2 * 3**sp_i_j_3 * ..., Ei = 2**Ei_2 * 3**Ei_3 * ..., qi =
+    //    2**qi_2 * 3**qi_3 * ...
+    //    - Also get a list of all variables on the power (sp_i_j_2, Ei_2, qi_2, ...) from
+    //    exp_decomp.
+    // 2. shorthands: vi = ... (all variables in vi_vars)
+    // 3. quotient: all the qi variables in non-exponential form:
+    //    q0 = E0 / (sp_0_0 * sp_0_1 * ...), q1 = E1 / (sp_1_0 * sp_1_1 * ...), ...
+    //    - This can help simplify expressions like (q0 * sp_0_0 * sp_0_1), without having to expand
+    //    everything into exp form.
+    auto& exp_decomp = this->variables.at("exp_decomp");
+    auto exp_sched_vars = exp_decomp->GetAllUsedVars(std::nullopt);
+    VarMapT exp_subst = exp_decomp->IntoUnwindedVarMap(),
+            shorthands = this->variables.at("vi_vars")->IntoUnwindedVarMap();
+    this->variables.erase("vi_vars");
+    VarMapT quotient;
+    for (auto& group : this->split_groups) {
+      PrimExpr extent = group->extent;
+      for (auto& vname : group->vars) {
+        extent = div(extent, SizeVar(vname, SizeVarKind::kScheduleKnob));
+      }
+      quotient[group->quotient->name_hint] = extent;
+    }
+    auto& features = this->variables.at("vi_and_features");
+    DiffableApprox da(exp_sched_vars, exp_subst, shorthands, quotient);
+    {
+      Timer timer("DA");
+      features->MapExprs([&da](const PrimExpr& e) { return da(e); });
+    }
+    auto& diff_approx = this->variables["diff_approx"];
+    for (auto& [vname, expr] : da.vdefs_out) {
+      diff_approx->Append(vname, expr);
+    }
+    VarDefStack constraints;
+    for (auto& pair : this->variables.at("constraints")->GetExprs()) {
+      auto* expr_lt = pair->expr.as<LENode>();
+      ICHECK(expr_lt);
+      auto diff = log(expr_lt->b) - log(expr_lt->a);
+      auto linexpr =
+          LinExprExtractor()(SimplifyExpr(SubstByName(SubstByName(diff, shorthands), exp_subst)));
+      if (linexpr.defined()) {
+        this->linear_cons.push_back(linexpr);
+      } else {
+        constraints->Append(pair->var, diff);
+      }
+    }
+    this->variables["constraints"] = constraints;
+  }
+
+  void RunSizeSubstitution(const Map<String, Integer>& size_subst) {
+    VarMapT size_subst_(size_subst.begin(), size_subst.end());
+    std::unordered_map<std::string, std::vector<uint64_t>> shape_value_primes;
+    for (auto& pair : this->variables.at("ei_vars")->GetExprs()) {
+      auto& vname = pair->var->name_hint;
+      auto expr = SubAndSimplify(pair->expr, size_subst_);
+      auto* expr_int = expr.as<IntImmNode>();
+      ICHECK(expr_int) << "All shape variables must be substituted to integers; got " << pair->var
+                       << " = " << expr;
+      size_subst_[vname] = expr;
+      std::unordered_map<uint64_t, uint64_t> factors = Factorize(expr_int->value);
+      auto& shape_var_decomp = this->var_decomp.at(vname);
+      for (auto& [prime, power] : factors) {
+        shape_value_primes[vname].push_back(prime);
+        auto it = shape_var_decomp.find(prime);
+        ICHECK(it != shape_var_decomp.end())
+            << "Shape variable " << vname << " = " << expr_int->value << " contains factor "
+            << prime << " that the features weren't factorized for.";
+        size_subst_[it->second->name_hint] = Integer(power);
+      }
+      for (auto& [k, v] : shape_var_decomp) {
+        if (!size_subst_.count(v->name_hint)) {
+          size_subst_[v->name_hint] = Integer(0);
+        }
+      }
+    }
+    VarDefStack concated;
+    for (auto& [table_name, vdefs] : this->variables) {
+      if (table_name != "ei_vars") {
+        for (auto& pair : vdefs->GetExprs()) {
+          concated->Append(pair->var, pair->expr);
+        }
+      }
+    }
+    concated->MapExprsParallel([this, &size_subst_](const PrimExpr& expr) {
+      return SubAndSimplify(expr, size_subst_, true);
+    });
+    for (auto& linexpr : this->linear_cons) {
+      linexpr = LinExprExtractor()(SubAndSimplify(linexpr.ToPrimExpr(), size_subst_, true));
+      std::cout << linexpr << "\n";
+      ICHECK(linexpr.defined());
+    }
+
+    this->variables.clear();
+    this->variables["vdefs"] = concated;
+  }
+
+  FeaturePackPy IntoPythonFeaturePack() const {
+    auto node = make_object<FeaturePackPyNode>();
+    auto& vdefs = this->variables.at("vdefs");
+    for (auto& pair : vdefs->GetExprs()) {
+      node->expressions.push_back({pair->var->name_hint, pair->expr});
+    }
+    node->linear_cons = this->linear_cons;
+    std::unordered_set<std::string> free_vars;
+    for (auto& [k1, vs] : this->var_decomp) {
+      Map<Integer, SizeVar> m;
+      for (auto& [k2, v] : vs) {
+        if (v->kind == SizeVarKind::kScheduleKnob) {
+          free_vars.insert(v->name_hint);
+        }
+        m.Set(k2, v);
+      }
+      node->var_decomp.Set(k1, m);
+    }
+    for (auto& vname : free_vars) {
+      node->free_vars.push_back(vname);
+    }
+    return FeaturePackPy(node);
+  }
+
+  static std::optional<FeaturePack> LoadFromJsonReader(dmlc::JSONReader& reader) {
+    bool is_defined;
+    reader.BeginArray();
+    ICHECK(reader.NextArrayItem());
+    reader.Read(&is_defined);
+    if (!is_defined) {
+      return std::nullopt;
+    }
+    FeaturePack fp;
+    ICHECK(reader.NextArrayItem());
+    reader.Read(&fp.variables);
+    ICHECK(reader.NextArrayItem());
+    reader.Read(&fp.linear_cons);
+    ICHECK(reader.NextArrayItem());
+    reader.Read(&fp.split_groups);
+    ICHECK(reader.NextArrayItem());
+    reader.Read(&fp.var_decomp);
+    ICHECK(!reader.NextArrayItem());
+    return fp;
+  }
+
+  static void SaveAsJson(const String& filepath, const std::optional<FeaturePack>& fp) {
+    std::ofstream fout(filepath);
+    dmlc::JSONWriter writer(&fout);
+    writer.BeginArray(true);
+    writer.WriteArrayItem((bool)fp);
+    if (fp) {
+      writer.WriteArrayItem(fp->variables);
+      writer.WriteArrayItem(fp->linear_cons);
+      writer.WriteArrayItem(fp->split_groups);
+      writer.WriteArrayItem(fp->var_decomp);
+    }
+    writer.EndArray();
+  }
+
+  PrimExpr DecomposeOneVar(const std::vector<size_t>& prime_bases, const std::string& vname,
+                           SizeVarKind kind) {
+    PrimExpr subst_expr = Integer(1);
+    for (size_t prime : prime_bases) {
+      SizeVar var(vname + "_" + std::to_string(prime), kind);
+      this->var_decomp[vname][prime] = var;
+      subst_expr *= pow(Integer(prime), var);
+    }
+    return subst_expr;
+  }
+
+  // void ProcessConstraint() {
+  //   auto* expr_lt = simpl.as<LTNode>();
+  //     ICHECK(expr_lt);
+  //     LinearExpr linexpr;
+  //     for (auto &piece: CollectSameOps<MulNode>(expr_lt->a)) {
+  //       auto *evar = piece.as<SizeVarNode>();
+  //       if (!evar) {
+  //         break;
+  //       }
+  //       auto it = this->var_decomp.find(evar->name_hint);
+  //       if (it == this->var_decomp.end()) {
+  //         break;
+  //       }
+  //       it->second
+  //     }
+  // }
+
+  std::unordered_map<std::string, VarDefStack> variables;
+  std::vector<LinearExpr> linear_cons{};
+  Array<SplitGroup> split_groups;
+  DecompT var_decomp{};
+};
+
+class StmtSimplifier : public StmtExprMutator {
+ public:
+  StmtSimplifier(RangeInfer& rinf, const VarDefStack& vdefs) : rinf(rinf) {
+    for (auto& pair : vdefs->GetExprs()) {
+      if (pair->var->kind == SizeVarKind::kShapeVar) {
+        this->e_vars.emplace_back(pair->var, pair->expr);
+      }
+    }
+  }
+
+  Stmt VisitStmt_(const ForNode* node) final {
+    PrimExpr extent = SimplifyExpr(this->rinf.GetMax(node->extent));
+    this->rinf.Bind(node->loop_var, Range::FromMinExtent(0, extent), true);
+    Stmt body = StmtExprMutator::VisitStmt(node->body);
+    return For(node->loop_var, 0, extent, node->kind, body);
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* node) final {
+    if (node->attr_key == tir::attr::thread_extent || node->attr_key == tir::attr::virtual_thread) {
+      PrimExpr extent = SimplifyExpr(this->rinf.GetMax(node->value));
+      const Var& var = node->node.as<IterVarNode>()->var;
+      this->rinf.Bind(var, Range::FromMinExtent(0, extent), true);
+      Stmt body = StmtExprMutator::VisitStmt(node->body);
+      return AttrStmt(node->node, node->attr_key, extent, body);
+    } else {
+      return StmtExprMutator::VisitStmt_(node);
+    }
+  }
+
+  // Remove tir.likely() for if-then-else
+  // which doesn't do anything for feature extraction
+  // and is hard to read.
+  Stmt VisitStmt_(const IfThenElseNode* node) final {
+    auto* call = node->condition.as<CallNode>();
+    static auto op_likely = Op::Get("tir.likely");
+    if (!call || !call->op.same_as(op_likely)) {
+      return StmtExprMutator::VisitStmt_(node);
+    }
+    return StmtExprMutator::VisitStmt(node->then_case);
+  }
+
+  Stmt VisitStmt_(const AllocateNode* op) final {
+    Array<PrimExpr> extents;
+    for (const auto& x : op->extents) {
+      extents.push_back(SimplifyExpr(this->rinf.GetMax(x)));
+    }
+    return Allocate(op->buffer_var, op->dtype, extents, op->condition,
+                    StmtExprMutator::VisitStmt(op->body), op->annotations);
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* node) final {
+    TryChangingBufferShape(node->buffer);
+    Array<Range> bounds;
+    for (auto& r : node->bounds) {
+      auto max = SimplifyExpr(this->rinf.GetMax(r->extent));
+      bounds.push_back(Range::FromMinExtent(0, max));
+    }
+    return BufferRealize(node->buffer, bounds, node->condition,
+                         StmtExprMutator::VisitStmt(node->body));
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* node) final {
+    TryChangingBufferShape(node->buffer);
+    Array<PrimExpr> indices;
+    for (auto& index : node->indices) {
+      indices.push_back(SimplifyExpr(index));
+    }
+    return BufferStore(node->buffer, StmtExprMutator::VisitExpr(node->value), indices);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* node) final {
+    TryChangingBufferShape(node->buffer);
+    Array<PrimExpr> indices;
+    for (auto& index : node->indices) {
+      indices.push_back(SimplifyExpr(index));
+    }
+    return BufferLoad(node->buffer, indices);
+  }
+
+ private:
+  void TryChangingBufferShape(const Buffer& buf) {
+    auto* buf_ = buf.get();
+    if (this->touched_bufs.count(buf_)) {
+      return;
+    }
+    Array<PrimExpr>& buf_shape = *const_cast<Array<PrimExpr>*>(&buf_->shape);
+    for (size_t i = 0; i < buf_shape.size(); ++i) {
+      // HACK: pattern-match buffer shape stored in Buffer and BufferRealize
+      // against the expressions of the E{i} variables we defined in VarContext.
+      // If we were able to define these variables earlier, we wouldn't need to do this.
+      if (auto size_var = FindEquivalentMatch(buf_shape[i])) {
+        buf_shape.Set(i, size_var.value());
+      }
+    }
+    this->touched_bufs.insert(buf_);
+  }
+
+  std::optional<SizeVar> FindEquivalentMatch(const PrimExpr& expr) {
+    for (auto& [v, e] : this->e_vars) {
+      if (arith::IsExprEquivalent(e, expr)) {
+        return v;
+      }
+    }
+    return std::nullopt;
+  }
+
+  PrimExpr SimplifyExpr(PrimExpr expr) {
+    auto it = this->_memo.find(expr);
+    if (it != this->_memo.end()) {
+      return it->second;
+    }
+    return this->_memo[expr] = arith::SimplifyExpr(expr);
+  }
+
+  RangeInfer& rinf;
+  std::unordered_set<const BufferNode*> touched_bufs;
+  std::vector<std::pair<SizeVar, PrimExpr>> e_vars;
+  std::unordered_map<PrimExpr, PrimExpr, StructuralHash, StructuralEqual> _memo;
+};
+
+void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {
+  PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+    if (const VarNode* op = node.as<VarNode>()) {
+      vars->insert(op);
+    }
+  });
+}
+
+std::optional<FeaturePack> GetFeaturePack(Stmt stmt, VarContext context,
+                                          const auto_scheduler::HardwareParams& hw_params,
+                                          size_t cache_line_size, size_t max_n_bufs) {
+  auto st_vdefs = context->var_defs;
+  FeaturePack fp;
+  try {
+    RangeInfer rinf;
+    {
+      Timer timer("StmtSimplifier");
+      stmt = StmtSimplifier(rinf, st_vdefs)(stmt);
+    }
+    // Feature and constraint extraction
+    Timer timer("FeatureExtraction");
+    VarDefStack features =
+        GetPerStoreFeatureExpr(stmt, *(st_vdefs.get()), rinf, cache_line_size, max_n_bufs);
+    auto constraints_ = GetConstraints(stmt, hw_params);
+    VarDefStack constraints;
+    for (size_t i = 0; i < constraints_.size(); ++i) {
+      auto name = "con_" + std::to_string(i);
+      constraints->Append("con_" + std::to_string(i), constraints_[i]);
+    }
+    fp = FeaturePack(std::move(context->var_defs), std::move(features), std::move(constraints),
+                     std::move(context->split_groups));
+  } catch (const std::exception& e) {
+    LOG_WARNING << "Feature extraction failed: " << e.what();
+    return std::nullopt;
+  }
+  // Simplify features that just came out from feature extraction
+  fp.RunRollingSimplify();
+  fp.RunExpTransform({2, 3, 5, 7});
+  fp.RunDiffTransform(1);
+  return fp;
+}
+
+TVM_REGISTER_GLOBAL("auto_scheduler.GetFeaturePack")
+    .set_body_typed([](Stmt stmt, VarContext context, auto_scheduler::HardwareParams hw_params,
+                       Map<String, Integer> sizes, size_t cache_line_size, size_t max_n_bufs,
+                       bool factorize, String save_load_path) {
+      std::ifstream fin(save_load_path);
+      FeaturePack fp;
+      if (fin.is_open()) {
+        dmlc::JSONReader reader(&fin);
+        auto fp_opt = FeaturePack::LoadFromJsonReader(reader);
+        if (!fp_opt) {
+          LOG_WARNING << "File " << save_load_path << " notes previous feature extraction failure";
+          return FeaturePackPy();  // None
+        }
+        fp = fp_opt.value();
+      } else {
+        LOG_INFO << "Extracting features to save to " << save_load_path;
+        auto fp_opt = GetFeaturePack(stmt, context, hw_params, cache_line_size, max_n_bufs);
+        if (fp_opt) {
+          fp = fp_opt.value();
+          FeaturePack::SaveAsJson(save_load_path, fp);
+        } else {
+          FeaturePack::SaveAsJson(save_load_path, std::nullopt);
+          return FeaturePackPy();  // None
+        }
+      }
+      fp.RunSizeSubstitution(sizes);
+      FeaturePack::SaveAsJson(save_load_path + ".json", fp);
+      return fp.IntoPythonFeaturePack();
+    });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.LinearExprAsPrimExpr").set_body_typed([](LinearExpr e) {
+  return e.ToPrimExpr();
+});
+
+}  // namespace felix
+}  // namespace tvm
+
+namespace dmlc {
+namespace json {
+
+template <typename V>
+struct Handler<std::unordered_map<uint64_t, V>> {
+  static void Write(JSONWriter* writer, const std::unordered_map<uint64_t, V>& map) {
+    writer->BeginArray(map.size() > 1);
+    for (auto& [k, v] : map) {
+      writer->WriteArraySeperator();
+      writer->BeginArray(false);
+      writer->WriteArrayItem(k);
+      writer->WriteArrayItem(v);
+      writer->EndArray();
+    }
+    writer->EndArray();
+  }
+
+  static void Read(JSONReader* reader, std::unordered_map<uint64_t, V>* map) {
+    map->clear();
+    reader->BeginArray();
+    while (reader->NextArrayItem()) {
+      uint64_t k;
+      V v;
+      reader->BeginArray();
+      ICHECK(reader->NextArrayItem());
+      reader->Read(&k);
+      ICHECK(reader->NextArrayItem());
+      reader->Read(&v);
+      ICHECK(!reader->NextArrayItem());
+      map->emplace(k, v);
+    }
+  }
+};
+
+template <>
+struct Handler<::tvm::felix::LinearExpr> {
+  static void Write(JSONWriter* writer, const tvm::felix::LinearExpr& e) {
+    writer->BeginArray();
+    writer->WriteArrayItem(e->constant->value);
+    std::unordered_map<std::string, double> name2coef;
+    for (auto& [var, coef] : e->lin_terms) {
+      name2coef[var->name_hint] = coef->value;
+    }
+    writer->WriteArrayItem(name2coef);
+    writer->EndArray();
+  }
+  static void Read(JSONReader* reader, tvm::felix::LinearExpr* e) {
+    reader->BeginArray();
+    ICHECK(reader->NextArrayItem());
+    double constant;
+    reader->Read(&constant);
+    ICHECK(reader->NextArrayItem());
+    std::unordered_map<std::string, double> name2coef;
+    reader->Read(&name2coef);
+    ICHECK(!reader->NextArrayItem());
+    *e = tvm::felix::LinearExpr(constant, name2coef);
+  }
+};
+}  // namespace json
+}  // namespace dmlc
diff --git a/src/felix/features.cc b/src/felix/features.cc
new file mode 100644
index 000000000..fdcfa392d
--- /dev/null
+++ b/src/felix/features.cc
@@ -0,0 +1,1060 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "features.h"
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+
+#include <algorithm>
+#include <optional>
+#include <unordered_map>
+#include <vector>
+
+#include "utils.h"
+
+namespace tvm {
+namespace felix {
+
+using namespace tvm::tir;
+using namespace tvm::arith;
+
+// The number of samples to extract for arithmetic intensity curves
+// static const constexpr int ARITH_INTENSITY_CURVE_SAMPLE_N = 10;
+
+// Annotation position encoding
+enum class AnnotationPosType : int {
+  kPosNone = 0,           // Does not have this kind of annotation
+  kPosInnerSpatial = 1,   // The annotated iterator is the innermost spatial iterator
+  kPosMiddleSpatial = 2,  // The annotated iterator is a middle spatial iterator
+  kPosOuterSpatial = 3,   // The annotated iterator is the outermost spatial iterator
+  kPosInnerReduce = 4,    // The annotated iterator is the innermost reduce iterator
+  kPosMiddleReduce = 5,   // The annotated iterator is a middle reduce iterator
+  kPosOuterReduce = 6,    // The annotated iterator is the outermost reduce iterator
+  kPosMixed = 7           // The annotated iterator is a mixed space and reduce iterator
+};
+
+// Buffer access type
+enum class BufferAccessType : int { kRead = 0, kWrite = 1, kReadWrite = 2, kUnknownRW = 3 };
+
+// Accesses to a buffer
+struct BufferAccess {
+  // data reuse type
+  BufferAccessType acc_type{BufferAccessType::kUnknownRW};
+  // Use a two-dimensional array to store multiple multi-dimensional accesses.
+  // The innermost vector stores the multi-dimensional indices of one access.
+  std::vector<std::vector<PrimExpr>> indices;
+};
+
+// Feature for an access of a buffer
+struct BufferAccessFeature {
+  std::string buffer_name;    // The name of the buffer
+  BufferAccessType acc_type;  // The type of the access
+  PrimExpr bytes;             // The touched memory in bytes
+  PrimExpr unique_bytes;      // The touched unique memory in bytes
+  PrimExpr lines;             // The number of touched cache lines
+  PrimExpr unique_lines;      // The number touched unique cache lines
+  // Types of data reuse
+  PrimExpr multi_read_cond, serial_multi_rw_cond, no_reuse_cond;
+  PrimExpr reuse_dis_iter;           // The reuse distance in iterator number
+  PrimExpr reuse_dis_bytes;          // The reuse distance in total touched bytes
+  PrimExpr reuse_ct;                 // The reuse ratio
+  PrimExpr bytes_d_reuse_ct;         // bytes / reuse_ct
+  PrimExpr unique_bytes_d_reuse_ct;  // unique_bytes / reuse_ct
+  PrimExpr lines_d_reuse_ct;         // lines / reuse_ct
+  PrimExpr unique_lines_d_reuse_ct;  // unique_lines / reuse_ct
+  PrimExpr stride;                   // The stride in access
+};
+
+// Feature set of a BufferStore statement
+struct FeatureSet {
+  // Group 1: Computation related features
+  PrimExpr float_mad;               // The number of float MAD (Multiply–add) ops
+  PrimExpr float_addsub;            // The number of float add and sub ops
+  PrimExpr float_mul;               // The number of float multiply ops
+  PrimExpr float_divmod;            // The number of float div and mod ops
+  PrimExpr float_cmp;               // The number of float comparison ops
+  PrimExpr float_math_func;         // The number of float math func calls
+  PrimExpr float_other_func;        // The number of other float func calls
+  PrimExpr int_mad;                 // The number of integer MAD (Multiply–add) ops
+  PrimExpr int_addsub;              // The number of integer add and sub ops
+  PrimExpr int_mul;                 // The number of float multiply ops
+  PrimExpr int_divmod;              // The number of float div and mod ops
+  PrimExpr int_cmp;                 // The number of float comparison ops
+  PrimExpr int_math_func;           // The number of float math func calls
+  PrimExpr int_other_func;          // The number of other float func calls
+  PrimExpr bool_op;                 // The number of bool ops
+  PrimExpr select_op;               // The number of select ops
+  PrimExpr vec_num;                 // The number of vectorized iterators
+  PrimExpr vec_prod;                // The product of the lengths of vectorized iterators
+  PrimExpr vec_len;                 // The length of the innermost vectorized iterator
+  AnnotationPosType vec_type;       // The type of vectorization position
+  PrimExpr unroll_num;              // The number of unrolled iterators
+  PrimExpr unroll_prod;             // The product of the lengths of vectorized iterators
+  PrimExpr unroll_len;              // The length of the innermost unrolled iterator
+  AnnotationPosType unroll_type;    // The type of unroll position
+  PrimExpr parallel_num;            // The number of paralleled iterators
+  PrimExpr parallel_prod;           // The product of the lengths of paralleled iterators
+  PrimExpr parallel_len;            // The length of the innermost paralleled iterators
+  AnnotationPosType parallel_type;  // The type of parallel position
+  PrimExpr is_gpu;                  // Whether it is a GPU task
+  PrimExpr blockIdx_x_len;          // The length of blockIdx.x
+  PrimExpr blockIdx_y_len;          // The length of blockIdx.y
+  PrimExpr blockIdx_z_len;          // The length of blockIdx.z
+  PrimExpr threadIdx_x_len;         // The length of threadIdx.x
+  PrimExpr threadIdx_y_len;         // The length of threadIdx.y
+  PrimExpr threadIdx_z_len;         // The length of threadIdx.z
+  PrimExpr vthread_len;             // The length of virtual thread
+
+  // Group 2: Buffer access related features (per buffer)
+  std::vector<BufferAccessFeature> access_feas;
+
+  // Group 3: Arithmetic intensity related features
+  // PrimExpr arith_intensity_curve[ARITH_INTENSITY_CURVE_SAMPLE_N];  // points sampled from the
+  //                                                                  // arithmetic intensity curve
+
+  // Group 4: Allocation related features
+  PrimExpr alloc_size;        // The size of allocated buffer in bytes
+  PrimExpr alloc_outer_prod;  // The product of lengths of loops outside the scope of the allocation
+  PrimExpr alloc_inner_prod;  // The product of lengths of loops inside the score of the allocation
+  PrimExpr alloc_prod;        // alloc_outer_prod * alloc_inner_prod
+
+  // Group 5: Outer scope related features
+  PrimExpr outer_prod;            // The product of lengths of outer loops
+  PrimExpr num_loops;             // The number of outer loops
+  PrimExpr auto_unroll_max_step;  // The value of pragma "auto_unroll_max_step"
+};
+
+namespace {
+
+// Return whether a var is in an expr
+bool VarInExpr(const Var& var, const PrimExpr& expr) {
+  bool found = false;
+
+  // Find by name, because TVM duplicates some loops such as threadIdx.x,
+  // creating 2 loop variables that have the same name but are different as objects.
+  PostOrderVisit(expr, [&found, &var](const ObjectRef& node) {
+    const VarNode* op = node.as<VarNode>();
+    if (op && op->name_hint == var->name_hint) {
+      found = true;
+    }
+  });
+
+  return found;
+}
+
+PrimExpr SelectNonZero(const PrimExpr& expr, PrimExpr non_zero) {
+  auto as_select = expr.as<SelectNode>();
+  if (as_select) {
+    auto false_value = as_select->false_value.as<IntImmNode>();
+    if (false_value && false_value->value == 0) {
+      return select(as_select->condition, CastToFloat(as_select->true_value), non_zero);
+    }
+  }
+  return select(expr == 0, non_zero, CastToFloat(expr));
+}
+
+PrimExpr SelectLogOr0(PrimExpr cond, PrimExpr value) { return select(cond, log(value), 0); }
+
+// Count math ops in an expr
+class MathOpCounter : public ExprVisitor {
+ public:
+  MathOpCounter(RangeInfer& rinf) : rinf(rinf) {}
+
+  PrimExpr FromExprMap(const ExprMap<size_t>& expr_map) const {
+    PrimExpr result = 0;
+    for (auto& [cond, count] : expr_map) {
+      result += select(cond, 0, Integer(count));
+    }
+    return result;
+  }
+
+ private:
+  PrimExpr ConstCond(Range lhs, Range rhs, Integer const_goal) {
+    bool lhs_const = this->ana.CanProveEqual(lhs->extent, 0),
+         rhs_const = this->ana.CanProveEqual(rhs->extent, 0);
+    if (lhs_const && rhs_const) {
+      return const_true();
+    } else if (lhs_const && const_goal.defined()) {
+      return lhs->min == const_goal;
+    } else if (rhs_const && const_goal.defined()) {
+      return rhs->min == const_goal;
+    } else {
+      return const_false();
+    }
+  }
+
+#define DefineVisitBinOp(Type, float_ct, int_ct, const_goal)  \
+  void VisitExpr_(const Type* op) override {                  \
+    if (op->a.dtype().is_float()) {                           \
+      float_ct++;                                             \
+    } else {                                                  \
+      Range lhs = this->rinf(op->a), rhs = this->rinf(op->b); \
+      PrimExpr const_cond = ConstCond(lhs, rhs, const_goal);  \
+      int_ct[const_cond]++;                                   \
+    }                                                         \
+    ExprVisitor::VisitExpr_(op);                              \
+  }
+  DefineVisitBinOp(AddNode, float_addsub, int_addsub, 0);
+  DefineVisitBinOp(SubNode, float_addsub, int_addsub, 0);
+  DefineVisitBinOp(MulNode, float_mul, int_mul, 1);
+  DefineVisitBinOp(DivNode, float_divmod, int_divmod, 1);
+  DefineVisitBinOp(FloorDivNode, float_divmod, int_divmod, 1);
+  DefineVisitBinOp(ModNode, float_divmod, int_divmod, 1);
+  DefineVisitBinOp(FloorModNode, float_divmod, int_divmod, 1);
+  DefineVisitBinOp(MaxNode, float_cmp, int_cmp, Integer());
+  DefineVisitBinOp(MinNode, float_cmp, int_cmp, Integer());
+
+#define BoolOp(Type)                         \
+  void VisitExpr_(const Type* op) override { \
+    bool_op++;                               \
+    ExprVisitor::VisitExpr_(op);             \
+  }
+  BoolOp(AndNode);
+  BoolOp(OrNode);
+  BoolOp(NotNode);
+#define NumToBoolCmpOp(Type)                 \
+  void VisitExpr_(const Type* op) override { \
+    if (op->a.dtype().is_float()) {          \
+      float_cmp++;                           \
+    } else {                                 \
+      int_cmp[const_false()]++;              \
+    }                                        \
+    ExprVisitor::VisitExpr_(op);             \
+  }
+  NumToBoolCmpOp(EQNode);
+  NumToBoolCmpOp(NENode);
+  NumToBoolCmpOp(LTNode);
+  NumToBoolCmpOp(LENode);
+  NumToBoolCmpOp(GTNode);
+  NumToBoolCmpOp(GENode);
+
+#undef DefineVisitBinOp
+#undef BoolOp
+#undef NumToBoolCmpOp
+
+  void VisitExpr_(const SelectNode* op) override {
+    select_op++;
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  // Returning empty range as we have no idea what the range could be.
+  void VisitExpr_(const CallNode* op) override {
+    auto* pop = op->op.as<OpNode>();
+    ICHECK(pop != nullptr);
+    auto effect_kind = op_call_effect_[GetRef<Op>(pop)];
+    bool is_pure =
+        effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation;
+
+    if (is_pure) {
+      if (op->dtype.is_float()) {
+        float_math_func++;
+      } else {
+        int_math_func++;
+      }
+    } else {
+      if (op->dtype.is_float()) {
+        float_other_func++;
+      } else {
+        int_other_func++;
+      }
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  RangeInfer& rinf;
+  Analyzer ana;
+  OpAttrMap<TCallEffectKind> op_call_effect_ = Op::GetAttrMap<TCallEffectKind>("TCallEffectKind");
+
+ public:
+  size_t float_mad{0};         // The number of float MAD (Multiply–add) ops
+  size_t float_addsub{0};      // The number of float add and sub ops
+  size_t float_mul{0};         // The number of float multiply ops
+  size_t float_divmod{0};      // The number of float div and mod ops
+  size_t float_cmp{0};         // The number of float comparison ops
+  size_t float_math_func{0};   // The number of float math func calls
+  size_t float_other_func{0};  // The number of other float func calls
+  size_t int_mad{0};           // The number of integer MAD (Multiply–add) ops
+  ExprMap<size_t> int_addsub;  // The number of integer add and sub ops
+  ExprMap<size_t> int_mul;     // The number of integer multiply ops
+  ExprMap<size_t> int_divmod;  // The number of integer div and mod ops
+  ExprMap<size_t> int_cmp;     // The number of integer comparison ops
+  size_t int_math_func{0};     // The number of float math func calls
+  size_t int_other_func{0};    // The number of other float func calls
+  size_t bool_op{0};           // The number of bool ops
+  size_t select_op{0};         // The number of select ops
+};
+
+// Extract all buffer accesses in an expr
+class BufferAccessExtractor : public StmtExprVisitor {
+ public:
+  void ExtractReads(const PrimExpr& expr) { this->VisitExpr(expr); }
+
+  void InsertAccess(const Buffer& buf, BufferAccessType acc_type, const Array<PrimExpr>& indices) {
+    BufferAccess& acc = buf_accesses[buf];
+    acc.acc_type = acc_type;
+    acc.indices.push_back(std::vector<PrimExpr>(indices.begin(), indices.end()));
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
+    BufferAccess& acc = buf_accesses[op->buffer];
+    switch (acc.acc_type) {
+      case BufferAccessType::kRead:
+        break;
+      case BufferAccessType::kWrite:
+        acc.acc_type = BufferAccessType::kReadWrite;
+        break;
+      case BufferAccessType::kReadWrite:
+        break;
+      case BufferAccessType::kUnknownRW:
+      default:
+        acc.acc_type = BufferAccessType::kRead;
+        break;
+    }
+
+    if (acc.acc_type != BufferAccessType::kReadWrite) {
+      // If a buffer is both read and written, in the tvm DSL, it must be a update,
+      // so the indices should be the same. Then we can skip appending indices for it.
+      // Otherwise we do the following.
+      buf_accesses[op->buffer].indices.push_back(
+          std::vector<PrimExpr>(op->indices.begin(), op->indices.end()));
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  BufferMap<BufferAccess> buf_accesses;
+};
+
+// Compute the coefficient for an loop iterator in an expression
+// Note: we use an approximation strategy to find coefficient.
+// Hopefully, it is faster than DetectLinearEquation and can handle more cases (non-linear)
+class CoefficientExtractor : public StmtExprVisitor {
+ public:
+  void VisitExpr_(const MulNode* node) final {
+    StmtExprVisitor::VisitExpr_(node);
+    if (visited_var) {
+      if (!visited_add) {
+        if (auto a = node->a.as<IntImmNode>()) {
+          visited_mul = true;
+          stride = a->value;
+        } else if (auto b = node->b.as<IntImmNode>()) {
+          visited_mul = true;
+          stride = b->value;
+        }
+      }
+    }
+  }
+
+  void VisitExpr_(const AddNode* node) final {
+    StmtExprVisitor::VisitExpr_(node);
+    if (visited_var) {
+      if (!visited_mul) {
+        visited_add = true;
+        stride = 1;
+      }
+    }
+  }
+
+  void VisitExpr_(const VarNode* node) final {
+    if (node == var_) {
+      visited_var = true;
+      // This is a magic default stride in case our approximation strategy fails
+      stride = 2;
+    }
+  }
+
+  int ExtractCoefficient(const PrimExpr& expr, const VarNode* var) {
+    visited_var = visited_mul = visited_add = false;
+    var_ = var;
+
+    this->VisitExpr(expr);
+
+    if (visited_var && !visited_mul && !visited_add) {
+      return 1;
+    } else {
+      return stride;
+    }
+  }
+
+  bool visited_var{false};
+  bool visited_mul{false};
+  bool visited_add{false};
+  int stride{0};
+
+ private:
+  const VarNode* var_{nullptr};
+};
+
+// Compute stride for the accesses to a buffer
+std::pair<bool, PrimExpr> ComputeStride(const std::vector<std::vector<PrimExpr>>& indices,
+                                        const Array<PrimExpr>& shape, const VarNode* stride_var) {
+  PrimExpr min_stride{};
+  bool found = false;
+  CoefficientExtractor extractor;
+
+  for (const auto& index : indices) {
+    PrimExpr shape_stride = 1;
+    for (int i = static_cast<int>(index.size()) - 1; i >= 0; i--) {
+      int coefficient = extractor.ExtractCoefficient(index[i], stride_var);
+      if (extractor.visited_var) {
+        found = true;
+        if (min_stride.defined()) {
+          min_stride = min(min_stride, std::abs(coefficient) * shape_stride);
+        } else {
+          min_stride = std::abs(coefficient) * shape_stride;
+        }
+        break;
+      }
+      shape_stride *= shape[i];
+    }
+  }
+  return {found, min_stride};
+}
+
+PrimExpr LoopNonTrivialCond(const ForNode* loop) {
+  std::string name = loop->loop_var->name_hint;
+  if (name.substr(0, 8) == "blockIdx" || name.substr(0, 9) == "threadIdx" || name == "vthread") {
+    // These loops are always there no matter what the loop size is.
+    return Bool(true);
+  }
+  return loop->extent > 1;
+}
+
+std::tuple<PrimExpr, PrimExpr, PrimExpr> ComputeStrideForLoops(
+    const std::vector<std::vector<PrimExpr>>& indices, const Array<PrimExpr>& shape,
+    const std::vector<const ForNode*> loops_reversed) {
+  PrimExpr reduce_ratio_acc = 1;
+  PrimExpr reduce_ratio = 0, stride = 0, innermost_stride = 0;
+  PrimExpr found = Bool(false), in_loop = Bool(true), first_loop = Bool(true);
+  for (const auto& loop : loops_reversed) {
+    PrimExpr non_trivial_loop = LoopNonTrivialCond(loop);
+    // If loop is trivial, then the following don't happen and we effectively have a `continue;`.
+    auto [found_, stride_] = ComputeStride(indices, shape, loop->loop_var.get());
+    PrimExpr found_this = in_loop && non_trivial_loop && Bool(found_);
+    reduce_ratio_acc *= loop->extent;
+    if (found_) {
+      // innermost_stride is non-zero only when the stride is found from the innermost loop.
+      innermost_stride += SelectLogOr0(found_this && first_loop, stride_);
+      stride += SelectLogOr0(found_this, stride_);
+    }
+    reduce_ratio += SelectLogOr0(found_this, reduce_ratio_acc);
+    found = found || found_this;
+    // Breaks out when we actually find something.
+    in_loop = in_loop && (!non_trivial_loop || Bool(!found_));
+    first_loop = first_loop && (!non_trivial_loop);
+  }
+  // Default value. Can also use !found here, but that expression is more complex.
+  reduce_ratio += SelectLogOr0(in_loop, reduce_ratio_acc);
+  ICHECK(innermost_stride.defined());
+  return {exp(stride), exp(innermost_stride), exp(reduce_ratio)};
+}
+
+// Compute touched bytes and cache lines for accesses to a buffer
+std::vector<PrimExpr> ComputeRegion(const std::vector<std::vector<PrimExpr>>& indices,
+                                    RangeInfer& rinf, VarDefStackNode& vdefs) {
+  std::vector<PrimExpr> ret;
+  if (indices.empty()) return ret;
+  if (indices.size() == 1) {
+    for (const auto& index : indices[0]) {
+      Range range = rinf(index);
+      ret.push_back(vdefs.DefineConstShorthand(range->extent + 1));
+    }
+  } else {
+    for (const auto& indices_ : indices) {
+      Range range = rinf(indices_[0]);
+      PrimExpr size = range->extent + 1;
+      for (size_t i = 1; i < indices_.size(); ++i) {
+        size = max(size, rinf(indices_[i])->extent + 1);
+      }
+      ret.push_back(vdefs.DefineConstShorthand(size));
+    }
+  }
+  return ret;
+}
+
+using BufferInfo3 = std::tuple<BufferAccessType, PrimExpr, int>;
+using ForTouchRegionT = std::unordered_map<const ForNode*, BufferMap<std::vector<BufferInfo3>>>;
+
+bool LoopIterInIndices(Var for_var, const std::vector<std::vector<PrimExpr>>& indices) {
+  for (size_t j = 0; j < indices.size(); j++) {
+    for (size_t k = 0; k < indices[j].size(); k++) {
+      if (VarInExpr(for_var, indices[j][k])) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+PrimExpr ReuseDistInBytes(const BufferMap<std::vector<BufferInfo3>>& this_for_region,
+                          bool include_n_elems) {
+  PrimExpr reuse_dis_bytes = 0;
+  for (const auto& iter : this_for_region) {
+    for (auto& [_, n_elem, dtype_bytes] : iter.second) {
+      if (include_n_elems) {
+        reuse_dis_bytes += n_elem * dtype_bytes;
+      } else {
+        reuse_dis_bytes += dtype_bytes;
+      }
+    }
+  }
+  return reuse_dis_bytes;
+}
+
+PrimExpr MultiRWReuseDistance(const std::vector<BufferInfo3>& buffers, PrimExpr for_extent) {
+  ICHECK(!buffers.empty());
+  PrimExpr reuse_dis_iter = std::get<1>(buffers[0]);
+  for (size_t i = 1; i < buffers.size(); ++i) {
+    reuse_dis_iter = min(reuse_dis_iter, std::get<1>(buffers[i]));
+  }
+  return div(reuse_dis_iter, for_extent);
+}
+
+// Compute reuse distance and reuse ratio for accesses to a buffer
+// return values: reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct
+std::tuple<PrimExpr, PrimExpr, PrimExpr, PrimExpr, PrimExpr> ComputeReuse(
+    const Buffer& buf, const std::vector<std::vector<PrimExpr>>& indices,
+    const std::vector<const ForNode*>& for_loops, const ForTouchRegionT& for_touch_regions) {
+  // for (i = 0; i < N; i++) {
+  //   if (trivial_loop[i]) continue;
+  //   if (has_loop_iter[i]) { ... }
+  //   else return x_i;                     // kLoopMultipleRead
+  //   if (has_serial_uses[i]) return y_i;  // kSerialMultipleReadWrite
+  // }
+  // return 0;                              // NoReuse
+  //
+  // Denote the condition that we're still in the loop for i-th iteration by `in_loop[i]`. Then
+  //   multi_read_reuse[i] := in_loop[i] && !trivial_loop[i] && !has_loop_iter[i]
+  //   serial_rw_reuse[i] = in_loop[i] && !trivial_loop[i] && has_serial_uses[i]
+  // Also
+  //   in_loop[i + 1] := in_loop[i] && (!multi_read_reuse[i] && !serial_rw_reuse[i])
+  //     in_loop[0] = true
+  // The return value R from this function is then
+  //   \sum_i select(multi_read_reuse[i], x_i, 0) + select(serial_rw_reuse[i], y_i, 0)
+  // This function has multiple return values, they all follow this same idea.
+  int n_loops = static_cast<int>(for_loops.size());
+  PrimExpr in_loop = Bool(true), multi_read_reuse = Bool(false), serial_rw_reuse = Bool(false);
+  PrimExpr read_reuse_dist_iter = 1;
+  PrimExpr reuse_dist_iter = 0, reuse_dist_bytes = 0, reuse_count = 0;
+  for (int i = n_loops - 1; i >= 0; --i) {
+    auto* loop = for_loops[i];
+    PrimExpr extent = loop->extent;
+    const auto& this_for_region = for_touch_regions.at(loop);
+    const auto& this_buffer = this_for_region.at(buf);
+    int serial_reuse = (int)this_buffer.size() - 1;
+    PrimExpr has_loop_iter = Bool(LoopIterInIndices(loop->loop_var, indices));
+    // Use extent > 1 here (instead of LoopNonTrivialCond) to skip _all_ loops with extent 1
+    // including threadIdx/blockIdx, because that's what the concrete version does
+    // (and it makes more sense because if extent is 1 then there won't really be a "reuse").
+    PrimExpr non_trivial_loop = extent > 1, has_serial_uses = Bool(serial_reuse > 0);
+    PrimExpr multi_read_reuse_ = in_loop && non_trivial_loop && !has_loop_iter,
+             serial_rw_reuse_ = in_loop && non_trivial_loop && has_serial_uses;
+    PrimExpr no_exit = !non_trivial_loop || (has_loop_iter && !has_serial_uses);
+    in_loop = in_loop && no_exit;
+
+    // accumulate/update reuse distance
+    PrimExpr rw_reuse_dist_iter = MultiRWReuseDistance(this_buffer, extent);
+    PrimExpr reuse_dist_iter_ = SelectLogOr0(multi_read_reuse_, read_reuse_dist_iter) +
+                                SelectLogOr0(serial_rw_reuse_, rw_reuse_dist_iter);
+    read_reuse_dist_iter *= extent;  // This after reuse_dist_iter_
+    // For multi read, reuse_dist_bytes is computed based on the previous (1-level inner) loop.
+    // When this is the innermost loop, it's computed from this loop with a slightly different
+    // algorithm (`n_elems` is not counted).
+    PrimExpr read_reuse_dist_bytes_ =
+        i == n_loops - 1 ? ReuseDistInBytes(this_for_region, false)
+                         : ReuseDistInBytes(for_touch_regions.at(for_loops[i + 1]), true);
+    PrimExpr rw_reuse_dist_bytes_ = div(ReuseDistInBytes(this_for_region, true), extent);
+    PrimExpr reuse_dist_bytes_ = SelectLogOr0(multi_read_reuse_, read_reuse_dist_bytes_) +
+                                 SelectLogOr0(serial_rw_reuse_, rw_reuse_dist_bytes_);
+    PrimExpr reuse_count_ =
+        SelectLogOr0(multi_read_reuse_, extent) + SelectLogOr0(serial_rw_reuse_, serial_reuse);
+
+    multi_read_reuse = multi_read_reuse || multi_read_reuse_;
+    serial_rw_reuse = serial_rw_reuse || serial_rw_reuse_;
+    reuse_dist_bytes += reuse_dist_bytes_;
+    reuse_dist_iter += reuse_dist_iter_;
+    reuse_count += reuse_count_;
+  }
+  return std::make_tuple(multi_read_reuse, serial_rw_reuse, exp(reuse_dist_iter),
+                         exp(reuse_dist_bytes), exp(reuse_count));
+}
+
+// Extract features for every BufferStore statement
+class PerStoreFeatureExtractor : public StmtExprVisitor {
+ public:
+  explicit PerStoreFeatureExtractor(VarDefStackNode& vdefs, RangeInfer& rinf, int cache_line_size)
+      : vdefs(vdefs), rinf(rinf), cache_line_size_(cache_line_size) {}
+
+  void VisitStmt_(const AttrStmtNode* node) final {
+    if (node->attr_key == tir::attr::thread_extent || node->attr_key == tir::attr::virtual_thread) {
+      const Var& var = node->node.as<IterVarNode>()->var;
+      this->is_gpu = true;
+
+      // make a fake for node for blockIdx.x or threadIdx.x
+      For fake_for(var, 0, node->value, ForKind::kParallel, node->body);
+      auto fake_node = fake_for.as<ForNode>();
+      this->for_loop_stack.push_back(fake_node);
+      StmtExprVisitor::VisitStmt_(node);
+      for_loop_stack.pop_back();
+    } else if (node->attr_key == "pragma_auto_unroll_max_step") {
+      PrimExpr old_value = cur_auto_unroll_max_step_;
+      cur_auto_unroll_max_step_ = node->value;
+      StmtExprVisitor::VisitStmt_(node);
+      cur_auto_unroll_max_step_ = old_value;
+    } else {
+      StmtExprVisitor::VisitStmt_(node);
+    }
+  }
+
+  void VisitStmt_(const ForNode* node) final {
+    for_loop_stack.push_back(node);
+    StmtExprVisitor::VisitStmt(node->body);
+    for_loop_stack.pop_back();
+  }
+
+  void VisitStmt_(const BufferStoreNode* node) final {
+    PrimExpr loop_prod = LoopProd(FilterForLoops(std::nullopt));
+
+    // Group 1: Computation related features
+    MathOpCounter moc(this->rinf);
+    moc(node->value);
+    ExtractComputationFeature(node, moc, loop_prod);
+
+    // Group 2: Buffer access related features (per buffer)
+    std::vector<PrimExpr> mem_bytes_list, compute_ops_list;
+    PrimExpr cur_compute_ops;
+    ExtractBufferAccessFeature(node, moc, loop_prod, &cur_compute_ops, &compute_ops_list,
+                               &mem_bytes_list);
+
+    // Group 3: Arithmetic intensity related features
+    // LOG_WARNING << "ExtractArithmeticIntensityFeature is unsupported yet";
+    // ExtractArithmeticIntensityFeature(node, cur_compute_ops, compute_ops_list, mem_bytes_list);
+
+    // Group 5: Outer scope related features
+    ExtractOuterScopeFeature(node, loop_prod);
+  }
+
+  void VisitStmt_(const BufferRealizeNode* node) final {
+    StmtExprVisitor::VisitStmt_(node);
+
+    // Group 4: Allocation related features
+    ExtractAllocationFeature(node);
+  }
+
+  // Extract computation related features (group 1)
+  void ExtractComputationFeature(const BufferStoreNode* node, const MathOpCounter& moc,
+                                 const PrimExpr& loop_prod) {
+    // Computation related features
+    FeatureSet& fea = bufstore_feats[node->buffer];
+
+    fea.float_mad = loop_prod * (int)moc.float_mad;
+    fea.float_addsub = loop_prod * (int)moc.float_addsub;
+    fea.float_mul = loop_prod * (int)moc.float_mul;
+    fea.float_divmod = loop_prod * (int)moc.float_divmod;
+    fea.float_cmp = loop_prod * (int)moc.float_cmp;
+    fea.float_math_func = loop_prod * (int)moc.float_math_func;
+    fea.float_other_func = loop_prod * (int)moc.float_other_func;
+    fea.int_mad = loop_prod * (int)moc.int_mad;
+    fea.int_addsub = loop_prod * moc.FromExprMap(moc.int_addsub);
+    fea.int_mul = loop_prod * moc.FromExprMap(moc.int_mul);
+    fea.int_divmod = loop_prod * moc.FromExprMap(moc.int_divmod);
+    fea.int_math_func = loop_prod * (int)moc.int_math_func;
+    fea.int_cmp = loop_prod * moc.FromExprMap(moc.int_cmp);
+    fea.int_other_func = loop_prod * (int)moc.int_other_func;
+    fea.bool_op = loop_prod * (int)moc.bool_op;
+    fea.select_op = loop_prod * (int)moc.select_op;
+
+    FillLoopFeatures(ForKind::kVectorized, fea.vec_num, fea.vec_len, fea.vec_prod, fea.vec_type);
+    FillLoopFeatures(ForKind::kUnrolled, fea.unroll_num, fea.unroll_len, fea.unroll_prod,
+                     fea.unroll_type);
+    FillLoopFeatures(ForKind::kParallel, fea.parallel_num, fea.parallel_len, fea.parallel_prod,
+                     fea.parallel_type);
+
+    // GPU threads
+    fea.is_gpu = Bool(this->is_gpu);
+    Map<String, PrimExpr> loop_sizes;
+    for (const auto& loop : this->for_loop_stack) {
+      loop_sizes.Set(loop->loop_var->name_hint, loop->extent);
+    }
+    fea.blockIdx_x_len = loop_sizes.Get("blockIdx.x").value_or(1);
+    fea.blockIdx_y_len = loop_sizes.Get("blockIdx.y").value_or(1);
+    fea.blockIdx_z_len = loop_sizes.Get("blockIdx.z").value_or(1);
+    fea.threadIdx_x_len = loop_sizes.Get("threadIdx.x").value_or(1);
+    fea.threadIdx_y_len = loop_sizes.Get("threadIdx.y").value_or(1);
+    fea.threadIdx_z_len = loop_sizes.Get("threadIdx.z").value_or(1);
+    fea.vthread_len = loop_sizes.Get("vthread").value_or(1);
+  }
+
+  // Extract buffer access related features (group 2)
+  void ExtractBufferAccessFeature(const BufferStoreNode* node, const MathOpCounter& moc,
+                                  PrimExpr loop_prod, PrimExpr* cur_compute_ops,
+                                  std::vector<PrimExpr>* compute_ops_list,
+                                  std::vector<PrimExpr>* mem_bytes_list) {
+    std::vector<BufferAccessFeature>& acc_feas = bufstore_feats[node->buffer].access_feas;
+    // We may have multiple bufferstore nodes for the same buffer (e.g., 1 for initializing an
+    // array, and 1 for computing it). In that case, delibrately overwrite the previous result.
+    acc_feas.clear();
+
+    BufferAccessExtractor buf_extractor;
+    buf_extractor.InsertAccess(node->buffer, BufferAccessType::kWrite, node->indices);
+    buf_extractor.ExtractReads(node->value);
+    auto for_loops = FilterForLoops(std::nullopt), for_loops_rev = for_loops;
+    std::reverse(for_loops_rev.begin(), for_loops_rev.end());
+
+    // Compute touched region for all outer loops
+    // * Make a copy of our global RangeInfer and override all loop variables to be [min, min]
+    RangeInfer rangeinf = this->rinf;
+    for (auto* loop : for_loops) {
+      // Using [a, b] convension for range, this means [x->min, x->min].
+      rangeinf.Bind(loop->loop_var, Range::FromMinExtent(loop->min, 0), true);
+    }
+
+    mem_bytes_list->reserve(for_loops.size());
+    compute_ops_list->reserve(for_loops.size());
+
+    *cur_compute_ops = (int)(moc.float_mad + moc.float_addsub + moc.float_mul + moc.float_divmod +
+                             moc.float_cmp + moc.float_math_func + moc.float_other_func);
+
+    // std::cout << "In BufferStoreNode " << node->buffer->name << "\n";
+    std::vector<PrimExpr> tmp_region;
+    for (auto* loop : for_loops_rev) {
+      rangeinf.BindLoop(loop, true);
+      // std::cout << "  in for loop " << loop->loop_var->name_hint << "\n";
+      // Note, here we do overwrite.
+      // So if there are multiple BufferStoreNode, the last one will overwrite the first few.
+      // e.g. The update part in gemm will overwrite the init part.
+      BufferMap<std::vector<BufferInfo3>>& buffer_regions_map = for_touch_regions_[loop];
+      PrimExpr mem_bytes = 0;
+      for (const auto& x : buf_extractor.buf_accesses) {
+        const Buffer& t = x.first;
+        const BufferAccess& acc = x.second;
+        // std::cout << "    in buffer access name " << t->name << "\n";
+        tmp_region = ComputeRegion(acc.indices, rangeinf, this->vdefs);
+        PrimExpr touched_size = ElementProduct(tmp_region);
+        buffer_regions_map[t].emplace_back(acc.acc_type, touched_size, t->dtype.bytes());
+        mem_bytes += touched_size * t->dtype.bytes();
+      }
+
+      mem_bytes_list->push_back(log2(mem_bytes));
+      *cur_compute_ops *= loop->extent;
+      compute_ops_list->push_back(log2(*cur_compute_ops));
+    }
+
+    //  Buffer access related features (per buffer)
+    auto bufmap = buf_extractor.buf_accesses;
+    std::vector<std::pair<Buffer, BufferAccess>> buf_accs(bufmap.begin(), bufmap.end());
+    for (size_t i = 0; i < buf_accs.size(); ++i) {
+      auto [buf, acc] = buf_accs[i];
+      Integer ele_bytes = buf->dtype.bytes();
+      // calculate bytes
+      PrimExpr bytes = loop_prod * ele_bytes, unique_bytes;
+      // calculate cache lines
+      PrimExpr stride, lines, unique_lines;
+      if (for_loops.empty()) {
+        unique_bytes = ele_bytes;
+        stride = 0;
+        lines = 1.0f;
+        unique_lines = 1.0f;
+      } else {
+        unique_bytes = this->vdefs.DefineConstShorthand(
+            std::get<1>(for_touch_regions_[for_loops.front()][buf].front()) * ele_bytes);
+        auto [stride_, innermost_stride, reduce_ratio] =
+            ComputeStrideForLoops(acc.indices, buf->shape, for_loops_rev);
+        // convert `stride` back to the stride of the innermost iterator
+        stride = innermost_stride;
+        auto term1 = min(1.0f, div(CastToFloat(stride_ * ele_bytes), (float)cache_line_size_));
+        lines = max(div(CastToFloat(loop_prod), CastToFloat(reduce_ratio)) * term1, 1.0f);
+
+        // Modeled after this:
+        // PrimExpr n_continuous = ele_bytes;
+        // for (int i = std::min(tmp_region.size() - 1, t->shape.size() - 1); i >= 0; i--) {
+        //   if (this->ana_.CanProveEqual(tmp_region[i], t->shape[i])) {
+        //     n_continuous *= tmp_region[i];
+        //     break;
+        //   }
+        // }
+        PrimExpr n_continuous = 0, in_loop = Bool(true);
+        for (int i = std::min(tmp_region.size() - 1, buf->shape.size() - 1); i >= 0; i--) {
+          PrimExpr is_equal = tmp_region[i] == buf->shape[i];
+          n_continuous += SelectLogOr0(in_loop && is_equal, ele_bytes * tmp_region[i]);
+          in_loop = in_loop && (!is_equal);
+        }
+        // If we've done the whole loop without `is_equal == True`, then the value
+        // should just be `ele_bytes`.
+        n_continuous += SelectLogOr0(in_loop, ele_bytes);
+        unique_lines =
+            max(div(CastToFloat(unique_bytes), min(exp(n_continuous), cache_line_size_)), 1.0f);
+      }
+
+      auto [multi_read_cond, serial_multi_rw_cond, reuse_dis_iter, reuse_dis_bytes, reuse_ct] =
+          ComputeReuse(buf, acc.indices, for_loops, for_touch_regions_);
+      multi_read_cond = SimplifyExpr(multi_read_cond);
+      serial_multi_rw_cond = SimplifyExpr(serial_multi_rw_cond);
+      reuse_dis_iter = SimplifyExpr(reuse_dis_iter);
+      reuse_dis_bytes = SimplifyExpr(reuse_dis_bytes);
+      reuse_ct = SimplifyExpr(reuse_ct);
+      PrimExpr no_reuse_cond = SimplifyExpr(!(serial_multi_rw_cond || multi_read_cond));
+
+      acc_feas.emplace_back();
+      BufferAccessFeature& acc_fea = acc_feas.back();
+
+      acc_fea.buffer_name = buf->name;
+      acc_fea.acc_type = acc.acc_type;
+      acc_fea.stride = stride;
+      acc_fea.bytes = bytes;
+      acc_fea.unique_bytes = unique_bytes;
+      acc_fea.lines = lines;
+      acc_fea.unique_lines = unique_lines;
+      acc_fea.multi_read_cond = multi_read_cond;
+      acc_fea.serial_multi_rw_cond = serial_multi_rw_cond;
+      acc_fea.no_reuse_cond = no_reuse_cond;
+      acc_fea.reuse_dis_iter = reuse_dis_iter;
+      acc_fea.reuse_dis_bytes = reuse_dis_bytes;
+      acc_fea.reuse_ct = reuse_ct;
+      // no reuse, multiply by a magic number '2'
+      PrimExpr coef = SelectNonZero(reuse_ct, 0.5f);
+      acc_fea.bytes_d_reuse_ct = bytes / coef;
+      acc_fea.unique_bytes_d_reuse_ct = unique_bytes / coef;
+      acc_fea.lines_d_reuse_ct = lines / coef;
+      acc_fea.unique_lines_d_reuse_ct = unique_lines / coef;
+    }
+  }
+
+  // Extract allocation related features (group 4)
+  void ExtractAllocationFeature(const BufferRealizeNode* node) {
+    FeatureSet& fea = bufstore_feats[node->buffer];
+    PrimExpr allocation_size = 1;
+    for (const auto& x : node->bounds) {
+      allocation_size *= this->rinf.GetMax(x->extent);
+    }
+    // allocation feature
+    allocation_size = this->vdefs.DefineConstShorthand(allocation_size);
+    auto loop_prod = LoopProd(FilterForLoops(std::nullopt));
+    fea.alloc_size = allocation_size * node->buffer->dtype.bytes();
+    fea.alloc_prod = allocation_size * loop_prod;
+    fea.alloc_outer_prod = loop_prod;
+    fea.alloc_inner_prod = div(fea.outer_prod, loop_prod);
+  }
+
+  // Extract outer scope related features (group 5)
+  void ExtractOuterScopeFeature(const BufferStoreNode* node, const PrimExpr& loop_prod) {
+    FeatureSet& fea = bufstore_feats[node->buffer];
+    fea.outer_prod = loop_prod;
+    fea.num_loops = CountLoops(for_loop_stack);
+    fea.auto_unroll_max_step = cur_auto_unroll_max_step_;
+  }
+
+  void FillLoopFeatures(ForKind kind, PrimExpr& num, PrimExpr& len, PrimExpr& prod,
+                        AnnotationPosType& type) {
+    auto loops = FilterForLoops(kind);
+    num = CountLoops(loops);
+    if (loops.empty()) {
+      len = prod = 0;
+      type = AnnotationPosType::kPosNone;
+    } else {
+      len = loops.back()->extent;
+      prod = 1;
+      for (auto* loop : loops) {
+        prod *= loop->extent;
+      }
+      type = AnnotationPosType::kPosMixed;
+    }
+  }
+
+  std::vector<const ForNode*> FilterForLoops(std::optional<ForKind> kind) {
+    std::unordered_set<std::string> var_registered;
+    std::vector<const ForNode*> loops;
+    for (auto* loop : this->for_loop_stack) {
+      if (kind && kind.value() != loop->kind) {
+        continue;
+      }
+      std::string var_name = loop->loop_var->name_hint;
+      if (var_registered.count(var_name)) {
+        ICHECK(var_name.substr(0, 8) == "blockIdx" || var_name.substr(0, 9) == "threadIdx")
+            << "Duplicate non-gpu-grid loop var: " << var_name;
+        continue;
+      }
+      loops.push_back(loop);
+    }
+    return loops;
+  }
+
+  PrimExpr CountLoops(const std::vector<const ForNode*>& loops) {
+    PrimExpr num = 0;
+    for (auto* loop : loops) {
+      num += select(LoopNonTrivialCond(loop), 1, 0);
+    }
+    return num;
+  }
+
+  PrimExpr LoopProd(const std::vector<const ForNode*>& loops) {
+    PrimExpr ret = 1.0f;
+    for (auto* loop : loops) {
+      ret *= loop->extent;
+    }
+    return ret;
+  }
+
+ public:
+  BufferMap<FeatureSet> bufstore_feats;
+
+ private:
+  // The shared arithmetic analyzers
+  VarDefStackNode& vdefs;
+  RangeInfer& rinf;
+  VarMapT flatmap;
+
+  std::vector<const ForNode*> for_loop_stack;
+  PrimExpr previous_outer;
+  // GPU-related features
+  bool is_gpu;
+  PrimExpr cur_auto_unroll_max_step_{0};
+
+  // Store touch region information for all for loops. The format of this nested map:
+  // For a loop, for all its touched buffers, for all different accesses to the buffers,
+  // its (access type, number of touched elements, number of bytes of single element)
+  ForTouchRegionT for_touch_regions_;
+
+  // The default cache line size in bytes
+  const int cache_line_size_ = 64;
+};
+
+PrimExpr slog(PrimExpr x) { return x.dtype().is_bool() ? x : log(x); }
+
+}  // namespace
+
+VarDefStack GetPerStoreFeatureExpr(const Stmt& stmt, VarDefStackNode& vdefs, RangeInfer& rinf,
+                                   size_t cache_line_size, size_t max_n_bufs) {
+  // Extract features
+  PerStoreFeatureExtractor extractor(vdefs, rinf, cache_line_size);
+  extractor(stmt);
+  std::vector<std::pair<Buffer, FeatureSet>> buffer_features(extractor.bufstore_feats.begin(),
+                                                             extractor.bufstore_feats.end());
+  std::sort(buffer_features.begin(), buffer_features.end(),
+            [](auto& a, auto& b) { return a.first->name < b.first->name; });
+
+  // Define features in context, and put the resulted variable names in ret.
+  VarDefStack feats;
+  for (size_t i = 0; i < buffer_features.size(); ++i) {
+    auto& [buf, fea_set] = buffer_features[i];
+    auto PushFeature = [&i, &feats](const std::string& name, const PrimExpr& val) {
+      auto name_ = "BS" + std::to_string(i) + "." + name;
+      feats->Append(name_, val);
+    };
+    auto PushEnumFeature = [&i, &feats](const std::string& field,
+                                        const std::vector<std::string>& kind_names, auto val) {
+      for (size_t j = 0; j < kind_names.size(); j++) {
+        auto name_ = "BS" + std::to_string(i) + "." + field + "." + kind_names[j];
+        feats->Append(name_, Bool(static_cast<size_t>(val) == j));
+      }
+    };
+
+#define PUSH_FEATURE(feature) PushFeature(#feature, fea_set.feature);
+#define PUSH_ENUM_FEATURE(feature, names) PushEnumFeature(#feature, names, fea_set.feature);
+    /***** Group 1: Computation related features *****/
+    PUSH_FEATURE(float_mad);
+    PUSH_FEATURE(float_addsub);
+    PUSH_FEATURE(float_mul);
+    PUSH_FEATURE(float_divmod);
+    PUSH_FEATURE(float_cmp);
+    PUSH_FEATURE(float_math_func);
+    PUSH_FEATURE(float_other_func);
+    PUSH_FEATURE(int_mad);
+    PUSH_FEATURE(int_addsub);
+    PUSH_FEATURE(int_mul);
+    PUSH_FEATURE(int_divmod);
+    PUSH_FEATURE(int_cmp);
+    PUSH_FEATURE(int_math_func);
+    PUSH_FEATURE(int_other_func);
+    PUSH_FEATURE(bool_op);
+    PUSH_FEATURE(select_op);
+
+    static const std::vector<std::string> annot_pos_names = {
+        "kPosNone",        "kPosInnerSpatial", "kPosMiddleSpatial", "kPosOuterSpatial",
+        "kPosInnerReduce", "kPosMiddleReduce", "kPosOuterReduce",   "kPosMixed",
+    };
+    PUSH_FEATURE(vec_num);
+    PUSH_FEATURE(vec_prod);
+    PUSH_FEATURE(vec_len);
+    PUSH_ENUM_FEATURE(vec_type, annot_pos_names);
+    PUSH_FEATURE(unroll_num);
+    PUSH_FEATURE(unroll_prod);
+    PUSH_FEATURE(unroll_len);
+    PUSH_ENUM_FEATURE(unroll_type, annot_pos_names);
+    PUSH_FEATURE(parallel_num);
+    PUSH_FEATURE(parallel_prod);
+    PUSH_FEATURE(parallel_len);
+    PUSH_ENUM_FEATURE(parallel_type, annot_pos_names);
+
+    PUSH_FEATURE(is_gpu);
+    PUSH_FEATURE(blockIdx_x_len);
+    PUSH_FEATURE(blockIdx_y_len);
+    PUSH_FEATURE(blockIdx_z_len);
+    PUSH_FEATURE(threadIdx_x_len);
+    PUSH_FEATURE(threadIdx_y_len);
+    PUSH_FEATURE(threadIdx_z_len);
+    PUSH_FEATURE(vthread_len);
+
+    /***** Group 2: Buffer access related features *****/
+    static const std::vector<std::string> acc_type_names = {"kRead", "kWrite", "kReadWrite"};
+    auto& buf_feats = fea_set.access_feas;
+    std::sort(buf_feats.begin(), buf_feats.end(),
+              [](auto& a, auto& b) { return a.buffer_name < b.buffer_name; });
+    buf_feats.resize(max_n_bufs);
+    for (size_t j = 0; j < max_n_bufs; ++j) {
+      const auto& acc_fea = buf_feats[j];
+
+#define PUSH_BUF_FEATURE(feature)                       \
+  PushFeature("B" + std::to_string(j) + "." + #feature, \
+              acc_fea.feature.defined() ? slog(acc_fea.feature) : Integer(0));
+
+      PushEnumFeature("B" + std::to_string(j) + ".acc_type", acc_type_names, acc_fea.acc_type);
+      PUSH_BUF_FEATURE(bytes);
+      PUSH_BUF_FEATURE(unique_bytes);
+      PUSH_BUF_FEATURE(lines);
+      PUSH_BUF_FEATURE(unique_lines);
+      PUSH_BUF_FEATURE(multi_read_cond);
+      PUSH_BUF_FEATURE(serial_multi_rw_cond);
+      PUSH_BUF_FEATURE(no_reuse_cond);
+      PUSH_BUF_FEATURE(reuse_dis_iter);
+      PUSH_BUF_FEATURE(reuse_dis_bytes);
+      PUSH_BUF_FEATURE(reuse_ct);
+      PUSH_BUF_FEATURE(bytes_d_reuse_ct);
+      PUSH_BUF_FEATURE(unique_bytes_d_reuse_ct);
+      PUSH_BUF_FEATURE(lines_d_reuse_ct);
+      PUSH_BUF_FEATURE(unique_lines_d_reuse_ct);
+      PUSH_BUF_FEATURE(stride);
+    }
+
+    /***** Group 4: Allocation related features *****/
+    PUSH_FEATURE(alloc_size);
+    PUSH_FEATURE(alloc_prod);
+    PUSH_FEATURE(alloc_outer_prod);
+    PUSH_FEATURE(alloc_inner_prod);
+
+    /***** Group 5: Outer scope related features *****/
+    PUSH_FEATURE(outer_prod);
+    PUSH_FEATURE(num_loops);
+    PUSH_FEATURE(auto_unroll_max_step);
+  }
+  return feats;
+}
+
+}  // namespace felix
+}  // namespace tvm
diff --git a/src/felix/features.h b/src/felix/features.h
new file mode 100644
index 000000000..586c879a9
--- /dev/null
+++ b/src/felix/features.h
@@ -0,0 +1,29 @@
+#ifndef TVM_AUTO_SCHEDULER_FEATURES_H_
+#define TVM_AUTO_SCHEDULER_FEATURES_H_
+
+#include <tvm/auto_scheduler/search_task.h>
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+#include "rangeinfer.h"
+
+namespace tvm {
+namespace felix {
+
+template <class T>
+using BufferMap = std::unordered_map<tir::Buffer, T, ObjectHash, ObjectEqual>;
+template <class T>
+using ExprMap = std::unordered_map<PrimExpr, T, StructuralHash, StructuralEqual>;
+
+arith::VarDefStack GetPerStoreFeatureExpr(const tir::Stmt& stmt, arith::VarDefStackNode& vdefs,
+                                          RangeInfer& rinf, size_t cache_line_size,
+                                          size_t max_n_bufs);
+
+std::vector<PrimExpr> GetConstraints(const tir::Stmt& code,
+                                     const auto_scheduler::HardwareParams& hw_params);
+
+}  // namespace felix
+}  // namespace tvm
+
+#endif  // TVM_AUTO_SCHEDULER_FEATURE_H_
diff --git a/src/felix/rangeinfer.h b/src/felix/rangeinfer.h
new file mode 100644
index 000000000..2d86370c4
--- /dev/null
+++ b/src/felix/rangeinfer.h
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+
+#include <algorithm>
+
+namespace tvm {
+namespace felix {
+
+using namespace tir;
+
+inline FloatImm ToFloatImm(double e) { return FloatImm(DataType::Float(32), e); }
+
+inline PrimExpr CastToFloat(const PrimExpr& value) {
+  return value->dtype.is_float() ? value : cast(DataType::Float(32), value);
+}
+
+inline bool ToConstNumber(const PrimExpr& x, double& val) {
+  if (!x.defined()) {
+    return false;
+  }
+  if (const auto op = x.as<IntImmNode>()) {
+    val = static_cast<float>(op->value);
+    return true;
+  } else if (const auto op = x.as<FloatImmNode>()) {
+    val = op->value;
+    return true;
+  } else {
+    return false;
+  }
+}
+
+inline bool ToConstRange(const Range& range, double& min, double& max) {
+  double extent;
+  if (!ToConstNumber(range->min, min) || !ToConstNumber(range->extent, extent)) {
+    return false;
+  }
+  max = min + extent;
+  return true;
+}
+
+inline bool ToConstNumber(const Range& range, double& x) {
+  double extent;
+  return ToConstNumber(range->min, x) && ToConstNumber(range->extent, extent) && extent == 0;
+}
+
+template <typename Ret>
+class MemoizedExprFunctor : public ExprFunctor<Ret(const PrimExpr&)> {
+ public:
+  Ret VisitExpr(const PrimExpr& expr) override {
+    auto it = this->memo.find(expr);
+    if (it != this->memo.end()) {
+      return it->second;
+    }
+    return this->memo[expr] = ExprFunctor<Ret(const PrimExpr&)>::VisitExpr(expr);
+  }
+
+ protected:
+  std::unordered_map<PrimExpr, Ret, StructuralHash, StructuralEqual> memo;
+};
+
+template <>
+class MemoizedExprFunctor<PrimExpr> : public ExprMutator {
+ public:
+  PrimExpr VisitExpr(const PrimExpr& expr) override {
+    auto it = this->memo.find(expr);
+    if (it != this->memo.end()) {
+      return it->second;
+    }
+    return this->memo[expr] = ExprMutator::VisitExpr(expr);
+  }
+
+ protected:
+  std::unordered_map<PrimExpr, PrimExpr, StructuralHash, StructuralEqual> memo;
+};
+
+class RangeInfer : public MemoizedExprFunctor<Range> {
+ public:
+  RangeInfer(Range range_for_sizevar = Range()) : range_for_sizevar(range_for_sizevar) {}
+
+  void BindLoop(const ForNode* loop, bool allow_override) {
+    // To tighten the bound, we adopt an [a, b] convension for the range instead of [a, b).
+    Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent - 1), allow_override);
+  }
+
+  void Bind(Var var, Range range, bool allow_override) {
+    ICHECK(arith::ExprIsConstant(range->min) && arith::ExprIsConstant(range->extent))
+        << "Cannot bind non-constant range for variable " << var << " = " << range;
+    if (this->var_bind.count(var->name_hint)) {
+      if (allow_override) {
+        this->memo.clear();
+      } else {
+        LOG_FATAL << "Cannot override range for " << var->name_hint << " (already defined)";
+      }
+    }
+    this->var_bind[var->name_hint] = range;
+  }
+
+  PrimExpr GetMax(PrimExpr e) {
+    auto range = VisitExpr(e);
+    return range->min + range->extent;
+  }
+
+ private:
+#define VisitBinOpSameMono(Type, Func)                                     \
+  Range VisitExpr_(const Type##Node* op) final {                           \
+    Range lhs = VisitExpr(op->a), rhs = VisitExpr(op->b);                  \
+    PrimExpr lmax = lhs->min + lhs->extent, rmax = rhs->min + rhs->extent; \
+    PrimExpr begin = Func(lhs->min, rhs->min), end = Func(lmax, rmax);     \
+    return Range(begin, end);                                              \
+  }
+
+#define VisitBinOpRevMono(Type, Func)                                      \
+  Range VisitExpr_(const Type##Node* op) final {                           \
+    Range lhs = VisitExpr(op->a), rhs = VisitExpr(op->b);                  \
+    PrimExpr lmax = lhs->min + lhs->extent, rmax = rhs->min + rhs->extent; \
+    PrimExpr begin = Func(lhs->min, rmax), end = Func(lmax, rhs->min);     \
+    return Range(begin, end);                                              \
+  }
+
+#define VisitMods(Type)                                              \
+  Range VisitExpr_(const Type##Node* op) final {                     \
+    Range lhs = VisitExpr(op->a), rhs = VisitExpr(op->b);            \
+    PrimExpr zero = Integer(0);                                      \
+    if (is_const_int(lhs->min, 0) && is_const_int(lhs->extent, 0)) { \
+      return Range(zero, zero);                                      \
+    }                                                                \
+    return Range(zero, rhs->min + rhs->extent - 1);                  \
+  }
+
+  VisitBinOpSameMono(Add, add);
+  VisitBinOpSameMono(Mul, mul);
+  VisitBinOpRevMono(Sub, sub);
+  VisitBinOpRevMono(Div, div);
+  VisitBinOpRevMono(FloorDiv, floordiv);
+  VisitBinOpSameMono(Min, min);
+  VisitBinOpSameMono(Max, max);
+  VisitMods(Mod);
+  VisitMods(FloorMod);
+
+  Range VisitExpr_(const IntImmNode* op) final {
+    return Range::FromMinExtent(GetRef<IntImm>(op), Integer(0));
+  }
+  Range VisitExpr_(const FloatImmNode* op) final {
+    return Range::FromMinExtent(GetRef<FloatImm>(op), ToFloatImm(0.0));
+  }
+  // SizeVar that we insert as schedule vars are seen as constants
+  // as they eventually will be constants for each given configuration.
+  Range VisitExpr_(const SizeVarNode* op) final {
+    if (this->range_for_sizevar.defined()) {
+      return this->range_for_sizevar;
+    }
+    return Range::FromMinExtent(GetRef<SizeVar>(op), 0);
+  }
+  Range VisitExpr_(const VarNode* op) final {
+    auto it = this->var_bind.find(op->name_hint);
+    if (it != this->var_bind.end()) {
+      return it->second;
+    } else {
+      LOG_FATAL << "Cannot find var \'" << op->name_hint << "\' in range inference";
+      return Range();
+    }
+  }
+
+  Range VisitExpr_(const SelectNode* op) final {
+    // Don't have much use for range of condition.
+    Range lhs = VisitExpr(op->true_value), rhs = VisitExpr(op->false_value);
+    PrimExpr lmax = lhs->min + lhs->extent, rmax = rhs->min + rhs->extent;
+    return Range(min(lhs->min, rhs->min), max(lmax, rmax));
+  }
+
+  Range VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); }
+
+  Range VisitExpr_(const CallNode* op) final {
+    auto* pop = op->op.as<OpNode>();
+    ICHECK(pop != nullptr);
+    if (pop->name == "tir.exp") {
+      Range r = VisitExpr(op->args[0]);
+      return Range(exp(r->min), exp(r->min + r->extent));
+    } else if (pop->name == "tir.log") {
+      Range r = VisitExpr(op->args[0]);
+      return Range(log(r->min), log(r->min + r->extent));
+    } else if (pop->name == "tir.logk") {
+      double base_ = 0;
+      if (!ToConstNumber(VisitExpr(op->args[0]), base_)) {
+        LOG_FATAL << "logk only supports constant base";
+      }
+      Range x = VisitExpr(op->args[1]);
+      double x_min = 0, x_max = 0;
+      if (ToConstRange(x, x_min, x_max)) {
+        return Range(ToFloatImm(std::log(x_min) / std::log(base_)),
+                     ToFloatImm(std::log(x_max) / std::log(base_)));
+      } else {
+        return Range(div(log(x->min), std::log(base_)),
+                     div(log(x->min + x->extent), std::log(base_)));
+      }
+    } else if (pop->name == "tir.pow") {
+      Range r1 = VisitExpr(op->args[0]), r2 = VisitExpr(op->args[1]);
+      double r1_min = 0, r1_max = 0, r2_min = 0, r2_max = 0;
+      bool is_r1_const = ToConstRange(r1, r1_min, r1_max),
+           is_r2_const = ToConstRange(r2, r2_min, r2_max);
+      if (is_r1_const && is_r2_const) {
+        return Range(ToFloatImm(std::pow(r1_min, r2_min)), ToFloatImm(std::pow(r1_max, r2_max)));
+      } else if (is_r1_const && r1_min == r1_max) {
+        FloatImm r1f = ToFloatImm(r1_min);
+        if (r1_min <= 0) {
+          LOG_FATAL << "only pow with base >= 1 is supported; got base " << r1 << " in "
+                    << GetRef<PrimExpr>(op);
+        } else if (r1_min < 1) {
+          return Range(pow(r1f, r2->min + r2->extent), pow(r1f, r2->min));
+        } else {
+          return Range(pow(r1f, r2->min), pow(r1f, r2->min + r2->extent));
+        }
+      } else if (is_r2_const && r2_min == r2_max) {
+        FloatImm r2f = ToFloatImm(r2_min);
+        if (r2_min < 0) {
+          return Range(pow(r1->min + r1->extent, r2f), pow(r1->min, r2f));
+        } else {
+          return Range(pow(r1->min, r2f), pow(r1->min + r1->extent, r2f));
+        }
+      } else {
+        LOG_FATAL << "pow with non-constant base and exponent is unsupported: "
+                  << GetRef<PrimExpr>(op);
+      }
+    } else {
+      LOG_FATAL << "Call to " << pop->name << " not supported";
+    }
+    return Range();  // unreachable
+  }
+
+  Range VisitExprDefault_(const Object* op) final {
+    LOG_FATAL << "Expression of type " << op->GetTypeKey() << " is unsupported in RangeInfer";
+    return Range();
+  }
+
+ public:
+  std::unordered_map<std::string, Range> var_bind;
+  Range range_for_sizevar;
+};
+
+}  // namespace felix
+}  // namespace tvm
diff --git a/src/felix/utils.h b/src/felix/utils.h
index c06d7b4f5..5f6f28134 100644
--- a/src/felix/utils.h
+++ b/src/felix/utils.h
@@ -5,6 +5,8 @@
 #include <tvm/auto_scheduler/transform_step.h>
 #include <tvm/runtime/container/string.h>
 
+#include <numeric>
+
 namespace tvm {
 namespace felix {
 
@@ -14,6 +16,12 @@ using auto_scheduler::Step;
 
 String PrintTrStep(const Step& step);
 
+/*! \brief Compute the product of all elements in a vector */
+inline PrimExpr ElementProduct(const std::vector<PrimExpr>& array) {
+  return std::accumulate(array.begin(), array.end(), PrimExpr(1),
+                         [](const PrimExpr& a, const PrimExpr& b) { return a * b; });
+}
+
 inline std::pair<PrimExpr, PrimExpr> GetCumulativeSpaceAndReductionLength_(const Stage& stage) {
   PrimExpr cum_space_len = 1, cum_reduce_len = 1;
   for (const auto& iter : stage->iters) {
@@ -26,6 +34,47 @@ inline std::pair<PrimExpr, PrimExpr> GetCumulativeSpaceAndReductionLength_(const
   return std::make_pair(cum_space_len, cum_reduce_len);
 }
 
+#define LOG_TIME
+
+class Timer {
+ public:
+  Timer(std::string event_name)
+      : event_name(std::move(event_name)),
+        duration(),
+        now(std::chrono::high_resolution_clock::now()),
+        stopped(false) {}
+
+  void Start() {
+#ifdef LOG_TIME
+    this->now = std::chrono::high_resolution_clock::now();
+    this->stopped = false;
+#endif
+  }
+
+  void Stop() {
+#ifdef LOG_TIME
+    if (!this->stopped) {
+      this->duration += std::chrono::high_resolution_clock::now() - this->now;
+      this->stopped = true;
+    }
+#endif
+  }
+
+  ~Timer() {
+#ifdef LOG_TIME
+    Stop();
+    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(this->duration);
+    LOG_INFO << this->event_name << " -- " << duration.count() << " ms";
+#endif
+  }
+
+ private:
+  std::string event_name;
+  std::chrono::duration<double> duration;
+  std::chrono::time_point<std::chrono::high_resolution_clock> now;
+  bool stopped;
+};
+
 }  // namespace felix
 }  // namespace tvm
 
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 64142c621..d27160b82 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -31,34 +31,30 @@
 namespace tvm {
 namespace tir {
 
-#define TVM_DEFINE_BINOP_CONSTRUCTOR(Name)                                                   \
-  Name::Name(PrimExpr a, PrimExpr b, Span span) {                                            \
-    using T = Name::ContainerType;                                                           \
-    ICHECK(a.defined()) << "ValueError: a is undefined\n";                                   \
-    ICHECK(b.defined()) << "ValueError: b is undefined\n";                                   \
-    CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. " \
-                                  << b.dtype() << "\n";                                      \
-    ObjectPtr<T> node = make_object<T>();                                                    \
-    node->dtype = a.dtype();                                                                 \
-    node->a = std::move(a);                                                                  \
-    node->b = std::move(b);                                                                  \
-    node->span = std::move(span);                                                            \
-    data_ = std::move(node);                                                                 \
+#define TVM_DEFINE_BINOP_CONSTRUCTOR(Name)                 \
+  Name::Name(PrimExpr a, PrimExpr b, Span span) {          \
+    using T = Name::ContainerType;                         \
+    ICHECK(a.defined()) << "ValueError: a is undefined\n"; \
+    ICHECK(b.defined()) << "ValueError: b is undefined\n"; \
+    ObjectPtr<T> node = make_object<T>();                  \
+    node->dtype = a.dtype();                               \
+    node->a = std::move(a);                                \
+    node->b = std::move(b);                                \
+    node->span = std::move(span);                          \
+    data_ = std::move(node);                               \
   }
 
-#define TVM_DEFINE_CMPOP_CONSTRUCTOR(Name)                                                   \
-  Name::Name(PrimExpr a, PrimExpr b, Span span) {                                            \
-    using T = Name::ContainerType;                                                           \
-    ICHECK(a.defined()) << "ValueError: a is undefined\n";                                   \
-    ICHECK(b.defined()) << "ValueError: b is undefined\n";                                   \
-    CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. " \
-                                  << b.dtype() << "\n";                                      \
-    ObjectPtr<T> node = make_object<T>();                                                    \
-    node->dtype = DataType::Bool(a.dtype().lanes());                                         \
-    node->a = std::move(a);                                                                  \
-    node->b = std::move(b);                                                                  \
-    node->span = std::move(span);                                                            \
-    data_ = std::move(node);                                                                 \
+#define TVM_DEFINE_CMPOP_CONSTRUCTOR(Name)                 \
+  Name::Name(PrimExpr a, PrimExpr b, Span span) {          \
+    using T = Name::ContainerType;                         \
+    ICHECK(a.defined()) << "ValueError: a is undefined\n"; \
+    ICHECK(b.defined()) << "ValueError: b is undefined\n"; \
+    ObjectPtr<T> node = make_object<T>();                  \
+    node->dtype = DataType::Bool(a.dtype().lanes());       \
+    node->a = std::move(a);                                \
+    node->b = std::move(b);                                \
+    node->span = std::move(span);                          \
+    data_ = std::move(node);                               \
   }
 
 // Var
@@ -610,7 +606,6 @@ Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Sp
   ICHECK(false_value.defined()) << "ValueError: true_value is undefined";
   ICHECK(condition.dtype().is_bool());
   ICHECK(condition.dtype().lanes() == true_value.dtype().lanes() || condition.dtype().lanes() == 1);
-  ICHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types";
 
   ObjectPtr<SelectNode> node = make_object<SelectNode>();
   node->dtype = true_value.dtype();
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index d9d6093f2..1b4ed4c9e 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -372,8 +372,8 @@ PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span) { return floordiv(a, b, spa
 PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span) { return floormod(a, b, span); }
 
 PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span) {
-  ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
-  ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
+  // ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
+  // ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
   BinaryOpMatchTypes(a, b, span);
   PrimExpr ret = arith::TryConstFold<tir::FloorDiv>(a, b);
   if (ret.defined()) return ret;
@@ -381,8 +381,8 @@ PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span) {
 }
 
 PrimExpr floormod(PrimExpr a, PrimExpr b, Span span) {
-  ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
-  ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
+  // ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
+  // ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
   BinaryOpMatchTypes(a, b, span);
   PrimExpr ret = arith::TryConstFold<tir::FloorMod>(a, b);
   if (ret.defined()) return ret;
@@ -612,8 +612,6 @@ TVM_REGISTER_GLOBAL("tir.bitwise_not").set_body_typed([](PrimExpr a, Span span)
 
 // pow
 PrimExpr pow(PrimExpr x, PrimExpr y, Span span) {
-  BinaryOpMatchTypes(x, y, span);
-  ICHECK(x.dtype().is_float()) << "power only applies to float";
   static auto op = Op::Get("tir.pow");
   return tir::Call(x.dtype(), op, {x, y}, span);
 }
@@ -879,6 +877,8 @@ TIR_REGISTER_PURE_UNARY_OP("tir.atanh");
 
 TIR_REGISTER_PURE_UNARY_OP("tir.clz");
 
+TIR_REGISTER_PURE_UNARY_OP("tir.hump");
+
 // binary intrinsics
 TIR_REGISTER_PURE_BINARY_OP("tir.atan2");
 
@@ -890,6 +890,8 @@ TIR_REGISTER_PURE_BINARY_OP("tir.copysign");
 
 TIR_REGISTER_PURE_BINARY_OP("tir.ldexp");
 
+TIR_REGISTER_PURE_BINARY_OP("tir.logk");
+
 // expose basic functions to node namespace
 TVM_REGISTER_GLOBAL("node._const").set_body([](TVMArgs args, TVMRetValue* ret) {
   if (args[0].type_code() == kDLInt) {
-- 
GitLab