billishyahao commented on code in PR #11966:
URL: https://github.com/apache/tvm/pull/11966#discussion_r919276218
##########
src/relay/op/nn/nn.cc:
##########
@@ -259,10 +260,25 @@ bool DensePackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
ICHECK(param != nullptr);
ICHECK_EQ(data->shape.size(), 2) << "Only 2D data is supported";
- ICHECK(weight->shape.size() == 3 || weight->shape.size() == 4) << "Expect weight to be 3D or 4D";
+ ICHECK(weight->shape.size() == 2 || weight->shape.size() == 3 || weight->shape.size() == 4) << "Expect weight to be 2D, 3D or 4D";
Array<tvm::PrimExpr> oshape = data->shape;
- oshape.Set(1, weight->shape[0] * weight->shape[2]);
+
+ std::string weight_layout = param->weight_layout;
+ std::regex blk_cn_fmt("NC[[:digit:]]+c[[:digit:]]+n");
+ std::regex blk_nc_fmt("NC[[:digit:]]+n[[:digit:]]+c");
+
+ if (weight->shape.size() == 3) {
+ oshape.Set(1, weight->shape[0] * weight->shape[2]);
Review Comment:
Thanks for the comment. I fixed this by introducing a new function.
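For readers following the layout logic in this hunk, here is a minimal C++ sketch of how the dense output width could be derived from a packed weight shape. It is not the helper added in this PR, whose name and body are not shown here. `InferDenseUnits` is hypothetical and uses `std::vector<int64_t>` in place of TVM's `Array<PrimExpr>`. The 3D case mirrors the `weight->shape[0] * weight->shape[2]` line in the diff, and the two regexes are copied from it. The 4D shape conventions (`NC{x}c{y}n` → `[N/y, C/x, x, y]`, `NC{y}n{x}c` → `[N/y, C/x, y, x]`), the 2D case, and the choice of `std::regex_match` over `std::regex_search` are assumptions based on TVM's layout-string convention.

```cpp
#include <cassert>
#include <cstdint>
#include <regex>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper (not the PR's function): number of output units of a
// dense op, given the packed weight shape and its layout string.
// Assumed layouts:
//   2D "NC"          -> [N, C]
//   3D "NC{k}n"      -> [N/k, C, k]        units = shape[0] * shape[2]
//   4D "NC{x}c{y}n"  -> [N/y, C/x, x, y]   units = shape[0] * shape[3]
//   4D "NC{y}n{x}c"  -> [N/y, C/x, y, x]   units = shape[0] * shape[2]
int64_t InferDenseUnits(const std::vector<int64_t>& wshape, const std::string& layout) {
  // Same patterns as in the diff; full-string matching is an assumption.
  static const std::regex blk_cn_fmt("NC[[:digit:]]+c[[:digit:]]+n");
  static const std::regex blk_nc_fmt("NC[[:digit:]]+n[[:digit:]]+c");
  switch (wshape.size()) {
    case 2:
      return wshape[0];
    case 3:
      return wshape[0] * wshape[2];
    case 4:
      // Inner N block is the last axis when C is blocked first.
      if (std::regex_match(layout, blk_cn_fmt)) return wshape[0] * wshape[3];
      // Inner N block is the third axis when N is blocked first.
      if (std::regex_match(layout, blk_nc_fmt)) return wshape[0] * wshape[2];
      throw std::invalid_argument("unsupported 4D weight layout: " + layout);
    default:
      throw std::invalid_argument("Expect weight to be 2D, 3D or 4D");
  }
}
```

With `NC16n4c`, a 64-unit weight packs to `[4, C/4, 16, 4]` and the sketch returns `4 * 16 = 64`. With `NC4c16n`, the same weight packs to `[4, C/4, 4, 16]`, so the N block size is read from the last axis instead. This is why a single `shape[0] * shape[2]` expression cannot cover both 4D layouts.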