Skip to content
Snippets Groups Projects
Commit 05fe7702 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Utils for estimating search space and printing various info

parent 95aac922
No related branches found
No related tags found
No related merge requests found
Pipeline #194577 failed
......@@ -378,6 +378,23 @@ Array<State> SketchPolicyNode::GenerateSketches() {
}
StdCout(verbose) << "Generate Sketches\t\t#s: " << out_states.size() << std::endl;
// const auto& sampleSketch = out_states[0];
// auto sampleState = AnnotateTillValid(sampleSketch);
// PrintStateReplaySteps(std::cerr, search_task->compute_dag, sampleState->transform_steps);
// double allSketchesSpace = 0.0;
// for (const auto &sketch: out_states) {
// double loopTilingSpace, loopUnrollSpace;
// std::tie(loopTilingSpace, loopUnrollSpace) = EstimateSketchSearchSpace(sketch, &std::cerr);
// std::cerr << "Loop tiling search space: " << loopTilingSpace << "\n"
// << "Loop unrolling search space: " << loopUnrollSpace << "\n"
// << "Total search space of sketch: " << loopTilingSpace * loopUnrollSpace << "\n"
// << "\n";
// allSketchesSpace += loopTilingSpace * loopUnrollSpace;
// }
// std::cerr << "Total search space: " << allSketchesSpace << std::endl;
return out_states;
}
......@@ -671,6 +688,51 @@ Array<MeasureInput> SketchPolicyNode::PickStatesWithEpsGreedy(const Array<State>
return inputs;
}
State SketchPolicyNode::AnnotateTillValid(const State& sketch) {
  // Keep applying the initial-population rules to a fresh copy of `sketch`
  // until a pass where every rule succeeds; the rules are randomized via
  // `rand_gen`, so retrying from the pristine sketch can eventually succeed.
  // NOTE(review): loops forever if no random annotation is ever valid.
  for (;;) {
    State annotated = sketch;
    bool all_rules_applied = true;
    for (const auto& rule : init_rules) {
      if (rule->Apply(this, &annotated, &rand_gen) ==
          PopulationGenerationRule::ResultKind::kInvalid) {
        all_rules_applied = false;
        break;
      }
    }
    if (all_rules_applied) {
      return annotated;
    }
  }
}
/*!
 * \brief Estimate the size of the tuning search space rooted at a sketch.
 * \param sketch The sketch state; it is first annotated via AnnotateTillValid
 *        so that all split/pragma steps are concrete.
 * \param os Optional stream for per-step diagnostics; pass nullptr to silence.
 * \return {loop tiling space, loop unrolling space}; the product of the two is
 *         the total estimated space for this sketch.
 */
