Skip to content
Snippets Groups Projects
Commit 56fd61c1 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Changed the workflow around CHR step

parent a2ba55a5
No related branches found
No related tags found
No related merge requests found
...@@ -16,16 +16,18 @@ using namespace tvm::tir; ...@@ -16,16 +16,18 @@ using namespace tvm::tir;
class BufstoreInfoNode : public Object { class BufstoreInfoNode : public Object {
public: public:
size_t stage_id, iter_id; size_t stage_id, iter_id;
BufferStore bufstore; BufferLoad lhs;
Array<String> iters_in_expr; Array<BufferLoad> rhs;
tvm::runtime::NDArray counts; BufferStore orig_bufstore;
PrimExpr orig_rhs;
void VisitAttrs(AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
v->Visit("stage_id", &stage_id); v->Visit("stage_id", &stage_id);
v->Visit("iter_id", &iter_id); v->Visit("iter_id", &iter_id);
v->Visit("bufstore", &bufstore); v->Visit("lhs", &lhs);
v->Visit("iters_in_expr", &iters_in_expr); v->Visit("rhs", &rhs);
v->Visit("counts", &counts); v->Visit("orig_bufstore", &orig_bufstore);
v->Visit("orig_rhs", &orig_rhs);
} }
static constexpr const char* _type_key = "ansor.BufstoreInfo"; static constexpr const char* _type_key = "ansor.BufstoreInfo";
...@@ -34,14 +36,15 @@ class BufstoreInfoNode : public Object { ...@@ -34,14 +36,15 @@ class BufstoreInfoNode : public Object {
class BufstoreInfo : public ObjectRef { class BufstoreInfo : public ObjectRef {
public: public:
explicit BufstoreInfo(size_t stage_id, size_t iter_id, BufferStore bufstore, explicit BufstoreInfo(size_t stage_id, size_t iter_id, BufferLoad lhs, Array<BufferLoad> rhs,
Array<String> iters_in_expr, tvm::runtime::NDArray counts) { BufferStore orig_bufstore, PrimExpr orig_rhs) {
auto node = make_object<BufstoreInfoNode>(); auto node = make_object<BufstoreInfoNode>();
node->stage_id = stage_id; node->stage_id = stage_id;
node->iter_id = iter_id; node->iter_id = iter_id;
node->bufstore = std::move(bufstore); node->lhs = std::move(lhs);
node->iters_in_expr = std::move(iters_in_expr); node->rhs = std::move(rhs);
node->counts = std::move(counts); node->orig_bufstore = std::move(orig_bufstore);
node->orig_rhs = std::move(orig_rhs);
data_ = std::move(node); data_ = std::move(node);
} }
...@@ -49,17 +52,9 @@ class BufstoreInfo : public ObjectRef { ...@@ -49,17 +52,9 @@ class BufstoreInfo : public ObjectRef {
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufstoreInfoNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufstoreInfoNode);
}; };
// TVM_REGISTER_NODE_TYPE(BufferAccessNode);
TVM_REGISTER_NODE_TYPE(BufstoreInfoNode); TVM_REGISTER_NODE_TYPE(BufstoreInfoNode);
class IterVarsExtractor : public StmtExprVisitor {
public:
explicit IterVarsExtractor() {}
void VisitExpr_(const VarNode* node) final { ++this->varcounts[node]; }
std::unordered_map<const VarNode*, size_t> varcounts;
};
class BufstoreExtractor : public StmtExprVisitor { class BufstoreExtractor : public StmtExprVisitor {
public: public:
explicit BufstoreExtractor(const Array<Stage>& stages) { explicit BufstoreExtractor(const Array<Stage>& stages) {
...@@ -84,66 +79,54 @@ class BufstoreExtractor : public StmtExprVisitor { ...@@ -84,66 +79,54 @@ class BufstoreExtractor : public StmtExprVisitor {
void VisitStmt_(const BufferStoreNode* node) final { void VisitStmt_(const BufferStoreNode* node) final {
auto& name = node->buffer->name; auto& name = node->buffer->name;
auto it = this->stage_name_to_id.find(name); auto it = this->stage_name_to_id.find(name);
if (it == this->stage_name_to_id.end()) LOG_FATAL << "Buffer " << name << " is not found"; if (it == this->stage_name_to_id.end()) {
IterVarsExtractor iv_extractor; LOG_WARNING << "Buffer " << name << " is not found";
iv_extractor(node->value); return;
BufferStore bufstore(node->buffer, node->value, node->indices, node->span);
size_t n = iv_extractor.varcounts.size(), i = 0;
Array<String> iters_in_expr;
auto counts =
tvm::runtime::NDArray::Empty({(int64_t)n}, DLDataType{kDLInt, 32, 1}, {kDLCPU, 0});
for (auto& kv : iv_extractor.varcounts) {
iters_in_expr.push_back(kv.first->name_hint);
static_cast<int*>(counts->data)[i] = (int)kv.second;
i += 1;
} }
StmtExprVisitor::VisitStmt_(node);
this->bufstore_info.push_back( this->bufstore_info.push_back(
BufstoreInfo(it->second, itervars_stack.size(), bufstore, iters_in_expr, counts)); BufstoreInfo(it->second, itervars_stack.size() - 1, BufferLoad(node->buffer, node->indices),
std::move(this->buffer_loads),
BufferStore(node->buffer, node->value, node->indices), node->value));
this->buffer_loads = Array<BufferLoad>();
}
void VisitExpr_(const BufferLoadNode* node) final {
this->buffer_loads.push_back(BufferLoad(node->buffer, node->indices));
} }
std::unordered_map<std::string, size_t> stage_name_to_id; std::unordered_map<std::string, size_t> stage_name_to_id;
Array<BufstoreInfo> bufstore_info; Array<BufstoreInfo> bufstore_info;
Array<Iterator> itervars_stack; Array<Iterator> itervars_stack;
const BufferStoreNode* cur_bufstore; Array<BufferLoad> buffer_loads; // Cleared at every BufferStore node
}; };
BufstoreInfo GetBufstoreByName(const SearchTask& task, const State& state, const Step& step, BufstoreInfo GetNewBufstore(const SearchTask& task, State &state, const Step &step) {
const std::string& chr_buf_name) { auto task_dag = task->compute_dag;
Stmt generated = GenerateCodeForState(task, state); state.CopyOnWrite()->transform_steps.push_back(step);
BufstoreExtractor extractor(state->stages); StepApplyToState(step, &state, task_dag);
extractor(generated); state = task_dag.InferBound(state);
bool found = false; auto stmt = GenerateCodeForState(task, state);
BufstoreInfo ret; if (auto chr = step.as<CacheReadStepNode>()) {
for (auto& bufstore_info : extractor.bufstore_info) { // The new CHR stage will (somehow) have stage id of chr->stage_id + 1
if (bufstore_info->bufstore->buffer->name != chr_buf_name) continue; int stage_id = chr->stage_id + 1;
ret = bufstore_info; BufstoreExtractor extractor({state->stages[stage_id]});
found = true; extractor(stmt);
break; if (extractor.bufstore_info.size() != 1)
LOG_FATAL << "Expected only one bufstore in the new CHR stage";
return extractor.bufstore_info[0];
} }
if (!found) return BufstoreInfo();
LOG_FATAL << "CHR stage " << TransformStepToStr(step) << " (buffer_name=" << chr_buf_name
<< ") not found";
return ret;
} }
Array<ObjectRef> GetCacheReadsBufferStore(const SearchTask& task, const State& state_) { class IterVarsExtractor : public StmtExprVisitor {
const auto& task_dag = task->compute_dag; public:
const auto& tr_steps = state_->transform_steps; explicit IterVarsExtractor() {}
auto state = task_dag->init_state;
Array<ObjectRef> ret(tr_steps.size(), ObjectRef()); void VisitExpr_(const VarNode* node) final { ++this->varcounts[node->name_hint]; }
for (size_t i = 0; i < tr_steps.size(); ++i) {
auto& step = tr_steps[i]; std::unordered_map<String, size_t> varcounts;
state.CopyOnWrite()->transform_steps.push_back(step); };
StepApplyToState(step, &state, task_dag);
auto* chr = step.as<CacheReadStepNode>();
if (!chr) continue;
auto& chr_stage = state->stages[chr->stage_id + 1];
auto bufstore_info = GetBufstoreByName(task, state, step, chr_stage->op->name);
ret.Set(i, bufstore_info);
}
return ret;
}
Array<Stmt> ReplayStepsGenCode(const SearchTask& task, const Array<Step>& trSteps) { Array<Stmt> ReplayStepsGenCode(const SearchTask& task, const Array<Step>& trSteps) {
const auto& taskDAG = task->compute_dag; const auto& taskDAG = task->compute_dag;
...@@ -162,16 +145,56 @@ Array<Stmt> ReplayStepsGenCode(const SearchTask& task, const Array<Step>& trStep ...@@ -162,16 +145,56 @@ Array<Stmt> ReplayStepsGenCode(const SearchTask& task, const Array<Step>& trStep
return generated_stmts; return generated_stmts;
} }
TVM_REGISTER_GLOBAL("auto_scheduler.GetBufferStores") Array<Step> DecodeSteps(const String& jsonString) {
.set_body_typed([](const SearchTask& task, const State& state) { std::istringstream is(jsonString);
dmlc::JSONReader reader(&is);
Array<Step> steps;
reader.Read(&steps);
return steps;
}
TVM_REGISTER_GLOBAL("auto_scheduler.GetInitialBufstores")
.set_body_typed([](const SearchTask& task) {
auto state = task->compute_dag->init_state;
auto stmt = GenerateCodeForState(task, state); auto stmt = GenerateCodeForState(task, state);
BufstoreExtractor extractor(state->stages); BufstoreExtractor extractor(state->stages);
extractor(stmt); extractor(stmt);
return extractor.bufstore_info; return extractor.bufstore_info;
}); });
TVM_REGISTER_GLOBAL("auto_scheduler.GetCacheReadsBufferStore") TVM_REGISTER_GLOBAL("auto_scheduler.GetNewBufstore")
.set_body_typed(GetCacheReadsBufferStore); .set_body_typed([](const SearchTask& task, const State &state, const String &step_json) {
std::istringstream is(step_json);
dmlc::JSONReader reader(&is);
reader.BeginArray();
Step step = StepReadFromRecord(&reader);
ICHECK(!reader.NextArrayItem());
State new_state = state;
auto bufstore = GetNewBufstore(task, new_state, step);
return Array<ObjectRef>{new_state, bufstore};
});
TVM_REGISTER_GLOBAL("auto_scheduler.CountIterators").set_body_typed([](const PrimExpr &expr) {
IterVarsExtractor extractor;
extractor(expr);
Array<Array<ObjectRef>> ret;
for (const auto &pair: extractor.varcounts) {
ObjectRef name = pair.first;
DataType dtype{kDLUInt, 32, 1};
ObjectRef count = IntImm(dtype, (int64_t) pair.second);
ret.push_back(Array<ObjectRef>{name, count});
}
return ret;
});
TVM_REGISTER_GLOBAL("auto_scheduler.EncodeTrSteps").set_body_typed([](const Array<Step>& steps) {
std::ostringstream os;
dmlc::JSONWriter writer(&os);
writer.Write(steps);
return os.str();
});
/********** Debug APIs ********************************************************/
TVM_REGISTER_GLOBAL("auto_scheduler.TransformStepToStr").set_body_typed(TransformStepToStr); TVM_REGISTER_GLOBAL("auto_scheduler.TransformStepToStr").set_body_typed(TransformStepToStr);
...@@ -186,15 +209,10 @@ TVM_REGISTER_GLOBAL("auto_scheduler.PrintStateAllLoops").set_body_typed([](const ...@@ -186,15 +209,10 @@ TVM_REGISTER_GLOBAL("auto_scheduler.PrintStateAllLoops").set_body_typed([](const
return state.ToStr(false); return state.ToStr(false);
}); });
TVM_REGISTER_GLOBAL("auto_scheduler.ApplyEncodedStepsToInitState")
TVM_REGISTER_GLOBAL("auto_scheduler.CreateStateFromEncodedSteps")
.set_body_typed([](const ComputeDAG& taskDAG, const String& jsonString) { .set_body_typed([](const ComputeDAG& taskDAG, const String& jsonString) {
auto state = taskDAG->init_state; auto state = taskDAG->init_state;
std::istringstream is(jsonString); for (auto& step : DecodeSteps(jsonString)) {
dmlc::JSONReader reader(&is);
Array<Step> steps;
reader.Read(&steps);
for (auto& step : steps) {
state.CopyOnWrite()->transform_steps.push_back(step); state.CopyOnWrite()->transform_steps.push_back(step);
StepApplyToState(step, &state, taskDAG); StepApplyToState(step, &state, taskDAG);
} }
...@@ -202,13 +220,6 @@ TVM_REGISTER_GLOBAL("auto_scheduler.CreateStateFromEncodedSteps") ...@@ -202,13 +220,6 @@ TVM_REGISTER_GLOBAL("auto_scheduler.CreateStateFromEncodedSteps")
return state; return state;
}); });
TVM_REGISTER_GLOBAL("auto_scheduler.EncodeTrSteps").set_body_typed([](const Array<Step>& steps) {
std::ostringstream os;
dmlc::JSONWriter writer(&os);
writer.Write(steps);
return os.str();
});
} // namespace auto_scheduler } // namespace auto_scheduler
} // namespace tvm } // namespace tvm
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment