Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
T
tvm-fork
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Yifan Zhao
tvm-fork
Commits
56fd61c1
"examples/src/main/python/ml/bisecting_k_means_example.py" did not exist on "414367850982c4f8fc5e63cc94caa422eb736db5"
Commit
56fd61c1
authored
3 years ago
by
Yifan Zhao
Browse files
Options
Downloads
Patches
Plain Diff
Changed the workflow around CHR step
parent
a2ba55a5
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
src/auto_scheduler/ansor_api.cc
+97
-86
97 additions, 86 deletions
src/auto_scheduler/ansor_api.cc
with
97 additions
and
86 deletions
src/auto_scheduler/ansor_api.cc
+
97
−
86
View file @
56fd61c1
...
@@ -16,16 +16,18 @@ using namespace tvm::tir;
...
@@ -16,16 +16,18 @@ using namespace tvm::tir;
class
BufstoreInfoNode
:
public
Object
{
class
BufstoreInfoNode
:
public
Object
{
public:
public:
size_t
stage_id
,
iter_id
;
size_t
stage_id
,
iter_id
;
BufferStore
bufstore
;
BufferLoad
lhs
;
Array
<
String
>
iters_in_expr
;
Array
<
BufferLoad
>
rhs
;
tvm
::
runtime
::
NDArray
counts
;
BufferStore
orig_bufstore
;
PrimExpr
orig_rhs
;
void
VisitAttrs
(
AttrVisitor
*
v
)
{
void
VisitAttrs
(
AttrVisitor
*
v
)
{
v
->
Visit
(
"stage_id"
,
&
stage_id
);
v
->
Visit
(
"stage_id"
,
&
stage_id
);
v
->
Visit
(
"iter_id"
,
&
iter_id
);
v
->
Visit
(
"iter_id"
,
&
iter_id
);
v
->
Visit
(
"bufstore"
,
&
bufstore
);
v
->
Visit
(
"lhs"
,
&
lhs
);
v
->
Visit
(
"iters_in_expr"
,
&
iters_in_expr
);
v
->
Visit
(
"rhs"
,
&
rhs
);
v
->
Visit
(
"counts"
,
&
counts
);
v
->
Visit
(
"orig_bufstore"
,
&
orig_bufstore
);
v
->
Visit
(
"orig_rhs"
,
&
orig_rhs
);
}
}
static
constexpr
const
char
*
_type_key
=
"ansor.BufstoreInfo"
;
static
constexpr
const
char
*
_type_key
=
"ansor.BufstoreInfo"
;
...
@@ -34,14 +36,15 @@ class BufstoreInfoNode : public Object {
...
@@ -34,14 +36,15 @@ class BufstoreInfoNode : public Object {
class
BufstoreInfo
:
public
ObjectRef
{
class
BufstoreInfo
:
public
ObjectRef
{
public:
public:
explicit
BufstoreInfo
(
size_t
stage_id
,
size_t
iter_id
,
Buffer
Store
bufstore
,
explicit
BufstoreInfo
(
size_t
stage_id
,
size_t
iter_id
,
Buffer
Load
lhs
,
Array
<
BufferLoad
>
rhs
,
Array
<
String
>
iters_in_expr
,
tvm
::
runtime
::
NDArray
count
s
)
{
BufferStore
orig_bufstore
,
PrimExpr
orig_rh
s
)
{
auto
node
=
make_object
<
BufstoreInfoNode
>
();
auto
node
=
make_object
<
BufstoreInfoNode
>
();
node
->
stage_id
=
stage_id
;
node
->
stage_id
=
stage_id
;
node
->
iter_id
=
iter_id
;
node
->
iter_id
=
iter_id
;
node
->
bufstore
=
std
::
move
(
bufstore
);
node
->
lhs
=
std
::
move
(
lhs
);
node
->
iters_in_expr
=
std
::
move
(
iters_in_expr
);
node
->
rhs
=
std
::
move
(
rhs
);
node
->
counts
=
std
::
move
(
counts
);
node
->
orig_bufstore
=
std
::
move
(
orig_bufstore
);
node
->
orig_rhs
=
std
::
move
(
orig_rhs
);
data_
=
std
::
move
(
node
);
data_
=
std
::
move
(
node
);
}
}
...
@@ -49,17 +52,9 @@ class BufstoreInfo : public ObjectRef {
...
@@ -49,17 +52,9 @@ class BufstoreInfo : public ObjectRef {
TVM_DEFINE_OBJECT_REF_COW_METHOD
(
BufstoreInfoNode
);
TVM_DEFINE_OBJECT_REF_COW_METHOD
(
BufstoreInfoNode
);
};
};
// TVM_REGISTER_NODE_TYPE(BufferAccessNode);
TVM_REGISTER_NODE_TYPE
(
BufstoreInfoNode
);
TVM_REGISTER_NODE_TYPE
(
BufstoreInfoNode
);
class
IterVarsExtractor
:
public
StmtExprVisitor
{
public:
explicit
IterVarsExtractor
()
{}
void
VisitExpr_
(
const
VarNode
*
node
)
final
{
++
this
->
varcounts
[
node
];
}
std
::
unordered_map
<
const
VarNode
*
,
size_t
>
varcounts
;
};
class
BufstoreExtractor
:
public
StmtExprVisitor
{
class
BufstoreExtractor
:
public
StmtExprVisitor
{
public:
public:
explicit
BufstoreExtractor
(
const
Array
<
Stage
>&
stages
)
{
explicit
BufstoreExtractor
(
const
Array
<
Stage
>&
stages
)
{
...
@@ -84,66 +79,54 @@ class BufstoreExtractor : public StmtExprVisitor {
...
@@ -84,66 +79,54 @@ class BufstoreExtractor : public StmtExprVisitor {
void
VisitStmt_
(
const
BufferStoreNode
*
node
)
final
{
void
VisitStmt_
(
const
BufferStoreNode
*
node
)
final
{
auto
&
name
=
node
->
buffer
->
name
;
auto
&
name
=
node
->
buffer
->
name
;
auto
it
=
this
->
stage_name_to_id
.
find
(
name
);
auto
it
=
this
->
stage_name_to_id
.
find
(
name
);
if
(
it
==
this
->
stage_name_to_id
.
end
())
LOG_FATAL
<<
"Buffer "
<<
name
<<
" is not found"
;
if
(
it
==
this
->
stage_name_to_id
.
end
())
{
IterVarsExtractor
iv_extractor
;
LOG_WARNING
<<
"Buffer "
<<
name
<<
" is not found"
;
iv_extractor
(
node
->
value
);
return
;
BufferStore
bufstore
(
node
->
buffer
,
node
->
value
,
node
->
indices
,
node
->
span
);
size_t
n
=
iv_extractor
.
varcounts
.
size
(),
i
=
0
;
Array
<
String
>
iters_in_expr
;
auto
counts
=
tvm
::
runtime
::
NDArray
::
Empty
({(
int64_t
)
n
},
DLDataType
{
kDLInt
,
32
,
1
},
{
kDLCPU
,
0
});
for
(
auto
&
kv
:
iv_extractor
.
varcounts
)
{
iters_in_expr
.
push_back
(
kv
.
first
->
name_hint
);
static_cast
<
int
*>
(
counts
->
data
)[
i
]
=
(
int
)
kv
.
second
;
i
+=
1
;
}
}
StmtExprVisitor
::
VisitStmt_
(
node
);
this
->
bufstore_info
.
push_back
(
this
->
bufstore_info
.
push_back
(
BufstoreInfo
(
it
->
second
,
itervars_stack
.
size
(),
bufstore
,
iters_in_expr
,
counts
));
BufstoreInfo
(
it
->
second
,
itervars_stack
.
size
()
-
1
,
BufferLoad
(
node
->
buffer
,
node
->
indices
),
std
::
move
(
this
->
buffer_loads
),
BufferStore
(
node
->
buffer
,
node
->
value
,
node
->
indices
),
node
->
value
));
this
->
buffer_loads
=
Array
<
BufferLoad
>
();
}
void
VisitExpr_
(
const
BufferLoadNode
*
node
)
final
{
this
->
buffer_loads
.
push_back
(
BufferLoad
(
node
->
buffer
,
node
->
indices
));
}
}
std
::
unordered_map
<
std
::
string
,
size_t
>
stage_name_to_id
;
std
::
unordered_map
<
std
::
string
,
size_t
>
stage_name_to_id
;
Array
<
BufstoreInfo
>
bufstore_info
;
Array
<
BufstoreInfo
>
bufstore_info
;
Array
<
Iterator
>
itervars_stack
;
Array
<
Iterator
>
itervars_stack
;
const
BufferStoreNode
*
cur_bufstore
;
Array
<
BufferLoad
>
buffer_loads
;
// Cleared at every BufferStore node
};
};
BufstoreInfo
GetBufstoreByName
(
const
SearchTask
&
task
,
const
State
&
state
,
const
Step
&
step
,
BufstoreInfo
GetNewBufstore
(
const
SearchTask
&
task
,
State
&
state
,
const
Step
&
step
)
{
const
std
::
string
&
chr_buf_name
)
{
auto
task_dag
=
task
->
compute_dag
;
Stmt
generated
=
GenerateCodeForState
(
task
,
state
);
state
.
CopyOnWrite
()
->
transform_steps
.
push_back
(
step
);
BufstoreExtractor
extractor
(
state
->
stages
);
StepApplyToState
(
step
,
&
state
,
task_dag
);
extractor
(
generated
);
state
=
task_dag
.
InferBound
(
state
);
bool
found
=
false
;
auto
stmt
=
GenerateCodeForState
(
task
,
state
);
BufstoreInfo
ret
;
if
(
auto
chr
=
step
.
as
<
CacheReadStepNode
>
())
{
for
(
auto
&
bufstore_info
:
extractor
.
bufstore_info
)
{
// The new CHR stage will (somehow) have stage id of chr->stage_id + 1
if
(
bufstore_info
->
bufstore
->
buffer
->
name
!=
chr_buf_name
)
continue
;
int
stage_id
=
chr
->
stage_id
+
1
;
ret
=
bufstore_info
;
BufstoreExtractor
extractor
({
state
->
stages
[
stage_id
]});
found
=
true
;
extractor
(
stmt
);
break
;
if
(
extractor
.
bufstore_info
.
size
()
!=
1
)
LOG_FATAL
<<
"Expected only one bufstore in the new CHR stage"
;
return
extractor
.
bufstore_info
[
0
];
}
}
if
(
!
found
)
return
BufstoreInfo
();
LOG_FATAL
<<
"CHR stage "
<<
TransformStepToStr
(
step
)
<<
" (buffer_name="
<<
chr_buf_name
<<
") not found"
;
return
ret
;
}
}
Array
<
ObjectRef
>
GetCacheReadsBufferStore
(
const
SearchTask
&
task
,
const
State
&
state_
)
{
class
IterVarsExtractor
:
public
StmtExprVisitor
{
const
auto
&
task_dag
=
task
->
compute_dag
;
public:
const
auto
&
tr_steps
=
state_
->
transform_steps
;
explicit
IterVarsExtractor
()
{}
auto
state
=
task_dag
->
init_state
;
Array
<
ObjectRef
>
ret
(
tr_steps
.
size
(),
ObjectRef
());
void
VisitExpr_
(
const
VarNode
*
node
)
final
{
++
this
->
varcounts
[
node
->
name_hint
];
}
for
(
size_t
i
=
0
;
i
<
tr_steps
.
size
();
++
i
)
{
auto
&
step
=
tr_steps
[
i
];
std
::
unordered_map
<
String
,
size_t
>
varcounts
;
state
.
CopyOnWrite
()
->
transform_steps
.
push_back
(
step
);
};
StepApplyToState
(
step
,
&
state
,
task_dag
);
auto
*
chr
=
step
.
as
<
CacheReadStepNode
>
();
if
(
!
chr
)
continue
;
auto
&
chr_stage
=
state
->
stages
[
chr
->
stage_id
+
1
];
auto
bufstore_info
=
GetBufstoreByName
(
task
,
state
,
step
,
chr_stage
->
op
->
name
);
ret
.
Set
(
i
,
bufstore_info
);
}
return
ret
;
}
Array
<
Stmt
>
ReplayStepsGenCode
(
const
SearchTask
&
task
,
const
Array
<
Step
>&
trSteps
)
{
Array
<
Stmt
>
ReplayStepsGenCode
(
const
SearchTask
&
task
,
const
Array
<
Step
>&
trSteps
)
{
const
auto
&
taskDAG
=
task
->
compute_dag
;
const
auto
&
taskDAG
=
task
->
compute_dag
;
...
@@ -162,16 +145,56 @@ Array<Stmt> ReplayStepsGenCode(const SearchTask& task, const Array<Step>& trStep
...
@@ -162,16 +145,56 @@ Array<Stmt> ReplayStepsGenCode(const SearchTask& task, const Array<Step>& trStep
return
generated_stmts
;
return
generated_stmts
;
}
}
TVM_REGISTER_GLOBAL
(
"auto_scheduler.GetBufferStores"
)
Array
<
Step
>
DecodeSteps
(
const
String
&
jsonString
)
{
.
set_body_typed
([](
const
SearchTask
&
task
,
const
State
&
state
)
{
std
::
istringstream
is
(
jsonString
);
dmlc
::
JSONReader
reader
(
&
is
);
Array
<
Step
>
steps
;
reader
.
Read
(
&
steps
);
return
steps
;
}
TVM_REGISTER_GLOBAL
(
"auto_scheduler.GetInitialBufstores"
)
.
set_body_typed
([](
const
SearchTask
&
task
)
{
auto
state
=
task
->
compute_dag
->
init_state
;
auto
stmt
=
GenerateCodeForState
(
task
,
state
);
auto
stmt
=
GenerateCodeForState
(
task
,
state
);
BufstoreExtractor
extractor
(
state
->
stages
);
BufstoreExtractor
extractor
(
state
->
stages
);
extractor
(
stmt
);
extractor
(
stmt
);
return
extractor
.
bufstore_info
;
return
extractor
.
bufstore_info
;
});
});
TVM_REGISTER_GLOBAL
(
"auto_scheduler.GetCacheReadsBufferStore"
)
TVM_REGISTER_GLOBAL
(
"auto_scheduler.GetNewBufstore"
)
.
set_body_typed
(
GetCacheReadsBufferStore
);
.
set_body_typed
([](
const
SearchTask
&
task
,
const
State
&
state
,
const
String
&
step_json
)
{
std
::
istringstream
is
(
step_json
);
dmlc
::
JSONReader
reader
(
&
is
);
reader
.
BeginArray
();
Step
step
=
StepReadFromRecord
(
&
reader
);
ICHECK
(
!
reader
.
NextArrayItem
());
State
new_state
=
state
;
auto
bufstore
=
GetNewBufstore
(
task
,
new_state
,
step
);
return
Array
<
ObjectRef
>
{
new_state
,
bufstore
};
});
TVM_REGISTER_GLOBAL
(
"auto_scheduler.CountIterators"
).
set_body_typed
([](
const
PrimExpr
&
expr
)
{
IterVarsExtractor
extractor
;
extractor
(
expr
);
Array
<
Array
<
ObjectRef
>>
ret
;
for
(
const
auto
&
pair
:
extractor
.
varcounts
)
{
ObjectRef
name
=
pair
.
first
;
DataType
dtype
{
kDLUInt
,
32
,
1
};
ObjectRef
count
=
IntImm
(
dtype
,
(
int64_t
)
pair
.
second
);
ret
.
push_back
(
Array
<
ObjectRef
>
{
name
,
count
});
}
return
ret
;
});
TVM_REGISTER_GLOBAL
(
"auto_scheduler.EncodeTrSteps"
).
set_body_typed
([](
const
Array
<
Step
>&
steps
)
{
std
::
ostringstream
os
;
dmlc
::
JSONWriter
writer
(
&
os
);
writer
.
Write
(
steps
);
return
os
.
str
();
});
/********** Debug APIs ********************************************************/
TVM_REGISTER_GLOBAL
(
"auto_scheduler.TransformStepToStr"
).
set_body_typed
(
TransformStepToStr
);
TVM_REGISTER_GLOBAL
(
"auto_scheduler.TransformStepToStr"
).
set_body_typed
(
TransformStepToStr
);
...
@@ -186,15 +209,10 @@ TVM_REGISTER_GLOBAL("auto_scheduler.PrintStateAllLoops").set_body_typed([](const
...
@@ -186,15 +209,10 @@ TVM_REGISTER_GLOBAL("auto_scheduler.PrintStateAllLoops").set_body_typed([](const
return
state
.
ToStr
(
false
);
return
state
.
ToStr
(
false
);
});
});
TVM_REGISTER_GLOBAL
(
"auto_scheduler.ApplyEncodedStepsToInitState"
)
TVM_REGISTER_GLOBAL
(
"auto_scheduler.CreateStateFromEncodedSteps"
)
.
set_body_typed
([](
const
ComputeDAG
&
taskDAG
,
const
String
&
jsonString
)
{
.
set_body_typed
([](
const
ComputeDAG
&
taskDAG
,
const
String
&
jsonString
)
{
auto
state
=
taskDAG
->
init_state
;
auto
state
=
taskDAG
->
init_state
;
std
::
istringstream
is
(
jsonString
);
for
(
auto
&
step
:
DecodeSteps
(
jsonString
))
{
dmlc
::
JSONReader
reader
(
&
is
);
Array
<
Step
>
steps
;
reader
.
Read
(
&
steps
);
for
(
auto
&
step
:
steps
)
{
state
.
CopyOnWrite
()
->
transform_steps
.
push_back
(
step
);
state
.
CopyOnWrite
()
->
transform_steps
.
push_back
(
step
);
StepApplyToState
(
step
,
&
state
,
taskDAG
);
StepApplyToState
(
step
,
&
state
,
taskDAG
);
}
}
...
@@ -202,13 +220,6 @@ TVM_REGISTER_GLOBAL("auto_scheduler.CreateStateFromEncodedSteps")
...
@@ -202,13 +220,6 @@ TVM_REGISTER_GLOBAL("auto_scheduler.CreateStateFromEncodedSteps")
return
state
;
return
state
;
});
});
TVM_REGISTER_GLOBAL
(
"auto_scheduler.EncodeTrSteps"
).
set_body_typed
([](
const
Array
<
Step
>&
steps
)
{
std
::
ostringstream
os
;
dmlc
::
JSONWriter
writer
(
&
os
);
writer
.
Write
(
steps
);
return
os
.
str
();
});
}
// namespace auto_scheduler
}
// namespace auto_scheduler
}
// namespace tvm
}
// namespace tvm
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment