jcf94 commented on a change in pull request #8605:
URL: https://github.com/apache/tvm/pull/8605#discussion_r681471388
##########
File path: src/relay/transforms/convert_sparse_conv2d.cc
##########
@@ -155,6 +316,18 @@ Pass Conv2dToSparse(const Array<ObjectRef>& weight_name, const Array<Array<PrimE
TVM_REGISTER_GLOBAL("relay._transform.Conv2dToSparse").set_body_typed(Conv2dToSparse);
+
+Pass Conv2dToSparse2(const String& layout, int kernel_size, int blockH, int blockW, double sparse_thresh) {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        auto f0 = Downcast<Function>(Conv2dToSparse2(f, layout, kernel_size, blockH, blockW, sparse_thresh));
+        return f0;
+      };
+  return CreateFunctionPass(pass_func, 5, "Conv2dToSparse2", {"DeadCodeElimination"});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.Conv2dToSparse2").set_body_typed(Conv2dToSparse2);
Review comment:
What's the difference between the newly added pass and the original pass?
Can this be merged into the original one?