manupa-arm commented on a change in pull request #10563:
URL: https://github.com/apache/tvm/pull/10563#discussion_r824706222
##########
File path: src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
##########
@@ -86,63 +85,62 @@ class ScalarToTensorConstantMutator : public
MixedModeMutator {
final_call = Call(global_var, call->args);
}
- // Substitute scalar constant with a tensor constant in the call to
composite function
- // comprising partitioned binary ops. Shape of the new constant should be
same as its
- // neighbouring tensor's shape.
+ // Substitute scalar constant with tensor constant in the call to
composite function.
if (auto* func_node = call->op.as<FunctionNode>()) {
Function func = GetRef<Function>(func_node);
- auto func_name = func->GetAttr<String>(attr::kComposite);
- if (func_name.defined() &&
- (func_name == "cmsis-nn.qnn_add" || func_name ==
"cmsis-nn.qnn_mul")) {
- final_call = ReplaceScalarWithTensorConstant(GetRef<Call>(call), func);
- }
+ final_call = ReplaceScalarWithTensorConstant(GetRef<Call>(call), func);
}
return final_call;
}
// Replaces scalar variable with a tensor variable with same shape as that
of the neibouring
- // operand tensor in a binary op
+ // operand tensor in a binary op (add or multiply supported via CMSIS-NN
path). This applies only
+ // to 1st and 2nd arguments of the ops.
Call ReplaceScalarWithTensorVariable(Call call) {
const OpNode* opnode = call->op.as<OpNode>();
- if (opnode == nullptr) {
+ if (opnode == nullptr || (opnode->name != "qnn.add" && opnode->name !=
"qnn.mul")) {
Review comment:
Ahhh, my bad. They are different.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]