apeskov commented on code in PR #11966:
URL: https://github.com/apache/tvm/pull/11966#discussion_r911878641


##########
src/relay/op/nn/nn.cc:
##########
@@ -259,10 +260,25 @@ bool DensePackRel(const Array<Type>& types, int 
num_inputs, const Attrs& attrs,
   ICHECK(param != nullptr);
 
   ICHECK_EQ(data->shape.size(), 2) << "Only 2D data is supported";
-  ICHECK(weight->shape.size() == 3 || weight->shape.size() == 4) << "Expect 
weight to be 3D or 4D";
+  ICHECK(weight->shape.size() == 2 || weight->shape.size() == 3 || 
weight->shape.size() == 4) << "Expect weight to be 2D, 3D or 4D";
 
   Array<tvm::PrimExpr> oshape = data->shape;
-  oshape.Set(1, weight->shape[0] * weight->shape[2]);
+
+  std::string weight_layout = param->weight_layout;
+  std::regex blk_cn_fmt("NC[[:digit:]]+c[[:digit:]]+n");
+  std::regex blk_nc_fmt("NC[[:digit:]]+n[[:digit:]]+c");
+
+  if (weight->shape.size() == 3) {
+    oshape.Set(1, weight->shape[0] * weight->shape[2]);

Review Comment:
   This line assumes that `weight->shape.size() == 3` implies the `NC_n` 
layout. I see that this logic is inherited from the previous implementation, 
but "NC_c", "CN_c" and "CN_n" are also possible layout values.
   
   Could you please check that the assumption holds and raise an error if the 
layout is not "NC_n"?
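
The check requested above could look roughly like the sketch below. It is standalone C++ rather than TVM code: the helper name `Rank3OutputUnits` is hypothetical, plain `int64_t` values stand in for `PrimExpr`, and a standard exception stands in for whatever error mechanism the PR ends up using.

```cpp
#include <cassert>
#include <cstdint>
#include <regex>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper, not TVM API: output units for a rank-3 packed weight.
// Only "NC<k>n" (for example "NC16n") is accepted, where the weight shape is
// [N / k, C, k]. Any other layout string is rejected with an exception.
int64_t Rank3OutputUnits(const std::vector<int64_t>& wshape, const std::string& layout) {
  assert(wshape.size() == 3);
  static const std::regex nc_n_fmt("NC[[:digit:]]+n");
  if (!std::regex_match(layout, nc_n_fmt)) {
    throw std::invalid_argument("Unsupported rank-3 weight layout: " + layout);
  }
  return wshape[0] * wshape[2];
}
```

With this shape of check, a layout such as "CN16n" fails loudly instead of silently producing a wrong output shape.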



##########
tests/python/contrib/test_dnnl.py:
##########
@@ -94,6 +94,9 @@ def partition_for_dnnl(mod, params=None, alter_layout=True, 
prune_subgraphs=True
             )
             with tvm.transform.PassContext(opt_level=3):
                 mod = seq(mod)
+
+    mod = dnnl.rewrite_dense_bias_gelu_reshape_last(mod)

Review Comment:
   Could you please explain why we need this pass before the "AlterOp" 
transformation passes?



##########
src/relay/op/nn/nn.cc:
##########
@@ -259,10 +260,25 @@ bool DensePackRel(const Array<Type>& types, int 
num_inputs, const Attrs& attrs,
   ICHECK(param != nullptr);
 
   ICHECK_EQ(data->shape.size(), 2) << "Only 2D data is supported";
-  ICHECK(weight->shape.size() == 3 || weight->shape.size() == 4) << "Expect 
weight to be 3D or 4D";
+  ICHECK(weight->shape.size() == 2 || weight->shape.size() == 3 || 
weight->shape.size() == 4) << "Expect weight to be 2D, 3D or 4D";
 
   Array<tvm::PrimExpr> oshape = data->shape;
-  oshape.Set(1, weight->shape[0] * weight->shape[2]);
+
+  std::string weight_layout = param->weight_layout;
+  std::regex blk_cn_fmt("NC[[:digit:]]+c[[:digit:]]+n");
+  std::regex blk_nc_fmt("NC[[:digit:]]+n[[:digit:]]+c");
+
+  if (weight->shape.size() == 3) {
+    oshape.Set(1, weight->shape[0] * weight->shape[2]);
+  } else if (weight->shape.size() == 4) {
+    if (std::regex_match(weight_layout, blk_cn_fmt)) {
+      oshape.Set(1, weight->shape[0] * weight->shape[3]);
+    } else if (std::regex_match(weight_layout, blk_nc_fmt)) {
+      oshape.Set(1, weight->shape[0] * weight->shape[2]);
+    }

Review Comment:
   The same situation as above for rank-3 weights: the layout can also be 
"CN_c_n" or "CN_n_c". With the current implementation such layouts are silently 
skipped, without any warning or error message. Could you please handle them 
explicitly?
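
One possible way to close the gap is an explicit fallback branch, sketched below as standalone C++. The two regular expressions are copied from the diff; the helper name `Rank4OutputUnits`, the use of `int64_t` instead of `PrimExpr`, and the exception type are assumptions made for illustration only.

```cpp
#include <cassert>
#include <cstdint>
#include <regex>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper, not TVM API: output units for a rank-4 packed weight.
int64_t Rank4OutputUnits(const std::vector<int64_t>& wshape, const std::string& layout) {
  assert(wshape.size() == 4);
  static const std::regex blk_cn_fmt("NC[[:digit:]]+c[[:digit:]]+n");
  static const std::regex blk_nc_fmt("NC[[:digit:]]+n[[:digit:]]+c");
  if (std::regex_match(layout, blk_cn_fmt)) {
    return wshape[0] * wshape[3];  // "NC<a>c<b>n": the n block is the last axis
  }
  if (std::regex_match(layout, blk_nc_fmt)) {
    return wshape[0] * wshape[2];  // "NC<a>n<b>c": the n block is the third axis
  }
  // Layouts such as "CN4c16n" end up here instead of being skipped silently.
  throw std::invalid_argument("Unsupported rank-4 weight layout: " + layout);
}
```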



##########
src/runtime/contrib/dnnl/dnnl_json_runtime.cc:
##########
@@ -241,6 +243,9 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
           Deconvolution(nid);
         } else if (std::regex_match(op_name, conv_pat)) {
           Convolution(nid);
+        } else if (std::regex_match(op_name, dense_pack_pat) ||
+                   std::regex_match(op_name, stock_dense_pack_pat)) {
+          Dense(nid, true);

Review Comment:
   There is no need for a separate dense version named "packeddense". It is 
exactly equivalent to the original "Dense" primitive with one additional 
attribute, `weight_layout`. So you can reuse the previous implementation of 
Dense and take the weight layout into account. An example is below:
   
   ```cpp
   // yes, GetNodeAttr() supports default value as optional argument 
   auto wgh_layout = GetNodeAttr<std::string>(node, "weight_layout", {"NC"});
   wgh_tr = wgh_tr.TreatAs(wgh_layout, "NC");
   ```
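
To illustrate the default-value behaviour the snippet above relies on, here is a minimal standalone sketch. `AttrMap` and `GetAttrOr` are hypothetical stand-ins, not the JSON runtime's real types or functions; the point is only that a node without a `weight_layout` attribute falls back to "NC", so the stock dense path and the packed path can share one code path.

```cpp
#include <string>
#include <unordered_map>

// Hypothetical stand-in for a node's string attributes (not a TVM type).
using AttrMap = std::unordered_map<std::string, std::string>;

// Return the attribute value if the key is present, otherwise the default.
std::string GetAttrOr(const AttrMap& attrs, const std::string& key, const std::string& def) {
  auto it = attrs.find(key);
  return it == attrs.end() ? def : it->second;
}
```

A plain dense node would then yield "NC", while a packed node would yield its stored layout, for example "NC16n".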



##########
src/relay/op/nn/nn.cc:
##########
@@ -259,10 +260,25 @@ bool DensePackRel(const Array<Type>& types, int 
num_inputs, const Attrs& attrs,
   ICHECK(param != nullptr);
 
   ICHECK_EQ(data->shape.size(), 2) << "Only 2D data is supported";
-  ICHECK(weight->shape.size() == 3 || weight->shape.size() == 4) << "Expect 
weight to be 3D or 4D";
+  ICHECK(weight->shape.size() == 2 || weight->shape.size() == 3 || 
weight->shape.size() == 4) << "Expect weight to be 2D, 3D or 4D";
 
   Array<tvm::PrimExpr> oshape = data->shape;
-  oshape.Set(1, weight->shape[0] * weight->shape[2]);
+
+  std::string weight_layout = param->weight_layout;
+  std::regex blk_cn_fmt("NC[[:digit:]]+c[[:digit:]]+n");
+  std::regex blk_nc_fmt("NC[[:digit:]]+n[[:digit:]]+c");
+
+  if (weight->shape.size() == 3) {
+    oshape.Set(1, weight->shape[0] * weight->shape[2]);
+  } else if (weight->shape.size() == 4) {
+    if (std::regex_match(weight_layout, blk_cn_fmt)) {
+      oshape.Set(1, weight->shape[0] * weight->shape[3]);
+    } else if (std::regex_match(weight_layout, blk_nc_fmt)) {
+      oshape.Set(1, weight->shape[0] * weight->shape[2]);
+    }
+  } else {
+    oshape.Set(1, weight->shape[0]);

Review Comment:
   The same applies here: both "NC" and "CN" are possible.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
