diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc
index 4a4ab18b5eed80dde7593b8c3b9a98ed7bc39556..e652c1baf87a68d7d2950d0b8eba46fd3737a012 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy.cc
@@ -378,6 +378,23 @@ Array<State> SketchPolicyNode::GenerateSketches() {
   }
 
   StdCout(verbose) << "Generate Sketches\t\t#s: " << out_states.size() << std::endl;
+
+  // const auto& sampleSketch = out_states[0];
+  // auto sampleState = AnnotateTillValid(sampleSketch);
+  // PrintStateReplaySteps(std::cerr, search_task->compute_dag, sampleState->transform_steps);
+
+  // double allSketchesSpace = 0.0;
+  // for (const auto &sketch: out_states) {
+  //   double loopTilingSpace, loopUnrollSpace;
+  //   std::tie(loopTilingSpace, loopUnrollSpace) = EstimateSketchSearchSpace(sketch, &std::cerr);
+  //   std::cerr << "Loop tiling search space: " << loopTilingSpace << "\n"
+  //             << "Loop unrolling search space: " << loopUnrollSpace << "\n"
+  //             << "Total search space of sketch: " << loopTilingSpace * loopUnrollSpace << "\n"
+  //             << "\n";
+  //   allSketchesSpace += loopTilingSpace * loopUnrollSpace;
+  // }
+  // std::cerr << "Total search space: " << allSketchesSpace << std::endl;
+
   return out_states;
 }
 
@@ -671,6 +688,82 @@ Array<MeasureInput> SketchPolicyNode::PickStatesWithEpsGreedy(const Array<State>
   return inputs;
 }
 
+/*!
+ * \brief Annotate a sketch with the init rules, retrying until every rule accepts the result.
+ *
+ * Each attempt starts from a fresh copy of \p sketch and applies all init_rules in order; the
+ * attempt is discarded as soon as one rule reports kInvalid.
+ * \param sketch The sketch to annotate; it is never modified.
+ * \return The first annotated state for which no rule reported kInvalid.
+ * \note Gives up through LOG(FATAL) after a fixed number of failed attempts instead of looping
+ *  forever on a sketch that cannot be annotated.
+ */
+State SketchPolicyNode::AnnotateTillValid(const State& sketch) {
+  constexpr int kMaxAttempts = 100000;
+  for (int attempt = 0; attempt < kMaxAttempts; ++attempt) {
+    State candidate = sketch;
+    bool valid = true;
+    for (const auto& rule : init_rules) {
+      if (rule->Apply(this, &candidate, &rand_gen) ==
+          PopulationGenerationRule::ResultKind::kInvalid) {
+        valid = false;
+        break;
+      }
+    }
+    if (valid) {
+      return candidate;
+    }
+  }
+  LOG(FATAL) << "AnnotateTillValid: no valid annotation found after " << kMaxAttempts
+             << " attempts";
+  // Not reached when LOG(FATAL) raises; present so every control path returns a value.
+  return sketch;
+}
+
+/*!
+ * \brief Estimate the size of the search space spanned by one sketch.
+ *
+ * The sketch is annotated once with AnnotateTillValid. Every SplitStep of the annotated state
+ * multiplies the tiling space by its number of factorization schemes, and every PragmaStep
+ * multiplies the unrolling space by the number of unroll options of the target kind.
+ * \param sketch The sketch to analyze.
+ * \param os Optional stream that receives one line per counted step; may be nullptr.
+ * \return (loop tiling space, loop unrolling space).
+ */
+std::tuple<double, double> SketchPolicyNode::EstimateSketchSearchSpace(const State& sketch,
+                                                                       std::ostream* os) {
+  // NOTE(review): hard-coded; presumably this should track the policy's maximum innermost
+  // split factor -- confirm.
+  constexpr int kMaxInnermostFactor = 64;
+  // Value taken from sketch_policy_rules.cc
+  const int nUnrollOptions = IsGPUTask(search_task) ? 5 : 4;
+
+  const State state = AnnotateTillValid(sketch);
+  // Constructed once for the whole walk instead of once per split step.
+  SplitFactorizationMemo split_memo;
+  double loopTilingSpace = 1.0;
+  double loopUnrollSpace = 1.0;
+  for (const auto& step : state->transform_steps) {
+    if (const auto* ps = step.as<SplitStepNode>()) {
+      const int extent = GetIntImm(ps->extent.value());
+      const auto& candidates =
+          split_memo.GetFactorizationSchemes(extent, ps->lengths.size(), kMaxInnermostFactor);
+      if (os) {
+        *os << TransformStepToStr(step) << " " << candidates.size()
+            << " possible factorizations\n";
+      }
+      loopTilingSpace *= candidates.size();
+    } else if (step.as<PragmaStepNode>()) {
+      // Every pragma step is counted as one unroll decision.
+      if (os) {
+        *os << TransformStepToStr(step) << " " << nUnrollOptions << " possible unroll options\n";
+      }
+      loopUnrollSpace *= nUnrollOptions;
+    }
+  }
+  return std::make_tuple(loopTilingSpace, loopUnrollSpace);
+}
+
 /********** PreloadCustomSketchRule **********/
 TVM_REGISTER_OBJECT_TYPE(PreloadCustomSketchRuleNode);
 
