This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5efe8b0bfd Enhancement for fold_scale_axis and dnnl_json_runtime (#11815)
5efe8b0bfd is described below

commit 5efe8b0bfdff4c9939185a7581dc77e23cbcb6d5
Author: Ivy Zhang <[email protected]>
AuthorDate: Mon Jul 4 14:06:00 2022 +0800

    Enhancement for fold_scale_axis and dnnl_json_runtime (#11815)
    
    * enhance the WA (workaround) in dnnl_convolution: support cropping tensors with mismatched groups and OC
    
    * add missing param checks for conv2d, conv3d
    
    * fix lint
---
 src/relay/transforms/fold_scale_axis.cc       |  8 ++++++++
 src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 10 ++++++++--
 2 files changed, 16 insertions(+), 2 deletions(-)
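
For context: call->attrs.as<Conv2DAttrs>() returns nullptr when the attributes are not of the requested type, so the ICHECK(param != nullptr) guards added below fail fast with a diagnostic instead of passing a null pointer into the Conv*Prep/Conv*Rewrite helpers. A minimal standalone sketch of the same guard pattern, using plain C++ stand-ins for TVM's attribute objects and assert in place of ICHECK (BaseAttrs, ConvGroups, and friends are illustrative names, not TVM's API):

    #include <cassert>

    // Illustrative stand-ins for TVM's attribute objects (not TVM's real API).
    struct BaseAttrs { virtual ~BaseAttrs() = default; };
    struct Conv2DAttrs : BaseAttrs { int groups = 1; };
    struct Conv3DAttrs : BaseAttrs { int groups = 1; };

    // dynamic_cast returns nullptr on a type mismatch, just as
    // call->attrs.as<Conv2DAttrs>() does in TVM. Asserting non-null fails
    // fast at the call site instead of crashing later on a null deref.
    int ConvGroups(const BaseAttrs* attrs, bool is_conv2d) {
      if (is_conv2d) {
        const auto* param = dynamic_cast<const Conv2DAttrs*>(attrs);
        assert(param != nullptr);  // mirrors ICHECK(param != nullptr)
        return param->groups;
      }
      const auto* param = dynamic_cast<const Conv3DAttrs*>(attrs);
      assert(param != nullptr);
      return param->groups;
    }

    int main() {
      Conv3DAttrs c3;
      return ConvGroups(&c3, /*is_conv2d=*/false) == 1 ? 0 : 1;
    }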

diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc
index f4f05badec..7cc15a8f93 100644
--- a/src/relay/transforms/fold_scale_axis.cc
+++ b/src/relay/transforms/fold_scale_axis.cc
@@ -588,9 +588,11 @@ Expr ConvForwardRewrite(const Call& ref_call, const ATTRS* param, const Array<Ex
 Array<Message> PreConvForwardPrep(const Call& call, const Message& out_message) {
   if (backend::IsOp(call.as<CallNode>(), "nn.conv2d")) {
     const auto* param = call->attrs.as<Conv2DAttrs>();
+    ICHECK(param != nullptr);
     return ConvForwardPrep(call, param, out_message);
   }
   const auto* param = call->attrs.as<Conv3DAttrs>();
+  ICHECK(param != nullptr);
   return ConvForwardPrep(call, param, out_message);
 }
 
@@ -598,9 +600,11 @@ Expr PreConvForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
                            const Message& message) {
   if (backend::IsOp(ref_call.as<CallNode>(), "nn.conv2d")) {
     const auto* param = ref_call->attrs.as<Conv2DAttrs>();
+    ICHECK(param != nullptr);
     return ConvForwardRewrite(ref_call, param, new_args, message);
   }
   const auto* param = ref_call->attrs.as<Conv3DAttrs>();
+  ICHECK(param != nullptr);
   return ConvForwardRewrite(ref_call, param, new_args, message);
 }
 
@@ -1040,9 +1044,11 @@ Expr ConvBackwardTransform(const Call& call, const ATTRS* param, const Message&
 Message PreConvBackwardPrep(const Call& call, const Array<Message>& in_messages) {
   if (backend::IsOp(call.as<CallNode>(), "nn.conv2d")) {
     const auto* param = call->attrs.as<Conv2DAttrs>();
+    ICHECK(param != nullptr);
     return ConvBackwardPrep(call, param, in_messages);
   }
   const auto* param = call->attrs.as<Conv3DAttrs>();
+  ICHECK(param != nullptr);
   return ConvBackwardPrep(call, param, in_messages);
 }
 
@@ -1050,9 +1056,11 @@ Expr PreConvBackwardTransform(const Call& call, const Message& message, const Ex
                               const BackwardTransformer& transformer) {
   if (backend::IsOp(call.as<CallNode>(), "nn.conv2d")) {
     const auto* param = call->attrs.as<Conv2DAttrs>();
+    ICHECK(param != nullptr);
     return ConvBackwardTransform(call, param, message, scale, transformer);
   }
   const auto* param = call->attrs.as<Conv3DAttrs>();
+  ICHECK(param != nullptr);
   return ConvBackwardTransform(call, param, message, scale, transformer);
 }
 
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index a4239186b4..a46f170fea 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -318,9 +318,15 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     // Let's try to compensate it for weight tensor. Weight IC should match with source IC.
     // Example src: [1, 3, 224, 224] with layout NCHW
     //         wgh: [16, 3, 3, 3] with layout OIHW2i8o -> [2, 2, 3, 3, 2, 8]
-    if (wgh_tr.dims()[2] != src_tr.dims()[1] / groups) {
+    // Similarly, Weight OC should match with destination OC.
+    // Example dst: [1, 1000, 7, 7] with layout NCHW
+    //         wgh: [1000, 1024, 1, 1] with layout OIHW48o -> [21, 1024, 1, 1, 48]
+    if (wgh_tr.dims()[0] != groups || wgh_tr.dims()[1] != dst_tr.dims()[1] / groups ||
+        wgh_tr.dims()[2] != src_tr.dims()[1] / groups) {
       auto wgh_croped_dims = wgh_tr.dims();
-      wgh_croped_dims[2] = src_tr.dims()[1];
+      wgh_croped_dims[0] = groups;
+      wgh_croped_dims[1] = dst_tr.dims()[1] / groups;  // wgh_OC = dst_OC / groups
+      wgh_croped_dims[2] = src_tr.dims()[1] / groups;  // wgh_IC = src_IC / groups
       auto zero_offset = dnnl::memory::dims(wgh_tr.dims().size(), 0);
       wgh_tr = wgh_tr.Crop(wgh_croped_dims, zero_offset);
     }
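
To make the new crop target concrete: with layout OIHW48o, the 1000 output channels in the example above are padded to 21 * 48 = 1008 (21 blocks of 48), so the weight dims no longer agree with groups, dst_OC, and src_IC and must be cropped back. A standalone sketch of the dimension arithmetic, assuming the grouped weight layout [groups, OC/groups, IC/groups, KH, KW] checked above (CropTarget and the plain std::vector Dims are illustrative; the real code uses the runtime's tensor helpers via wgh_tr.Crop):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    using Dims = std::vector<int64_t>;

    // Crop target for a (possibly padded) grouped weight tensor, mirroring
    // the logic added above. Layout: [groups, OC/groups, IC/groups, KH, KW];
    // trailing spatial dims are kept as-is.
    Dims CropTarget(const Dims& wgh, const Dims& src, const Dims& dst, int64_t groups) {
      Dims target = wgh;
      target[0] = groups;             // wgh_G  = groups
      target[1] = dst[1] / groups;    // wgh_OC = dst_OC / groups
      target[2] = src[1] / groups;    // wgh_IC = src_IC / groups
      return target;
    }

    int main() {
      // Numbers from the example above: dst [1, 1000, 7, 7]; OIHW48o pads
      // OC from 1000 to 21 * 48 = 1008 (weight shown in grouped 5D form).
      Dims src = {1, 1024, 7, 7};
      Dims dst = {1, 1000, 7, 7};
      Dims wgh = {1, 1008, 1024, 1, 1};  // padded OC: 1008 instead of 1000
      for (int64_t d : CropTarget(wgh, src, dst, /*groups=*/1)) {
        std::cout << d << ' ';           // prints: 1 1000 1024 1 1
      }
      std::cout << '\n';
    }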
