diff --git a/include/tvm/auto_scheduler/loop_state.h b/include/tvm/auto_scheduler/loop_state.h
index 0ca14c43eb470da85193492868c2f71831eb6433..55c023507e9d410503cd486dd4db0ed535478546 100755
--- a/include/tvm/auto_scheduler/loop_state.h
+++ b/include/tvm/auto_scheduler/loop_state.h
@@ -350,7 +350,7 @@ class State : public ObjectRef {
    * most iterator of split results will become the new attach point.
    */
   TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
-                                const Array<Optional<Integer>>& lengths,
+                                const Array<PrimExpr>& lengths,
                                 bool inner_to_outer = true);
   /*!
    * \brief The schedule primitive similar to split, but uses split factors from previous steps.
diff --git a/include/tvm/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h
index 4cc1551e76fcb33cfb0a9527610fde035204c8a6..45bc9f922e69d6d9932900f189ee5e42823fccdc 100755
--- a/include/tvm/auto_scheduler/transform_step.h
+++ b/include/tvm/auto_scheduler/transform_step.h
@@ -505,7 +505,7 @@ class SplitStepNode : public StepNode {
   /*! \brief The extent length of the axis to split. */
   Optional<PrimExpr> extent;
   /*! \brief The split factors. */
-  Array<Optional<Integer>> lengths;
+  Array<PrimExpr> lengths;
   /*!
    * \brief If true, the `lengths` denote the lengths of iterators
    * from inner level to outer level
@@ -561,7 +561,7 @@ class SplitStep : public Step {
    * \param inner_to_outer The split direction.
    */
   SplitStep(int stage_id, int iter_id, Optional<PrimExpr> extent,
-            const Array<Optional<Integer>>& lengths, bool inner_to_outer);
+            const Array<PrimExpr>& lengths, bool inner_to_outer);
 
   /*!
    * \brief The constructor used to read a step record from JSONReader and create the
@@ -591,7 +591,7 @@ class FollowSplitStepNode : public StepNode {
    * \param transform_steps An array of history transform steps.
    * \return The multiple split factors.
    */
-  Array<Optional<Integer>> ExtractSplitLengths(const Array<Step>& transform_steps) const;
+  Array<PrimExpr> ExtractSplitLengths(const Array<Step>& transform_steps) const;
 
   /*!
    * \brief Apply the current step to State.
@@ -672,7 +672,7 @@ class FollowFusedSplitStepNode : public StepNode {
    * \param transform_steps An array of history transform steps.
    * \return Split factor.
    */
-  Optional<Integer> ExtractSplitLength(const Array<Step>& transform_steps) const;
+  PrimExpr ExtractSplitLength(const Array<Step>& transform_steps) const;
 
   /*!
    * \brief Apply the current step to State.
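Note (not part of the patch itself): a minimal sketch of how a caller might use the new `State::split` signature once `lengths` becomes `Array<PrimExpr>`. The surrounding `State`/`Iterator` objects are assumed to come from existing auto_scheduler code, and the helper name below is hypothetical; only the element type of `lengths` is what this change affects, so both constant and symbolic factors can be passed.

#include <tvm/auto_scheduler/loop_state.h>
#include <tvm/tir/var.h>

// Hypothetical helper for illustration only.
void SplitWithSymbolicFactor(tvm::auto_scheduler::State* state, int stage_id,
                             const tvm::auto_scheduler::Iterator& it) {
  tvm::PrimExpr fixed(32);         // a constant factor still works
  tvm::tir::Var factor("factor");  // a symbolic factor, now representable directly
  // Before this change, `lengths` was Array<Optional<Integer>>, so only
  // (optional) integer constants could be expressed here.
  state->split(stage_id, it, {fixed, factor});
}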
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
deleted file mode 100755
index 2bb722781f2f577e77eecb0465a38179ecc2d45f..0000000000000000000000000000000000000000
--- a/src/auto_scheduler/feature.cc
+++ /dev/null
@@ -1,1677 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file auto_scheduler/feature.cc
- * \brief Feature extraction for the cost model
- */
-
-#include <tvm/arith/analyzer.h>
-#include <tvm/auto_scheduler/feature.h>
-#include <tvm/auto_scheduler/measure.h>
-#include <tvm/auto_scheduler/measure_record.h>
-#include <tvm/driver/driver_api.h>
-#include <tvm/runtime/ndarray.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/support/parallel_for.h>
-#include <tvm/te/operation.h>
-#include <tvm/te/schedule_pass.h>
-#include <tvm/tir/analysis.h>
-#include <tvm/tir/op_attr_types.h>
-#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
-
-#include <algorithm>
-#include <cmath>
-#include <numeric>
-#include <unordered_map>
-#include <vector>
-
-#include "search_policy/utils.h"
-#include "utils.h"
-
-namespace tvm {
-namespace auto_scheduler {
-
-using namespace tvm::tir;
-using arith::Analyzer;
-using arith::ConstIntBound;
-
-template <class T>
-using BufferMap = std::unordered_map<Buffer, T, ObjectHash, ObjectEqual>;
-
-// The number of samples to extract for arithmetic intensity curves
-static const int ARITH_INTENSITY_CURVE_SAMPLE_N = 10;
-
-// Annotation position encoding
-enum class AnnotationPosType : int {
-  kPosNone = 0,           // Does not have this kind of annotation
-  kPosInnerSpatial = 1,   // The annotated iterator is the innermost spatial iterator
-  kPosMiddleSpatial = 2,  // The annotated iterator is a middle spatial iterator
-  kPosOuterSpatial = 3,   // The annotated iterator is the outermost spatial iterator
-  kPosInnerReduce = 4,    // The annotated iterator is the innermost reduce iterator
-  kPosMiddleReduce = 5,   // The annotated iterator is a middle reduce iterator
-  kPosOuterReduce = 6,    // The annotated iterator is the outermost reduce iterator
-  kPosMixed = 7           // The annotated iterator is a mixed space and reduce iterator
-};
-
-// Buffer access type
-enum class BufferAccessType : int { kRead = 0, kWrite = 1, kReadWrite = 2, kUnknownRW = 3 };
-
-// Accesses to a buffer
-struct BufferAccess {
-  // buffer access type
-  BufferAccessType acc_type{BufferAccessType::kUnknownRW};
-  // Use a two-dimensional array to store multiple multi-dimensional accesses.
-  // The innermost vector stores the multi-dimensional indices of one access.
-  std::vector<std::vector<PrimExpr>> indices;
-};
-
-// Data reuse type
-enum class ReuseType : int { kLoopMultipleRead = 0, kSerialMultipleReadWrite = 1, kNoReuse = 2 };
-
-// Feature for an access of a buffer
-struct BufferAccessFeature {
-  std::string buffer_name;        // The name of the buffer
-  BufferAccessType acc_type;      // The type of the access
-  float bytes;                    // The touched memory in bytes
-  float unique_bytes;             // The touched unique memory in bytes
-  float lines;                    // The number of touched cache lines
-  float unique_lines;             // The number of touched unique cache lines
-  ReuseType reuse_type;           // The type of data reuse
-  float reuse_dis_iter;           // The reuse distance in iterator number
-  float reuse_dis_bytes;          // The reuse distance in total touched bytes
-  float reuse_ct;                 // The reuse ratio
-  float bytes_d_reuse_ct;         // bytes / reuse_ct
-  float unique_bytes_d_reuse_ct;  // unique_bytes / reuse_ct
-  float lines_d_reuse_ct;         // lines / reuse_ct
-  float unique_lines_d_reuse_ct;  // unique_lines / reuse_ct
-  float stride;                   // The stride in access
-};
-
-// Feature set of a BufferStore statement
-struct FeatureSet {
-  // Group 1: Computation related features
-  float float_mad;                  // The number of float MAD (Multiply–add) ops
-  float float_addsub;               // The number of float add and sub ops
-  float float_mul;                  // The number of float multiply ops
-  float float_divmod;               // The number of float div and mod ops
-  float float_cmp;                  // The number of float comparison ops
-  float float_math_func;            // The number of float math func calls
-  float float_other_func;           // The number of other float func calls
-  float int_mad;                    // The number of integer MAD (Multiply–add) ops
-  float int_addsub;                 // The number of integer add and sub ops
-  float int_mul;                    // The number of integer multiply ops
-  float int_divmod;                 // The number of integer div and mod ops
-  float int_cmp;                    // The number of integer comparison ops
-  float int_math_func;              // The number of integer math func calls
-  float int_other_func;             // The number of other integer func calls
-  float bool_op;                    // The number of bool ops
-  float select_op;                  // The number of select ops
-  float vec_num;                    // The number of vectorized iterators
-  float vec_prod;                   // The product of the lengths of vectorized iterators
-  float vec_len;                    // The length of the innermost vectorized iterator
-  AnnotationPosType vec_type;       // The type of vectorization position
-  float unroll_num;                 // The number of unrolled iterators
-  float unroll_prod;                // The product of the lengths of unrolled iterators
-  float unroll_len;                 // The length of the innermost unrolled iterator
-  AnnotationPosType unroll_type;    // The type of unroll position
-  float parallel_num;               // The number of paralleled iterators
-  float parallel_prod;              // The product of the lengths of paralleled iterators
-  float parallel_len;               // The length of the innermost paralleled iterators
-  AnnotationPosType parallel_type;  // The type of parallel position
-  float is_gpu;                     // Whether it is a GPU task
-  float blockIdx_x_len;             // The length of blockIdx.x
-  float blockIdx_y_len;             // The length of blockIdx.y
-  float blockIdx_z_len;             // The length of blockIdx.z
-  float threadIdx_x_len;            // The length of threadIdx.x
-  float threadIdx_y_len;            // The length of threadIdx.y
-  float threadIdx_z_len;            // The length of threadIdx.z
-  float vthread_len;                // The length of virtual thread
-
-  // Group 2: Buffer access related features (per buffer)
-  std::vector<BufferAccessFeature> access_feas;
-
-  // Group 3: Arithmetic intensity related features
-  float arith_intensity_curve[ARITH_INTENSITY_CURVE_SAMPLE_N];  // points sampled from the
-                                                                // arithmetic intensity curve
-
-  // Group 4: Allocation related features
-  float alloc_size;        // The size of allocated buffer in bytes
-  float alloc_outer_prod;  // The product of lengths of loops outside the scope of the allocation
-  float alloc_inner_prod;  // The product of lengths of loops inside the scope of the allocation
-  float alloc_prod;        // alloc_outer_prod * alloc_inner_prod
-
-  // Group 5: Outer scope related features
-  float outer_prod;            // The product of lengths of outer loops
-  float num_loops;             // The number of outer loops
-  float auto_unroll_max_step;  // The value of pragma "auto_unroll_max_step"
-};
-
-// Return whether a var is in an expr
-bool VarInExpr(const Var& var, const PrimExpr& expr) {
-  bool find = false;
-
-  PostOrderVisit(expr, [&find, &var](const ObjectRef& node) {
-    if (find) {
-      return;
-    }
-
-    if (const VarNode* op = node.as<VarNode>()) {
-      if (op == var.get()) {
-        find = true;
-      }
-    }
-  });
-
-  return find;
-}
-
-// Get position encoding for annotation
-AnnotationPosType GetAnnotationPosEncoding(const Var& var, const Array<PrimExpr>& spatial_args,
-                                           const Array<IterVar>& axis,
-                                           const Array<IterVar>& reduce_axis) {
-  // Try to match spatial args first
-  size_t find_i = 0;
-  size_t find_ct = 0;
-  for (size_t i = 0; i < spatial_args.size(); ++i) {
-    if (VarInExpr(var, spatial_args[i])) {
-      find_i = i;
-      find_ct += 1;
-    }
-  }
-
-  if (find_ct == 0) {
-    // If it is not found in spatial args, then it is a reduce iterator.
-    // Use name to match
-    const std::string& var_name = var->name_hint;
-    for (size_t i = 0; i < reduce_axis.size(); ++i) {
-      if (var_name.find(reduce_axis[i]->var->name_hint) != std::string::npos) {
-        find_i = i;
-        find_ct++;
-      }
-    }
-    if (find_ct >= 1) {
-      if (find_i == 0) {
-        return AnnotationPosType::kPosInnerReduce;
-      } else if (find_i == reduce_axis.size() - 1) {
-        return AnnotationPosType::kPosOuterReduce;
-      } else {
-        return AnnotationPosType::kPosMiddleReduce;
-      }
-    } else {
-      // If the axis is not found in both spatial args and reduce axis,
-      // then this stage must compute_at somewhere under this axis and this axis is simplified out
-      // We assume it is an outer spatial
-      return AnnotationPosType::kPosOuterSpatial;
-    }
-  } else if (find_ct == 1) {
-    if (find_i == spatial_args.size() - 1) {
-      return AnnotationPosType::kPosInnerSpatial;
-    } else if (find_i == 0) {
-      return AnnotationPosType::kPosOuterSpatial;
-    } else {
-      return AnnotationPosType::kPosMiddleSpatial;
-    }
-  } else {
-    return AnnotationPosType::kPosMixed;
-  }
-}
-
-// Return the extent of a for loop
-int64_t GetLoopExtent(const ForNode* node) {
-  auto pint = node->extent.as<IntImmNode>();
-  if (pint != nullptr) {
-    return pint->value;
-  } else {
-    return 1;
-  }
-}
-
-// Count math ops in an expr
-class MathOpCounter : public StmtExprVisitor {
- public:
-#define VisitBinary(Type, float_ct, int_ct) \
-  void VisitExpr_(const Type* op) final {   \
-    if (op->a.dtype().is_float()) {         \
-      float_ct++;                           \
-    } else {                                \
-      int_ct++;                             \
-    }                                       \
-    StmtExprVisitor::VisitExpr_(op);        \
-  }
-
-  VisitBinary(AddNode, float_addsub, int_addsub);
-  VisitBinary(SubNode, float_addsub, int_addsub);
-  VisitBinary(MulNode, float_mul, int_mul);
-  VisitBinary(DivNode, float_divmod, int_divmod);
-  VisitBinary(ModNode, float_divmod, int_divmod);
-  VisitBinary(FloorDivNode, float_divmod, int_divmod);
-  VisitBinary(FloorModNode, float_divmod, int_divmod);
-  VisitBinary(MaxNode, float_cmp, int_cmp);
-  VisitBinary(MinNode, float_cmp, int_cmp);
-  VisitBinary(EQNode, float_cmp, int_cmp);
-  VisitBinary(NENode, float_cmp, int_cmp);
-  VisitBinary(LTNode, float_cmp, int_cmp);
-  VisitBinary(LENode, float_cmp, int_cmp);
-  VisitBinary(GTNode, float_cmp, int_cmp);
-  VisitBinary(GENode, float_cmp, int_cmp);
-
-#undef VisitBinary
-
-  void VisitExpr_(const AndNode* op) final {
-    bool_op++;
-    StmtExprVisitor::VisitExpr_(op);
-  }
-  void VisitExpr_(const OrNode* op) final {
-    bool_op++;
-    StmtExprVisitor::VisitExpr_(op);
-  }
-  void VisitExpr_(const NotNode* op) final {
-    bool_op++;
-    StmtExprVisitor::VisitExpr_(op);
-  }
-  void VisitExpr_(const SelectNode* op) final {
-    select_op++;
-    StmtExprVisitor::VisitExpr_(op);
-  }
-
-  void VisitExpr_(const CallNode* op) final {
-    auto* pop = op->op.as<OpNode>();
-    ICHECK(pop != nullptr);
-    auto effect_kind = op_call_effect_[GetRef<Op>(pop)];
-    bool is_pure =
-        effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation;
-
-    if (is_pure) {
-      if (op->dtype.is_float()) {
-        float_math_func++;
-      } else {
-        int_math_func++;
-      }
-    } else {
-      if (op->dtype.is_float()) {
-        float_other_func++;
-      } else {
-        int_other_func++;
-      }
-    }
-    StmtExprVisitor::VisitExpr_(op);
-  }
-
-  // todo(merrymercy): Detect MAD (Multiply–add)
-  size_t float_mad{0};         // The number of float MAD (Multiply–add) ops
-  size_t float_addsub{0};      // The number of float add and sub ops
-  size_t float_mul{0};         // The number of float multiply ops
-  size_t float_divmod{0};      // The number of float div and mod ops
-  size_t float_cmp{0};         // The number of float comparison ops
-  size_t float_math_func{0};   // The number of float math func calls
-  size_t float_other_func{0};  // The number of other float func calls
-  size_t int_mad{0};           // The number of integer MAD (Multiply–add) ops
-  size_t int_addsub{0};        // The number of integer add and sub ops
-  size_t int_mul{0};           // The number of integer multiply ops
-  size_t int_divmod{0};        // The number of integer div and mod ops
-  size_t int_cmp{0};           // The number of integer comparison ops
-  size_t int_math_func{0};     // The number of integer math func calls
-  size_t int_other_func{0};    // The number of other integer func calls
-  size_t bool_op{0};           // The number of bool ops
-  size_t select_op{0};         // The number of select ops
-
-  OpAttrMap<TCallEffectKind> op_call_effect_ = Op::GetAttrMap<TCallEffectKind>("TCallEffectKind");
-};
-
-// Extract all buffer accesses in an expr
-class BufferAccessExtractor : public StmtExprVisitor {
- public:
-  void ExtractReads(const PrimExpr& expr) { this->VisitExpr(expr); }
-
-  void InsertAccess(const Buffer& buf, BufferAccessType acc_type, const Array<PrimExpr>& indices) {
-    BufferAccess& acc = buf_accesses[buf];
-    acc.acc_type = acc_type;
-    acc.indices.push_back(std::vector<PrimExpr>(indices.begin(), indices.end()));
-  }
-
-  void VisitExpr_(const BufferLoadNode* op) final {
-    BufferAccess& acc = buf_accesses[op->buffer];
-    switch (acc.acc_type) {
-      case BufferAccessType::kRead:
-        break;
-      case BufferAccessType::kWrite:
-        acc.acc_type = BufferAccessType::kReadWrite;
-        break;
-      case BufferAccessType::kReadWrite:
-        break;
-      case BufferAccessType::kUnknownRW:
-      default:
-        acc.acc_type = BufferAccessType::kRead;
-        break;
-    }
-
-    if (acc.acc_type != BufferAccessType::kReadWrite) {
-      // If a buffer is both read and written, in the TVM DSL, it must be an update,
-      // so the indices should be the same. Then we can skip appending indices for it.
-      // Otherwise we do the following.
-      buf_accesses[op->buffer].indices.push_back(
-          std::vector<PrimExpr>(op->indices.begin(), op->indices.end()));
-    }
-    StmtExprVisitor::VisitExpr_(op);
-  }
-
-  BufferMap<BufferAccess> buf_accesses;
-};
-
-// Compute the coefficient for a loop iterator in an expression
-// Note: we use an approximation strategy to find coefficient.
-// Hopefully, it is faster than DetectLinearEquation and can handle more cases (non-linear)
-class CoefficientExtractor : public StmtExprVisitor {
- public:
-  void VisitExpr_(const MulNode* node) final {
-    StmtExprVisitor::VisitExpr_(node);
-    if (visited_var) {
-      if (!visited_add) {
-        if (auto a = node->a.as<IntImmNode>()) {
-          visited_mul = true;
-          stride = a->value;
-        } else if (auto b = node->b.as<IntImmNode>()) {
-          visited_mul = true;
-          stride = b->value;
-        }
-      }
-    }
-  }
-
-  void VisitExpr_(const AddNode* node) final {
-    StmtExprVisitor::VisitExpr_(node);
-    if (visited_var) {
-      if (!visited_mul) {
-        visited_add = true;
-        stride = 1;
-      }
-    }
-  }
-
-  void VisitExpr_(const VarNode* node) final {
-    if (node == var_) {
-      visited_var = true;
-      // This is a magic default stride in case our approximation strategy fails
-      stride = 2;
-    }
-  }
-
-  int ExtractCoefficient(const PrimExpr& expr, const VarNode* var) {
-    visited_var = visited_mul = visited_add = false;
-    var_ = var;
-
-    this->VisitExpr(expr);
-
-    if (visited_var && !visited_mul && !visited_add) {
-      return 1;
-    } else {
-      return stride;
-    }
-  }
-
-  bool visited_var{false};
-  bool visited_mul{false};
-  bool visited_add{false};
-  int stride{0};
-
- private:
-  const VarNode* var_{nullptr};
-};
-
-// Compute stride for the accesses to a buffer
-int64_t ComputeStride(const std::vector<std::vector<PrimExpr>>& indices,
-                      const std::vector<int>& shape, const VarNode* stride_var) {
-  int64_t min_stride = std::numeric_limits<int64_t>::max();
-  bool find = false;
-  CoefficientExtractor extractor;
-
-  for (const auto& index : indices) {
-    int64_t shape_stride = 1;
-    for (int i = static_cast<int>(index.size()) - 1; i >= 0; i--) {
-      int coefficient = extractor.ExtractCoefficient(index[i], stride_var);
-      if (extractor.visited_var) {
-        find = true;
-        min_stride = std::min(min_stride, std::abs(coefficient) * shape_stride);
-        break;
-      }
-      shape_stride *= shape[i];
-    }
-  }
-
-  return find ? min_stride : 0;
-}
-
-// Compute the per-dimension extents of the region touched by accesses to a buffer
-void ComputeRegion(const std::vector<std::vector<PrimExpr>>& indices, arith::Analyzer* ana,
-                   std::vector<int>* region) {
-  region->clear();
-
-  if (indices.empty()) {
-    return;
-  }
-
-  region->reserve(indices[0].size());
-
-  if (indices.size() == 1) {
-    for (const auto& index : indices[0]) {
-      ConstIntBound bound = ana->const_int_bound(index);
-      region->push_back(bound->max_value - bound->min_value + 1);
-    }
-  } else {
-    // future(lmzheng): implement a more accurate IntSet?
-    for (size_t i = 0; i < indices[0].size(); ++i) {
-      int64_t minimum = ConstIntBound::kPosInf, maximum = ConstIntBound::kNegInf;
-      for (size_t j = 0; j < indices.size(); ++j) {
-        ConstIntBound bound = ana->const_int_bound(indices[j][i]);
-
-        minimum = std::min(minimum, bound->min_value);
-        maximum = std::max(maximum, bound->max_value);
-      }
-      region->push_back(maximum - minimum + 1);
-    }
-  }
-}
-
-// Compute reuse distance and reuse ratio for accesses to a buffer
-// return values: reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct
-std::tuple<ReuseType, float, float, float> ComputeReuse(
-    const Buffer& buf, const std::vector<std::vector<PrimExpr>>& indices,
-    const std::vector<const ForNode*>& for_loop_stack,
-    const std::unordered_map<const ForNode*,
-                             BufferMap<std::vector<std::tuple<BufferAccessType, int64_t, int>>>>&
-        for_touch_regions) {
-  float reuse_dis_iter = 1.0f;
-  float reuse_dis_bytes = -1.0f;
-
-  for (int i = static_cast<int>(for_loop_stack.size()) - 1; i >= 0; --i) {
-    const ForNode* cur_for = for_loop_stack[i];
-    bool find = false;
-
-    for (size_t j = 0; j < indices.size(); j++) {
-      for (size_t k = 0; k < indices[j].size(); k++) {
-        if (VarInExpr(cur_for->loop_var, indices[j][k])) {
-          find = true;
-          break;
-        }
-      }
-      if (find) {
-        break;
-      }
-    }
-
-    int64_t extent = GetLoopExtent(for_loop_stack[i]);
-    if (find) {
-      // accumulate/update reuse distance
-      reuse_dis_iter *= extent;
-      reuse_dis_bytes = 0.0f;
-      for (const auto& iter : for_touch_regions.at(cur_for)) {
-        for (const auto& access : iter.second) {
-          reuse_dis_bytes += std::get<1>(access) * std::get<2>(access);
-        }
-      }
-    } else {
-      // Have LoopMultipleRead reuse
-      if (reuse_dis_bytes < 0) {
-        // For the reuse in the innermost axis, the above code won't be executed.
-        // So we compute bytes here
-        reuse_dis_bytes = 0.0f;
-        for (const auto& iter : for_touch_regions.at(cur_for)) {
-          for (const auto& access : iter.second) {
-            reuse_dis_bytes += 1 * std::get<2>(access);
-          }
-        }
-      }
-      return std::make_tuple(ReuseType::kLoopMultipleRead, reuse_dis_iter, reuse_dis_bytes, extent);
-    }
-
-    const BufferMap<std::vector<std::tuple<BufferAccessType, int64_t, int>>>& buffer_map =
-        for_touch_regions.at(cur_for);
-
-    int serial_reuse = static_cast<int>(buffer_map.at(buf).size()) - 1;
-    if (serial_reuse > 0) {
-      int64_t extent = GetLoopExtent(cur_for);
-
-      // Have SerialMultipleReadWrite reuse
-      reuse_dis_iter = std::numeric_limits<float>::max();
-      for (const auto& acc_info : buffer_map.at(buf)) {
-        reuse_dis_iter = std::min(reuse_dis_iter, static_cast<float>(std::get<1>(acc_info)));
-      }
-
-      reuse_dis_bytes = 0.0f;
-      for (const auto& iter : for_touch_regions.at(cur_for)) {
-        for (const auto& access : iter.second) {
-          reuse_dis_bytes += std::get<1>(access) * std::get<2>(access);
-        }
-      }
-
-      return std::make_tuple(ReuseType::kSerialMultipleReadWrite, reuse_dis_iter / extent,
-                             reuse_dis_bytes / extent, serial_reuse);
-    }
-  }
-
-  return std::make_tuple(ReuseType::kNoReuse, 0, 0, 0);
-}
-
-class StateFeaturesNode : public Object {
- public:
-  Stmt state_code;
-  Array<BufferStore> buffers;
-  tvm::runtime::NDArray features;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("state_code", &state_code);
-    v->Visit("buffers", &buffers);
-    v->Visit("features", &features);
-  }
-
-  void add_buffer_and_features(const BufferStore& bufNode, std::vector<size_t> features_) {
-    this->buffers.push_back(bufNode);
-  }
-
-  static constexpr const char* _type_key = "ansor.StateFeatures";
-  TVM_DECLARE_BASE_OBJECT_INFO(StateFeaturesNode, Object);
-};
-
-constexpr size_t NFeatures = 17;
-using Features = std::array<size_t, NFeatures>;
-
-class StateFeatures : public ObjectRef {
- public:
-  explicit StateFeatures(
-      const std::vector<std::pair<BufferStore, Features>>& buffers_and_features) {
-    auto node = make_object<StateFeaturesNode>();
-    size_t n = buffers_and_features.size();
-    node->features = tvm::runtime::NDArray::Empty({(int64_t)n, NFeatures},
-                                                  DLDataType{kDLFloat, 32, 1}, {kDLCPU, 0});
-    for (size_t i = 0; i < n; i++) {
-      const auto& [buf, features] = buffers_and_features[i];
-      node->buffers.push_back(buf);
-      for (size_t j = 0; j < NFeatures; ++j)
-        static_cast<float*>(node->features->data)[i * NFeatures + j] = (float)features[j];
-    }
-    data_ = std::move(node);
-  }
-
-  TVM_DEFINE_OBJECT_REF_METHODS(StateFeatures, ObjectRef, StateFeaturesNode);
-  TVM_DEFINE_OBJECT_REF_COW_METHOD(StateFeaturesNode);
-};
-
-TVM_REGISTER_NODE_TYPE(StateFeaturesNode);
-
-// Extract features for every BufferStore statement
-class PerStoreFeatureExtractor : public StmtExprVisitor {
- public:
-  explicit PerStoreFeatureExtractor(int cache_line_size) : cache_line_size_(cache_line_size) {}
-
-  void VisitStmt_(const AttrStmtNode* node) final {
-    if (node->attr_key == tir::attr::thread_extent || node->attr_key == tir::attr::virtual_thread) {
-      const Var& var = node->node.as<IterVarNode>()->var;
-      int extent = GetIntImm(node->value);
-
-      int* plen = nullptr;
-
-      const std::string& name = var.get()->name_hint;
-      if (node->attr_key == tir::attr::thread_extent) {
-        if (name == "blockIdx.x") {
-          plen = &blockIdx_x_len_;
-        } else if (name == "blockIdx.y") {
-          plen = &block_idx_y_len_;
-        } else if (name == "blockIdx.z") {
-          plen = &block_idx_z_len_;
-        } else if (name == "threadIdx.x") {
-          plen = &threadIdx_x_len_;
-        } else if (name == "threadIdx.y") {
-          plen = &thread_idx_y_len_;
-        } else if (name == "threadIdx.z") {
-          plen = &thread_idx_z_len_;
-        } else {
-          LOG(FATAL) << "invalid thread itervar " + name;
-        }
-      } else {
-        plen = &vthread_len_;
-      }
-
-      int extent_before = *plen;
-      if (node->attr_key == tir::attr::thread_extent) {
-        *plen = extent;
-      } else {
-        *plen *= extent;
-      }
-
-      is_gpu_ = true;
-
-      // make a fake for node for blockIdx.x or threadIdx.x
-      Stmt fake_for_node = For(var, 0, extent, ForKind::kParallel, node->body);
-
-      outer_loop_prod_ *= extent;
-      for_loop_stack_.push_back(fake_for_node.as<ForNode>());
-      StmtExprVisitor::VisitStmt_(node);
-      for_loop_stack_.pop_back();
-      outer_loop_prod_ /= extent;
-
-      *plen = extent_before;
-    } else if (node->attr_key == "pragma_auto_unroll_max_step") {
-      int value = GetIntImm(node->value);
-
-      int16_t old_value = cur_auto_unroll_max_step_;
-      cur_auto_unroll_max_step_ = value;
-      StmtExprVisitor::VisitStmt_(node);
-      cur_auto_unroll_max_step_ = old_value;
-    } else {
-      StmtExprVisitor::VisitStmt_(node);
-    }
-  }
-
-  void VisitStmt_(const ForNode* node) final {
-    int64_t loop_extent = GetLoopExtent(node);
-
-    if (node->kind == ForKind::kVectorized) {
-      vec_for_stack_.push_back(node);
-    } else if (node->kind == ForKind::kUnrolled) {
-      unroll_for_stack_.push_back(node);
-    } else if (node->kind == ForKind::kParallel) {
-      parallel_for_stack_.push_back(node);
-    }
-
-    outer_loop_prod_ *= loop_extent;
-    for_loop_stack_.push_back(node);
-    StmtExprVisitor::VisitStmt_(node);
-    for_loop_stack_.pop_back();
-    outer_loop_prod_ /= loop_extent;
-
-    if (node->kind == ForKind::kVectorized) {
-      vec_for_stack_.pop_back();
-    } else if (node->kind == ForKind::kUnrolled) {
-      unroll_for_stack_.pop_back();
-    } else if (node->kind == ForKind::kParallel) {
-      parallel_for_stack_.pop_back();
-    }
-  }
-
-  void VisitStmt_(const BufferStoreNode* node) final {
-    MathOpCounter math_op_counter;
-    math_op_counter(node->value);
-
-    BufferStore store(node->buffer, node->value, node->indices, node->span);
-    std::array<size_t, 17> features = {
-        math_op_counter.float_mad,        math_op_counter.float_addsub,
-        math_op_counter.float_mul,        math_op_counter.float_divmod,
-        math_op_counter.float_cmp,        math_op_counter.float_math_func,
-        math_op_counter.float_other_func, math_op_counter.int_mad,
-        math_op_counter.int_addsub,       math_op_counter.int_mul,
-        math_op_counter.int_divmod,       math_op_counter.int_math_func,
-        math_op_counter.int_cmp,          math_op_counter.int_other_func,
-        math_op_counter.bool_op,          math_op_counter.select_op,
-        (size_t)this->outer_loop_prod_};
-    this->bufstore_and_feats.emplace_back(store, features);
-
-    std::vector<float> mem_bytes_list;
-    std::vector<float> compute_ops_list;
-    double cur_compute_ops;
-
-    // Group 1: Computation related features
-    ExtractComputationFeature(node, math_op_counter);
-
-    // Group 2: Buffer access related features (per buffer)
-    ExtractBufferAccessFeature(node, math_op_counter, &cur_compute_ops, &compute_ops_list,
-                               &mem_bytes_list);
-
-    // Group 3: Arithmetic intensity related features
-    ExtractArithmeticIntensityFeature(node, cur_compute_ops, compute_ops_list, mem_bytes_list);
-
-    // Group 5: Outer scope related features
-    ExtractOuterScopeFeature(node);
-  }
-
-  void VisitStmt_(const BufferRealizeNode* node) final {
-    StmtExprVisitor::VisitStmt_(node);
-
-    // Group 4: Allocation related features
-    ExtractAllocationFeature(node);
-  }
-
-  // Extract computation related features (group 1)
-  void ExtractComputationFeature(const BufferStoreNode* node,
-                                 const MathOpCounter& math_op_counter) {
-    FeatureSet& fea = buffer_features[node->buffer];
-
-    // Computation related features
-    fea.float_mad = outer_loop_prod_ * math_op_counter.float_mad;
-    fea.float_addsub = outer_loop_prod_ * math_op_counter.float_addsub;
-    fea.float_mul = outer_loop_prod_ * math_op_counter.float_mul;
-    fea.float_divmod = outer_loop_prod_ * math_op_counter.float_divmod;
-    fea.float_cmp = outer_loop_prod_ * math_op_counter.float_cmp;
-    fea.float_math_func = outer_loop_prod_ * math_op_counter.float_math_func;
-    fea.float_other_func = outer_loop_prod_ * math_op_counter.float_other_func;
-    fea.int_mad = outer_loop_prod_ * math_op_counter.int_mad;
-    fea.int_addsub = outer_loop_prod_ * math_op_counter.int_addsub;
-    fea.int_mul = outer_loop_prod_ * math_op_counter.int_mul;
-    fea.int_divmod = outer_loop_prod_ * math_op_counter.int_divmod;
-    fea.int_math_func = outer_loop_prod_ * math_op_counter.int_math_func;
-    fea.int_cmp = outer_loop_prod_ * math_op_counter.int_cmp;
-    fea.int_other_func = outer_loop_prod_ * math_op_counter.int_other_func;
-    fea.bool_op = outer_loop_prod_ * math_op_counter.bool_op;
-    fea.select_op = outer_loop_prod_ * math_op_counter.select_op;
-
-    fea.vec_len = fea.unroll_len = fea.parallel_len = 0.0f;
-    fea.vec_type = fea.unroll_type = fea.parallel_type = AnnotationPosType::kPosNone;
-
-    fea.vec_num = vec_for_stack_.size();
-    if (!vec_for_stack_.empty()) {
-      fea.vec_len = GetLoopExtent(vec_for_stack_.back());
-      fea.vec_prod = 1.0;
-      for (const ForNode* pfor : vec_for_stack_) {
-        fea.vec_prod *= GetLoopExtent(pfor);
-      }
-      fea.vec_type = AnnotationPosType::kPosMixed;
-      // todo(merrymercy): this feature requires operation (tvm.compute) information
-      // GetAnnotationPosEncoding(vec_for_stack_.back()->loop_var,
-      // node->args, pcompute->axis, pcompute->reduce_axis);
-    }
-
-    fea.unroll_num = unroll_for_stack_.size();
-    if (!unroll_for_stack_.empty()) {
-      fea.unroll_len = GetLoopExtent(unroll_for_stack_.back());
-      fea.unroll_prod = 1.0;
-      for (const ForNode* pfor : unroll_for_stack_) {
-        fea.unroll_prod *= GetLoopExtent(pfor);
-      }
-      fea.unroll_type = AnnotationPosType::kPosMixed;
-      // GetAnnotationPosEncoding(unroll_for_stack_.back()->loop_var,
-      // node->args, pcompute->axis, pcompute->reduce_axis);
-    }
-
-    fea.parallel_num = parallel_for_stack_.size();
-    if (!parallel_for_stack_.empty()) {
-      fea.parallel_len = GetLoopExtent(parallel_for_stack_.back());
-      fea.parallel_prod = 1.0;
-      for (const ForNode* pfor : parallel_for_stack_) {
-        fea.parallel_prod *= GetLoopExtent(pfor);
-      }
-      fea.parallel_type = AnnotationPosType::kPosMixed;
-      // GetAnnotationPosEncoding(parallel_for_stack_.back()->loop_var,
-      // node->args, pcompute->axis, pcompute->reduce_axis);
-    }
-
-    // GPU threads
-    fea.is_gpu = is_gpu_;
-    fea.blockIdx_x_len = blockIdx_x_len_;
-    fea.blockIdx_y_len = block_idx_y_len_;
-    fea.blockIdx_z_len = block_idx_z_len_;
-    fea.threadIdx_x_len = threadIdx_x_len_;
-    fea.threadIdx_y_len = thread_idx_y_len_;
-    fea.threadIdx_z_len = thread_idx_z_len_;
-    fea.vthread_len = vthread_len_;
-  }
-
-  // Extract buffer access related features (group 2)
-  void ExtractBufferAccessFeature(const BufferStoreNode* node, const MathOpCounter& math_op_counter,
-                                  double* cur_compute_ops, std::vector<float>* compute_ops_list,
-                                  std::vector<float>* mem_bytes_list) {
-    FeatureSet& fea = buffer_features[node->buffer];
-
-    // Extract all buffer accesses
-    std::vector<BufferAccessFeature> acc_feas;
-    BufferAccessExtractor buf_extractor;
-    buf_extractor.InsertAccess(node->buffer, BufferAccessType::kWrite, node->indices);
-    buf_extractor.ExtractReads(node->value);
-
-    // Compute touched region for all outer loops
-    for (auto x : for_loop_stack_) {
-      ana_.Bind(x->loop_var, Range::FromMinExtent(x->min, 1), true);
-    }
-
-    mem_bytes_list->reserve(for_loop_stack_.size());
-    compute_ops_list->reserve(for_loop_stack_.size());
-
-    *cur_compute_ops = math_op_counter.float_mad + math_op_counter.float_addsub +
-                       math_op_counter.float_mul + math_op_counter.float_divmod +
-                       math_op_counter.float_cmp + math_op_counter.float_math_func +
-                       math_op_counter.float_other_func;
-
-    std::vector<int> tmp_region;
-    for (int i = static_cast<int>(for_loop_stack_.size()) - 1; i >= 0; i--) {
-      const ForNode* p_for = for_loop_stack_[i];
-
-      ana_.Bind(p_for->loop_var,
-                Range::FromMinExtent(for_loop_stack_[i]->min, for_loop_stack_[i]->extent), true);
-
-      // Note, here we do overwrite.
-      // So if there are multiple BufferStoreNode, the last one will overwrite the first few.
-      // e.g. The update part in gemm will overwrite the init part.
-      BufferMap<std::vector<std::tuple<BufferAccessType, int64_t, int>>>& buffer_regions_map =
-          for_touch_regions_[p_for];
-
-      int64_t mem_bytes = 0;
-      for (const auto& x : buf_extractor.buf_accesses) {
-        const Buffer& t = x.first;
-        const BufferAccess& acc = x.second;
-
-        ComputeRegion(acc.indices, &ana_, &tmp_region);
-        int64_t touched_size = ElementProduct(tmp_region);
-        buffer_regions_map[t].push_back(
-            std::make_tuple(acc.acc_type, touched_size, t->dtype.bytes()));
-        mem_bytes += touched_size * t->dtype.bytes();
-      }
-
-      mem_bytes_list->push_back(std::log2(mem_bytes));
-      *cur_compute_ops *= GetLoopExtent(for_loop_stack_[i]);
-      compute_ops_list->push_back(std::log2(*cur_compute_ops));
-    }
-
-    //  Buffer access related features (per buffer)
-    for (const auto& x : buf_extractor.buf_accesses) {
-      const Buffer& t = x.first;
-      const BufferAccess& acc = x.second;
-
-      std::vector<int> int_shape;
-      for (const auto& dim : t->shape) {
-        int_shape.push_back(GetIntImm(dim));
-      }
-
-      size_t ele_bytes = t->dtype.bytes();
-
-      // calculate bytes
-      float bytes = outer_loop_prod_ * ele_bytes;
-      float unique_bytes;
-
-      // calculate cache lines
-      int64_t stride;
-      float lines;
-      float unique_lines;
-
-      if (for_loop_stack_.empty()) {
-        unique_bytes = ele_bytes;
-        stride = 0;
-        lines = 1.0f;
-        unique_lines = 1.0f;
-      } else {
-        unique_bytes =
-            std::get<1>(for_touch_regions_[for_loop_stack_.front()][t].front()) * ele_bytes;
-
-        stride = 0;
-        int64_t reduce_ratio = 1;
-
-        int i;
-        for (i = static_cast<int>(for_loop_stack_.size()) - 1; i >= 0; i--) {
-          stride = ComputeStride(acc.indices, int_shape, for_loop_stack_[i]->loop_var.get());
-          if (stride != 0) {
-            break;
-          }
-          reduce_ratio *= GetLoopExtent(for_loop_stack_.back());
-        }
-
-        lines = outer_loop_prod_ / reduce_ratio *
-                std::min(1.0f, 1.0f * stride * ele_bytes / cache_line_size_);
-        lines = std::max(lines, 1.0f);
-
-        // convert `stride` back to the stride of the innermost iterator
-        stride = (i == static_cast<int>(for_loop_stack_.size()) - 1 ? stride : 0);
-
-        float n_continuous = ele_bytes;
-        for (int i = std::min(static_cast<int>(tmp_region.size()) - 1,
-                              static_cast<int>(int_shape.size()) - 1);
-             i >= 0; i--) {
-          if (tmp_region[i] == int_shape[i]) {
-            n_continuous *= tmp_region[i];
-            break;
-          }
-        }
-        unique_lines = unique_bytes / std::min(n_continuous, static_cast<float>(cache_line_size_));
-        unique_lines = std::max(unique_lines, 1.0f);
-      }
-
-      ReuseType reuse_type;
-      float reuse_dis_iter, reuse_dis_bytes, reuse_ct;
-      std::tie(reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct) =
-          ComputeReuse(t, acc.indices, for_loop_stack_, for_touch_regions_);
-
-      acc_feas.emplace_back();
-      BufferAccessFeature& acc_fea = acc_feas.back();
-
-      acc_fea.buffer_name = t->name;
-      acc_fea.acc_type = acc.acc_type;
-      acc_fea.stride = stride;
-      acc_fea.bytes = bytes;
-      acc_fea.unique_bytes = unique_bytes;
-      acc_fea.lines = lines;
-      acc_fea.unique_lines = unique_lines;
-      acc_fea.reuse_type = reuse_type;
-      acc_fea.reuse_dis_iter = reuse_dis_iter;
-      acc_fea.reuse_dis_bytes = reuse_dis_bytes;
-      acc_fea.reuse_ct = reuse_ct;
-      if (acc_fea.reuse_ct > 0.5) {
-        acc_fea.bytes_d_reuse_ct = bytes / reuse_ct;
-        acc_fea.unique_bytes_d_reuse_ct = unique_bytes / reuse_ct;
-        acc_fea.lines_d_reuse_ct = lines / reuse_ct;
-        acc_fea.unique_lines_d_reuse_ct = unique_lines / reuse_ct;
-      } else {
-        // no reuse, multiply by a magic number '2'
-        acc_fea.bytes_d_reuse_ct = bytes * 2;
-        acc_fea.unique_bytes_d_reuse_ct = unique_bytes * 2;
-        acc_fea.lines_d_reuse_ct = lines * 2;
-        acc_fea.unique_lines_d_reuse_ct = unique_lines * 2;
-      }
-    }
-
-    fea.access_feas = acc_feas;
-  }
-
-  // Extract arithmetic intensity related feature (group 3)
-  void ExtractArithmeticIntensityFeature(const BufferStoreNode* node, double cur_compute_ops,
-                                         const std::vector<float>& compute_ops_list,
-                                         const std::vector<float>& mem_bytes_list) {
-    FeatureSet& fea = buffer_features[node->buffer];
-
-    // Compute arithmetic intensity curve (y axis : arithmetic intensity, x axis : flops).
-    // We use piecewise linear interpolation to fit this curve.
-    int pt = 0;
-    if (cur_compute_ops <= 0 || compute_ops_list.empty()) {
-      std::fill(fea.arith_intensity_curve,
-                fea.arith_intensity_curve + ARITH_INTENSITY_CURVE_SAMPLE_N, 0.0);
-    } else {
-      for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) {
-        float cur_compute_ops = compute_ops_list.back() * (i + 1) / ARITH_INTENSITY_CURVE_SAMPLE_N;
-        while (compute_ops_list[pt] < cur_compute_ops - 1e-4) {
-          pt++;
-        }
-        ICHECK_LT(pt, compute_ops_list.size());
-
-        float value;
-        if (pt == 0) {
-          value = compute_ops_list[pt] / mem_bytes_list[pt];
-        } else {
-          float base = compute_ops_list[pt - 1] / mem_bytes_list[pt - 1];
-          float slope = (compute_ops_list[pt] / mem_bytes_list[pt] -
-                         compute_ops_list[pt - 1] / mem_bytes_list[pt - 1]) /
-                        (compute_ops_list[pt] - compute_ops_list[pt - 1]);
-          value = base + slope * (cur_compute_ops - compute_ops_list[pt - 1]);
-        }
-        fea.arith_intensity_curve[i] = value;
-      }
-    }
-  }
-
-  // Extract allocation related features (group 4)
-  void ExtractAllocationFeature(const BufferRealizeNode* node) {
-    FeatureSet& fea = buffer_features[node->buffer];
-
-    float allocation_size = 1.0f;
-    for (const auto& x : node->bounds) {
-      allocation_size *= GetIntImm(x->extent);
-    }
-    // allocation feature
-    fea.alloc_size = allocation_size * node->buffer->dtype.bytes();
-    fea.alloc_prod = allocation_size * outer_loop_prod_;
-    fea.alloc_outer_prod = outer_loop_prod_;
-    fea.alloc_inner_prod = fea.outer_prod / outer_loop_prod_;
-  }
-
-  // Extract outer scope related features (group 5)
-  void ExtractOuterScopeFeature(const BufferStoreNode* node) {
-    FeatureSet& fea = buffer_features[node->buffer];
-
-    fea.outer_prod = outer_loop_prod_;
-    fea.num_loops = for_loop_stack_.size();
-    fea.auto_unroll_max_step = cur_auto_unroll_max_step_;
-  }
-
-  // Stores FeatureSet for every buffer
-  BufferMap<FeatureSet> buffer_features;
-
-  std::vector<std::pair<BufferStore, std::array<size_t, 17>>> bufstore_and_feats;
-
- private:
-  // The shared arithmetic analyzer
-  Analyzer ana_;
-
-  // The product of the extents of the outer loops
-  float outer_loop_prod_ = 1.0f;
-
-  // The stacks to store parent loops during DFS
-  std::vector<const ForNode*> for_loop_stack_;
-  std::vector<const ForNode*> parallel_for_stack_;
-  std::vector<const ForNode*> vec_for_stack_;
-  std::vector<const ForNode*> unroll_for_stack_;
-
-  // GPU-related features
-  bool is_gpu_{false};
-  int blockIdx_x_len_{1};
-  int block_idx_y_len_{1};
-  int block_idx_z_len_{1};
-  int threadIdx_x_len_{1};
-  int thread_idx_y_len_{1};
-  int thread_idx_z_len_{1};
-  int vthread_len_{1};
-  int16_t cur_auto_unroll_max_step_{0};
-
-  // Store touch region information for all for loops. The format of this nested map:
-  // For a loop, for all its touched buffers, for all different accesses to the buffers,
-  // its (access type, number of touched elements, number of bytes of single element)
-  std::unordered_map<const ForNode*,
-                     BufferMap<std::vector<std::tuple<BufferAccessType, int64_t, int>>>>
-      for_touch_regions_;
-
-  // The default cache line size in bytes
-  const int cache_line_size_ = 64;
-};
-
-// shifted log with the property slog(0) = 0 (disabled here; the identity is used instead):
-// inline float slog(float x) { return x < 0 ? -std::log2(-x + 1) : std::log2(x + 1); }
-inline float slog(float x) { return x; }
-
-void GetPerStoreFeature(const Stmt& stmt, int cache_line_size, int max_n_bufs,
-                        std::vector<float>* ret, StateFeatures* st_features) {
-  PerStoreFeatureExtractor extractor(cache_line_size);
-  extractor(stmt);
-
-  ret->push_back(extractor.buffer_features.size());
-  if (st_features) *st_features = StateFeatures(extractor.bufstore_and_feats);
-
-  for (const auto& x : extractor.buffer_features) {
-    const FeatureSet& fea_set = x.second;
-
-    /***** Group 1: Computation related features *****/
-    ret->push_back(slog(fea_set.float_mad));
-    ret->push_back(slog(fea_set.float_addsub));
-    ret->push_back(slog(fea_set.float_mul));
-    ret->push_back(slog(fea_set.float_divmod));
-    ret->push_back(slog(fea_set.float_cmp));
-    ret->push_back(slog(fea_set.float_math_func));
-    ret->push_back(slog(fea_set.float_other_func));
-    ret->push_back(slog(fea_set.int_mad));
-    ret->push_back(slog(fea_set.int_addsub));
-    ret->push_back(slog(fea_set.int_mul));
-    ret->push_back(slog(fea_set.int_divmod));
-    ret->push_back(slog(fea_set.int_cmp));
-    ret->push_back(slog(fea_set.int_math_func));
-    ret->push_back(slog(fea_set.int_other_func));
-    ret->push_back(slog(fea_set.bool_op));
-    ret->push_back(slog(fea_set.select_op));
-
-    ret->push_back(slog(fea_set.vec_num));
-    ret->push_back(slog(fea_set.vec_prod));
-    ret->push_back(slog(fea_set.vec_len));
-    for (int i = 0; i <= static_cast<int>(AnnotationPosType::kPosMixed); i++) {
-      ret->push_back(i == static_cast<int>(fea_set.vec_type));
-    }
-
-    ret->push_back(slog(fea_set.unroll_num));
-    ret->push_back(slog(fea_set.unroll_prod));
-    ret->push_back(slog(fea_set.unroll_len));
-    for (int i = 0; i <= static_cast<int>(AnnotationPosType::kPosMixed); i++) {
-      ret->push_back(i == static_cast<int>(fea_set.unroll_type));
-    }
-
-    ret->push_back(slog(fea_set.parallel_num));
-    ret->push_back(slog(fea_set.parallel_prod));
-    ret->push_back(slog(fea_set.parallel_len));
-    for (int i = 0; i <= static_cast<int>(AnnotationPosType::kPosMixed); i++) {
-      ret->push_back(i == static_cast<int>(fea_set.parallel_type));
-    }
-
-    ret->push_back(fea_set.is_gpu);
-    ret->push_back(slog(fea_set.blockIdx_x_len));
-    ret->push_back(slog(fea_set.blockIdx_y_len));
-    ret->push_back(slog(fea_set.blockIdx_z_len));
-    ret->push_back(slog(fea_set.threadIdx_x_len));
-    ret->push_back(slog(fea_set.threadIdx_y_len));
-    ret->push_back(slog(fea_set.threadIdx_z_len));
-    ret->push_back(slog(fea_set.vthread_len));
-
-    /***** Group 2: Buffer access related features *****/
-    // sort according to pair (lines, bytes)
-    std::vector<std::pair<float, float>> buf_order_key;
-    for (const auto& acc_fea : fea_set.access_feas) {
-      buf_order_key.emplace_back(acc_fea.lines, acc_fea.bytes);
-    }
-    std::vector<int> buf_order(buf_order_key.size());
-    std::iota(buf_order.begin(), buf_order.end(), 0);
-
-    auto cmp = [&buf_order_key](int l, int r) {
-      return buf_order_key[l].first > buf_order_key[r].first ||
-             (buf_order_key[l].first == buf_order_key[r].first &&
-              buf_order_key[l].second > buf_order_key[r].second);
-    };
-    std::sort(buf_order.begin(), buf_order.end(), cmp);
-    int n_bufs = std::min(max_n_bufs, static_cast<int>(buf_order.size()));
-    buf_order.resize(n_bufs);
-
-    for (int idx : buf_order) {
-      const auto& acc_fea = fea_set.access_feas[idx];
-      for (int j = 0; j <= static_cast<int>(BufferAccessType::kReadWrite); ++j) {
-        ret->push_back(j == static_cast<int>(acc_fea.acc_type));
-      }
-      ret->push_back(slog(acc_fea.bytes));
-      ret->push_back(slog(acc_fea.unique_bytes));
-      ret->push_back(slog(acc_fea.lines));
-      ret->push_back(slog(acc_fea.unique_lines));
-      for (int j = 0; j <= static_cast<int>(ReuseType::kNoReuse); ++j) {
-        ret->push_back(j == static_cast<int>(acc_fea.reuse_type));
-      }
-      ret->push_back(slog(acc_fea.reuse_dis_iter));
-      ret->push_back(slog(acc_fea.reuse_dis_bytes));
-      ret->push_back(slog(acc_fea.reuse_ct));
-      ret->push_back(slog(acc_fea.bytes_d_reuse_ct));
-      ret->push_back(slog(acc_fea.unique_bytes_d_reuse_ct));
-      ret->push_back(slog(acc_fea.lines_d_reuse_ct));
-      ret->push_back(slog(acc_fea.unique_lines_d_reuse_ct));
-      ret->push_back(slog(acc_fea.stride));
-    }
-    // - fill padding
-    for (int i = 0; i < max_n_bufs - n_bufs; ++i) {
-      for (int j = 0; j <= static_cast<int>(BufferAccessType::kReadWrite); ++j) {  // 3
-        ret->push_back(0.0f);
-      }
-      ret->push_back(0.0f);
-      ret->push_back(0.0f);
-      ret->push_back(0.0f);
-      ret->push_back(0.0f);
-      for (int j = 0; j <= static_cast<int>(ReuseType::kNoReuse); ++j) {  // 3
-        ret->push_back(0.0f);
-      }
-      ret->push_back(0.0f);
-      ret->push_back(0.0f);
-      ret->push_back(0.0f);
-      ret->push_back(0.0f);
-      ret->push_back(0.0f);
-      ret->push_back(0.0f);
-      ret->push_back(0.0f);
-      ret->push_back(0.0f);
-    }
-
-    /***** Group 3: Arithmetic intensity related features *****/
-    for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) {
-      ret->push_back(fea_set.arith_intensity_curve[i]);
-    }
-
-    /***** Group 4: Allocation related features *****/
-    ret->push_back(slog(fea_set.alloc_size));
-    ret->push_back(slog(fea_set.alloc_prod));
-    ret->push_back(slog(fea_set.alloc_outer_prod));
-    ret->push_back(slog(fea_set.alloc_inner_prod));
-
-    /***** Group 5: Outer scope related features *****/
-    ret->push_back(slog(fea_set.outer_prod));
-    ret->push_back(slog(fea_set.num_loops));
-    ret->push_back(slog(fea_set.auto_unroll_max_step));
-  }
-}
-
-void GetPerStoreFeatureName(int max_n_bufs, std::vector<std::string>* ret) {
-  /***** Group 1: Computation related features *****/
-  ret->push_back(("float_mad"));
-  ret->push_back(("float_addsub"));
-  ret->push_back(("float_mul"));
-  ret->push_back(("float_divmod"));
-  ret->push_back(("float_cmp"));
-  ret->push_back(("float_mathfunc"));
-  ret->push_back(("float_otherfunc"));
-  ret->push_back(("int_mad"));
-  ret->push_back(("int_addsub"));
-  ret->push_back(("int_mul"));
-  ret->push_back(("int_divmod"));
-  ret->push_back(("int_cmp"));
-  ret->push_back(("int_mathfunc"));
-  ret->push_back(("int_otherfunc"));
-  ret->push_back(("bool_op"));
-  ret->push_back(("select_op"));
-  ret->push_back(("vec_num"));
-  ret->push_back(("vec_prod"));
-  ret->push_back(("vec_len"));
-  ret->push_back(("vec_type.kPosNone"));
-  ret->push_back(("vec_type.kPosInnerSpatial"));
-  ret->push_back(("vec_type.kPosMiddleSpatial"));
-  ret->push_back(("vec_type.kPosOuterSpatial"));
-  ret->push_back(("vec_type.kPosInnerReduce"));
-  ret->push_back(("vec_type.kPosMiddleReduce"));
-  ret->push_back(("vec_type.kPosOuterReduce"));
-  ret->push_back(("vec_type.kPosMixed"));
-  ret->push_back(("unroll_num"));
-  ret->push_back(("unroll_prod"));
-  ret->push_back(("unroll_len"));
-  ret->push_back(("unroll_type.kPosNone"));
-  ret->push_back(("unroll_type.kPosInnerSpatial"));
-  ret->push_back(("unroll_type.kPosMiddleSpatial"));
-  ret->push_back(("unroll_type.kPosOuterSpatial"));
-  ret->push_back(("unroll_type.kPosInnerReduce"));
-  ret->push_back(("unroll_type.kPosMiddleReduce"));
-  ret->push_back(("unroll_type.kPosOuterReduce"));
-  ret->push_back(("unroll_type.kPosMixed"));
-  ret->push_back(("parallel_num"));
-  ret->push_back(("parallel_prod"));
-  ret->push_back(("parallel_len"));
-  ret->push_back(("parallel_type.kPosNone"));
-  ret->push_back(("parallel_type.kPosInnerSpatial"));
-  ret->push_back(("parallel_type.kPosMiddleSpatial"));
-  ret->push_back(("parallel_type.kPosOuterSpatial"));
-  ret->push_back(("parallel_type.kPosInnerReduce"));
-  ret->push_back(("parallel_type.kPosMiddleReduce"));
-  ret->push_back(("parallel_type.kPosOuterReduce"));
-  ret->push_back(("parallel_type.kPosMixed"));
-  ret->push_back(("is_gpu"));
-  ret->push_back(("blockIdx_x_len"));
-  ret->push_back(("blockIdx_y_len"));
-  ret->push_back(("blockIdx_z_len"));
-  ret->push_back(("threadIdx_x_len"));
-  ret->push_back(("threadIdx_y_len"));
-  ret->push_back(("threadIdx_z_len"));
-  ret->push_back(("vthread_len"));
-  // section total: 57
-
-  /***** Group 2: Buffer access related features *****/
-  for (size_t i = 0; i < static_cast<size_t>(max_n_bufs); ++i) {
-    std::string prefix = "B" + std::to_string(i) + ".";
-    ret->push_back((prefix + "acc_type.kRead"));
-    ret->push_back((prefix + "acc_type.kWrite"));
-    ret->push_back((prefix + "acc_type.kReadWrite"));
-    ret->push_back((prefix + "bytes"));
-    ret->push_back((prefix + "unique_bytes"));
-    ret->push_back((prefix + "lines"));
-    ret->push_back((prefix + "unique_lines"));
-    ret->push_back((prefix + "reuse_type.kLoopMultipleRead"));
-    ret->push_back((prefix + "reuse_type.kSerialMultipleReadWrite"));
-    ret->push_back((prefix + "reuse_type.kNoReuse"));
-    ret->push_back((prefix + "reuse_dis_iter"));
-    ret->push_back((prefix + "reuse_dis_bytes"));
-    ret->push_back((prefix + "reuse_ct"));
-    ret->push_back((prefix + "bytes_d_reuse_ct"));
-    ret->push_back((prefix + "unique_bytes_d_reuse_ct"));
-    ret->push_back((prefix + "lines_d_reuse_ct"));
-    ret->push_back((prefix + "unique_lines_d_reuse_ct"));
-    ret->push_back((prefix + "stride"));
-  }
-  // section total : max_n_bufs * 18
-
-  /***** Group 3: Arithmetic intensity related features *****/
-  for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) {
-    ret->push_back(("arith_intensity_curve_" + std::to_string(i)));
-  }
-  // section total: ARITH_INTENSITY_CURVE_SAMPLE_N = 10
-
-  /***** Group 4: Allocation related features *****/
-  ret->push_back(("alloc_size"));
-  ret->push_back(("alloc_prod"));
-  ret->push_back(("alloc_outer_prod"));
-  ret->push_back(("alloc_inner_prod"));
-  // section total : 4
-
-  /***** Group 5: Outer scope related features *****/
-  ret->push_back(("outer_prod"));
-  ret->push_back(("num_loops"));
-  ret->push_back(("auto_unroll_max_step"));
-  // section total : 3
-}
-
-void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, int max_n_bufs,
-                                   std::vector<float>* feature, std::atomic<int>* error_ct,
-                                   StateFeatures* st_features = nullptr) {
-  Stmt code_body = GenerateCodeForState(task, state);
-  if (!code_body.defined())
-    (*error_ct)++;
-  else
-    GetPerStoreFeature(code_body, task->hardware_params->cache_line_bytes, max_n_bufs, feature,
-                       st_features);
-}
-
-void GetPerStoreFeaturesFromStates(const Array<State>& states, const SearchTask& task,
-                                   int skip_first_n_feature_extraction, int max_n_bufs,
-                                   std::vector<std::vector<float>>* features) {
-  // extract features
-  features->assign(states.size(), std::vector<float>());
-
-  std::atomic<int> error_ct(0);
-
-  support::parallel_for(skip_first_n_feature_extraction, states.size(),
-                        [&task, &states, &max_n_bufs, &features, &error_ct](int i) {
-                          GetPerStoreFeaturesWorkerFunc(task, states[i], max_n_bufs,
-                                                        &(*features)[i], &error_ct);
-                        });
-}
-
-void GetPerStoreFeaturesFromStates(const Array<State>& states, const std::vector<SearchTask>& tasks,
-                                   int skip_first_n_feature_extraction, int max_n_bufs,
-                                   std::vector<std::vector<float>>* features) {
-  // extract features
-  features->assign(states.size(), std::vector<float>());
-
-  std::atomic<int> error_ct(0);
-
-  support::parallel_for(skip_first_n_feature_extraction, states.size(),
-                        [&tasks, &states, &max_n_bufs, &features, &error_ct](int i) {
-                          GetPerStoreFeaturesWorkerFunc(tasks[i], states[i], max_n_bufs,
-                                                        &(*features)[i], &error_ct);
-                        });
-}
-
-void GetPerStoreFeaturesFromFile(const std::string& filename, int max_lines, int max_n_bufs,
-                                 std::vector<std::vector<float>>* features,
-                                 std::vector<float>* normalized_throughputs,
-                                 std::vector<int>* task_ids) {
-  Array<State> states;
-  std::vector<SearchTask> tasks;
-
-  normalized_throughputs->clear();
-  task_ids->clear();
-
-  // (workload_key, target) -> (search_task, task_id)
-  std::unordered_map<std::pair<std::string, std::string>, std::pair<SearchTask, size_t>> task_cache;
-  // task_id -> min_cost
-  std::vector<float> min_costs;
-
-  const auto* workload_key_to_tensors =
-      tvm::runtime::Registry::Get("auto_scheduler.workload_key_to_tensors");
-  ICHECK(workload_key_to_tensors != nullptr);
-
-  // read from file
-  RecordReader reader(filename);
-  auto cur_inp = make_object<MeasureInputNode>();
-  auto cur_res = make_object<MeasureResultNode>();
-  while (reader->ReadNext(cur_inp.get(), cur_res.get())) {
-    float cost = static_cast<float>(FloatArrayMean(cur_res->costs));
-    const std::string& workload_key = cur_inp->task->workload_key;
-
-    SearchTask task;
-    size_t task_id;
-    std::pair<std::string, std::string> key(workload_key, cur_inp->task->target->str());
-    auto find_res = task_cache.find(key);
-    if (find_res == task_cache.end()) {
-      // rebuild task
-      Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
-      Target target = cur_inp->task->target;
-      Target target_host = cur_inp->task->target_host;
-      CheckAndUpdateHostConsistency(&target, &target_host);
-      task = SearchTask(ComputeDAG(tensors), workload_key, target, target_host,
-                        cur_inp->task->hardware_params, cur_inp->task->layout_rewrite_option,
-                        cur_inp->task->task_input_names);
-      task_id = task_cache.size();
-
-      // compute min cost for each task
-      task_cache.insert(std::make_pair(key, std::make_pair(task, task_id)));
-      min_costs.push_back(cost);
-    } else {
-      std::tie(task, task_id) = find_res->second;
-      min_costs[task_id] = std::min(min_costs[task_id], cost);
-    }
-
-    tasks.push_back(std::move(task));
-    task_ids->push_back(task_id);
-    states.push_back(cur_inp->state);
-    normalized_throughputs->push_back(cost);
-
-    if (max_lines > 0 && static_cast<int>(states.size()) >= max_lines) {
-      break;
-    }
-  }
-
-  for (size_t i = 0; i < normalized_throughputs->size(); ++i) {
-    (*normalized_throughputs)[i] = min_costs[(*task_ids)[i]] / (*normalized_throughputs)[i];
-  }
-
-  GetPerStoreFeaturesFromStates(states, tasks, 0, max_n_bufs, features);
-}
-
-void GetPerStoreFeaturesFromMeasurePairs(const Array<MeasureInput>& inputs,
-                                         const Array<MeasureResult>& results,
-                                         int skip_first_n_feature_extraction, int max_n_bufs,
-                                         std::vector<std::vector<float>>* features,
-                                         std::vector<float>* normalized_throughputs,
-                                         std::vector<int>* task_ids) {
-  Array<State> states;
-  std::vector<SearchTask> tasks;
-
-  normalized_throughputs->clear();
-  task_ids->clear();
-
-  // (workload_key, target) -> (search_task, task_id)
-  std::unordered_map<std::pair<std::string, std::string>, std::pair<SearchTask, size_t>> task_cache;
-  // task_id -> min_cost
-  std::vector<float> min_costs;
-
-  const auto* workload_key_to_tensors =
-      tvm::runtime::Registry::Get("auto_scheduler.workload_key_to_tensors");
-  ICHECK(workload_key_to_tensors != nullptr);
-
-  tasks.reserve(inputs.size());
-  normalized_throughputs->reserve(inputs.size());
-  task_ids->reserve(inputs.size());
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    float cost = static_cast<float>(FloatArrayMean(results[i]->costs));
-    const std::string& workload_key = inputs[i]->task->workload_key;
-    SearchTask task;
-
-    size_t task_id;
-    std::pair<std::string, std::string> key(workload_key, inputs[i]->task->target->str());
-    auto find_res = task_cache.find(key);
-    if (find_res == task_cache.end()) {
-      if (inputs[i]->task->compute_dag.defined()) {  // the measure input is complete
-        task = inputs[i]->task;
-      } else {
-        // The measure input is incomplete, rebuild task for incomplete measure pairs read from file
-        try {
-          Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
-          Target target = inputs[i]->task->target;
-          Target target_host = inputs[i]->task->target_host;
-          CheckAndUpdateHostConsistency(&target, &target_host);
-          task =
-              SearchTask(ComputeDAG(tensors), workload_key, target, target_host,
-                         inputs[i]->task->hardware_params, inputs[i]->task->layout_rewrite_option,
-                         inputs[i]->task->task_input_names);
-        } catch (std::exception& e) {
-          // Cannot build ComputeDAG from workload key, the task may have not been registered in
-          // this search round
-          continue;
-        }
-      }
-      task_id = task_cache.size();
-
-      // compute min cost for each task
-      task_cache.insert(std::make_pair(key, std::make_pair(task, task_id)));
-      min_costs.push_back(cost);
-    } else {
-      std::tie(task, task_id) = find_res->second;
-      min_costs[task_id] = std::min(min_costs[task_id], cost);
-    }
-
-    tasks.push_back(std::move(task));
-    task_ids->push_back(task_id);
-    states.push_back(inputs[i]->state);
-    normalized_throughputs->push_back(cost);
-  }
-
-  for (size_t i = 0; i < normalized_throughputs->size(); ++i) {
-    (*normalized_throughputs)[i] = min_costs[(*task_ids)[i]] / (*normalized_throughputs)[i];
-  }
-
-  GetPerStoreFeaturesFromStates(states, tasks, skip_first_n_feature_extraction, max_n_bufs,
-                                features);
-}
-
-/*
- * \brief Serialize a two-dimensional variable-size feature vector with normalized throughputs
- * and task ids to a one-dimensional flatten byte array.
- *
- * For faster data copy between c++ and python, the c++ part returns features in a single
- * flatten array using a packed format. The python part then unpacks the flatten array.
- *
- * The packed format for n records is:
- * {
- *   int   n;
- *   int   sizes[n+2];           // The sizes for the following arrays
- *
- *   float features_0[size[0]];  // The features for record 0
- *   float features_1[size[1]];  // The features for record 1
- *   ...
- *   float features_i[size[i]];  // The features for record i
- *   ... // until i == n - 1
- *
- *   float throughputs[sizes[n]];  // The normalized throughputs for n records
- *   int   task_ids[size[n+1]];   // The task ids for n records
- *
- * }
- * To implement this format, we also store int as float, so we can store all numbers
- * into a single float array.
- */
-TVMByteArray SerializeFeatures(std::vector<std::vector<float>>&& features,
-                               std::vector<float>&& normalized_throughputs,
-                               std::vector<int>&& task_ids, std::vector<char>* out_data) {
-  size_t total_bytes = 0;
-  std::vector<int> size_vector;
-
-  int n = features.size();
-
-  // serialize sizes
-  size_t size_vector_size = 1 + n + 2;
-  total_bytes += size_vector_size * sizeof(int);
-
-  size_vector.reserve(size_vector_size);
-  size_vector.push_back(features.size());
-  for (const auto& x : features) {
-    size_vector.push_back(static_cast<int>(x.size()));
-    total_bytes += sizeof(float) * x.size();
-  }
-  size_vector.push_back(static_cast<int>(normalized_throughputs.size()));
-  total_bytes += sizeof(float) * normalized_throughputs.size();
-  size_vector.push_back(static_cast<int>(task_ids.size()));
-  total_bytes += sizeof(int) * task_ids.size();
-
-  ICHECK_EQ(size_vector.size(), size_vector_size);
-
-  // allocate memory
-  out_data->reserve(total_bytes);
-  char* ptr = out_data->data();
-
-  // serialize size_vector
-  memmove(ptr, reinterpret_cast<char*>(size_vector.data()), size_vector.size() * sizeof(int));
-  ptr += size_vector.size() * sizeof(int);
-
-  // serialize features
-  for (auto& x : features) {
-    memmove(ptr, x.data(), sizeof(float) * x.size());
-    ptr += sizeof(float) * x.size();
-    x.clear();
-  }
-
-  // serialize normalized_throughputs
-  memmove(ptr, reinterpret_cast<char*>(normalized_throughputs.data()),
-          normalized_throughputs.size() * sizeof(int));
-  ptr += normalized_throughputs.size() * sizeof(int);
-
-  // serialize task_ids
-  memmove(ptr, reinterpret_cast<char*>(task_ids.data()), task_ids.size() * sizeof(int));
-  ptr += task_ids.size() * sizeof(int);
-
-  ICHECK_EQ(ptr - out_data->data(), total_bytes);
-
-  return TVMByteArray{out_data->data(), total_bytes};
-}
-
-TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeaturesFromFile")
-    .set_body([](TVMArgs args, TVMRetValue* ret) {
-      std::string filename = args[0];
-      int max_lines = args[1];
-      int max_n_bufs = args[2];
-
-      std::vector<std::vector<float>> features;
-      std::vector<float> normalized_throughputs;
-      std::vector<int> task_ids;
-
-      GetPerStoreFeaturesFromFile(filename, max_lines, max_n_bufs, &features,
-                                  &normalized_throughputs, &task_ids);
-
-      std::vector<char> byte_data;
-      *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs),
-                               std::move(task_ids), &byte_data);
-    });
-
-TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeaturesFromMeasurePairs")
-    .set_body([](TVMArgs args, TVMRetValue* ret) {
-      Array<MeasureInput> inputs = args[0];
-      Array<MeasureResult> results = args[1];
-      int skip_first_n_feature_extraction = args[2];
-      int max_n_bufs = args[3];
-
-      std::vector<std::vector<float>> features;
-      std::vector<float> normalized_throughputs;
-      std::vector<int> task_ids;
-
-      GetPerStoreFeaturesFromMeasurePairs(inputs, results, skip_first_n_feature_extraction,
-                                          max_n_bufs, &features, &normalized_throughputs,
-                                          &task_ids);
-
-      std::vector<char> byte_data;
-      *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs),
-                               std::move(task_ids), &byte_data);
-    });
-
-TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeaturesFromStates")
-    .set_body([](TVMArgs args, TVMRetValue* ret) {
-      Array<State> states = args[0];
-      SearchTask task = args[1];
-      int max_n_bufs = args[2];
-
-      std::vector<std::vector<float>> features;
-      std::vector<float> normalized_throughputs;
-      std::vector<int> task_ids;
-
-      GetPerStoreFeaturesFromStates(states, task, 0, max_n_bufs, &features);
-
-      std::vector<char> byte_data;
-      *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs),
-                               std::move(task_ids), &byte_data);
-    });
-
-TVM_REGISTER_GLOBAL("auto_scheduler.GetBufAndFeatsFromStates")
-    .set_body_typed([](Array<State> states, SearchTask task, int max_n_bufs) {
-      std::atomic<int> error_ct(0);
-      Array<StateFeatures> buf_and_feats;
-      for (size_t i = 0; i < states.size(); ++i) {
-        StateFeatures st_features;
-        std::vector<float> features;
-        GetPerStoreFeaturesWorkerFunc(task, states[i], max_n_bufs, &features, &error_ct,
-                                      &st_features);
-        if (!st_features.defined()) {
-          LOG_WARNING << "No code generated for state " << i << ", error_ct =" << error_ct
-                      << std::endl;
-          st_features = StateFeatures(std::vector<std::pair<BufferStore, Features>>());
-        }
-        st_features.CopyOnWrite()->state_code = GenerateCodeForState(task, states[i], true);
-        buf_and_feats.push_back(st_features);
-      }
-      return buf_and_feats;
-    });
-
-TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeatureNames")
-    .set_body([](TVMArgs args, TVMRetValue* ret) {
-      int max_n_bufs = args[0];
-      std::vector<std::string> names;
-
-      GetPerStoreFeatureName(max_n_bufs, &names);
-
-      Array<String> arr;
-      for (const auto& x : names) {
-        arr.push_back(x);
-      }
-      *ret = arr;
-    });
-
-}  // namespace auto_scheduler
-}  // namespace tvm
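
Note (illustrative, not part of the patch): the packed byte layout documented above for the deleted SerializeFeatures() can be decoded with a short stand-alone routine such as the sketch below. DeserializeFeatures and UnpackedFeatures are hypothetical names; the layout is assumed to be exactly `int n; int sizes[n+2];` followed by the per-record feature arrays, the normalized throughputs, and the task ids, with sizeof(int) == sizeof(float) as the writer assumes.

#include <cstring>
#include <vector>

// Hypothetical mirror of the packed format written by the deleted SerializeFeatures().
struct UnpackedFeatures {
  std::vector<std::vector<float>> features;   // features[i] has sizes[i] entries
  std::vector<float> normalized_throughputs;  // sizes[n] entries
  std::vector<int> task_ids;                  // sizes[n+1] entries
};

UnpackedFeatures DeserializeFeatures(const char* data) {
  UnpackedFeatures out;
  int n = 0;
  std::memcpy(&n, data, sizeof(int));
  std::vector<int> sizes(n + 2);
  std::memcpy(sizes.data(), data + sizeof(int), sizes.size() * sizeof(int));
  const char* ptr = data + (1 + n + 2) * sizeof(int);
  out.features.resize(n);
  for (int i = 0; i < n; ++i) {
    out.features[i].resize(sizes[i]);
    std::memcpy(out.features[i].data(), ptr, sizes[i] * sizeof(float));
    ptr += sizes[i] * sizeof(float);
  }
  out.normalized_throughputs.resize(sizes[n]);
  std::memcpy(out.normalized_throughputs.data(), ptr, sizes[n] * sizeof(float));
  ptr += sizes[n] * sizeof(float);
  out.task_ids.resize(sizes[n + 1]);
  std::memcpy(out.task_ids.data(), ptr, sizes[n + 1] * sizeof(int));
  return out;
}
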
diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc
index 517f7ff91f558a5e1b5564358fa5087c388d58eb..1a685cfe01a7c88f251fd5bcce67c2c5010b22d6 100755
--- a/src/auto_scheduler/loop_state.cc
+++ b/src/auto_scheduler/loop_state.cc
@@ -275,7 +275,7 @@ void State::reorder(int stage_id, const Array<Iterator>& order) {
 }
 
 Array<Iterator> State::split(int stage_id, const Iterator& it,
-                             const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
+                             const Array<PrimExpr>& lengths, bool inner_to_outer) {
   const Stage& stage = operator->()->stages[stage_id];
   SplitStep step =
       SplitStep(stage_id, GetIndex(stage->iters, it),
@@ -501,7 +501,7 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateReorder")
 
 TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit")
     .set_body_typed([](State state, int stage_id, const Iterator& it,
-                       const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
+                       const Array<PrimExpr>& lengths, bool inner_to_outer) {
       const auto& res = state.split(stage_id, it, lengths, inner_to_outer);
       return Array<ObjectRef>{state, res};
     });
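
Note (illustrative, not part of the patch): with the signature change above, split factors are ordinary PrimExprs, so symbolic tir variables can be passed where only concrete integers (or NullOpt) were accepted before. A minimal usage sketch, assuming `state`, `stage_id`, and `it` already exist and the usual tvm::auto_scheduler namespaces are in scope:

Array<PrimExpr> lengths{Var("sp0_0"), Var("sp0_1")};            // two symbolic tile sizes
Array<Iterator> tiles = state.split(stage_id, it, lengths,      // same call shape as before,
                                    /*inner_to_outer=*/true);   // but the factors stay symbolic
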
diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc
index e652c1baf87a68d7d2950d0b8eba46fd3737a012..3bb4cf9608fd4fcba2b1ded8f23b55283b5a0908 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy.cc
@@ -111,10 +111,10 @@ SketchPolicy::SketchPolicy(SearchTask task, CostModel program_cost_model,
     node->init_rules.push_back(&init_vectorization);
 
     // Mutation Rules for Evolutionary Search
-    node->mutation_rules.push_back(std::make_shared<MutateTileSize>(0.90));
-    node->mutation_rules.push_back(std::make_shared<MutateAutoUnroll>(0.04));
-    node->mutation_rules.push_back(std::make_shared<MutateComputeLocation>(0.05));
-    node->mutation_rules.push_back(std::make_shared<MutateParallel>(0.01));
+    // node->mutation_rules.push_back(std::make_shared<MutateTileSize>(0.90));
+    // node->mutation_rules.push_back(std::make_shared<MutateAutoUnroll>(0.04));
+    // node->mutation_rules.push_back(std::make_shared<MutateComputeLocation>(0.05));
+    // node->mutation_rules.push_back(std::make_shared<MutateParallel>(0.01));
   } else if (IsGPUTask(node->search_task)) {
     // Sketch Generation Rules
     if (node->search_task->target->GetAttr<String>("device", "") == "mali") {
@@ -147,8 +147,8 @@ SketchPolicy::SketchPolicy(SearchTask task, CostModel program_cost_model,
     }
 
     // Mutation Rules for Evolutionary Search
-    node->mutation_rules.push_back(std::make_shared<MutateTileSize>(0.90));
-    node->mutation_rules.push_back(std::make_shared<MutateAutoUnroll>(0.10));
+    // node->mutation_rules.push_back(std::make_shared<MutateTileSize>(0.90));
+    // node->mutation_rules.push_back(std::make_shared<MutateAutoUnroll>(0.10));
   } else {
     LOG(FATAL) << "No default sketch rules for target: " << task->target;
   }
@@ -370,7 +370,7 @@ Array<State> SketchPolicyNode::GenerateSketches() {
         auto step = pstate->transform_steps[split_step_id].as<SplitStepNode>();
         ICHECK(step != nullptr);
         pstate->transform_steps.Set(
-            split_step_id, SplitStep(step->stage_id, step->iter_id, step->extent, {NullOpt},
+            split_step_id, SplitStep(step->stage_id, step->iter_id, step->extent, {PrimExpr()},
                                      step->inner_to_outer));
       }
     }
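
Note (illustrative, not part of the patch): the `{PrimExpr()}` placeholder above works because a default-constructed PrimExpr is simply an undefined handle, playing the role NullOpt played before:

PrimExpr undefined_factor;              // default-constructed: split length not yet decided
ICHECK(!undefined_factor.defined());    // the "to be filled in" marker checked by InitFillTileSize
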
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
index 8df69fc7ce3b93daa27e589e6cce385d9b8ec6b8..83d12b0edf2e4750451278ca03c95b701307894e 100644
--- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
@@ -311,7 +311,7 @@ std::vector<std::pair<State, int>> RuleSimplifyComputeWithConstTensor::Apply(
       // tile other space indices
       ICHECK(iter->iter_kind == IteratorKind::kSpatial);
       tiled_outer_iters.push_back(
-          tmp_s.split(stage_id, iter, Array<Optional<Integer>>(tile_level - 1, NullOpt)));
+          tmp_s.split(stage_id, iter, Array<PrimExpr>(tile_level - 1, PrimExpr())));
     }
   }
 
@@ -494,36 +494,27 @@ std::vector<std::pair<State, int>> RuleCustomSketch::Apply(const SketchPolicyNod
 PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, State* state,
                                                              std::mt19937* rand_gen) const {
   SplitFactorizationMemo split_memo;
-  int max_innermost_split_factor =
-      GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor);
-
+  // TODO: track the value range constraints (e.g. max_innermost_split_factor below) on the symbolic split variables
+  // int max_innermost_split_factor =
+  //     GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor);
   StateNode* pstate = state->CopyOnWrite();
   // Scan the transformation history and randomly fill tiles size for all SplitStep
   for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) {
     if (auto ps = (*state)->transform_steps[step_id].as<SplitStepNode>()) {
-      bool all_defined = true;
-      for (const auto& len : ps->lengths) {
-        if (!len) {
-          all_defined = false;
-          break;
-        }
-      }
-      if (all_defined) {
-        continue;
-      }
-
+      bool all_defined =
+          std::accumulate(ps->lengths.begin(), ps->lengths.end(), true,
+                          [](bool acc, const ObjectRef& item) { return acc && item.defined(); });
+      if (all_defined) continue;
       ICHECK(ps->extent);
-      int extent = GetIntImm(ps->extent.value());
-      const auto& candidate_lens = split_memo.GetFactorizationSchemes(extent, ps->lengths.size(),
-                                                                      max_innermost_split_factor);
-      ICHECK(!candidate_lens.empty());
-      const auto& candidate_lengths = candidate_lens[(*rand_gen)() % candidate_lens.size()];
-
-      pstate->transform_steps.Set(
-          step_id,
-          SplitStep(ps->stage_id, ps->iter_id, ps->extent,
-                    Array<Optional<Integer>>(candidate_lengths.begin(), candidate_lengths.end()),
-                    ps->inner_to_outer));
+      Array<PrimExpr> new_split_lengths;
+      for (size_t len_id = 0; len_id < ps->lengths.size(); ++len_id) {
+        if (ps->lengths[len_id].defined())
+          LOG_FATAL << "SplitStep lengths should not be partially defined";
+        String var_name = "sp" + std::to_string(step_id) + "_" + std::to_string(len_id);
+        new_split_lengths.push_back(Var(var_name));
+      }
+      pstate->transform_steps.Set(step_id, SplitStep(ps->stage_id, ps->iter_id, ps->extent,
+                                                     new_split_lengths, ps->inner_to_outer));
     }
   }
   pstate->concrete = true;
@@ -912,330 +903,23 @@ PopulationGenerationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* pol
 
 PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, State* state,
                                                            std::mt19937* rand_gen) const {
-  int max_innermost_split_factor =
-      GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor);
-
-  // Extract all SplitStep
-  std::vector<size_t> split_step_ids;
-  for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) {
-    if (auto ps = (*state)->transform_steps[i].as<SplitStepNode>()) {
-      if (!ps->extent.defined() || !ps->extent.value()->IsInstance<IntImmNode>()) {
-        continue;
-      }
-      auto innermost_factor = ps->lengths.back().value_or(max_innermost_split_factor + 1);
-      if (GetIntImm(innermost_factor) <= max_innermost_split_factor) {
-        split_step_ids.push_back(i);
-      }
-    }
-  }
-  if (split_step_ids.empty()) {
-    // No tile size could be mutated.
-    return ResultKind::kInvalid;
-  }
-
-  // Select a SplitStep with extent larger than one to mutate.
-  int retry_ct = 0;
-  int64_t extent = 1;
-  int step_id;
-  const SplitStepNode* ps;
-
-  do {
-    step_id = split_step_ids[(*rand_gen)() % split_step_ids.size()];
-    ps = (*state)->transform_steps[step_id].as<SplitStepNode>();
-    ICHECK(ps != nullptr);
-    extent = GetIntImm(ps->extent.value());
-    retry_ct += 1;
-  } while (retry_ct < static_cast<int>(split_step_ids.size()) << 2 && (extent == 1 || extent == 0));
-
-  if (extent <= 1) {
-    // Cannot find a step with extent larger than one.
-    return ResultKind::kInvalid;
-  }
-
-  // Fetch the current tile sizes.
-  std::vector<int> lengths(ps->lengths.size() + 1, 1);
-  for (int i = 0; i < static_cast<int>(ps->lengths.size()); ++i) {
-    lengths[i + 1] = GetIntImm(ps->lengths[i].value());
-  }
-  lengths[0] = extent / ElementProduct(lengths);
-
-  // Random permute the tile size order.
-  std::vector<int> random_perm;
-  RandomPermutation(lengths.size(), &random_perm, rand_gen);
-
-  // Try to divide a factor from one tile size and multiple it to another.
-  for (size_t i = 0; i < random_perm.size(); ++i) {
-    size_t src_idx = random_perm[i];
-    int length = lengths[src_idx];
-    if (length <= 1) {
-      continue;
-    }
-
-    // Divide one factor from lengths[src_idx] and multiply it to lengths[dst_idx]
-    size_t dst_idx = random_perm[(i + 1) % random_perm.size()];
-    const std::vector<int>& factors = policy->split_memo.GetFactors(length);
-    ICHECK_GE(factors.size(), 1);
-
-    int divide_factor;
-    if (dst_idx == lengths.size() - 1) {
-      // Maintain the restriction of hardware_params.max_innermost_split_factor.
-      int max_factor_index = static_cast<int>(factors.size()) - 1;
-      for (; max_factor_index >= 1; max_factor_index--) {
-        if (factors[max_factor_index] * lengths[dst_idx] <= max_innermost_split_factor) {
-          break;
-        }
-      }
-      if (max_factor_index == 0) {
-        // Failed on this dst_idx, try next one.
-        continue;
-      }
-      divide_factor = factors[1 + (*rand_gen)() % (max_factor_index)];
-    } else {
-      divide_factor = factors[1 + (*rand_gen)() % (factors.size() - 1)];
-    }
-
-    // Divide one factor from lengths[src_idx] and multiply it to lengths[dst_idx].
-    Array<Integer> new_lengths;
-    for (size_t j = 1; j < lengths.size(); ++j) {
-      if (j == src_idx) {
-        new_lengths.push_back(Integer(lengths[j] / divide_factor));
-      } else if (j == dst_idx) {
-        new_lengths.push_back(Integer(lengths[j] * divide_factor));
-      } else {
-        new_lengths.push_back(Integer(lengths[j]));
-      }
-    }
-
-    ICHECK_LE(GetIntImm(new_lengths.back()), max_innermost_split_factor);
-
-    StateNode* pstate = state->CopyOnWrite();
-    pstate->transform_steps.Set(
-        step_id, SplitStep(ps->stage_id, ps->iter_id, ps->extent,
-                           Array<Optional<Integer>>(new_lengths.begin(), new_lengths.end()),
-                           ps->inner_to_outer));
-    return ResultKind::kValid;
-  }
   return ResultKind::kInvalid;
 }
 
 PopulationGenerationRule::ResultKind MutateAutoUnroll::Apply(SketchPolicyNode* policy, State* state,
                                                              std::mt19937* rand_gen) const {
-  // Extract all auto_unroll_max_step pragma steps.
-  std::vector<int> pragma_steps;
-  for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) {
-    if (auto ps = (*state)->transform_steps[i].as<PragmaStepNode>()) {
-      if (StrStartsWith(ps->pragma_type, "auto_unroll_max_step")) {
-        pragma_steps.push_back(i);
-      }
-    }
-  }
-  if (pragma_steps.empty()) {
-    return ResultKind::kInvalid;
-  }
-
-  std::vector<int>& auto_unroll_configs =
-      IsGPUTask(policy->search_task) ? auto_unroll_configs_gpu : auto_unroll_configs_cpu;
-
-  // Randomly pick up an auto unroll pragma step
-  auto step_id = pragma_steps[(*rand_gen)() % pragma_steps.size()];
-  auto ps = (*state)->transform_steps[step_id].as<PragmaStepNode>();
-  ICHECK(ps);
-
-  // Mutate its value to a random candidates
-  int val = auto_unroll_configs[(*rand_gen)() % auto_unroll_configs.size()];
-  StateNode* pstate = state->CopyOnWrite();
-  pstate->transform_steps.Set(
-      step_id, PragmaStep(ps->stage_id, ps->iter_id,
-                          std::string("auto_unroll_max_step") + "$" + std::to_string(val)));
-  Stage new_stage = pstate->stages[ps->stage_id];
-  new_stage.CopyOnWrite()->attrs.auto_unroll_max_step = val;
-  pstate->stages.Set(ps->stage_id, new_stage);
-  return ResultKind::kValid;
+  return ResultKind::kInvalid;
 }
 
 PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNode* policy,
                                                                   State* state,
                                                                   std::mt19937* rand_gen) const {
-  if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) {
-    return ResultKind::kInvalid;
-  }
-
-  // Extract all compute_at steps.
-  std::vector<int> compute_at_steps;
-  for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) {
-    if (auto ps = (*state)->transform_steps[s].as<ComputeAtStepNode>()) {
-      int stage_inc = GetTargetStageIDInState(*state, s) - ps->stage_id;
-
-      if (IsTiled((*state)->stages[ps->stage_id + stage_inc])) {
-        continue;
-      }
-
-      if (NeedsMultilevelTiling(policy->search_task, *state, ps->stage_id + stage_inc)) {
-        continue;
-      }
-      compute_at_steps.push_back(s);
-    }
-  }
-  if (compute_at_steps.empty()) {
-    return ResultKind::kInvalid;
-  }
-
-  // Randomly pick one step
-  size_t step_id = compute_at_steps[(*rand_gen)() % compute_at_steps.size()];
-  auto ps = (*state)->transform_steps[step_id].as<ComputeAtStepNode>();
-  int stage_inc = GetTargetStageIDInState(*state, step_id) - ps->stage_id;
-  ICHECK(ps != nullptr);
-
-  // Randomly pick a new computation location
-  std::vector<std::pair<int, int>> candidates =
-      GetComputeLocationCandidates(policy->search_task, *state, ps->stage_id + stage_inc);
-  if (candidates.empty()) {
-    return ResultKind::kInvalid;
-  }
-  int choice = (*rand_gen)() % (candidates.size());
-  int new_compute_at_stage_id = candidates[choice].first;
-  int new_compute_at_iter_id = candidates[choice].second;
-
-  // Replay a new state.
-  State tmp_s = policy->search_task->compute_dag->init_state;
-  for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) {
-    if (s == step_id) {
-      tmp_s.CopyOnWrite()->transform_steps.push_back(
-          ComputeAtStep(ps->stage_id, new_compute_at_stage_id - stage_inc, new_compute_at_iter_id));
-    } else {
-      tmp_s.CopyOnWrite()->transform_steps.push_back((*state)->transform_steps[s]);
-    }
-    try {
-      StepApplyToState(tmp_s->transform_steps.back(), &tmp_s, policy->search_task->compute_dag);
-    } catch (Error& e) {
-      return ResultKind::kInvalid;
-    }
-  }
-
-  *state = tmp_s;
-  return ResultKind::kValid;
+  return ResultKind::kInvalid;
 }
 
 PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy, State* state,
                                                            std::mt19937* rand_gen) const {
-  // This mutation rule only focuses on a case that parallel was added to
-  // the outermost loop and the loop is generated by fusing other loops.
-  // In short, we mutate the fusion step before the parallel step.
-
-  // Extract all parallel steps.
-  std::vector<int> parallel_steps;
-  for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) {
-    auto ps = (*state)->transform_steps[s].as<AnnotationStepNode>();
-    if (!ps || ps->annotation != IteratorAnnotation::kParallel) {
-      continue;
-    }
-
-    // Skip non-outermost loop or the parallel step without fusion beforehand.
-    if (ps->iter_id != 0 || s == 0 || !(*state)->transform_steps[s - 1].as<FuseStepNode>()) {
-      continue;
-    }
-    auto fuse_step = (*state)->transform_steps[s - 1].as<FuseStepNode>();
-    if (fuse_step->fused_ids[0] != 0) {
-      continue;
-    }
-
-    parallel_steps.push_back(s);
-  }
-  if (parallel_steps.empty()) {
-    return ResultKind::kInvalid;
-  }
-
-  // Randomly pick one parallel step.
-  size_t step_id = parallel_steps[(*rand_gen)() % parallel_steps.size()];
-
-  // Replay a new state until the picked fuse step.
-  State tmp_s = policy->search_task->compute_dag->init_state;
-  for (size_t s = 0; s < step_id - 1; ++s) {
-    const auto& step = (*state)->transform_steps[s];
-    tmp_s.CopyOnWrite()->transform_steps.push_back(step);
-    StepApplyToState(step, &tmp_s, policy->search_task->compute_dag);
-  }
-
-  // Compute all possible fusion granularities
-  auto fuse_step = (*state)->transform_steps[step_id - 1].as<FuseStepNode>();
-  int stage_id = fuse_step->stage_id;
-  const Stage& stage = tmp_s->stages[stage_id];
-  size_t max_fusable_iter_id;
-  for (max_fusable_iter_id = 0; max_fusable_iter_id < stage->iters.size(); ++max_fusable_iter_id) {
-    const Iterator& it = stage->iters[max_fusable_iter_id];
-    if (it->iter_kind == IteratorKind::kReduction || it->annotation != IteratorAnnotation::kNone) {
-      break;
-    }
-
-    if (tmp_s->attach_map->iter_to_attached_stages.count(
-            std::make_pair(stage_id, max_fusable_iter_id))) {
-      break;
-    }
-  }
-
-  if (max_fusable_iter_id == 0) {
-    return ResultKind::kInvalid;
-  }
-
-  // Randomly pick one granularity
-  int fuse_to_iter_id = (*rand_gen)() % max_fusable_iter_id + 1;
-  Array<Integer> fused_ids;
-  for (int i = 0; i < fuse_to_iter_id; ++i) {
-    fused_ids.push_back(i);
-  }
-  int iter_offset = fuse_step->fused_ids.back()->value - fused_ids.back()->value;
-  if (iter_offset == 0) {
-    return ResultKind::kInvalid;
-  }
-
-  // Replay the mutated fused and annotation step.
-  auto new_fuse_step = FuseStep(stage_id, fused_ids);
-  tmp_s.CopyOnWrite()->transform_steps.push_back(new_fuse_step);
-  StepApplyToState(new_fuse_step, &tmp_s, policy->search_task->compute_dag);
-  tmp_s.CopyOnWrite()->transform_steps.push_back((*state)->transform_steps[step_id]);
-  StepApplyToState((*state)->transform_steps[step_id], &tmp_s, policy->search_task->compute_dag);
-
-  // Replay the rest steps.
-  for (size_t s = step_id + 1; s < (*state)->transform_steps.size(); ++s) {
-    auto step = (*state)->transform_steps[s];
-    if (step->stage_id == stage_id) {
-      // Since we changed the loop structure, iter ID in later steps to the same stage
-      // has to be adjusted.
-      if (auto ps = step.as<AnnotationStepNode>()) {
-        if (ps->iter_id == 0) {
-          step = AnnotationStep(ps->stage_id, 0, ps->annotation);
-        } else {
-          ICHECK_LE(ps->iter_id + iter_offset, tmp_s->stages[stage_id]->iters.size());
-          step = AnnotationStep(ps->stage_id, ps->iter_id + iter_offset, ps->annotation);
-        }
-      } else if (auto ps = step.as<PragmaStepNode>()) {
-        if (ps->iter_id == 0) {
-          step = PragmaStep(ps->stage_id, 0, ps->pragma_type);
-        } else {
-          ICHECK_LE(ps->iter_id + iter_offset, tmp_s->stages[stage_id]->iters.size());
-          step = PragmaStep(ps->stage_id, ps->iter_id + iter_offset, ps->pragma_type);
-        }
-      } else {
-        return ResultKind::kInvalid;
-      }
-    }
-    if (IsStageNumberChangingStep(step)) {
-      // For these steps, we have to update stage_id because these steps will make stage_id
-      // out-dated. But here we just simply give up this mutation for simplicity.
-      // This is not an issue because this will never happend in normal cases where all these steps
-      // are before parallel steps.
-      return ResultKind::kInvalid;
-    }
-    tmp_s.CopyOnWrite()->transform_steps.push_back(step);
-    try {
-      StepApplyToState(tmp_s->transform_steps.back(), &tmp_s, policy->search_task->compute_dag);
-    } catch (Error& e) {
-      return ResultKind::kInvalid;
-    }
-  }
-
-  *state = tmp_s;
-  return ResultKind::kValid;
+  return ResultKind::kInvalid;
 }
 
 }  // namespace auto_scheduler
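
Note (illustrative, not part of the patch): InitFillTileSize above now fills every undefined split length with a fresh symbolic variable named `sp<step_id>_<len_id>` instead of sampling a concrete factorization. The naming scheme, extracted into a hypothetical helper for clarity:

// Mirrors the loop in InitFillTileSize::Apply above (helper name is hypothetical).
Array<PrimExpr> MakeSymbolicSplitLengths(size_t step_id, size_t n_lengths) {
  Array<PrimExpr> lengths;
  for (size_t len_id = 0; len_id < n_lengths; ++len_id) {
    // e.g. split step 3 with two factors yields the variables "sp3_0" and "sp3_1"
    lengths.push_back(Var("sp" + std::to_string(step_id) + "_" + std::to_string(len_id)));
  }
  return lengths;
}
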
diff --git a/src/auto_scheduler/search_policy/utils.cc b/src/auto_scheduler/search_policy/utils.cc
index ac1cf2dd82c9176cacce0e323360df10407dffc9..c55e8dbd055b01e3c1a127f61b43a0e0acb9fb66 100644
--- a/src/auto_scheduler/search_policy/utils.cc
+++ b/src/auto_scheduler/search_policy/utils.cc
@@ -182,7 +182,7 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo
       levels[0].push_back(iter);
     } else {
       Array<Iterator> split_res =
-          tmp_s.split(stage_id, iter, Array<Optional<Integer>>(size - 1, NullOpt));
+          tmp_s.split(stage_id, iter, Array<PrimExpr>(size - 1, PrimExpr()));
       for (int i = 0; i < size; i++) {
         levels[i].push_back(split_res[i]);
       }
diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc
index b67d5cdd7bd93c464bea2911b0896afc5d3cced0..8c8742130819b6df62d6ab72888163340d91bb62 100644
--- a/src/auto_scheduler/transform_step.cc
+++ b/src/auto_scheduler/transform_step.cc
@@ -818,7 +818,7 @@ String ReorderStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
 /********** Split **********/
 // common part for SplitStep, FollowSplitStep, and FollowFusedSplitStep
 Array<Iterator> ApplySplitToState(State* state, int stage_id, int iter_id,
-                                  const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
+                                  const Array<PrimExpr>& lengths, bool inner_to_outer) {
   const Stage& stage = (*state)->stages[stage_id];
   const Iterator& it = stage->iters[iter_id];
   size_t old_iter_size = stage->iters.size();
@@ -835,7 +835,7 @@ Array<Iterator> ApplySplitToState(State* state, int stage_id, int iter_id,
 
   Array<Iterator> outs;
   for (size_t i = 0; i < lengths.size(); ++i) {
-    Optional<Integer> l;
+    PrimExpr l;
     String name;
     if (inner_to_outer) {
       l = lengths[lengths.size() - i - 1];
@@ -845,11 +845,11 @@ Array<Iterator> ApplySplitToState(State* state, int stage_id, int iter_id,
       name = it->name + "." + std::to_string(i);
     }
     Iterator res;
-    if (l && tosplit_min && tosplit_extent) {
-      res = Iterator(name, Range::FromMinExtent(tosplit_min.value(), l.value()), it->iter_kind,
+    if (l.defined() && tosplit_min && tosplit_extent) {
+      res = Iterator(name, Range::FromMinExtent(tosplit_min.value(), l), it->iter_kind,
                      IteratorAnnotation::kNone);
       tosplit_min = Integer(0);
-      tosplit_extent = indexdiv(tosplit_extent.value() + l.value() - 1, l.value());
+      tosplit_extent = indexdiv(tosplit_extent.value() + l - 1, l);
     } else {
       res = Iterator(name, Range(), it->iter_kind, IteratorAnnotation::kNone);
       tosplit_min = NullOpt;
@@ -899,8 +899,8 @@ Array<Iterator> ApplySplitToState(State* state, int stage_id, int iter_id,
 }
 
 Array<IterVar> ApplySplitToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
-                                    int stage_id, int iter_id,
-                                    const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
+                                    int stage_id, int iter_id, const Array<PrimExpr>& lengths,
+                                    bool inner_to_outer) {
   auto stage = (*stages)[stage_id];
   const Array<IterVar>& axes = stage_to_axes->at(stage);
 
@@ -909,7 +909,7 @@ Array<IterVar> ApplySplitToSchedule(Array<te::Stage>* stages, StageToAxesMap* st
     IterVar outer = axes[iter_id], inner;
     for (int i = static_cast<int>(lengths.size()) - 1; i >= 0; i--) {
       IterVar to_split = outer;
-      stage.split(to_split, lengths[i].value(), &outer, &inner);
+      stage.split(to_split, lengths[i], &outer, &inner);
       outs.push_back(inner);
     }
     outs.push_back(outer);
@@ -917,7 +917,7 @@ Array<IterVar> ApplySplitToSchedule(Array<te::Stage>* stages, StageToAxesMap* st
     IterVar outer, inner = axes[iter_id];
     for (size_t i = 0; i < lengths.size(); i++) {
       IterVar to_split = inner;
-      stage.split_by_nparts(to_split, lengths[i].value(), &outer, &inner);
+      stage.split_by_nparts(to_split, lengths[i], &outer, &inner);
       outs.push_back(outer);
     }
     outs.push_back(inner);
@@ -942,8 +942,7 @@ Array<IterVar> ApplySplitToSchedule(Array<te::Stage>* stages, StageToAxesMap* st
 }
 
 String PrintSplitAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, int stage_id,
-                             int iter_id, const Array<Optional<Integer>>& lengths,
-                             bool inner_to_outer) {
+                             int iter_id, const Array<PrimExpr>& lengths, bool inner_to_outer) {
   const auto& stage = (*stages)[stage_id];
   auto to_split = stage_to_axes->at(stage)[iter_id];
   const auto& func_name = CleanName(stage->op->name);
@@ -974,7 +973,7 @@ String PrintSplitAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_
 }
 
 SplitStep::SplitStep(int stage_id, int iter_id, Optional<PrimExpr> extent,
-                     const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
+                     const Array<PrimExpr>& lengths, bool inner_to_outer) {
   auto node = make_object<SplitStepNode>();
   node->stage_id = stage_id;
   // Extent can be a irreducible expression in some special cases
@@ -999,15 +998,16 @@ SplitStep::SplitStep(dmlc::JSONReader* reader) {
   int int_val;
   s = reader->NextArrayItem();
   ICHECK(s);
-  reader->Read(&int_val);
-  if (int_val) {
-    node->extent = Integer(int_val);
-  }
-  s = reader->NextArrayItem();
-  ICHECK(s);
-  reader->Read(&node->lengths);
-  s = reader->NextArrayItem();
-  ICHECK(s);
+  // TODO: reading a symbolic extent and symbolic split lengths from a record is not supported yet
+  // reader->Read(&int_val);
+  // if (int_val) {
+  //   node->extent = Integer(int_val);
+  // }
+  // s = reader->NextArrayItem();
+  // ICHECK(s);
+  // reader->Read(&node->lengths);
+  // s = reader->NextArrayItem();
+  // ICHECK(s);
   reader->Read(&node->inner_to_outer);
   data_ = std::move(node);
 }
@@ -1017,8 +1017,9 @@ void SplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
   writer->WriteString(record_prefix_str);
   writer->WriteArrayItem(stage_id);
   writer->WriteArrayItem(iter_id);
-  writer->WriteArrayItem(extent ? GetIntImm(extent.value()) : 0);
-  writer->WriteArrayItem(lengths);
+  // TODO: writing a symbolic extent and symbolic split lengths to a record is not supported yet
+  // writer->WriteArrayItem(extent ? GetIntImm(extent.value()) : 0);
+  // writer->WriteArrayItem(lengths);
   writer->WriteArrayItem(static_cast<int>(inner_to_outer));
 }
 
@@ -1055,8 +1056,7 @@ void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
   writer->WriteArrayItem(n_split);
 }
 
-Array<Optional<Integer>> FollowSplitStepNode::ExtractSplitLengths(
-    const Array<Step>& transform_steps) const {
+Array<PrimExpr> FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps) const {
   // Make sure src_step_id is within the range of transform_steps.
   ICHECK_LT(src_step_id, transform_steps.size());
   auto ps = transform_steps[src_step_id].as<SplitStepNode>();
@@ -1067,7 +1067,7 @@ Array<Optional<Integer>> FollowSplitStepNode::ExtractSplitLengths(
   ICHECK_LE(n_split, ps->lengths.size() + 1);
   ICHECK(ps != nullptr);
 
-  Array<Optional<Integer>> lengths;
+  Array<PrimExpr> lengths;
   lengths.reserve(n_split);
   int j = 0;
   // Get the first (n_split-1) split factors of followed src_step.
@@ -1079,19 +1079,14 @@ Array<Optional<Integer>> FollowSplitStepNode::ExtractSplitLengths(
   // ps->lengths.size()+1.
   PrimExpr last_factor = 1;
   for (; j < static_cast<int>(ps->lengths.size()); ++j) {
-    if (ps->lengths[j]) {
-      last_factor *= ps->lengths[j].value();
+    if (ps->lengths[j].defined()) {
+      last_factor *= ps->lengths[j];
     } else {
       last_factor = PrimExpr();
       break;
     }
   }
-  if (last_factor.defined()) {
-    lengths.push_back(Downcast<Integer>(last_factor));
-  } else {
-    lengths.push_back(NullOpt);
-  }
-
+  lengths.push_back(last_factor);
   return lengths;
 }
 
@@ -1176,8 +1171,7 @@ void FollowFusedSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
   writer->WriteArrayItem(static_cast<int>(factor_or_nparts));
 }
 
-Optional<Integer> FollowFusedSplitStepNode::ExtractSplitLength(
-    const Array<Step>& transform_steps) const {
+PrimExpr FollowFusedSplitStepNode::ExtractSplitLength(const Array<Step>& transform_steps) const {
   PrimExpr ret(1);
 
   for (int src_step_id : src_step_ids) {
@@ -1186,13 +1180,9 @@ Optional<Integer> FollowFusedSplitStepNode::ExtractSplitLength(
     auto ps = transform_steps[src_step_id].as<SplitStepNode>();
     ICHECK(ps != nullptr);
     // Multiple the splitting factor on corresponding splitting level of src_steps.
-    if (ps->lengths[level] && ret.defined()) {
-      ret *= ps->lengths[level].value();
-    } else {
-      return NullOpt;
-    }
+    ret *= ps->lengths[level];
   }
-  return Downcast<Integer>(ret);
+  return ret;
 }
 
 Array<Iterator> FollowFusedSplitStepNode::ApplyToState(State* state) const {
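
Note (illustrative, not part of the patch): once a split factor is symbolic, the remaining (outer) extent computed in ApplySplitToState stays symbolic as well, via the same ceil-division used above. The variable names are placeholders:

PrimExpr extent = Var("N");                                      // symbolic extent of the iterator
PrimExpr factor = Var("sp0_0");                                  // symbolic split factor
PrimExpr outer_extent = indexdiv(extent + factor - 1, factor);   // ceil(N / sp0_0), kept as an expression
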
diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h
index 2943adf19f26dbfb5781320fc8cfba32f4c152e6..0249c55e1b657a94799a357e5f0717065e682db8 100755
--- a/src/auto_scheduler/utils.h
+++ b/src/auto_scheduler/utils.h
@@ -105,12 +105,9 @@ inline void FindAndDeleteItem(std::vector<T>* array, const T& to_delete) {
 }
 
 /*! \brief Compute the product of all elements in a vector */
-inline int64_t ElementProduct(const std::vector<int>& array) {
-  int64_t ret = 1;
-  for (auto x : array) {
-    ret *= x;
-  }
-  return ret;
+inline PrimExpr ElementProduct(const std::vector<PrimExpr>& array) {
+  return std::accumulate(array.begin(), array.end(), PrimExpr(1),
+                         [](const PrimExpr& a, const PrimExpr& b) { return a * b; });
 }
 
 /*! \brief Move elements from multiple vectors to one vector */
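
Note (illustrative, not part of the patch): a usage sketch for the PrimExpr-based ElementProduct above; the factor names are hypothetical:

std::vector<PrimExpr> factors = {Var("sp0_0"), Var("sp0_1"), PrimExpr(4)};
PrimExpr tile_prod = ElementProduct(factors);  // symbolic product of all split factors
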