diff --git a/src/auto_scheduler/search_policy/sketch_policy.h b/src/auto_scheduler/search_policy/sketch_policy.h
index faf058b45b198a3f3b56e6977ce811206f112bc8..0faa8a55704e8f96434afe5778681d2145492980 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.h
+++ b/src/auto_scheduler/search_policy/sketch_policy.h
@@ -163,6 +163,21 @@ class SketchPolicyNode : public SearchPolicyNode {
                                               const Array<State>& random_states,
                                               int remaining_n_trials);
 
+  /*!
+   * \brief Annotate a sketch with the init rules, retrying until all of them accept it.
+   * \param sketch The sketch to annotate.
+   * \return The annotated state.
+   */
+  State AnnotateTillValid(const State& sketch);
+
+  /*!
+   * \brief Estimate the search space size of one sketch.
+   * \param sketch The sketch to analyze.
+   * \param os Optional stream for a per-step breakdown; may be nullptr.
+   * \return (loop tiling space, loop unrolling space).
+   */
+  std::tuple<double, double> EstimateSketchSearchSpace(const State& sketch, std::ostream* os);
+
   /*! \brief The number of states to measure per iteration. */
   int num_measure_per_iter_;
 
diff --git a/src/auto_scheduler/utils.cc b/src/auto_scheduler/utils.cc
index 68f503836cfb966cff7c745416f4771e36a152cd..1405b1e434ce5e8831dbb5ae064d814a60a6c263 100755
--- a/src/auto_scheduler/utils.cc
+++ b/src/auto_scheduler/utils.cc
@@ -24,13 +24,247 @@
 
 #include "utils.h"
 
