diff --git a/include/tvm/auto_scheduler/loop_state.h b/include/tvm/auto_scheduler/loop_state.h index 0ca14c43eb470da85193492868c2f71831eb6433..55c023507e9d410503cd486dd4db0ed535478546 100755 --- a/include/tvm/auto_scheduler/loop_state.h +++ b/include/tvm/auto_scheduler/loop_state.h @@ -350,7 +350,7 @@ class State : public ObjectRef { * most iterator of split results will become the new attach point. */ TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it, - const Array<Optional<Integer>>& lengths, + const Array<PrimExpr>& lengths, bool inner_to_outer = true); /*! * \brief The schedule primitive similar to split, but uses split factors from previous steps. diff --git a/include/tvm/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h index 4cc1551e76fcb33cfb0a9527610fde035204c8a6..45bc9f922e69d6d9932900f189ee5e42823fccdc 100755 --- a/include/tvm/auto_scheduler/transform_step.h +++ b/include/tvm/auto_scheduler/transform_step.h @@ -505,7 +505,7 @@ class SplitStepNode : public StepNode { /*! \brief The extent length of the axis to split. */ Optional<PrimExpr> extent; /*! \brief The split factors. */ - Array<Optional<Integer>> lengths; + Array<PrimExpr> lengths; /*! * \brief If true, the `lengths` denote the lengths of iterators * from inner level to outer level @@ -561,7 +561,7 @@ class SplitStep : public Step { * \param inner_to_outer The split direction. */ SplitStep(int stage_id, int iter_id, Optional<PrimExpr> extent, - const Array<Optional<Integer>>& lengths, bool inner_to_outer); + const Array<PrimExpr>& lengths, bool inner_to_outer); /*! * \brief The constructor used to read a step record from JSONReader and create the @@ -591,7 +591,7 @@ class FollowSplitStepNode : public StepNode { * \param transform_steps An array of history transform steps. * \return The multiple split factors. */ - Array<Optional<Integer>> ExtractSplitLengths(const Array<Step>& transform_steps) const; + Array<PrimExpr> ExtractSplitLengths(const Array<Step>& transform_steps) const; /*! * \brief Apply the current step to State. @@ -672,7 +672,7 @@ class FollowFusedSplitStepNode : public StepNode { * \param transform_steps An array of history transform steps. * \return Split factor. */ - Optional<Integer> ExtractSplitLength(const Array<Step>& transform_steps) const; + PrimExpr ExtractSplitLength(const Array<Step>& transform_steps) const; /*! * \brief Apply the current step to State. diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc deleted file mode 100755 index 2bb722781f2f577e77eecb0465a38179ecc2d45f..0000000000000000000000000000000000000000 --- a/src/auto_scheduler/feature.cc +++ /dev/null @@ -1,1677 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file auto_scheduler/feature.cc - * \brief Feature extraction for the cost model - */ - -#include <tvm/arith/analyzer.h> -#include <tvm/auto_scheduler/feature.h> -#include <tvm/auto_scheduler/measure.h> -#include <tvm/auto_scheduler/measure_record.h> -#include <tvm/driver/driver_api.h> -#include <tvm/runtime/ndarray.h> -#include <tvm/runtime/registry.h> -#include <tvm/support/parallel_for.h> -#include <tvm/te/operation.h> -#include <tvm/te/schedule_pass.h> -#include <tvm/tir/analysis.h> -#include <tvm/tir/op_attr_types.h> -#include <tvm/tir/stmt_functor.h> -#include <tvm/tir/transform.h> - -#include <algorithm> -#include <cmath> -#include <numeric> -#include <unordered_map> -#include <vector> - -#include "search_policy/utils.h" -#include "utils.h" - -namespace tvm { -namespace auto_scheduler { - -using namespace tvm::tir; -using arith::Analyzer; -using arith::ConstIntBound; - -template <class T> -using BufferMap = std::unordered_map<Buffer, T, ObjectHash, ObjectEqual>; - -// The number of samples to extract for arithmetic intensity curves -static const int ARITH_INTENSITY_CURVE_SAMPLE_N = 10; - -// Annotation position encoding -enum class AnnotationPosType : int { - kPosNone = 0, // Does not have this kind of annotation - kPosInnerSpatial = 1, // The annotated iterator is the innermost spatial iterator - kPosMiddleSpatial = 2, // The annotated iterator is a middle spatial iterator - kPosOuterSpatial = 3, // The annotated iterator is the outermost spatial iterator - kPosInnerReduce = 4, // The annotated iterator is the innermost reduce iterator - kPosMiddleReduce = 5, // The annotated iterator is a middle reduce iterator - kPosOuterReduce = 6, // The annotated iterator is the outermost reduce iterator - kPosMixed = 7 // The annotated iterator is a mixed space and reduce iterator -}; - -// Buffer access type -enum class BufferAccessType : int { kRead = 0, kWrite = 1, kReadWrite = 2, kUnknownRW = 3 }; - -// Accesses to a buffer -struct BufferAccess { - // data reuse type - BufferAccessType acc_type{BufferAccessType::kUnknownRW}; - // Use a two-dimensional array to store multiple multi-dimensional accesses. - // The innermost vector stores the multi-dimensional indices of one access. - std::vector<std::vector<PrimExpr>> indices; -}; - -// Data reuse type -enum class ReuseType : int { kLoopMultipleRead = 0, kSerialMultipleReadWrite = 1, kNoReuse = 2 }; - -// Feature for an access of a buffer -struct BufferAccessFeature { - std::string buffer_name; // The name of the buffer - BufferAccessType acc_type; // The type of the access - float bytes; // The touched memory in bytes - float unique_bytes; // The touched unique memory in bytes - float lines; // The number of touched cache lines - float unique_lines; // The number touched unique cache lines - ReuseType reuse_type; // Tye type of data reuse - float reuse_dis_iter; // The reuse distance in iterator number - float reuse_dis_bytes; // The reuse distance in total touched bytes - float reuse_ct; // The reuse ratio - float bytes_d_reuse_ct; // bytes / reuse_ct - float unique_bytes_d_reuse_ct; // unique_bytes / reuse_ct - float lines_d_reuse_ct; // lines / reuse_ct - float unique_lines_d_reuse_ct; // unique_lines / reuse_ct - float stride; // The stride in access -}; - -// Feature set of a BufferStore statement -struct FeatureSet { - // Group 1: Computation related features - float float_mad; // The number of float MAD (Multiply–add) ops - float float_addsub; // The number of float add and sub ops - float float_mul; // The number of float multiply ops - float float_divmod; // The number of float div and mod ops - float float_cmp; // The number of float comparison ops - float float_math_func; // The number of float math func calls - float float_other_func; // The number of other float func calls - float int_mad; // The number of integer MAD (Multiply–add) ops - float int_addsub; // The number of integer add and sub ops - float int_mul; // The number of float multiply ops - float int_divmod; // The number of float div and mod ops - float int_cmp; // The number of float comparison ops - float int_math_func; // The number of float math func calls - float int_other_func; // The number of other float func calls - float bool_op; // The number of bool ops - float select_op; // The number of select ops - float vec_num; // The number of vectorized iterators - float vec_prod; // The product of the lengths of vectorized iterators - float vec_len; // The length of the innermost vectorized iterator - AnnotationPosType vec_type; // The type of vectorization position - float unroll_num; // The number of unrolled iterators - float unroll_prod; // The product of the lengths of vectorized iterators - float unroll_len; // The length of the innermost unrolled iterator - AnnotationPosType unroll_type; // The type of unroll position - float parallel_num; // The number of paralleled iterators - float parallel_prod; // The product of the lengths of paralleled iterators - float parallel_len; // The length of the innermost paralleled iterators - AnnotationPosType parallel_type; // The type of parallel position - float is_gpu; // Whether it is a GPU task - float blockIdx_x_len; // The length of blockIdx.x - float blockIdx_y_len; // The length of blockIdx.y - float blockIdx_z_len; // The length of blockIdx.z - float threadIdx_x_len; // The length of threadIdx.x - float threadIdx_y_len; // The length of threadIdx.y - float threadIdx_z_len; // The length of threadIdx.z - float vthread_len; // The length of virtual thread - - // Group 2: Buffer access related features (per buffer) - std::vector<BufferAccessFeature> access_feas; - - // Group 3: Arithmetic intensity related features - float arith_intensity_curve[ARITH_INTENSITY_CURVE_SAMPLE_N]; // points sampled from the - // arithmetic intensity curve - - // Group 4: Allocation related features - float alloc_size; // The size of allocated buffer in bytes - float alloc_outer_prod; // The product of lengths of loops outside the scope of the allocation - float alloc_inner_prod; // The product of lengths of loops inside the score of the allocation - float alloc_prod; // alloc_outer_prod * alloc_inner_prod - - // Group 5: Outer scope related features - float outer_prod; // The product of lengths of outer loops - float num_loops; // The number of outer loops - float auto_unroll_max_step; // The value of pragma "auto_unroll_max_step" -}; - -// Return whether a var is in an expr -bool VarInExpr(const Var& var, const PrimExpr& expr) { - bool find = false; - - PostOrderVisit(expr, [&find, &var](const ObjectRef& node) { - if (find) { - return; - } - - if (const VarNode* op = node.as<VarNode>()) { - if (op == var.get()) { - find = true; - } - } - }); - - return find; -} - -// Get position encoding for annotation -AnnotationPosType GetAnnotationPosEncoding(const Var& var, const Array<PrimExpr>& spatial_args, - const Array<IterVar>& axis, - const Array<IterVar>& reduce_axis) { - // Try to match spatial args first - size_t find_i = 0; - size_t find_ct = 0; - for (size_t i = 0; i < spatial_args.size(); ++i) { - if (VarInExpr(var, spatial_args[i])) { - find_i = i; - find_ct += 1; - } - } - - if (find_ct == 0) { - // If it is not found in spacial args, then it is a reduce iterator. - // Use name to match - const std::string& var_name = var->name_hint; - for (size_t i = 0; i < reduce_axis.size(); ++i) { - if (var_name.find(reduce_axis[i]->var->name_hint) != std::string::npos) { - find_i = i; - find_ct++; - } - } - if (find_ct >= 1) { - if (find_i == 0) { - return AnnotationPosType::kPosInnerReduce; - } else if (find_i == reduce_axis.size() - 1) { - return AnnotationPosType::kPosOuterReduce; - } else { - return AnnotationPosType::kPosMiddleReduce; - } - } else { - // If the axis is not found in both spatial args and reduce axis, - // then this stage must compute_at somewhere under this axis and this axis is simplified out - // We assume it is an outer spatial - return AnnotationPosType::kPosOuterSpatial; - } - } else if (find_ct == 1) { - if (find_i == spatial_args.size() - 1) { - return AnnotationPosType::kPosInnerSpatial; - } else if (find_i == 0) { - return AnnotationPosType::kPosOuterSpatial; - } else { - return AnnotationPosType::kPosMiddleSpatial; - } - } else { - return AnnotationPosType::kPosMixed; - } -} - -// Return the extent of a for loop -int64_t GetLoopExtent(const ForNode* node) { - auto pint = node->extent.as<IntImmNode>(); - if (pint != nullptr) { - return pint->value; - } else { - return 1; - } -} - -// Count math ops in an expr -class MathOpCounter : public StmtExprVisitor { - public: -#define VisitBinary(Type, float_ct, int_ct) \ - void VisitExpr_(const Type* op) final { \ - if (op->a.dtype().is_float()) { \ - float_ct++; \ - } else { \ - int_ct++; \ - } \ - StmtExprVisitor::VisitExpr_(op); \ - } - - VisitBinary(AddNode, float_addsub, int_addsub); - VisitBinary(SubNode, float_addsub, int_addsub); - VisitBinary(MulNode, float_mul, int_mul); - VisitBinary(DivNode, float_divmod, int_divmod); - VisitBinary(ModNode, float_divmod, int_divmod); - VisitBinary(FloorDivNode, float_divmod, int_divmod); - VisitBinary(FloorModNode, float_divmod, int_divmod); - VisitBinary(MaxNode, float_cmp, int_cmp); - VisitBinary(MinNode, float_cmp, int_cmp); - VisitBinary(EQNode, float_cmp, int_cmp); - VisitBinary(NENode, float_cmp, int_cmp); - VisitBinary(LTNode, float_cmp, int_cmp); - VisitBinary(LENode, float_cmp, int_cmp); - VisitBinary(GTNode, float_cmp, int_cmp); - VisitBinary(GENode, float_cmp, int_cmp); - -#undef VisitBinary - - void VisitExpr_(const AndNode* op) final { - bool_op++; - StmtExprVisitor::VisitExpr_(op); - } - void VisitExpr_(const OrNode* op) final { - bool_op++; - StmtExprVisitor::VisitExpr_(op); - } - void VisitExpr_(const NotNode* op) final { - bool_op++; - StmtExprVisitor::VisitExpr_(op); - } - void VisitExpr_(const SelectNode* op) final { - select_op++; - StmtExprVisitor::VisitExpr_(op); - } - - void VisitExpr_(const CallNode* op) final { - auto* pop = op->op.as<OpNode>(); - ICHECK(pop != nullptr); - auto effect_kind = op_call_effect_[GetRef<Op>(pop)]; - bool is_pure = - effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation; - - if (is_pure) { - if (op->dtype.is_float()) { - float_math_func++; - } else { - int_math_func++; - } - } else { - if (op->dtype.is_float()) { - float_other_func++; - } else { - int_other_func++; - } - } - StmtExprVisitor::VisitExpr_(op); - } - - // todo(merrymercy): Detect MAD (Multiply–add) - size_t float_mad{0}; // The number of float MAD (Multiply–add) ops - size_t float_addsub{0}; // The number of float add and sub ops - size_t float_mul{0}; // The number of float multiply ops - size_t float_divmod{0}; // The number of float div and mod ops - size_t float_cmp{0}; // The number of float comparison ops - size_t float_math_func{0}; // The number of float math func calls - size_t float_other_func{0}; // The number of other float func calls - size_t int_mad{0}; // The number of integer MAD (Multiply–add) ops - size_t int_addsub{0}; // The number of integer add and sub ops - size_t int_mul{0}; // The number of float multiply ops - size_t int_divmod{0}; // The number of float div and mod ops - size_t int_cmp{0}; // The number of float comparison ops - size_t int_math_func{0}; // The number of float math func calls - size_t int_other_func{0}; // The number of other float func calls - size_t bool_op{0}; // The number of bool ops - size_t select_op{0}; // The number of select ops - - OpAttrMap<TCallEffectKind> op_call_effect_ = Op::GetAttrMap<TCallEffectKind>("TCallEffectKind"); -}; - -// Extract all buffer accesses in an expr -class BufferAccessExtractor : public StmtExprVisitor { - public: - void ExtractReads(const PrimExpr& expr) { this->VisitExpr(expr); } - - void InsertAccess(const Buffer& buf, BufferAccessType acc_type, const Array<PrimExpr>& indices) { - BufferAccess& acc = buf_accesses[buf]; - acc.acc_type = acc_type; - acc.indices.push_back(std::vector<PrimExpr>(indices.begin(), indices.end())); - } - - void VisitExpr_(const BufferLoadNode* op) final { - BufferAccess& acc = buf_accesses[op->buffer]; - switch (acc.acc_type) { - case BufferAccessType::kRead: - break; - case BufferAccessType::kWrite: - acc.acc_type = BufferAccessType::kReadWrite; - break; - case BufferAccessType::kReadWrite: - break; - case BufferAccessType::kUnknownRW: - default: - acc.acc_type = BufferAccessType::kRead; - break; - } - - if (acc.acc_type != BufferAccessType::kReadWrite) { - // If a buffer is both read and written, in the tvm DSL, it must be a update, - // so the indices should be the same. Then we can skip appending indices for it. - // Otherwise we do the following. - buf_accesses[op->buffer].indices.push_back( - std::vector<PrimExpr>(op->indices.begin(), op->indices.end())); - } - StmtExprVisitor::VisitExpr_(op); - } - - BufferMap<BufferAccess> buf_accesses; -}; - -// Compute the coefficient for an loop iterator in an expression -// Note: we use an approximation strategy to find coefficient. -// Hopefully, it is faster than DetectLinearEquation and can handle more cases (non-linear) -class CoefficientExtractor : public StmtExprVisitor { - public: - void VisitExpr_(const MulNode* node) final { - StmtExprVisitor::VisitExpr_(node); - if (visited_var) { - if (!visited_add) { - if (auto a = node->a.as<IntImmNode>()) { - visited_mul = true; - stride = a->value; - } else if (auto b = node->b.as<IntImmNode>()) { - visited_mul = true; - stride = b->value; - } - } - } - } - - void VisitExpr_(const AddNode* node) final { - StmtExprVisitor::VisitExpr_(node); - if (visited_var) { - if (!visited_mul) { - visited_add = true; - stride = 1; - } - } - } - - void VisitExpr_(const VarNode* node) final { - if (node == var_) { - visited_var = true; - // This is a magic default stride in case our approximation strategy fails - stride = 2; - } - } - - int ExtractCoefficient(const PrimExpr& expr, const VarNode* var) { - visited_var = visited_mul = visited_add = false; - var_ = var; - - this->VisitExpr(expr); - - if (visited_var && !visited_mul && !visited_add) { - return 1; - } else { - return stride; - } - } - - bool visited_var{false}; - bool visited_mul{false}; - bool visited_add{false}; - int stride{0}; - - private: - const VarNode* var_{nullptr}; -}; - -// Compute stride for the accesses to a buffer -int64_t ComputeStride(const std::vector<std::vector<PrimExpr>>& indices, - const std::vector<int>& shape, const VarNode* stride_var) { - int64_t min_stride = std::numeric_limits<int64_t>::max(); - bool find = false; - CoefficientExtractor extractor; - - for (const auto& index : indices) { - int64_t shape_stride = 1; - for (int i = static_cast<int>(index.size()) - 1; i >= 0; i--) { - int coefficient = extractor.ExtractCoefficient(index[i], stride_var); - if (extractor.visited_var) { - find = true; - min_stride = std::min(min_stride, std::abs(coefficient) * shape_stride); - break; - } - shape_stride *= shape[i]; - } - } - - return find ? min_stride : 0; -} - -// Compute touched bytes and cache lines for accesses to a buffer -void ComputeRegion(const std::vector<std::vector<PrimExpr>>& indices, arith::Analyzer* ana, - std::vector<int>* region) { - region->clear(); - - if (indices.empty()) { - return; - } - - region->reserve(indices[0].size()); - - if (indices.size() == 1) { - for (const auto& index : indices[0]) { - ConstIntBound bound = ana->const_int_bound(index); - region->push_back(bound->max_value - bound->min_value + 1); - } - } else { - // future(lmzheng): implement a more accurate IntSet? - for (size_t i = 0; i < indices[0].size(); ++i) { - int64_t minimum = ConstIntBound::kPosInf, maximum = ConstIntBound::kNegInf; - for (size_t j = 0; j < indices.size(); ++j) { - ConstIntBound bound = ana->const_int_bound(indices[j][i]); - - minimum = std::min(minimum, bound->min_value); - maximum = std::max(maximum, bound->max_value); - } - region->push_back(maximum - minimum + 1); - } - } -} - -// Compute reuse distance and reuse ratio for accesses to a buffer -// return values: reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct -std::tuple<ReuseType, float, float, float> ComputeReuse( - const Buffer& buf, const std::vector<std::vector<PrimExpr>>& indices, - const std::vector<const ForNode*>& for_loop_stack, - const std::unordered_map<const ForNode*, - BufferMap<std::vector<std::tuple<BufferAccessType, int64_t, int>>>>& - for_touch_regions) { - float reuse_dis_iter = 1.0f; - float reuse_dis_bytes = -1.0f; - - for (int i = static_cast<int>(for_loop_stack.size()) - 1; i >= 0; --i) { - const ForNode* cur_for = for_loop_stack[i]; - bool find = false; - - for (size_t j = 0; j < indices.size(); j++) { - for (size_t k = 0; k < indices[j].size(); k++) { - if (VarInExpr(cur_for->loop_var, indices[j][k])) { - find = true; - break; - } - } - if (find) { - break; - } - } - - int64_t extent = GetLoopExtent(for_loop_stack[i]); - if (find) { - // accumulate/update reuse distance - reuse_dis_iter *= extent; - reuse_dis_bytes = 0.0f; - for (const auto& iter : for_touch_regions.at(cur_for)) { - for (const auto& access : iter.second) { - reuse_dis_bytes += std::get<1>(access) * std::get<2>(access); - } - } - } else { - // Have LoopMultipleRead reuse - if (reuse_dis_bytes < 0) { - // For the reuse in the innermost axis, the above code won't be executed. - // So we compute bytes here - reuse_dis_bytes = 0.0f; - for (const auto& iter : for_touch_regions.at(cur_for)) { - for (const auto& access : iter.second) { - reuse_dis_bytes += 1 * std::get<2>(access); - } - } - } - return std::make_tuple(ReuseType::kLoopMultipleRead, reuse_dis_iter, reuse_dis_bytes, extent); - } - - const BufferMap<std::vector<std::tuple<BufferAccessType, int64_t, int>>>& buffer_map = - for_touch_regions.at(cur_for); - - int serial_reuse = static_cast<int>(buffer_map.at(buf).size()) - 1; - if (serial_reuse > 0) { - int64_t extent = GetLoopExtent(cur_for); - - // Have SerialMultipleReadWrite reuse - reuse_dis_iter = std::numeric_limits<float>::max(); - for (const auto& acc_info : buffer_map.at(buf)) { - reuse_dis_iter = std::min(reuse_dis_iter, static_cast<float>(std::get<1>(acc_info))); - } - - reuse_dis_bytes = 0.0f; - for (const auto& iter : for_touch_regions.at(cur_for)) { - for (const auto& access : iter.second) { - reuse_dis_bytes += std::get<1>(access) * std::get<2>(access); - } - } - - return std::make_tuple(ReuseType::kSerialMultipleReadWrite, reuse_dis_iter / extent, - reuse_dis_bytes / extent, serial_reuse); - } - } - - return std::make_tuple(ReuseType::kNoReuse, 0, 0, 0); -} - -class StateFeaturesNode : public Object { - public: - Stmt state_code; - Array<BufferStore> buffers; - tvm::runtime::NDArray features; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("state_code", &state_code); - v->Visit("buffers", &buffers); - v->Visit("features", &features); - } - - void add_buffer_and_features(const BufferStore& bufNode, std::vector<size_t> features_) { - this->buffers.push_back(bufNode); - } - - static constexpr const char* _type_key = "ansor.StateFeatures"; - TVM_DECLARE_BASE_OBJECT_INFO(StateFeaturesNode, Object); -}; - -constexpr size_t NFeatures = 17; -using Features = std::array<size_t, NFeatures>; - -class StateFeatures : public ObjectRef { - public: - explicit StateFeatures( - const std::vector<std::pair<BufferStore, Features>>& buffers_and_features) { - auto node = make_object<StateFeaturesNode>(); - size_t n = buffers_and_features.size(); - node->features = tvm::runtime::NDArray::Empty({(int64_t)n, NFeatures}, - DLDataType{kDLFloat, 32, 1}, {kDLCPU, 0}); - for (size_t i = 0; i < n; i++) { - const auto& [buf, features] = buffers_and_features[i]; - node->buffers.push_back(buf); - for (size_t j = 0; j < NFeatures; ++j) - static_cast<float*>(node->features->data)[i * NFeatures + j] = (float)features[j]; - } - data_ = std::move(node); - } - - TVM_DEFINE_OBJECT_REF_METHODS(StateFeatures, ObjectRef, StateFeaturesNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(StateFeaturesNode); -}; - -TVM_REGISTER_NODE_TYPE(StateFeaturesNode); - -// Extract features for every BufferStore statement -class PerStoreFeatureExtractor : public StmtExprVisitor { - public: - explicit PerStoreFeatureExtractor(int cache_line_size) : cache_line_size_(cache_line_size) {} - - void VisitStmt_(const AttrStmtNode* node) final { - if (node->attr_key == tir::attr::thread_extent || node->attr_key == tir::attr::virtual_thread) { - const Var& var = node->node.as<IterVarNode>()->var; - int extent = GetIntImm(node->value); - - int* plen = nullptr; - - const std::string& name = var.get()->name_hint; - if (node->attr_key == tir::attr::thread_extent) { - if (name == "blockIdx.x") { - plen = &blockIdx_x_len_; - } else if (name == "blockIdx.y") { - plen = &block_idx_y_len_; - } else if (name == "blockIdx.z") { - plen = &block_idx_z_len_; - } else if (name == "threadIdx.x") { - plen = &threadIdx_x_len_; - } else if (name == "threadIdx.y") { - plen = &thread_idx_y_len_; - } else if (name == "threadIdx.z") { - plen = &thread_idx_z_len_; - } else { - LOG(FATAL) << "invalid thread itervar " + name; - } - } else { - plen = &vthread_len_; - } - - int extent_before = *plen; - if (node->attr_key == tir::attr::thread_extent) { - *plen = extent; - } else { - *plen *= extent; - } - - is_gpu_ = true; - - // make a fake for node for blockIdx.x or threadIdx.x - Stmt fake_for_node = For(var, 0, extent, ForKind::kParallel, node->body); - - outer_loop_prod_ *= extent; - for_loop_stack_.push_back(fake_for_node.as<ForNode>()); - StmtExprVisitor::VisitStmt_(node); - for_loop_stack_.pop_back(); - outer_loop_prod_ /= extent; - - *plen = extent_before; - } else if (node->attr_key == "pragma_auto_unroll_max_step") { - int value = GetIntImm(node->value); - - int16_t old_value = cur_auto_unroll_max_step_; - cur_auto_unroll_max_step_ = value; - StmtExprVisitor::VisitStmt_(node); - cur_auto_unroll_max_step_ = old_value; - } else { - StmtExprVisitor::VisitStmt_(node); - } - } - - void VisitStmt_(const ForNode* node) final { - int64_t loop_extent = GetLoopExtent(node); - - if (node->kind == ForKind::kVectorized) { - vec_for_stack_.push_back(node); - } else if (node->kind == ForKind::kUnrolled) { - unroll_for_stack_.push_back(node); - } else if (node->kind == ForKind::kParallel) { - parallel_for_stack_.push_back(node); - } - - outer_loop_prod_ *= loop_extent; - for_loop_stack_.push_back(node); - StmtExprVisitor::VisitStmt_(node); - for_loop_stack_.pop_back(); - outer_loop_prod_ /= loop_extent; - - if (node->kind == ForKind::kVectorized) { - vec_for_stack_.pop_back(); - } else if (node->kind == ForKind::kUnrolled) { - unroll_for_stack_.pop_back(); - } else if (node->kind == ForKind::kParallel) { - parallel_for_stack_.pop_back(); - } - } - - void VisitStmt_(const BufferStoreNode* node) final { - MathOpCounter math_op_counter; - math_op_counter(node->value); - - BufferStore store(node->buffer, node->value, node->indices, node->span); - std::array<size_t, 17> features = { - math_op_counter.float_mad, math_op_counter.float_addsub, - math_op_counter.float_mul, math_op_counter.float_divmod, - math_op_counter.float_cmp, math_op_counter.float_math_func, - math_op_counter.float_other_func, math_op_counter.int_mad, - math_op_counter.int_addsub, math_op_counter.int_mul, - math_op_counter.int_divmod, math_op_counter.int_math_func, - math_op_counter.int_cmp, math_op_counter.int_other_func, - math_op_counter.bool_op, math_op_counter.select_op, - (size_t)this->outer_loop_prod_}; - this->bufstore_and_feats.emplace_back(store, features); - - std::vector<float> mem_bytes_list; - std::vector<float> compute_ops_list; - double cur_compute_ops; - - // Group 1: Computation related features - ExtractComputationFeature(node, math_op_counter); - - // Group 2: Buffer access related features (per buffer) - ExtractBufferAccessFeature(node, math_op_counter, &cur_compute_ops, &compute_ops_list, - &mem_bytes_list); - - // Group 3: Arithmetic intensity related features - ExtractArithmeticIntensityFeature(node, cur_compute_ops, compute_ops_list, mem_bytes_list); - - // Group 4: Allocation related features - ExtractOuterScopeFeature(node); - } - - void VisitStmt_(const BufferRealizeNode* node) final { - StmtExprVisitor::VisitStmt_(node); - - // Group 5: Outer scope related features - ExtractAllocationFeature(node); - } - - // Extract computation related features (group 1) - void ExtractComputationFeature(const BufferStoreNode* node, - const MathOpCounter& math_op_counter) { - FeatureSet& fea = buffer_features[node->buffer]; - - // Computation related features - fea.float_mad = outer_loop_prod_ * math_op_counter.float_mad; - fea.float_addsub = outer_loop_prod_ * math_op_counter.float_addsub; - fea.float_mul = outer_loop_prod_ * math_op_counter.float_mul; - fea.float_divmod = outer_loop_prod_ * math_op_counter.float_divmod; - fea.float_cmp = outer_loop_prod_ * math_op_counter.float_cmp; - fea.float_math_func = outer_loop_prod_ * math_op_counter.float_math_func; - fea.float_other_func = outer_loop_prod_ * math_op_counter.float_other_func; - fea.int_mad = outer_loop_prod_ * math_op_counter.int_mad; - fea.int_addsub = outer_loop_prod_ * math_op_counter.int_addsub; - fea.int_mul = outer_loop_prod_ * math_op_counter.int_mul; - fea.int_divmod = outer_loop_prod_ * math_op_counter.int_divmod; - fea.int_math_func = outer_loop_prod_ * math_op_counter.int_math_func; - fea.int_cmp = outer_loop_prod_ * math_op_counter.int_cmp; - fea.int_other_func = outer_loop_prod_ * math_op_counter.int_other_func; - fea.bool_op = outer_loop_prod_ * math_op_counter.bool_op; - fea.select_op = outer_loop_prod_ * math_op_counter.select_op; - - fea.vec_len = fea.unroll_len = fea.parallel_len = 0.0f; - fea.vec_type = fea.unroll_type = fea.parallel_type = AnnotationPosType::kPosNone; - - fea.vec_num = vec_for_stack_.size(); - if (!vec_for_stack_.empty()) { - fea.vec_len = GetLoopExtent(vec_for_stack_.back()); - fea.vec_prod = 1.0; - for (const ForNode* pfor : vec_for_stack_) { - fea.vec_prod *= GetLoopExtent(pfor); - } - fea.vec_type = AnnotationPosType::kPosMixed; - // todo(merrymercy): this feature requires operation (tvm.compute) information - // GetAnnotationPosEncoding(vec_for_stack_.back()->loop_var, - // node->args, pcompute->axis, pcompute->reduce_axis); - } - - fea.unroll_num = unroll_for_stack_.size(); - if (!unroll_for_stack_.empty()) { - fea.unroll_len = GetLoopExtent(unroll_for_stack_.back()); - fea.unroll_prod = 1.0; - for (const ForNode* pfor : unroll_for_stack_) { - fea.unroll_prod *= GetLoopExtent(pfor); - } - fea.unroll_type = AnnotationPosType::kPosMixed; - // GetAnnotationPosEncoding(unroll_for_stack_.back()->loop_var, - // node->args, pcompute->axis, pcompute->reduce_axis); - } - - fea.parallel_num = parallel_for_stack_.size(); - if (!parallel_for_stack_.empty()) { - fea.parallel_len = GetLoopExtent(parallel_for_stack_.back()); - fea.parallel_prod = 1.0; - for (const ForNode* pfor : parallel_for_stack_) { - fea.parallel_prod *= GetLoopExtent(pfor); - } - fea.parallel_type = AnnotationPosType::kPosMixed; - // GetAnnotationPosEncoding(parallel_for_stack_.back()->loop_var, - // node->args, pcompute->axis, pcompute->reduce_axis); - } - - // GPU threads - fea.is_gpu = is_gpu_; - fea.blockIdx_x_len = blockIdx_x_len_; - fea.blockIdx_y_len = block_idx_y_len_; - fea.blockIdx_z_len = block_idx_z_len_; - fea.threadIdx_x_len = threadIdx_x_len_; - fea.threadIdx_y_len = thread_idx_y_len_; - fea.threadIdx_z_len = thread_idx_z_len_; - fea.vthread_len = vthread_len_; - } - - // Extract buffer access related features (group 2) - void ExtractBufferAccessFeature(const BufferStoreNode* node, const MathOpCounter& math_op_counter, - double* cur_compute_ops, std::vector<float>* compute_ops_list, - std::vector<float>* mem_bytes_list) { - FeatureSet& fea = buffer_features[node->buffer]; - - // Extract all buffer accesses - std::vector<BufferAccessFeature> acc_feas; - BufferAccessExtractor buf_extractor; - buf_extractor.InsertAccess(node->buffer, BufferAccessType::kWrite, node->indices); - buf_extractor.ExtractReads(node->value); - - // Compute touched region for all outer loops - for (auto x : for_loop_stack_) { - ana_.Bind(x->loop_var, Range::FromMinExtent(x->min, 1), true); - } - - mem_bytes_list->reserve(for_loop_stack_.size()); - compute_ops_list->reserve(for_loop_stack_.size()); - - *cur_compute_ops = math_op_counter.float_mad + math_op_counter.float_addsub + - math_op_counter.float_mul + math_op_counter.float_divmod + - math_op_counter.float_cmp + math_op_counter.float_math_func + - math_op_counter.float_other_func; - - std::vector<int> tmp_region; - for (int i = static_cast<int>(for_loop_stack_.size()) - 1; i >= 0; i--) { - const ForNode* p_for = for_loop_stack_[i]; - - ana_.Bind(p_for->loop_var, - Range::FromMinExtent(for_loop_stack_[i]->min, for_loop_stack_[i]->extent), true); - - // Note, here we do overwrite. - // So if there are multiple BufferStoreNode, the last one will overwrite the first few. - // e.g. The update part in gemm will overwrite the init part. - BufferMap<std::vector<std::tuple<BufferAccessType, int64_t, int>>>& buffer_regions_map = - for_touch_regions_[p_for]; - - int64_t mem_bytes = 0; - for (const auto& x : buf_extractor.buf_accesses) { - const Buffer& t = x.first; - const BufferAccess& acc = x.second; - - ComputeRegion(acc.indices, &ana_, &tmp_region); - int64_t touched_size = ElementProduct(tmp_region); - buffer_regions_map[t].push_back( - std::make_tuple(acc.acc_type, touched_size, t->dtype.bytes())); - mem_bytes += touched_size * t->dtype.bytes(); - } - - mem_bytes_list->push_back(std::log2(mem_bytes)); - *cur_compute_ops *= GetLoopExtent(for_loop_stack_[i]); - compute_ops_list->push_back(std::log2(*cur_compute_ops)); - } - - // Buffer access related features (per buffer) - for (const auto& x : buf_extractor.buf_accesses) { - const Buffer& t = x.first; - const BufferAccess& acc = x.second; - - std::vector<int> int_shape; - for (const auto& dim : t->shape) { - int_shape.push_back(GetIntImm(dim)); - } - - size_t ele_bytes = t->dtype.bytes(); - - // calculate bytes - float bytes = outer_loop_prod_ * ele_bytes; - float unique_bytes; - - // calculate cache lines - int64_t stride; - float lines; - float unique_lines; - - if (for_loop_stack_.empty()) { - unique_bytes = ele_bytes; - stride = 0; - lines = 1.0f; - unique_lines = 1.0f; - } else { - unique_bytes = - std::get<1>(for_touch_regions_[for_loop_stack_.front()][t].front()) * ele_bytes; - - stride = 0; - int64_t reduce_ratio = 1; - - int i; - for (i = static_cast<int>(for_loop_stack_.size()) - 1; i >= 0; i--) { - stride = ComputeStride(acc.indices, int_shape, for_loop_stack_[i]->loop_var.get()); - if (stride != 0) { - break; - } - reduce_ratio *= GetLoopExtent(for_loop_stack_.back()); - } - - lines = outer_loop_prod_ / reduce_ratio * - std::min(1.0f, 1.0f * stride * ele_bytes / cache_line_size_); - lines = std::max(lines, 1.0f); - - // convert `stride` back to the stride of the innermost iterator - stride = (i == static_cast<int>(for_loop_stack_.size()) - 1 ? stride : 0); - - float n_continuous = ele_bytes; - for (int i = std::min(static_cast<int>(tmp_region.size()) - 1, - static_cast<int>(int_shape.size()) - 1); - i >= 0; i--) { - if (tmp_region[i] == int_shape[i]) { - n_continuous *= tmp_region[i]; - break; - } - } - unique_lines = unique_bytes / std::min(n_continuous, static_cast<float>(cache_line_size_)); - unique_lines = std::max(unique_lines, 1.0f); - } - - ReuseType reuse_type; - float reuse_dis_iter, reuse_dis_bytes, reuse_ct; - std::tie(reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct) = - ComputeReuse(t, acc.indices, for_loop_stack_, for_touch_regions_); - - acc_feas.emplace_back(); - BufferAccessFeature& acc_fea = acc_feas.back(); - - acc_fea.buffer_name = t->name; - acc_fea.acc_type = acc.acc_type; - acc_fea.stride = stride; - acc_fea.bytes = bytes; - acc_fea.unique_bytes = unique_bytes; - acc_fea.lines = lines; - acc_fea.unique_lines = unique_lines; - acc_fea.reuse_type = reuse_type; - acc_fea.reuse_dis_iter = reuse_dis_iter; - acc_fea.reuse_dis_bytes = reuse_dis_bytes; - acc_fea.reuse_ct = reuse_ct; - if (acc_fea.reuse_ct > 0.5) { - acc_fea.bytes_d_reuse_ct = bytes / reuse_ct; - acc_fea.unique_bytes_d_reuse_ct = unique_bytes / reuse_ct; - acc_fea.lines_d_reuse_ct = lines / reuse_ct; - acc_fea.unique_lines_d_reuse_ct = unique_lines / reuse_ct; - } else { - // no reuse, multiply by a magic number '2' - acc_fea.bytes_d_reuse_ct = bytes * 2; - acc_fea.unique_bytes_d_reuse_ct = unique_bytes * 2; - acc_fea.lines_d_reuse_ct = lines * 2; - acc_fea.unique_lines_d_reuse_ct = unique_lines * 2; - } - } - - fea.access_feas = acc_feas; - } - - // Extract arithmetic intensity related feature (group 3) - void ExtractArithmeticIntensityFeature(const BufferStoreNode* node, double cur_compute_ops, - const std::vector<float>& compute_ops_list, - const std::vector<float>& mem_bytes_list) { - FeatureSet& fea = buffer_features[node->buffer]; - - // Compute arithmetic intensity curve (y axis : arithmetic intensity, x axis : flops). - // We use piecewise linear interpolation to fit this curve. - int pt = 0; - if (cur_compute_ops <= 0 || compute_ops_list.empty()) { - std::fill(fea.arith_intensity_curve, - fea.arith_intensity_curve + ARITH_INTENSITY_CURVE_SAMPLE_N, 0.0); - } else { - for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) { - float cur_compute_ops = compute_ops_list.back() * (i + 1) / ARITH_INTENSITY_CURVE_SAMPLE_N; - while (compute_ops_list[pt] < cur_compute_ops - 1e-4) { - pt++; - } - ICHECK_LT(pt, compute_ops_list.size()); - - float value; - if (pt == 0) { - value = compute_ops_list[pt] / mem_bytes_list[pt]; - } else { - float base = compute_ops_list[pt - 1] / mem_bytes_list[pt - 1]; - float slope = (compute_ops_list[pt] / mem_bytes_list[pt] - - compute_ops_list[pt - 1] / mem_bytes_list[pt - 1]) / - (compute_ops_list[pt] - compute_ops_list[pt - 1]); - value = base + slope * (cur_compute_ops - compute_ops_list[pt - 1]); - } - fea.arith_intensity_curve[i] = value; - } - } - } - - // Extract allocation related features (group 4) - void ExtractAllocationFeature(const BufferRealizeNode* node) { - FeatureSet& fea = buffer_features[node->buffer]; - - float allocation_size = 1.0f; - for (const auto& x : node->bounds) { - allocation_size *= GetIntImm(x->extent); - } - // allocation feature - fea.alloc_size = allocation_size * node->buffer->dtype.bytes(); - fea.alloc_prod = allocation_size * outer_loop_prod_; - fea.alloc_outer_prod = outer_loop_prod_; - fea.alloc_inner_prod = fea.outer_prod / outer_loop_prod_; - } - - // Extract outer scope related features (group 5) - void ExtractOuterScopeFeature(const BufferStoreNode* node) { - FeatureSet& fea = buffer_features[node->buffer]; - - fea.outer_prod = outer_loop_prod_; - fea.num_loops = for_loop_stack_.size(); - fea.auto_unroll_max_step = cur_auto_unroll_max_step_; - } - - // Stores FeatureSet for every buffer - BufferMap<FeatureSet> buffer_features; - - std::vector<std::pair<BufferStore, std::array<size_t, 17>>> bufstore_and_feats; - - private: - // The shared arithmetic analyzer - Analyzer ana_; - - // The product of outer loop - float outer_loop_prod_ = 1.0f; - - // The stacks to store parent loops during DFS - std::vector<const ForNode*> for_loop_stack_; - std::vector<const ForNode*> parallel_for_stack_; - std::vector<const ForNode*> vec_for_stack_; - std::vector<const ForNode*> unroll_for_stack_; - - // GPU-related features - bool is_gpu_{false}; - int blockIdx_x_len_{1}; - int block_idx_y_len_{1}; - int block_idx_z_len_{1}; - int threadIdx_x_len_{1}; - int thread_idx_y_len_{1}; - int thread_idx_z_len_{1}; - int vthread_len_{1}; - int16_t cur_auto_unroll_max_step_{0}; - - // Store touch region information for all for loops. The format of this nested map: - // For a loop, for all its touched buffers, for all different accesses to the buffers, - // its (access type, number of touched elements, number of bytes of single element) - std::unordered_map<const ForNode*, - BufferMap<std::vector<std::tuple<BufferAccessType, int64_t, int>>>> - for_touch_regions_; - - // The default cache line size in bytes - const int cache_line_size_ = 64; -}; - -// shifted log to incorporate the property that slog(0) = 0 -// inline float slog(float x) { return x < 0 ? -std::log2(-x + 1) : std::log2(x + 1); } -inline float slog(float x) { return x; } - -void GetPerStoreFeature(const Stmt& stmt, int cache_line_size, int max_n_bufs, - std::vector<float>* ret, StateFeatures* st_features) { - PerStoreFeatureExtractor extractor(cache_line_size); - extractor(stmt); - - ret->push_back(extractor.buffer_features.size()); - if (st_features) *st_features = StateFeatures(extractor.bufstore_and_feats); - - for (const auto& x : extractor.buffer_features) { - const FeatureSet& fea_set = x.second; - - /***** Group 1: Computation related features *****/ - ret->push_back(slog(fea_set.float_mad)); - ret->push_back(slog(fea_set.float_addsub)); - ret->push_back(slog(fea_set.float_mul)); - ret->push_back(slog(fea_set.float_divmod)); - ret->push_back(slog(fea_set.float_cmp)); - ret->push_back(slog(fea_set.float_math_func)); - ret->push_back(slog(fea_set.float_other_func)); - ret->push_back(slog(fea_set.int_mad)); - ret->push_back(slog(fea_set.int_addsub)); - ret->push_back(slog(fea_set.int_mul)); - ret->push_back(slog(fea_set.int_divmod)); - ret->push_back(slog(fea_set.int_cmp)); - ret->push_back(slog(fea_set.int_math_func)); - ret->push_back(slog(fea_set.int_other_func)); - ret->push_back(slog(fea_set.bool_op)); - ret->push_back(slog(fea_set.select_op)); - - ret->push_back(slog(fea_set.vec_num)); - ret->push_back(slog(fea_set.vec_prod)); - ret->push_back(slog(fea_set.vec_len)); - for (int i = 0; i <= static_cast<int>(AnnotationPosType::kPosMixed); i++) { - ret->push_back(i == static_cast<int>(fea_set.vec_type)); - } - - ret->push_back(slog(fea_set.unroll_num)); - ret->push_back(slog(fea_set.unroll_prod)); - ret->push_back(slog(fea_set.unroll_len)); - for (int i = 0; i <= static_cast<int>(AnnotationPosType::kPosMixed); i++) { - ret->push_back(i == static_cast<int>(fea_set.unroll_type)); - } - - ret->push_back(slog(fea_set.parallel_num)); - ret->push_back(slog(fea_set.parallel_prod)); - ret->push_back(slog(fea_set.parallel_len)); - for (int i = 0; i <= static_cast<int>(AnnotationPosType::kPosMixed); i++) { - ret->push_back(i == static_cast<int>(fea_set.parallel_type)); - } - - ret->push_back(fea_set.is_gpu); - ret->push_back(slog(fea_set.blockIdx_x_len)); - ret->push_back(slog(fea_set.blockIdx_y_len)); - ret->push_back(slog(fea_set.blockIdx_z_len)); - ret->push_back(slog(fea_set.threadIdx_x_len)); - ret->push_back(slog(fea_set.threadIdx_y_len)); - ret->push_back(slog(fea_set.threadIdx_z_len)); - ret->push_back(slog(fea_set.vthread_len)); - - /***** Group 2: Buffer access related features *****/ - // sort according to pair (lines, bytes) - std::vector<std::pair<float, float>> buf_order_key; - for (const auto& acc_fea : fea_set.access_feas) { - buf_order_key.emplace_back(acc_fea.lines, acc_fea.bytes); - } - std::vector<int> buf_order(buf_order_key.size()); - std::iota(buf_order.begin(), buf_order.end(), 0); - - auto cmp = [&buf_order_key](int l, int r) { - return buf_order_key[l].first > buf_order_key[r].first || - (buf_order_key[l].first == buf_order_key[r].first && - buf_order_key[l].second > buf_order_key[r].second); - }; - std::sort(buf_order.begin(), buf_order.end(), cmp); - int n_bufs = std::min(max_n_bufs, static_cast<int>(buf_order.size())); - buf_order.resize(n_bufs); - - for (int idx : buf_order) { - const auto& acc_fea = fea_set.access_feas[idx]; - for (int j = 0; j <= static_cast<int>(BufferAccessType::kReadWrite); ++j) { - ret->push_back(j == static_cast<int>(acc_fea.acc_type)); - } - ret->push_back(slog(acc_fea.bytes)); - ret->push_back(slog(acc_fea.unique_bytes)); - ret->push_back(slog(acc_fea.lines)); - ret->push_back(slog(acc_fea.unique_lines)); - for (int j = 0; j <= static_cast<int>(ReuseType::kNoReuse); ++j) { - ret->push_back(j == static_cast<int>(acc_fea.reuse_type)); - } - ret->push_back(slog(acc_fea.reuse_dis_iter)); - ret->push_back(slog(acc_fea.reuse_dis_bytes)); - ret->push_back(slog(acc_fea.reuse_ct)); - ret->push_back(slog(acc_fea.bytes_d_reuse_ct)); - ret->push_back(slog(acc_fea.unique_bytes_d_reuse_ct)); - ret->push_back(slog(acc_fea.lines_d_reuse_ct)); - ret->push_back(slog(acc_fea.unique_lines_d_reuse_ct)); - ret->push_back(slog(acc_fea.stride)); - } - // - fill padding - for (int i = 0; i < max_n_bufs - n_bufs; ++i) { - for (int j = 0; j <= static_cast<int>(BufferAccessType::kReadWrite); ++j) { // 3 - ret->push_back(0.0f); - } - ret->push_back(0.0f); - ret->push_back(0.0f); - ret->push_back(0.0f); - ret->push_back(0.0f); - for (int j = 0; j <= static_cast<int>(ReuseType::kNoReuse); ++j) { // 3 - ret->push_back(0.0f); - } - ret->push_back(0.0f); - ret->push_back(0.0f); - ret->push_back(0.0f); - ret->push_back(0.0f); - ret->push_back(0.0f); - ret->push_back(0.0f); - ret->push_back(0.0f); - ret->push_back(0.0f); - } - - /***** Group 3: Arithmetic intensity related features *****/ - for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) { - ret->push_back(fea_set.arith_intensity_curve[i]); - } - - /***** Group 4: Allocation related features *****/ - ret->push_back(slog(fea_set.alloc_size)); - ret->push_back(slog(fea_set.alloc_prod)); - ret->push_back(slog(fea_set.alloc_outer_prod)); - ret->push_back(slog(fea_set.alloc_inner_prod)); - - /***** Group 5: Outer scope related features *****/ - ret->push_back(slog(fea_set.outer_prod)); - ret->push_back(slog(fea_set.num_loops)); - ret->push_back(slog(fea_set.auto_unroll_max_step)); - } -} - -void GetPerStoreFeatureName(int max_n_bufs, std::vector<std::string>* ret) { - /***** Group 1: Computation related features *****/ - ret->push_back(("float_mad")); - ret->push_back(("float_addsub")); - ret->push_back(("float_mul")); - ret->push_back(("float_divmod")); - ret->push_back(("float_cmp")); - ret->push_back(("float_mathfunc")); - ret->push_back(("float_otherfunc")); - ret->push_back(("int_mad")); - ret->push_back(("int_addsub")); - ret->push_back(("int_mul")); - ret->push_back(("int_divmod")); - ret->push_back(("int_cmp")); - ret->push_back(("int_mathfunc")); - ret->push_back(("int_otherfunc")); - ret->push_back(("bool_op")); - ret->push_back(("select_op")); - ret->push_back(("vec_num")); - ret->push_back(("vec_prod")); - ret->push_back(("vec_len")); - ret->push_back(("vec_type.kPosNone")); - ret->push_back(("vec_type.kPosInnerSpatial")); - ret->push_back(("vec_type.kPosMiddleSpatial")); - ret->push_back(("vec_type.kPosOuterSpatial")); - ret->push_back(("vec_type.kPosInnerReduce")); - ret->push_back(("vec_type.kPosMiddleReduce")); - ret->push_back(("vec_type.kPosOuterReduce")); - ret->push_back(("vec_type.kPosMixed")); - ret->push_back(("unroll_num")); - ret->push_back(("unroll_prod")); - ret->push_back(("unroll_len")); - ret->push_back(("unroll_type.kPosNone")); - ret->push_back(("unroll_type.kPosInnerSpatial")); - ret->push_back(("unroll_type.kPosMiddleSpatial")); - ret->push_back(("unroll_type.kPosOuterSpatial")); - ret->push_back(("unroll_type.kPosInnerReduce")); - ret->push_back(("unroll_type.kPosMiddleReduce")); - ret->push_back(("unroll_type.kPosOuterReduce")); - ret->push_back(("unroll_type.kPosMixed")); - ret->push_back(("parallel_num")); - ret->push_back(("parallel_prod")); - ret->push_back(("parallel_len")); - ret->push_back(("parallel_type.kPosNone")); - ret->push_back(("parallel_type.kPosInnerSpatial")); - ret->push_back(("parallel_type.kPosMiddleSpatial")); - ret->push_back(("parallel_type.kPosOuterSpatial")); - ret->push_back(("parallel_type.kPosInnerReduce")); - ret->push_back(("parallel_type.kPosMiddleReduce")); - ret->push_back(("parallel_type.kPosOuterReduce")); - ret->push_back(("parallel_type.kPosMixed")); - ret->push_back(("is_gpu")); - ret->push_back(("blockIdx_x_len")); - ret->push_back(("blockIdx_y_len")); - ret->push_back(("blockIdx_z_len")); - ret->push_back(("threadIdx_x_len")); - ret->push_back(("threadIdx_y_len")); - ret->push_back(("threadIdx_z_len")); - ret->push_back(("vthread_len")); - // section total: 57 - - /***** Group 2: Buffer access related features *****/ - for (size_t i = 0; i < static_cast<size_t>(max_n_bufs); ++i) { - std::string prefix = "B" + std::to_string(i) + "."; - ret->push_back((prefix + "acc_type.kRead")); - ret->push_back((prefix + "acc_type.kWrite")); - ret->push_back((prefix + "acc_type.kReadWrite")); - ret->push_back((prefix + "bytes")); - ret->push_back((prefix + "unique_bytes")); - ret->push_back((prefix + "lines")); - ret->push_back((prefix + "unique_lines")); - ret->push_back((prefix + "reuse_type.kLoopMultipleRead")); - ret->push_back((prefix + "reuse_type.kSerialMultipleReadWrite")); - ret->push_back((prefix + "reuse_type.kNoReuse")); - ret->push_back((prefix + "reuse_dis_iter")); - ret->push_back((prefix + "reuse_dis_bytes")); - ret->push_back((prefix + "reuse_ct")); - ret->push_back((prefix + "bytes_d_reuse_ct")); - ret->push_back((prefix + "unique_bytes_d_reuse_ct")); - ret->push_back((prefix + "lines_d_reuse_ct")); - ret->push_back((prefix + "unique_lines_d_reuse_ct")); - ret->push_back((prefix + "stride")); - } - // section total : max_n_bufs * 18 - - /***** Group 3: Arithmetic intensity related features *****/ - for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) { - ret->push_back(("arith_intensity_curve_" + std::to_string(i))); - } - // section total: ARITH_INTENSITY_CURVE_SAMPLE_N = 10 - - /***** Group 4: Allocation related features *****/ - ret->push_back(("alloc_size")); - ret->push_back(("alloc_prod")); - ret->push_back(("alloc_outer_prod")); - ret->push_back(("alloc_inner_prod")); - // section total : 4 - - /***** Group 5: Outer scope related features *****/ - ret->push_back(("outer_prod")); - ret->push_back(("num_loops")); - ret->push_back(("auto_unroll_max_step")); - // section total : 3 -} - -void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, int max_n_bufs, - std::vector<float>* feature, std::atomic<int>* error_ct, - StateFeatures* st_features = nullptr) { - Stmt code_body = GenerateCodeForState(task, state); - if (!code_body.defined()) - (*error_ct)++; - else - GetPerStoreFeature(code_body, task->hardware_params->cache_line_bytes, max_n_bufs, feature, - st_features); -} - -void GetPerStoreFeaturesFromStates(const Array<State>& states, const SearchTask& task, - int skip_first_n_feature_extraction, int max_n_bufs, - std::vector<std::vector<float>>* features) { - // extract features - features->assign(states.size(), std::vector<float>()); - - std::atomic<int> error_ct(0); - - support::parallel_for(skip_first_n_feature_extraction, states.size(), - [&task, &states, &max_n_bufs, &features, &error_ct](int i) { - GetPerStoreFeaturesWorkerFunc(task, states[i], max_n_bufs, - &(*features)[i], &error_ct); - }); -} - -void GetPerStoreFeaturesFromStates(const Array<State>& states, const std::vector<SearchTask>& tasks, - int skip_first_n_feature_extraction, int max_n_bufs, - std::vector<std::vector<float>>* features) { - // extract features - features->assign(states.size(), std::vector<float>()); - - std::atomic<int> error_ct(0); - - support::parallel_for(skip_first_n_feature_extraction, states.size(), - [&tasks, &states, &max_n_bufs, &features, &error_ct](int i) { - GetPerStoreFeaturesWorkerFunc(tasks[i], states[i], max_n_bufs, - &(*features)[i], &error_ct); - }); -} - -void GetPerStoreFeaturesFromFile(const std::string& filename, int max_lines, int max_n_bufs, - std::vector<std::vector<float>>* features, - std::vector<float>* normalized_throughputs, - std::vector<int>* task_ids) { - Array<State> states; - std::vector<SearchTask> tasks; - - normalized_throughputs->clear(); - task_ids->clear(); - - // (workload_key, target) -> (search_task, task_id) - std::unordered_map<std::pair<std::string, std::string>, std::pair<SearchTask, size_t>> task_cache; - // task_id -> min_cost - std::vector<float> min_costs; - - const auto* workload_key_to_tensors = - tvm::runtime::Registry::Get("auto_scheduler.workload_key_to_tensors"); - ICHECK(workload_key_to_tensors != nullptr); - - // read from file - RecordReader reader(filename); - auto cur_inp = make_object<MeasureInputNode>(); - auto cur_res = make_object<MeasureResultNode>(); - while (reader->ReadNext(cur_inp.get(), cur_res.get())) { - float cost = static_cast<float>(FloatArrayMean(cur_res->costs)); - const std::string& workload_key = cur_inp->task->workload_key; - - SearchTask task; - size_t task_id; - std::pair<std::string, std::string> key(workload_key, cur_inp->task->target->str()); - auto find_res = task_cache.find(key); - if (find_res == task_cache.end()) { - // rebuild task - Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key); - Target target = cur_inp->task->target; - Target target_host = cur_inp->task->target_host; - CheckAndUpdateHostConsistency(&target, &target_host); - task = SearchTask(ComputeDAG(tensors), workload_key, target, target_host, - cur_inp->task->hardware_params, cur_inp->task->layout_rewrite_option, - cur_inp->task->task_input_names); - task_id = task_cache.size(); - - // compute min cost for each task - task_cache.insert(std::make_pair(key, std::make_pair(task, task_id))); - min_costs.push_back(cost); - } else { - std::tie(task, task_id) = find_res->second; - min_costs[task_id] = std::min(min_costs[task_id], cost); - } - - tasks.push_back(std::move(task)); - task_ids->push_back(task_id); - states.push_back(cur_inp->state); - normalized_throughputs->push_back(cost); - - if (max_lines > 0 && static_cast<int>(states.size()) >= max_lines) { - break; - } - } - - for (size_t i = 0; i < normalized_throughputs->size(); ++i) { - (*normalized_throughputs)[i] = min_costs[(*task_ids)[i]] / (*normalized_throughputs)[i]; - } - - GetPerStoreFeaturesFromStates(states, tasks, 0, max_n_bufs, features); -} - -void GetPerStoreFeaturesFromMeasurePairs(const Array<MeasureInput>& inputs, - const Array<MeasureResult>& results, - int skip_first_n_feature_extraction, int max_n_bufs, - std::vector<std::vector<float>>* features, - std::vector<float>* normalized_throughputs, - std::vector<int>* task_ids) { - Array<State> states; - std::vector<SearchTask> tasks; - - normalized_throughputs->clear(); - task_ids->clear(); - - // (workload_key, target) -> (search_task, task_id) - std::unordered_map<std::pair<std::string, std::string>, std::pair<SearchTask, size_t>> task_cache; - // task_id -> min_cost - std::vector<float> min_costs; - - const auto* workload_key_to_tensors = - tvm::runtime::Registry::Get("auto_scheduler.workload_key_to_tensors"); - ICHECK(workload_key_to_tensors != nullptr); - - tasks.reserve(inputs.size()); - normalized_throughputs->reserve(inputs.size()); - task_ids->reserve(inputs.size()); - for (size_t i = 0; i < inputs.size(); ++i) { - float cost = static_cast<float>(FloatArrayMean(results[i]->costs)); - const std::string& workload_key = inputs[i]->task->workload_key; - SearchTask task; - - size_t task_id; - std::pair<std::string, std::string> key(workload_key, inputs[i]->task->target->str()); - auto find_res = task_cache.find(key); - if (find_res == task_cache.end()) { - if (inputs[i]->task->compute_dag.defined()) { // the measure input is complete - task = inputs[i]->task; - } else { - // The measure input is incomplete, rebuild task for incomplete measure pairs read from file - try { - Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key); - Target target = inputs[i]->task->target; - Target target_host = inputs[i]->task->target_host; - CheckAndUpdateHostConsistency(&target, &target_host); - task = - SearchTask(ComputeDAG(tensors), workload_key, target, target_host, - inputs[i]->task->hardware_params, inputs[i]->task->layout_rewrite_option, - inputs[i]->task->task_input_names); - } catch (std::exception& e) { - // Cannot build ComputeDAG from workload key, the task may have not been registered in - // this search round - continue; - } - } - task_id = task_cache.size(); - - // compute min cost for each task - task_cache.insert(std::make_pair(key, std::make_pair(task, task_id))); - min_costs.push_back(cost); - } else { - std::tie(task, task_id) = find_res->second; - min_costs[task_id] = std::min(min_costs[task_id], cost); - } - - tasks.push_back(std::move(task)); - task_ids->push_back(task_id); - states.push_back(inputs[i]->state); - normalized_throughputs->push_back(cost); - } - - for (size_t i = 0; i < normalized_throughputs->size(); ++i) { - (*normalized_throughputs)[i] = min_costs[(*task_ids)[i]] / (*normalized_throughputs)[i]; - } - - GetPerStoreFeaturesFromStates(states, tasks, skip_first_n_feature_extraction, max_n_bufs, - features); -} - -/* - * \brief Serialize a two-dimensional variable-size feature vector with normalized throughputs - * and task ids to a one-dimensional flatten byte array. - * - * For faster data copy between c++ and python, the c++ part returns features in a single - * flatten array using a packed format. The python part then unpacks the flatten array. - * - * The packed format for n records is: - * { - * int n; - * int sizes[n+2]; // The sizes for the following arrays - * - * float features_0[size[0]]; // The features for record 0 - * float features_1[size[1]]; // The features for record 1 - * ... - * float features_i[size[i]]; // The features for record i - * ... // until i == n - 1 - * - * float throughputs[sizes[n]]; // The normalized throughputs for n records - * int task_ids[size[n+1]]; // The task ids for n records - * - * } - * To implement this format, we also store int as float, so we can store all numbers - * into a single float array. - */ -TVMByteArray SerializeFeatures(std::vector<std::vector<float>>&& features, - std::vector<float>&& normalized_throughputs, - std::vector<int>&& task_ids, std::vector<char>* out_data) { - size_t total_bytes = 0; - std::vector<int> size_vector; - - int n = features.size(); - - // serialize sizes - size_t size_vector_size = 1 + n + 2; - total_bytes += size_vector_size * sizeof(int); - - size_vector.reserve(size_vector_size); - size_vector.push_back(features.size()); - for (const auto& x : features) { - size_vector.push_back(static_cast<int>(x.size())); - total_bytes += sizeof(float) * x.size(); - } - size_vector.push_back(static_cast<int>(normalized_throughputs.size())); - total_bytes += sizeof(float) * normalized_throughputs.size(); - size_vector.push_back(static_cast<int>(task_ids.size())); - total_bytes += sizeof(int) * task_ids.size(); - - ICHECK_EQ(size_vector.size(), size_vector_size); - - // allocate memory - out_data->reserve(total_bytes); - char* ptr = out_data->data(); - - // serialize size_vector - memmove(ptr, reinterpret_cast<char*>(size_vector.data()), size_vector.size() * sizeof(int)); - ptr += size_vector.size() * sizeof(int); - - // serialize features - for (auto& x : features) { - memmove(ptr, x.data(), sizeof(float) * x.size()); - ptr += sizeof(float) * x.size(); - x.clear(); - } - - // serialize normalized_throughputs - memmove(ptr, reinterpret_cast<char*>(normalized_throughputs.data()), - normalized_throughputs.size() * sizeof(int)); - ptr += normalized_throughputs.size() * sizeof(int); - - // serialize task_ids - memmove(ptr, reinterpret_cast<char*>(task_ids.data()), task_ids.size() * sizeof(int)); - ptr += task_ids.size() * sizeof(int); - - ICHECK_EQ(ptr - out_data->data(), total_bytes); - - return TVMByteArray{out_data->data(), total_bytes}; -} - -TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeaturesFromFile") - .set_body([](TVMArgs args, TVMRetValue* ret) { - std::string filename = args[0]; - int max_lines = args[1]; - int max_n_bufs = args[2]; - - std::vector<std::vector<float>> features; - std::vector<float> normalized_throughputs; - std::vector<int> task_ids; - - GetPerStoreFeaturesFromFile(filename, max_lines, max_n_bufs, &features, - &normalized_throughputs, &task_ids); - - std::vector<char> byte_data; - *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs), - std::move(task_ids), &byte_data); - }); - -TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeaturesFromMeasurePairs") - .set_body([](TVMArgs args, TVMRetValue* ret) { - Array<MeasureInput> inputs = args[0]; - Array<MeasureResult> results = args[1]; - int skip_first_n_feature_extraction = args[2]; - int max_n_bufs = args[3]; - - std::vector<std::vector<float>> features; - std::vector<float> normalized_throughputs; - std::vector<int> task_ids; - - GetPerStoreFeaturesFromMeasurePairs(inputs, results, skip_first_n_feature_extraction, - max_n_bufs, &features, &normalized_throughputs, - &task_ids); - - std::vector<char> byte_data; - *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs), - std::move(task_ids), &byte_data); - }); - -TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeaturesFromStates") - .set_body([](TVMArgs args, TVMRetValue* ret) { - Array<State> states = args[0]; - SearchTask task = args[1]; - int max_n_bufs = args[2]; - - std::vector<std::vector<float>> features; - std::vector<float> normalized_throughputs; - std::vector<int> task_ids; - - GetPerStoreFeaturesFromStates(states, task, 0, max_n_bufs, &features); - - std::vector<char> byte_data; - *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs), - std::move(task_ids), &byte_data); - }); - -TVM_REGISTER_GLOBAL("auto_scheduler.GetBufAndFeatsFromStates") - .set_body_typed([](Array<State> states, SearchTask task, int max_n_bufs) { - std::atomic<int> error_ct(0); - Array<StateFeatures> buf_and_feats; - for (size_t i = 0; i < states.size(); ++i) { - StateFeatures st_features; - std::vector<float> features; - GetPerStoreFeaturesWorkerFunc(task, states[i], max_n_bufs, &features, &error_ct, - &st_features); - if (!st_features.defined()) { - LOG_WARNING << "No code generated for state " << i << ", error_ct =" << error_ct - << std::endl; - st_features = StateFeatures(std::vector<std::pair<BufferStore, Features>>()); - } - st_features.CopyOnWrite()->state_code = GenerateCodeForState(task, states[i], true); - buf_and_feats.push_back(st_features); - } - return buf_and_feats; - }); - -TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeatureNames") - .set_body([](TVMArgs args, TVMRetValue* ret) { - int max_n_bufs = args[0]; - std::vector<std::string> names; - - GetPerStoreFeatureName(max_n_bufs, &names); - - Array<String> arr; - for (const auto& x : names) { - arr.push_back(x); - } - *ret = arr; - }); - -} // namespace auto_scheduler -} // namespace tvm diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 517f7ff91f558a5e1b5564358fa5087c388d58eb..1a685cfe01a7c88f251fd5bcce67c2c5010b22d6 100755 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -275,7 +275,7 @@ void State::reorder(int stage_id, const Array<Iterator>& order) { } Array<Iterator> State::split(int stage_id, const Iterator& it, - const Array<Optional<Integer>>& lengths, bool inner_to_outer) { + const Array<PrimExpr>& lengths, bool inner_to_outer) { const Stage& stage = operator->()->stages[stage_id]; SplitStep step = SplitStep(stage_id, GetIndex(stage->iters, it), @@ -501,7 +501,7 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateReorder") TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit") .set_body_typed([](State state, int stage_id, const Iterator& it, - const Array<Optional<Integer>>& lengths, bool inner_to_outer) { + const Array<PrimExpr>& lengths, bool inner_to_outer) { const auto& res = state.split(stage_id, it, lengths, inner_to_outer); return Array<ObjectRef>{state, res}; }); diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc index e652c1baf87a68d7d2950d0b8eba46fd3737a012..3bb4cf9608fd4fcba2b1ded8f23b55283b5a0908 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.cc +++ b/src/auto_scheduler/search_policy/sketch_policy.cc @@ -111,10 +111,10 @@ SketchPolicy::SketchPolicy(SearchTask task, CostModel program_cost_model, node->init_rules.push_back(&init_vectorization); // Mutation Rules for Evolutionary Search - node->mutation_rules.push_back(std::make_shared<MutateTileSize>(0.90)); - node->mutation_rules.push_back(std::make_shared<MutateAutoUnroll>(0.04)); - node->mutation_rules.push_back(std::make_shared<MutateComputeLocation>(0.05)); - node->mutation_rules.push_back(std::make_shared<MutateParallel>(0.01)); + // node->mutation_rules.push_back(std::make_shared<MutateTileSize>(0.90)); + // node->mutation_rules.push_back(std::make_shared<MutateAutoUnroll>(0.04)); + // node->mutation_rules.push_back(std::make_shared<MutateComputeLocation>(0.05)); + // node->mutation_rules.push_back(std::make_shared<MutateParallel>(0.01)); } else if (IsGPUTask(node->search_task)) { // Sketch Generation Rules if (node->search_task->target->GetAttr<String>("device", "") == "mali") { @@ -147,8 +147,8 @@ SketchPolicy::SketchPolicy(SearchTask task, CostModel program_cost_model, } // Mutation Rules for Evolutionary Search - node->mutation_rules.push_back(std::make_shared<MutateTileSize>(0.90)); - node->mutation_rules.push_back(std::make_shared<MutateAutoUnroll>(0.10)); + // node->mutation_rules.push_back(std::make_shared<MutateTileSize>(0.90)); + // node->mutation_rules.push_back(std::make_shared<MutateAutoUnroll>(0.10)); } else { LOG(FATAL) << "No default sketch rules for target: " << task->target; } @@ -370,7 +370,7 @@ Array<State> SketchPolicyNode::GenerateSketches() { auto step = pstate->transform_steps[split_step_id].as<SplitStepNode>(); ICHECK(step != nullptr); pstate->transform_steps.Set( - split_step_id, SplitStep(step->stage_id, step->iter_id, step->extent, {NullOpt}, + split_step_id, SplitStep(step->stage_id, step->iter_id, step->extent, {PrimExpr()}, step->inner_to_outer)); } } diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc index 8df69fc7ce3b93daa27e589e6cce385d9b8ec6b8..83d12b0edf2e4750451278ca03c95b701307894e 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc @@ -311,7 +311,7 @@ std::vector<std::pair<State, int>> RuleSimplifyComputeWithConstTensor::Apply( // tile other space indices ICHECK(iter->iter_kind == IteratorKind::kSpatial); tiled_outer_iters.push_back( - tmp_s.split(stage_id, iter, Array<Optional<Integer>>(tile_level - 1, NullOpt))); + tmp_s.split(stage_id, iter, Array<PrimExpr>(tile_level - 1, PrimExpr()))); } } @@ -494,36 +494,27 @@ std::vector<std::pair<State, int>> RuleCustomSketch::Apply(const SketchPolicyNod PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const { SplitFactorizationMemo split_memo; - int max_innermost_split_factor = - GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor); - + // TODO: remember variable range constraints + // int max_innermost_split_factor = + // GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor); StateNode* pstate = state->CopyOnWrite(); // Scan the transformation history and randomly fill tiles size for all SplitStep for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) { if (auto ps = (*state)->transform_steps[step_id].as<SplitStepNode>()) { - bool all_defined = true; - for (const auto& len : ps->lengths) { - if (!len) { - all_defined = false; - break; - } - } - if (all_defined) { - continue; - } - + bool all_defined = + std::accumulate(ps->lengths.begin(), ps->lengths.end(), true, + [](bool acc, const ObjectRef& item) { return acc && item.defined(); }); + if (all_defined) continue; ICHECK(ps->extent); - int extent = GetIntImm(ps->extent.value()); - const auto& candidate_lens = split_memo.GetFactorizationSchemes(extent, ps->lengths.size(), - max_innermost_split_factor); - ICHECK(!candidate_lens.empty()); - const auto& candidate_lengths = candidate_lens[(*rand_gen)() % candidate_lens.size()]; - - pstate->transform_steps.Set( - step_id, - SplitStep(ps->stage_id, ps->iter_id, ps->extent, - Array<Optional<Integer>>(candidate_lengths.begin(), candidate_lengths.end()), - ps->inner_to_outer)); + Array<PrimExpr> new_split_lengths; + for (size_t len_id = 0; len_id < ps->lengths.size(); ++len_id) { + if (ps->lengths[len_id].defined()) + LOG_FATAL << "SplitStep length should not be partially defined"; + String var_name = "sp" + std::to_string(step_id) + "_" + std::to_string(len_id); + new_split_lengths.push_back(Var(var_name)); + } + pstate->transform_steps.Set(step_id, SplitStep(ps->stage_id, ps->iter_id, ps->extent, + new_split_lengths, ps->inner_to_outer)); } } pstate->concrete = true; @@ -912,330 +903,23 @@ PopulationGenerationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* pol PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const { - int max_innermost_split_factor = - GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor); - - // Extract all SplitStep - std::vector<size_t> split_step_ids; - for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) { - if (auto ps = (*state)->transform_steps[i].as<SplitStepNode>()) { - if (!ps->extent.defined() || !ps->extent.value()->IsInstance<IntImmNode>()) { - continue; - } - auto innermost_factor = ps->lengths.back().value_or(max_innermost_split_factor + 1); - if (GetIntImm(innermost_factor) <= max_innermost_split_factor) { - split_step_ids.push_back(i); - } - } - } - if (split_step_ids.empty()) { - // No tile size could be mutated. - return ResultKind::kInvalid; - } - - // Select a SplitStep with extent larger than one to mutate. - int retry_ct = 0; - int64_t extent = 1; - int step_id; - const SplitStepNode* ps; - - do { - step_id = split_step_ids[(*rand_gen)() % split_step_ids.size()]; - ps = (*state)->transform_steps[step_id].as<SplitStepNode>(); - ICHECK(ps != nullptr); - extent = GetIntImm(ps->extent.value()); - retry_ct += 1; - } while (retry_ct < static_cast<int>(split_step_ids.size()) << 2 && (extent == 1 || extent == 0)); - - if (extent <= 1) { - // Cannot find a step with extent larger than one. - return ResultKind::kInvalid; - } - - // Fetch the current tile sizes. - std::vector<int> lengths(ps->lengths.size() + 1, 1); - for (int i = 0; i < static_cast<int>(ps->lengths.size()); ++i) { - lengths[i + 1] = GetIntImm(ps->lengths[i].value()); - } - lengths[0] = extent / ElementProduct(lengths); - - // Random permute the tile size order. - std::vector<int> random_perm; - RandomPermutation(lengths.size(), &random_perm, rand_gen); - - // Try to divide a factor from one tile size and multiple it to another. - for (size_t i = 0; i < random_perm.size(); ++i) { - size_t src_idx = random_perm[i]; - int length = lengths[src_idx]; - if (length <= 1) { - continue; - } - - // Divide one factor from lengths[src_idx] and multiply it to lengths[dst_idx] - size_t dst_idx = random_perm[(i + 1) % random_perm.size()]; - const std::vector<int>& factors = policy->split_memo.GetFactors(length); - ICHECK_GE(factors.size(), 1); - - int divide_factor; - if (dst_idx == lengths.size() - 1) { - // Maintain the restriction of hardware_params.max_innermost_split_factor. - int max_factor_index = static_cast<int>(factors.size()) - 1; - for (; max_factor_index >= 1; max_factor_index--) { - if (factors[max_factor_index] * lengths[dst_idx] <= max_innermost_split_factor) { - break; - } - } - if (max_factor_index == 0) { - // Failed on this dst_idx, try next one. - continue; - } - divide_factor = factors[1 + (*rand_gen)() % (max_factor_index)]; - } else { - divide_factor = factors[1 + (*rand_gen)() % (factors.size() - 1)]; - } - - // Divide one factor from lengths[src_idx] and multiply it to lengths[dst_idx]. - Array<Integer> new_lengths; - for (size_t j = 1; j < lengths.size(); ++j) { - if (j == src_idx) { - new_lengths.push_back(Integer(lengths[j] / divide_factor)); - } else if (j == dst_idx) { - new_lengths.push_back(Integer(lengths[j] * divide_factor)); - } else { - new_lengths.push_back(Integer(lengths[j])); - } - } - - ICHECK_LE(GetIntImm(new_lengths.back()), max_innermost_split_factor); - - StateNode* pstate = state->CopyOnWrite(); - pstate->transform_steps.Set( - step_id, SplitStep(ps->stage_id, ps->iter_id, ps->extent, - Array<Optional<Integer>>(new_lengths.begin(), new_lengths.end()), - ps->inner_to_outer)); - return ResultKind::kValid; - } return ResultKind::kInvalid; } PopulationGenerationRule::ResultKind MutateAutoUnroll::Apply(SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const { - // Extract all auto_unroll_max_step pragma steps. - std::vector<int> pragma_steps; - for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) { - if (auto ps = (*state)->transform_steps[i].as<PragmaStepNode>()) { - if (StrStartsWith(ps->pragma_type, "auto_unroll_max_step")) { - pragma_steps.push_back(i); - } - } - } - if (pragma_steps.empty()) { - return ResultKind::kInvalid; - } - - std::vector<int>& auto_unroll_configs = - IsGPUTask(policy->search_task) ? auto_unroll_configs_gpu : auto_unroll_configs_cpu; - - // Randomly pick up an auto unroll pragma step - auto step_id = pragma_steps[(*rand_gen)() % pragma_steps.size()]; - auto ps = (*state)->transform_steps[step_id].as<PragmaStepNode>(); - ICHECK(ps); - - // Mutate its value to a random candidates - int val = auto_unroll_configs[(*rand_gen)() % auto_unroll_configs.size()]; - StateNode* pstate = state->CopyOnWrite(); - pstate->transform_steps.Set( - step_id, PragmaStep(ps->stage_id, ps->iter_id, - std::string("auto_unroll_max_step") + "$" + std::to_string(val))); - Stage new_stage = pstate->stages[ps->stage_id]; - new_stage.CopyOnWrite()->attrs.auto_unroll_max_step = val; - pstate->stages.Set(ps->stage_id, new_stage); - return ResultKind::kValid; + return ResultKind::kInvalid; } PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const { - if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) { - return ResultKind::kInvalid; - } - - // Extract all compute_at steps. - std::vector<int> compute_at_steps; - for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) { - if (auto ps = (*state)->transform_steps[s].as<ComputeAtStepNode>()) { - int stage_inc = GetTargetStageIDInState(*state, s) - ps->stage_id; - - if (IsTiled((*state)->stages[ps->stage_id + stage_inc])) { - continue; - } - - if (NeedsMultilevelTiling(policy->search_task, *state, ps->stage_id + stage_inc)) { - continue; - } - compute_at_steps.push_back(s); - } - } - if (compute_at_steps.empty()) { - return ResultKind::kInvalid; - } - - // Randomly pick one step - size_t step_id = compute_at_steps[(*rand_gen)() % compute_at_steps.size()]; - auto ps = (*state)->transform_steps[step_id].as<ComputeAtStepNode>(); - int stage_inc = GetTargetStageIDInState(*state, step_id) - ps->stage_id; - ICHECK(ps != nullptr); - - // Randomly pick a new computation location - std::vector<std::pair<int, int>> candidates = - GetComputeLocationCandidates(policy->search_task, *state, ps->stage_id + stage_inc); - if (candidates.empty()) { - return ResultKind::kInvalid; - } - int choice = (*rand_gen)() % (candidates.size()); - int new_compute_at_stage_id = candidates[choice].first; - int new_compute_at_iter_id = candidates[choice].second; - - // Replay a new state. - State tmp_s = policy->search_task->compute_dag->init_state; - for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) { - if (s == step_id) { - tmp_s.CopyOnWrite()->transform_steps.push_back( - ComputeAtStep(ps->stage_id, new_compute_at_stage_id - stage_inc, new_compute_at_iter_id)); - } else { - tmp_s.CopyOnWrite()->transform_steps.push_back((*state)->transform_steps[s]); - } - try { - StepApplyToState(tmp_s->transform_steps.back(), &tmp_s, policy->search_task->compute_dag); - } catch (Error& e) { - return ResultKind::kInvalid; - } - } - - *state = tmp_s; - return ResultKind::kValid; + return ResultKind::kInvalid; } PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const { - // This mutation rule only focuses on a case that parallel was added to - // the outermost loop and the loop is generated by fusing other loops. - // In short, we mutate the fusion step before the parallel step. - - // Extract all parallel steps. - std::vector<int> parallel_steps; - for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) { - auto ps = (*state)->transform_steps[s].as<AnnotationStepNode>(); - if (!ps || ps->annotation != IteratorAnnotation::kParallel) { - continue; - } - - // Skip non-outermost loop or the parallel step without fusion beforehand. - if (ps->iter_id != 0 || s == 0 || !(*state)->transform_steps[s - 1].as<FuseStepNode>()) { - continue; - } - auto fuse_step = (*state)->transform_steps[s - 1].as<FuseStepNode>(); - if (fuse_step->fused_ids[0] != 0) { - continue; - } - - parallel_steps.push_back(s); - } - if (parallel_steps.empty()) { - return ResultKind::kInvalid; - } - - // Randomly pick one parallel step. - size_t step_id = parallel_steps[(*rand_gen)() % parallel_steps.size()]; - - // Replay a new state until the picked fuse step. - State tmp_s = policy->search_task->compute_dag->init_state; - for (size_t s = 0; s < step_id - 1; ++s) { - const auto& step = (*state)->transform_steps[s]; - tmp_s.CopyOnWrite()->transform_steps.push_back(step); - StepApplyToState(step, &tmp_s, policy->search_task->compute_dag); - } - - // Compute all possible fusion granularities - auto fuse_step = (*state)->transform_steps[step_id - 1].as<FuseStepNode>(); - int stage_id = fuse_step->stage_id; - const Stage& stage = tmp_s->stages[stage_id]; - size_t max_fusable_iter_id; - for (max_fusable_iter_id = 0; max_fusable_iter_id < stage->iters.size(); ++max_fusable_iter_id) { - const Iterator& it = stage->iters[max_fusable_iter_id]; - if (it->iter_kind == IteratorKind::kReduction || it->annotation != IteratorAnnotation::kNone) { - break; - } - - if (tmp_s->attach_map->iter_to_attached_stages.count( - std::make_pair(stage_id, max_fusable_iter_id))) { - break; - } - } - - if (max_fusable_iter_id == 0) { - return ResultKind::kInvalid; - } - - // Randomly pick one granularity - int fuse_to_iter_id = (*rand_gen)() % max_fusable_iter_id + 1; - Array<Integer> fused_ids; - for (int i = 0; i < fuse_to_iter_id; ++i) { - fused_ids.push_back(i); - } - int iter_offset = fuse_step->fused_ids.back()->value - fused_ids.back()->value; - if (iter_offset == 0) { - return ResultKind::kInvalid; - } - - // Replay the mutated fused and annotation step. - auto new_fuse_step = FuseStep(stage_id, fused_ids); - tmp_s.CopyOnWrite()->transform_steps.push_back(new_fuse_step); - StepApplyToState(new_fuse_step, &tmp_s, policy->search_task->compute_dag); - tmp_s.CopyOnWrite()->transform_steps.push_back((*state)->transform_steps[step_id]); - StepApplyToState((*state)->transform_steps[step_id], &tmp_s, policy->search_task->compute_dag); - - // Replay the rest steps. - for (size_t s = step_id + 1; s < (*state)->transform_steps.size(); ++s) { - auto step = (*state)->transform_steps[s]; - if (step->stage_id == stage_id) { - // Since we changed the loop structure, iter ID in later steps to the same stage - // has to be adjusted. - if (auto ps = step.as<AnnotationStepNode>()) { - if (ps->iter_id == 0) { - step = AnnotationStep(ps->stage_id, 0, ps->annotation); - } else { - ICHECK_LE(ps->iter_id + iter_offset, tmp_s->stages[stage_id]->iters.size()); - step = AnnotationStep(ps->stage_id, ps->iter_id + iter_offset, ps->annotation); - } - } else if (auto ps = step.as<PragmaStepNode>()) { - if (ps->iter_id == 0) { - step = PragmaStep(ps->stage_id, 0, ps->pragma_type); - } else { - ICHECK_LE(ps->iter_id + iter_offset, tmp_s->stages[stage_id]->iters.size()); - step = PragmaStep(ps->stage_id, ps->iter_id + iter_offset, ps->pragma_type); - } - } else { - return ResultKind::kInvalid; - } - } - if (IsStageNumberChangingStep(step)) { - // For these steps, we have to update stage_id because these steps will make stage_id - // out-dated. But here we just simply give up this mutation for simplicity. - // This is not an issue because this will never happend in normal cases where all these steps - // are before parallel steps. - return ResultKind::kInvalid; - } - tmp_s.CopyOnWrite()->transform_steps.push_back(step); - try { - StepApplyToState(tmp_s->transform_steps.back(), &tmp_s, policy->search_task->compute_dag); - } catch (Error& e) { - return ResultKind::kInvalid; - } - } - - *state = tmp_s; - return ResultKind::kValid; + return ResultKind::kInvalid; } } // namespace auto_scheduler diff --git a/src/auto_scheduler/search_policy/utils.cc b/src/auto_scheduler/search_policy/utils.cc index ac1cf2dd82c9176cacce0e323360df10407dffc9..c55e8dbd055b01e3c1a127f61b43a0e0acb9fb66 100644 --- a/src/auto_scheduler/search_policy/utils.cc +++ b/src/auto_scheduler/search_policy/utils.cc @@ -182,7 +182,7 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo levels[0].push_back(iter); } else { Array<Iterator> split_res = - tmp_s.split(stage_id, iter, Array<Optional<Integer>>(size - 1, NullOpt)); + tmp_s.split(stage_id, iter, Array<PrimExpr>(size - 1, PrimExpr())); for (int i = 0; i < size; i++) { levels[i].push_back(split_res[i]); } diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index b67d5cdd7bd93c464bea2911b0896afc5d3cced0..8c8742130819b6df62d6ab72888163340d91bb62 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -818,7 +818,7 @@ String ReorderStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, /********** Split **********/ // common part for SplitStep, FollowSplitStep, and FollowFusedSplitStep Array<Iterator> ApplySplitToState(State* state, int stage_id, int iter_id, - const Array<Optional<Integer>>& lengths, bool inner_to_outer) { + const Array<PrimExpr>& lengths, bool inner_to_outer) { const Stage& stage = (*state)->stages[stage_id]; const Iterator& it = stage->iters[iter_id]; size_t old_iter_size = stage->iters.size(); @@ -835,7 +835,7 @@ Array<Iterator> ApplySplitToState(State* state, int stage_id, int iter_id, Array<Iterator> outs; for (size_t i = 0; i < lengths.size(); ++i) { - Optional<Integer> l; + PrimExpr l; String name; if (inner_to_outer) { l = lengths[lengths.size() - i - 1]; @@ -845,11 +845,11 @@ Array<Iterator> ApplySplitToState(State* state, int stage_id, int iter_id, name = it->name + "." + std::to_string(i); } Iterator res; - if (l && tosplit_min && tosplit_extent) { - res = Iterator(name, Range::FromMinExtent(tosplit_min.value(), l.value()), it->iter_kind, + if (l.defined() && tosplit_min && tosplit_extent) { + res = Iterator(name, Range::FromMinExtent(tosplit_min.value(), l), it->iter_kind, IteratorAnnotation::kNone); tosplit_min = Integer(0); - tosplit_extent = indexdiv(tosplit_extent.value() + l.value() - 1, l.value()); + tosplit_extent = indexdiv(tosplit_extent.value() + l - 1, l); } else { res = Iterator(name, Range(), it->iter_kind, IteratorAnnotation::kNone); tosplit_min = NullOpt; @@ -899,8 +899,8 @@ Array<Iterator> ApplySplitToState(State* state, int stage_id, int iter_id, } Array<IterVar> ApplySplitToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, - int stage_id, int iter_id, - const Array<Optional<Integer>>& lengths, bool inner_to_outer) { + int stage_id, int iter_id, const Array<PrimExpr>& lengths, + bool inner_to_outer) { auto stage = (*stages)[stage_id]; const Array<IterVar>& axes = stage_to_axes->at(stage); @@ -909,7 +909,7 @@ Array<IterVar> ApplySplitToSchedule(Array<te::Stage>* stages, StageToAxesMap* st IterVar outer = axes[iter_id], inner; for (int i = static_cast<int>(lengths.size()) - 1; i >= 0; i--) { IterVar to_split = outer; - stage.split(to_split, lengths[i].value(), &outer, &inner); + stage.split(to_split, lengths[i], &outer, &inner); outs.push_back(inner); } outs.push_back(outer); @@ -917,7 +917,7 @@ Array<IterVar> ApplySplitToSchedule(Array<te::Stage>* stages, StageToAxesMap* st IterVar outer, inner = axes[iter_id]; for (size_t i = 0; i < lengths.size(); i++) { IterVar to_split = inner; - stage.split_by_nparts(to_split, lengths[i].value(), &outer, &inner); + stage.split_by_nparts(to_split, lengths[i], &outer, &inner); outs.push_back(outer); } outs.push_back(inner); @@ -942,8 +942,7 @@ Array<IterVar> ApplySplitToSchedule(Array<te::Stage>* stages, StageToAxesMap* st } String PrintSplitAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, int stage_id, - int iter_id, const Array<Optional<Integer>>& lengths, - bool inner_to_outer) { + int iter_id, const Array<PrimExpr>& lengths, bool inner_to_outer) { const auto& stage = (*stages)[stage_id]; auto to_split = stage_to_axes->at(stage)[iter_id]; const auto& func_name = CleanName(stage->op->name); @@ -974,7 +973,7 @@ String PrintSplitAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_ } SplitStep::SplitStep(int stage_id, int iter_id, Optional<PrimExpr> extent, - const Array<Optional<Integer>>& lengths, bool inner_to_outer) { + const Array<PrimExpr>& lengths, bool inner_to_outer) { auto node = make_object<SplitStepNode>(); node->stage_id = stage_id; // Extent can be a irreducible expression in some special cases @@ -999,15 +998,16 @@ SplitStep::SplitStep(dmlc::JSONReader* reader) { int int_val; s = reader->NextArrayItem(); ICHECK(s); - reader->Read(&int_val); - if (int_val) { - node->extent = Integer(int_val); - } - s = reader->NextArrayItem(); - ICHECK(s); - reader->Read(&node->lengths); - s = reader->NextArrayItem(); - ICHECK(s); + // TODO + // reader->Read(&int_val); + // if (int_val) { + // node->extent = Integer(int_val); + // } + // s = reader->NextArrayItem(); + // ICHECK(s); + // reader->Read(&node->lengths); + // s = reader->NextArrayItem(); + // ICHECK(s); reader->Read(&node->inner_to_outer); data_ = std::move(node); } @@ -1017,8 +1017,9 @@ void SplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { writer->WriteString(record_prefix_str); writer->WriteArrayItem(stage_id); writer->WriteArrayItem(iter_id); - writer->WriteArrayItem(extent ? GetIntImm(extent.value()) : 0); - writer->WriteArrayItem(lengths); + // TODO + // writer->WriteArrayItem(extent ? GetIntImm(extent.value()) : 0); + // writer->WriteArrayItem(lengths); writer->WriteArrayItem(static_cast<int>(inner_to_outer)); } @@ -1055,8 +1056,7 @@ void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { writer->WriteArrayItem(n_split); } -Array<Optional<Integer>> FollowSplitStepNode::ExtractSplitLengths( - const Array<Step>& transform_steps) const { +Array<PrimExpr> FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps) const { // Make sure src_step_id is within the range of transform_steps. ICHECK_LT(src_step_id, transform_steps.size()); auto ps = transform_steps[src_step_id].as<SplitStepNode>(); @@ -1067,7 +1067,7 @@ Array<Optional<Integer>> FollowSplitStepNode::ExtractSplitLengths( ICHECK_LE(n_split, ps->lengths.size() + 1); ICHECK(ps != nullptr); - Array<Optional<Integer>> lengths; + Array<PrimExpr> lengths; lengths.reserve(n_split); int j = 0; // Get the first (n_split-1) split factors of followed src_step. @@ -1079,19 +1079,14 @@ Array<Optional<Integer>> FollowSplitStepNode::ExtractSplitLengths( // ps->lengths.size()+1. PrimExpr last_factor = 1; for (; j < static_cast<int>(ps->lengths.size()); ++j) { - if (ps->lengths[j]) { - last_factor *= ps->lengths[j].value(); + if (ps->lengths[j].defined()) { + last_factor *= ps->lengths[j]; } else { last_factor = PrimExpr(); break; } } - if (last_factor.defined()) { - lengths.push_back(Downcast<Integer>(last_factor)); - } else { - lengths.push_back(NullOpt); - } - + lengths.push_back(last_factor); return lengths; } @@ -1176,8 +1171,7 @@ void FollowFusedSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { writer->WriteArrayItem(static_cast<int>(factor_or_nparts)); } -Optional<Integer> FollowFusedSplitStepNode::ExtractSplitLength( - const Array<Step>& transform_steps) const { +PrimExpr FollowFusedSplitStepNode::ExtractSplitLength(const Array<Step>& transform_steps) const { PrimExpr ret(1); for (int src_step_id : src_step_ids) { @@ -1186,13 +1180,9 @@ Optional<Integer> FollowFusedSplitStepNode::ExtractSplitLength( auto ps = transform_steps[src_step_id].as<SplitStepNode>(); ICHECK(ps != nullptr); // Multiple the splitting factor on corresponding splitting level of src_steps. - if (ps->lengths[level] && ret.defined()) { - ret *= ps->lengths[level].value(); - } else { - return NullOpt; - } + ret *= ps->lengths[level]; } - return Downcast<Integer>(ret); + return ret; } Array<Iterator> FollowFusedSplitStepNode::ApplyToState(State* state) const { diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h index 2943adf19f26dbfb5781320fc8cfba32f4c152e6..0249c55e1b657a94799a357e5f0717065e682db8 100755 --- a/src/auto_scheduler/utils.h +++ b/src/auto_scheduler/utils.h @@ -105,12 +105,9 @@ inline void FindAndDeleteItem(std::vector<T>* array, const T& to_delete) { } /*! \brief Compute the product of all elements in a vector */ -inline int64_t ElementProduct(const std::vector<int>& array) { - int64_t ret = 1; - for (auto x : array) { - ret *= x; - } - return ret; +inline PrimExpr ElementProduct(const std::vector<PrimExpr>& array) { + return std::accumulate(array.begin(), array.end(), PrimExpr(1), + [](const PrimExpr& a, const PrimExpr& b) { return a * b; }); } /*! \brief Move elements from multiple vectors to one vector */