junrushao1994 commented on a change in pull request #10823:
URL: https://github.com/apache/tvm/pull/10823#discussion_r838026907
##########
File path: src/meta_schedule/postproc/rewrite_unbound_block.cc
##########
@@ -195,17 +200,31 @@ bool RewriteUnboundBlockNode::Apply(const tir::Schedule&
sch) {
if (bind_type == tir::BindType::kBindBlock) {
sch->Bind(fused, "blockIdx.x");
} else if (bind_type == tir::BindType::kBindBlockThread) {
- Array<LoopRV> splits = sch->Split(fused, {NullOpt,
Integer(this->warp_size_)});
- ICHECK_EQ(splits.size(), 2);
- sch->Bind(splits[0], "blockIdx.x");
- sch->Bind(splits[1], "threadIdx.x");
+ int extent_size = 0, max_block = 256;
+ if (const IntImmNode* node = sch->Get(fused)->extent.as<IntImmNode>()) {
+ extent_size = node->value;
+ }
+ Array<LoopRV> splits;
+ if (extent_size > max_block * max_num_threads_) {
+ splits = sch->Split(fused, {NullOpt, Integer(max_block),
Integer(max_num_threads_)});
+ ICHECK_EQ(splits.size(), 3);
+ sch->Reorder({splits[1], splits[2], splits[0]});
+ sch->Bind(splits[1], "blockIdx.x");
+ sch->Bind(splits[2], "threadIdx.x");
+ } else if (extent_size != 0) {
+ splits = sch->Split(fused, {NullOpt,
Integer(std::min(max_num_threads_, extent_size))});
+ ICHECK_EQ(splits.size(), 2);
+ sch->Bind(splits[0], "blockIdx.x");
+ sch->Bind(splits[1], "threadIdx.x");
+ }
}
}
return true;
}
Postproc Postproc::RewriteUnboundBlock() {
ObjectPtr<RewriteUnboundBlockNode> n =
make_object<RewriteUnboundBlockNode>();
+ n->max_num_threads_ = -1;
n->warp_size_ = -1;
Review comment:
Could you confirm whether `warp_size` is still used anywhere? If not, let's
simply remove this field.
##########
File path: src/meta_schedule/postproc/rewrite_unbound_block.cc
##########
@@ -195,17 +200,31 @@ bool RewriteUnboundBlockNode::Apply(const tir::Schedule&
sch) {
if (bind_type == tir::BindType::kBindBlock) {
sch->Bind(fused, "blockIdx.x");
} else if (bind_type == tir::BindType::kBindBlockThread) {
- Array<LoopRV> splits = sch->Split(fused, {NullOpt,
Integer(this->warp_size_)});
- ICHECK_EQ(splits.size(), 2);
- sch->Bind(splits[0], "blockIdx.x");
- sch->Bind(splits[1], "threadIdx.x");
+ int extent_size = 0, max_block = 256;
+ if (const IntImmNode* node = sch->Get(fused)->extent.as<IntImmNode>()) {
+ extent_size = node->value;
+ }
+ Array<LoopRV> splits;
+ if (extent_size > max_block * max_num_threads_) {
+ splits = sch->Split(fused, {NullOpt, Integer(max_block),
Integer(max_num_threads_)});
+ ICHECK_EQ(splits.size(), 3);
+ sch->Reorder({splits[1], splits[2], splits[0]});
+ sch->Bind(splits[1], "blockIdx.x");
+ sch->Bind(splits[2], "threadIdx.x");
+ } else if (extent_size != 0) {
+ splits = sch->Split(fused, {NullOpt,
Integer(std::min(max_num_threads_, extent_size))});
+ ICHECK_EQ(splits.size(), 2);
+ sch->Bind(splits[0], "blockIdx.x");
+ sch->Bind(splits[1], "threadIdx.x");
+ }
Review comment:
We will need to handle the case where `extent_size == 0`. For now, you
can either throw an error or write some simplistic split patterns.
##########
File path: src/meta_schedule/postproc/rewrite_unbound_block.cc
##########
@@ -195,17 +200,31 @@ bool RewriteUnboundBlockNode::Apply(const tir::Schedule&
sch) {
if (bind_type == tir::BindType::kBindBlock) {
sch->Bind(fused, "blockIdx.x");
} else if (bind_type == tir::BindType::kBindBlockThread) {
- Array<LoopRV> splits = sch->Split(fused, {NullOpt,
Integer(this->warp_size_)});
- ICHECK_EQ(splits.size(), 2);
- sch->Bind(splits[0], "blockIdx.x");
- sch->Bind(splits[1], "threadIdx.x");
+ int extent_size = 0, max_block = 256;
Review comment:
Use int64_t for safety
##########
File path: src/meta_schedule/postproc/rewrite_unbound_block.cc
##########
@@ -195,17 +200,31 @@ bool RewriteUnboundBlockNode::Apply(const tir::Schedule&
sch) {
if (bind_type == tir::BindType::kBindBlock) {
sch->Bind(fused, "blockIdx.x");
} else if (bind_type == tir::BindType::kBindBlockThread) {
- Array<LoopRV> splits = sch->Split(fused, {NullOpt,
Integer(this->warp_size_)});
- ICHECK_EQ(splits.size(), 2);
- sch->Bind(splits[0], "blockIdx.x");
- sch->Bind(splits[1], "threadIdx.x");
+ int extent_size = 0, max_block = 256;
+ if (const IntImmNode* node = sch->Get(fused)->extent.as<IntImmNode>()) {
Review comment:
Use `GetLoopIntExtent` instead
##########
File path: src/meta_schedule/postproc/rewrite_unbound_block.cc
##########
@@ -157,6 +157,8 @@ class RewriteUnboundBlockNode : public PostprocNode {
Optional<Integer> warp_size =
context->target.value()->GetAttr<Integer>("thread_warp_size");
CHECK(warp_size.defined()) << "ValueError: missing attribute
`thread_warp_size` in the target";
this->warp_size_ = warp_size.value();
+ this->max_num_threads_ =
+
context->target.value()->GetAttr<Integer>("max_threads_per_block").value();
Review comment:
Let's emit a clear error message, just like what's done with `warp_size`.
##########
File path: src/meta_schedule/postproc/rewrite_unbound_block.cc
##########
@@ -195,17 +200,31 @@ bool RewriteUnboundBlockNode::Apply(const tir::Schedule&
sch) {
if (bind_type == tir::BindType::kBindBlock) {
sch->Bind(fused, "blockIdx.x");
} else if (bind_type == tir::BindType::kBindBlockThread) {
- Array<LoopRV> splits = sch->Split(fused, {NullOpt,
Integer(this->warp_size_)});
- ICHECK_EQ(splits.size(), 2);
- sch->Bind(splits[0], "blockIdx.x");
- sch->Bind(splits[1], "threadIdx.x");
+ int extent_size = 0, max_block = 256;
Review comment:
Note that 256 is a magic number for `max_block`. Ideally we should:
- Rename `max_block` => `max_threadblock`
- Make it a parameter to the constructor of `RewriteUnboundBlock`
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]