+#include <tvm/auto_scheduler/transform_step.h>
+#include <tvm/driver/driver_api.h>
+#include <tvm/te/operation.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/transform.h>
+
+#include <sstream>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+
+#include "search_policy/utils.h"
+
+namespace dmlc {
+namespace json {
+/*!
+ * \brief JSON (de)serialization of a list of transform steps, one nested array per step.
+ * \note Defined ahead of the registered functions below that read and write step arrays,
+ *  because an explicit specialization must be declared before its first use.
+ * \note NOTE(review): if another translation unit defines this same specialization, both
+ *  definitions have to stay identical -- confirm.
+ */
+template <>
+struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> {
+  inline static void Write(dmlc::JSONWriter* writer,
+                           const ::tvm::Array<::tvm::auto_scheduler::Step>& data) {
+    writer->BeginArray(false);
+    for (const auto& step : data) {
+      writer->WriteArraySeperator();
+      writer->BeginArray(false);
+      step->WriteToRecord(writer);
+      writer->EndArray();
+    }
+    writer->EndArray();
+  }
+
+  inline static void Read(dmlc::JSONReader* reader,
+                          ::tvm::Array<::tvm::auto_scheduler::Step>* data) {
+    bool s;
+    reader->BeginArray();
+    data->clear();
+    while (reader->NextArrayItem()) {
+      reader->BeginArray();
+      data->push_back(::tvm::auto_scheduler::StepReadFromRecord(reader));
+      // The nested array must hold nothing beyond what StepReadFromRecord consumed.
+      s = reader->NextArrayItem();
+      ICHECK(!s);
+    }
+  }
+};
+}  // namespace json
+}  // namespace dmlc
+
 namespace tvm {
 namespace auto_scheduler {
 
+TVM_REGISTER_PASS_CONFIG_OPTION("ansor.verbose", Bool);
+
 NullStream& NullStream::Global() {
   static NullStream stream;
   return stream;
 }
 
+/*!
+ * \brief Render one transform step as a single-line string of the form `Kind(field=value, ...)`.
+ * \param step The step to render.
+ * \return The rendered string.
+ * \note A step type that is not handled here is reported through LOG(FATAL).
+ */
+String TransformStepToStr(const Step& step) {
+  std::ostringstream os;
+  if (const auto* ps = step.as<AnnotationStepNode>()) {
+    os << "Annotation(stage_id=" << ps->stage_id << ", loop=" << ps->iter_id << ", annotation=\""
+       << IteratorAnnotationString[static_cast<int>(ps->annotation)] << "\")";
+  } else if (const auto* ps = step.as<FuseStepNode>()) {
+    os << "Fuse(stage_id=" << ps->stage_id << ", fused_ids=" << ps->fused_ids << ")";
+  } else if (const auto* ps = step.as<PragmaStepNode>()) {
+    os << "Pragma(stage_id=" << ps->stage_id << ", loop=" << ps->iter_id
+       << ", pragma=" << ps->pragma_type << ")";
+  } else if (const auto* ps = step.as<ReorderStepNode>()) {
+    os << "Reorder(stage_id=" << ps->stage_id << ", order_after=" << ps->after_ids << ")";
+  } else if (const auto* ps = step.as<SplitStepNode>()) {
+    os << "Split(stage_id=" << ps->stage_id << ", loop=" << ps->iter_id << ", extent=" << ps->extent
+       << ", lengths=" << ps->lengths << ", inner_to_outer=" << ps->inner_to_outer << ")";
+  } else if (const auto* ps = step.as<FollowSplitStepNode>()) {
+    os << "FollowSplit(stage_id=" << ps->stage_id << ", loop=" << ps->iter_id
+       << ", src_step_id=" << ps->src_step_id << ", n_split=" << ps->n_split << ")";
+  } else if (const auto* ps = step.as<FollowFusedSplitStepNode>()) {
+    os << "FollowFusedSplit(stage_id=" << ps->stage_id << ", loop=" << ps->iter_id
+       << ", src_step_ids=" << ps->src_step_ids << ", level=" << ps->level
+       << ", factor_or_nparts=" << ps->factor_or_nparts << ")";
+  } else if (const auto* ps = step.as<StorageAlignStepNode>()) {
+    os << "StorageAlign(stage_id=" << ps->stage_id << ", loop=" << ps->iter_id
+       << ", factor=" << ps->factor << ", offset=" << ps->offset << ")";
+  } else if (const auto* ps = step.as<ComputeAtStepNode>()) {
+    os << "ComputeAt(stage_id=" << ps->stage_id << ", target_stage_id=" << ps->target_stage_id
+       << ", loop=" << ps->target_iter_id << ")";
+  } else if (const auto* ps = step.as<ComputeInlineStepNode>()) {
+    os << "ComputeInline(stage_id=" << ps->stage_id << ")";
+  } else if (const auto* ps = step.as<ComputeRootStepNode>()) {
+    os << "ComputeRoot(stage_id=" << ps->stage_id << ")";
+  } else if (const auto* ps = step.as<CacheReadStepNode>()) {
+    os << "CacheRead(stage_id=" << ps->stage_id << ", scope_name=" << ps->scope_name
+       << ", reader_stage_ids=" << ps->reader_stage_ids << ")";
+  } else if (const auto* ps = step.as<CacheWriteStepNode>()) {
+    os << "CacheWrite(stage_id=" << ps->stage_id << ", scope_name=" << ps->scope_name << ")";
+  } else if (const auto* ps = step.as<RfactorStepNode>()) {
+    os << "RFactor(stage_id=" << ps->stage_id << ", from_loop=" << ps->iter_id
+       << ", to_loop=" << ps->factor_iter_id << ")";
+  } else {
+    LOG(FATAL) << "Invalid step: " << step;
+  }
+  return os.str();
+}
+
+/*!
+ * \brief Replay transform steps one by one and generate code after each of them.
+ * \param task The search task whose compute DAG the steps apply to.
+ * \param trSteps The transform steps to replay, in order.
+ * \return One statement per step: the code generated for the state reached after that step,
+ *  or a null Stmt where code generation failed.
+ * \note A bound inference failure is reported through LOG(FATAL).
+ */
+Array<Stmt> ReplayStepsGenCode(const SearchTask& task, const Array<Step>& trSteps) {
+  const auto& taskDAG = task->compute_dag;
+  State state = taskDAG->init_state;
+  Array<Stmt> generated_stmts;
+  for (const auto& step : trSteps) {
+    state.CopyOnWrite()->transform_steps.push_back(step);
+    StepApplyToState(step, &state, taskDAG);
+    try {
+      state = taskDAG.InferBound(state);
+    } catch (const Error& e) {
+      LOG(FATAL) << "Failed inferring bound: " << e.what() << "\n";
+    }
+    generated_stmts.push_back(GenerateCodeForState(task, state));
+  }
+  return generated_stmts;
+}
+
+/*!
+ * \brief Build the pass pipeline that GenerateCodeForState runs on a lowered schedule.
+ * \param hw_params Hardware limits handed to the GPU code verification pass.
+ * \param is_gpu Whether to build the GPU pipeline; otherwise only Simplify is scheduled.
+ * \return The passes wrapped in a Sequential.
+ * \note The GPU pipeline reads tir.disable_vectorize and tir.instrument_bound_checkers from the
+ *  current PassContext.
+ */
+tvm::transform::Sequential GetCodeGenPasses(const HardwareParams& hw_params, bool is_gpu) {
+  using namespace tvm::tir::transform;
+  if (!is_gpu) {
+    return Sequential(Array<tvm::transform::Pass>({Simplify()}));
+  }
+
+  auto pass_ctx = tvm::transform::PassContext::Current();
+  const bool disable_vectorize =
+      pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
+  const bool instrument_bound_checkers =
+      pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
+  Map<String, PrimExpr> gpu_params{
+      {"max_shared_memory_per_block", hw_params->max_shared_memory_per_block},
+      {"max_local_memory_per_block", hw_params->max_local_memory_per_block},
+      {"max_threads_per_block", hw_params->max_threads_per_block},
+      {"max_vector_bytes", hw_params->vector_unit_bytes},
+      {"max_vthread", hw_params->max_vthread_extent},
+  };
+  return Sequential(Array<tvm::transform::Pass>(
+      {// Phase 0
+       InjectPrefetch(), StorageFlatten(64, instrument_bound_checkers),
+       // Phase 1
+       NarrowDataType(32), Simplify(), VectorizeLoop(!disable_vectorize), InjectVirtualThread(),
+       StorageRewrite(), Simplify(), VerifyGPUCode(gpu_params)}));
+}
+
+/*!
+ * \brief Lower a state to TIR and return the body of the resulting function.
+ * \param task The search task that provides the compute DAG and the hardware parameters.
+ * \param state The state whose transform steps are applied.
+ * \param print_error Whether to log a warning when module creation or a pass raises an Error.
+ * \return The body of the lowered "main" function, or a null Stmt when module creation or a
+ *  pass raised an Error.
+ */
+Stmt GenerateCodeForState(const SearchTask& task, const State& state, bool print_error) {
+  te::Schedule sch;
+  Array<te::Tensor> tensors;
+  std::tie(sch, tensors) = task->compute_dag.ApplySteps(state->transform_steps);
+  // When inlining, replace const matrices with const values.
+  // Produces wrong IR, but good enough for feature extraction, and
+  // can improve the speed of feature extraction/search.  Must be
+  // called before ScheduleToModule to have an effect.
+  sch = sch.normalize_for_feature_extraction();
+
+  try {
+    const std::string name = "main";
+    auto mod = ScheduleToModule(sch, Array<ObjectRef>{tensors.begin(), tensors.end()}, name,
+                                std::unordered_map<te::Tensor, te::Buffer>());
+    // Applying a pass returns the transformed module, so the result has to be kept for the
+    // lookup below to see the output of the pipeline.
+    const auto passes = GetCodeGenPasses(task->hardware_params, IsGPUTask(task));
+    mod = passes(std::move(mod));
+    PrimFunc prim_func = Downcast<PrimFunc>(mod->Lookup(name));
+    return prim_func->body;
+  } catch (const Error& e) {
+    if (print_error) LOG(WARNING) << "Failed to generate code: " << e.what() << "\n";
+    return Stmt();
+  }
+}
+
+TVM_REGISTER_GLOBAL("auto_scheduler.CreateStateFromEncodedSteps")
+    .set_body_typed([](const ComputeDAG& taskDAG, const String& jsonString) {
+      // Decode the JSON step list, replay it on the DAG's initial state, then infer bounds.
+      std::istringstream is(jsonString);
+      dmlc::JSONReader reader(&is);
+      Array<Step> steps;
+      reader.Read(&steps);
+
+      State state = taskDAG->init_state;
+      for (const auto& step : steps) {
+        state.CopyOnWrite()->transform_steps.push_back(step);
+        StepApplyToState(step, &state, taskDAG);
+      }
+      return taskDAG.InferBound(state);
+    });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.EncodeTrSteps").set_body_typed([](const Array<Step>& steps) {
+  // Serialize the steps to the JSON form that CreateStateFromEncodedSteps decodes.
+  std::ostringstream os;
+  dmlc::JSONWriter writer(&os);
+  writer.Write(steps);
+  return os.str();
+});
+
+TVM_REGISTER_GLOBAL("auto_scheduler.TransformStepToStr").set_body_typed(TransformStepToStr);
+
+TVM_REGISTER_GLOBAL("auto_scheduler.ReplayStepsGenCode").set_body_typed(ReplayStepsGenCode);
+
+TVM_REGISTER_GLOBAL("auto_scheduler.GenerateCodeForState")
+    .set_body_typed([](const SearchTask& task, const State& state) {
+      // Failures are always logged when called through the registry.
+      return GenerateCodeForState(task, state, true);
+    });
+
 }  // namespace auto_scheduler
 }  // namespace tvm
diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h
index 9fc5a1dd8f22969fec0d79353a60f76471f01275..ebe560e536c4f1a631391e08b6ca80bc21c77166 100755
--- a/src/auto_scheduler/utils.h
+++ b/src/auto_scheduler/utils.h
@@ -26,6 +26,9 @@
 #define TVM_AUTO_SCHEDULER_UTILS_H_
 
 #include <dmlc/common.h>
+#include <tvm/auto_scheduler/compute_dag.h>
+#include <tvm/auto_scheduler/search_task.h>
+#include <tvm/auto_scheduler/transform_step.h>
 #include <tvm/tir/expr.h>
 
 #include <algorithm>
@@ -296,6 +299,27 @@ inline void ParseKernelLayout(const String& layout, Array<PrimExpr>* shape,
 /*! \brief Get the base name before '_' of an axis */
 inline std::string AxisBaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
 
+// NOTE(review): this patch contains no definition of PrintTransformStep -- confirm one exists.
+void PrintTransformStep(std::ostream& os, const Step& step);
+
+// NOTE(review): this patch contains no definition of PrintStateReplaySteps -- confirm one exists.
+void PrintStateReplaySteps(std::ostream& os, const ComputeDAG& taskDAG, const Array<Step>& trSteps);
+
+/*! \brief Render one transform step as a single-line, human-readable string. */
+String TransformStepToStr(const Step& step);
+
+/*! \brief Replay transform steps one by one and generate code after each of them. */
+Array<Stmt> ReplayStepsGenCode(const SearchTask& task, const Array<Step>& trSteps);
+
+/*! \brief Build the pass pipeline that GenerateCodeForState runs (GPU or Simplify-only). */
+tvm::transform::Sequential GetCodeGenPasses(const HardwareParams& hw_params, bool is_gpu);
+
+/*!
+ * \brief Lower a state to TIR and return the function body, or a null Stmt on failure.
+ * \param print_error Whether to log a warning on failure.
+ */
+Stmt GenerateCodeForState(const SearchTask& task, const State& state, bool print_error = false);
+
 }  // namespace auto_scheduler
 }  // namespace tvm
 