std::tuple<double, double> SketchPolicyNode::EstimateSketchSearchSpace(const State& sketch,
                                                                       std::ostream* os) {
  auto state = AnnotateTillValid(sketch);
  // Number of auto_unroll_max_step candidates; value taken from sketch_policy_rules.cc.
  int nUnrollOptions = IsGPUTask(search_task) ? 5 : 4;
  double loopTilingSpace = 1.0, loopUnrollSpace = 1.0;
  for (const auto& step : state->transform_steps) {
    if (auto ps = step.as<SplitStepNode>()) {
      // Each split step multiplies the tiling space by the number of ways its
      // extent can be factorized into ps->lengths.size() parts (factors <= 64).
      SplitFactorizationMemo split_memo;
      int extent = GetIntImm(ps->extent.value());
      const auto& candidates = split_memo.GetFactorizationSchemes(extent, ps->lengths.size(), 64);
      if (os) {
        // BUGFIX: the whole diagnostic line goes to the caller-provided stream
        // (previously the tail was hard-coded to std::cerr, splitting output).
        *os << TransformStepToStr(step) << " " << candidates.size()
            << " possible factorizations\n";
      }
      loopTilingSpace *= candidates.size();
    } else if (step.as<PragmaStepNode>()) {
      if (os) {
        // BUGFIX: message described unroll choices as "factorizations".
        *os << TransformStepToStr(step) << " " << nUnrollOptions << " possible unroll options\n";
      }
      loopUnrollSpace *= nUnrollOptions;
    }
  }
  return {loopTilingSpace, loopUnrollSpace};
}
/********** PreloadCustomSketchRule **********/
TVM_REGISTER_OBJECT_TYPE(PreloadCustomSketchRuleNode);
......
......@@ -163,6 +163,10 @@ class SketchPolicyNode : public SearchPolicyNode {
const Array<State>& random_states,
int remaining_n_trials);
State AnnotateTillValid(const State& sketch);
std::tuple<double, double> EstimateSketchSearchSpace(const State& sketch, std::ostream* os);
/*! \brief The number of states to measure per iteration. */
int num_measure_per_iter_;
......
......@@ -24,13 +24,201 @@
#include "utils.h"
#include <tvm/auto_scheduler/transform_step.h>
#include <tvm/driver/driver_api.h>
#include <tvm/te/operation.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/transform.h>
#include "search_policy/utils.h"
namespace tvm {
namespace auto_scheduler {
TVM_REGISTER_PASS_CONFIG_OPTION("ansor.verbose", Bool);
// Return the process-wide NullStream singleton (a stream that discards output).
NullStream& NullStream::Global() {
  static NullStream singleton;
  return singleton;
}
/*!
 * \brief Render a transform step as a human-readable "Name(key=value, ...)" string.
 * \param step The transform step to describe.
 * \return The formatted description.
 * \note Fatals on an unrecognized step type so new step kinds are noticed here.
 */
String TransformStepToStr(const Step& step) {
  std::ostringstream os;
  if (auto ps = step.as<AnnotationStepNode>()) {
    os << "Annotation(stage_id=" << ps->stage_id << ", loop=" << ps->iter_id << ", annotation=\""
       << IteratorAnnotationString[static_cast<int>(ps->annotation)] << "\")";
  } else if (auto ps = step.as<FuseStepNode>()) {
    // BUGFIX: was "fused_ids" without '=', inconsistent with every other field.
    os << "Fuse(stage_id=" << ps->stage_id << ", fused_ids=" << ps->fused_ids << ")";
  } else if (auto ps = step.as<PragmaStepNode>()) {
    os << "Pragma(stage_id=" << ps->stage_id << ", loop=" << ps->iter_id
       << ", pragma=" << ps->pragma_type << ")";
  } else if (auto ps = step.as<ReorderStepNode>()) {
    os << "Reorder(stage_id=" << ps->stage_id << ", order_after=" << ps->after_ids << ")";
  } else if (auto ps = step.as<SplitStepNode>()) {
    os << "Split(stage_id=" << ps->stage_id << ", loop=" << ps->iter_id << ", extent=" << ps->extent
       << ", " << ps->lengths << ", inner_to_outer=" << ps->inner_to_outer << ")";
  } else if (auto ps = step.as<FollowSplitStepNode>()) {
    os << "FollowSplit(stage_id=" << ps->stage_id << ", loop=" << ps->iter_id
       << ", src_step_id=" << ps->src_step_id << ", n_split=" << ps->n_split << ")";
  } else if (auto ps = step.as<FollowFusedSplitStepNode>()) {
    os << "FollowFusedSplit(stage_id=" << ps->stage_id << ", loop=" << ps->iter_id
       << ", src_step_ids=" << ps->src_step_ids << ", level=" << ps->level
       << ", factor_or_nparts=" << ps->factor_or_nparts << ")";
  } else if (auto ps = step.as<StorageAlignStepNode>()) {
    os << "StorageAlign(stage_id=" << ps->stage_id << ", loop=" << ps->iter_id
       << ", factor=" << ps->factor << ", offset=" << ps->offset << ")";
  } else if (auto ps = step.as<ComputeAtStepNode>()) {
    os << "ComputeAt(stage_id=" << ps->stage_id << ", target_stage_id=" << ps->target_stage_id
       << ", loop=" << ps->target_iter_id << ")";
  } else if (auto ps = step.as<ComputeInlineStepNode>()) {
    os << "ComputeInline(stage_id=" << ps->stage_id << ")";
  } else if (auto ps = step.as<ComputeRootStepNode>()) {
    os << "ComputeRoot(stage_id=" << ps->stage_id << ")";
  } else if (auto ps = step.as<CacheReadStepNode>()) {
    os << "CacheRead(stage_id=" << ps->stage_id << ", scope_name=" << ps->scope_name
       << ", reader_stage_ids=" << ps->reader_stage_ids << ")";
  } else if (auto ps = step.as<CacheWriteStepNode>()) {
    os << "CacheWrite(stage_id=" << ps->stage_id << ", scope_name=" << ps->scope_name << ")";
  } else if (auto ps = step.as<RfactorStepNode>()) {
    os << "RFactor(stage_id=" << ps->stage_id << ", from_loop=" << ps->iter_id
       << ", to_loop=" << ps->factor_iter_id << ")";
  } else {
    LOG(FATAL) << "Invalid step: " << step;
  }
  return os.str();
}
// Replay `trSteps` one at a time on the task's initial state and generate the
// lowered code after each step, returning one Stmt per prefix of the step list.
// Fatals if bound inference fails mid-replay.
Array<Stmt> ReplayStepsGenCode(const SearchTask& task, const Array<Step>& trSteps) {
  const auto& dag = task->compute_dag;
  auto replayed = dag->init_state;
  Array<Stmt> lowered_per_step;
  for (const auto& step : trSteps) {
    replayed.CopyOnWrite()->transform_steps.push_back(step);
    StepApplyToState(step, &replayed, dag);
    try {
      replayed = dag.InferBound(replayed);
    } catch (Error& e) {
      LOG_FATAL << "Failed inferring bound: " << e.what() << "\n";
    }
    lowered_per_step.push_back(GenerateCodeForState(task, replayed));
  }
  return lowered_per_step;
}
// Build the lowering/verification pass pipeline used when generating code for
// a state. CPU targets only need Simplify; GPU targets run the phase-0/1
// lowering passes plus VerifyGPUCode against the hardware limits.
tvm::transform::Sequential GetCodeGenPasses(const HardwareParams& hw_params, bool is_gpu) {
  using namespace tvm::tir::transform;
  if (!is_gpu) {
    return Sequential(Array<tvm::transform::Pass>({Simplify()}));
  }
  auto pass_ctx = tvm::transform::PassContext::Current();
  bool disable_vectorize =
      pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
  bool instrument_bound_checkers =
      pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
  // Hardware limits checked by VerifyGPUCode.
  Map<String, PrimExpr> gpu_params{
      {"max_shared_memory_per_block", hw_params->max_shared_memory_per_block},
      {"max_local_memory_per_block", hw_params->max_local_memory_per_block},
      {"max_threads_per_block", hw_params->max_threads_per_block},
      {"max_vector_bytes", hw_params->vector_unit_bytes},
      {"max_vthread", hw_params->max_vthread_extent},
  };
  Array<tvm::transform::Pass> gpu_passes{
      // Phase 0
      InjectPrefetch(), StorageFlatten(64, instrument_bound_checkers),
      // Phase 1
      NarrowDataType(32), Simplify(), VectorizeLoop(!disable_vectorize), InjectVirtualThread(),
      StorageRewrite(), Simplify(), VerifyGPUCode(gpu_params)};
  return Sequential(gpu_passes);
}
/*!
 * \brief Apply a state's transform steps and lower them to a TIR statement.
 * \param task The search task providing the compute DAG and hardware params.
 * \param state The state whose transform steps are applied.
 * \param print_error Whether to log a warning when lowering fails.
 * \return The lowered function body, or an empty Stmt() on failure.
 */
Stmt GenerateCodeForState(const SearchTask& task, const State& state, bool print_error) {
  te::Schedule sch;
  Array<te::Tensor> tensors;
  std::tie(sch, tensors) = task->compute_dag.ApplySteps(state->transform_steps);
  // When inlining, replace const matrices with const values.
  // Produces wrong IR, but good enough for feature extraction, and
  // can improve the speed of feature extraction/search. Must be
  // called before ScheduleToModule to have an effect.
  sch = sch.normalize_for_feature_extraction();
  try {
    const std::string& name = "main";
    auto mod = ScheduleToModule(sch, Array<ObjectRef>{tensors.begin(), tensors.end()}, name,
                                std::unordered_map<te::Tensor, te::Buffer>());
    auto passes = GetCodeGenPasses(task->hardware_params, IsGPUTask(task));
    // BUGFIX: Sequential::operator() returns the transformed module; the
    // result was previously discarded, so Lookup() read the unlowered body.
    mod = passes(mod);
    PrimFunc prim_func = Downcast<PrimFunc>(mod->Lookup(name));
    return prim_func->body;
  } catch (Error& e) {
    if (print_error) LOG_WARNING << "Failed to generate code: " << e.what() << "\n";
    return Stmt();
  }
}
// FFI entry point: rebuild a State from a JSON-encoded list of transform steps
// (the format written by auto_scheduler.EncodeTrSteps). Each decoded step is
// appended to the DAG's initial state and replayed, then bounds are inferred.
TVM_REGISTER_GLOBAL("auto_scheduler.CreateStateFromEncodedSteps")
    .set_body_typed([](const ComputeDAG& taskDAG, const String& jsonString) {
      auto state = taskDAG->init_state;
      std::istringstream is(jsonString);
      dmlc::JSONReader reader(&is);
      Array<Step> steps;
      // Uses the dmlc::json::Handler<Array<Step>> specialization below.
      reader.Read(&steps);
      for (auto& step : steps) {
        state.CopyOnWrite()->transform_steps.push_back(step);
        StepApplyToState(step, &state, taskDAG);
      }
      // NOTE(review): unlike ReplayStepsGenCode, InferBound failures here
      // propagate to the caller rather than being caught.
      state = taskDAG.InferBound(state);
      return state;
    });
// FFI entry point: serialize transform steps to the JSON format consumed by
// auto_scheduler.CreateStateFromEncodedSteps.
TVM_REGISTER_GLOBAL("auto_scheduler.EncodeTrSteps").set_body_typed([](const Array<Step>& steps) {
  std::ostringstream os;
  dmlc::JSONWriter writer(&os);
  writer.Write(steps);
  return os.str();
});

// FFI passthroughs for the step-printing and replay helpers defined above.
TVM_REGISTER_GLOBAL("auto_scheduler.TransformStepToStr").set_body_typed(TransformStepToStr);

TVM_REGISTER_GLOBAL("auto_scheduler.ReplayStepsGenCode").set_body_typed(ReplayStepsGenCode);

// FFI entry point: lower a state's schedule; errors are logged (print_error=true)
// and an empty Stmt is returned on failure.
TVM_REGISTER_GLOBAL("auto_scheduler.GenerateCodeForState")
    .set_body_typed([](const SearchTask& task, const State& state) {
      return GenerateCodeForState(task, state, true);
    });
} // namespace auto_scheduler
} // namespace tvm
namespace dmlc {
namespace json {
// JSON (de)serialization for an array of transform steps. Each step is encoded
// as its own JSON sub-array using the step's record format, so the overall
// encoding is an array of arrays.
template <>
struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> {
  // Write: [[step-record...], [step-record...], ...]
  inline static void Write(dmlc::JSONWriter* writer,
                           const ::tvm::Array<::tvm::auto_scheduler::Step>& data) {
    writer->BeginArray(false);
    for (const auto& step : data) {
      writer->WriteArraySeperator();
      writer->BeginArray(false);
      step->WriteToRecord(writer);
      writer->EndArray();
    }
    writer->EndArray();
  }
  // Read the inverse encoding, replacing the contents of `data`.
  inline static void Read(dmlc::JSONReader* reader,
                          ::tvm::Array<::tvm::auto_scheduler::Step>* data) {
    reader->BeginArray();
    data->clear();
    while (reader->NextArrayItem()) {
      reader->BeginArray();
      data->push_back(::tvm::auto_scheduler::StepReadFromRecord(reader));
      // Each inner array holds exactly one step record.
      bool has_trailing = reader->NextArrayItem();
      ICHECK(!has_trailing);
    }
  }
};
} // namespace json
} // namespace dmlc
......@@ -26,6 +26,9 @@
#define TVM_AUTO_SCHEDULER_UTILS_H_
#include <dmlc/common.h>
#include <tvm/auto_scheduler/compute_dag.h>
#include <tvm/auto_scheduler/search_task.h>
#include <tvm/auto_scheduler/transform_step.h>
#include <tvm/tir/expr.h>
#include <algorithm>
......@@ -296,6 +299,16 @@ inline void ParseKernelLayout(const String& layout, Array<PrimExpr>* shape,
/*! \brief Get the base name before '_' of an axis */
inline std::string AxisBaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
void PrintTransformStep(std::ostream& os, const Step& step);
void PrintStateReplaySteps(std::ostream& os, const ComputeDAG& taskDAG, const Array<Step>& trSteps);
String TransformStepToStr(const Step& step);
tvm::transform::Sequential GetCodeGenPasses(const HardwareParams& hw_params, bool is_gpu);
Stmt GenerateCodeForState(const SearchTask& task, const State& state, bool print_error = false);
} // namespace auto_scheduler
} // namespace tvm
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment