manupa-arm commented on a change in pull request #9163:
URL: https://github.com/apache/tvm/pull/9163#discussion_r719578588



##########
File path: src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
##########
@@ -32,17 +32,37 @@ namespace relay {
 namespace contrib {
 namespace cmsisnn {
 
-class RelayToTIR : public MixedModeVisitor {
+class RelayToTIRVisitor : public MixedModeVisitor {
  public:
-  explicit RelayToTIR(String func_name) : func_name_(func_name) {}
+  explicit RelayToTIRVisitor(String func_name) : func_name_(func_name) {}
+
+  tir::PrimFunc GetReplacementPrimFunc() { return primfunc_; }
 
  private:
-  void emit_softmax_tir(const Expr& expr) {
+  template <typename T>
+  const T ArgumentToConstantValue(const Expr& arg) {

Review comment:
       nit: Why is the function name specific to "Argument"? It seems like it
can work for any Expr.

##########
File path: src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
##########
@@ -79,15 +99,51 @@ class RelayToTIR : public MixedModeVisitor {
         IntImm(DataType::Int(32), num_rows), IntImm(DataType::Int(32), 
row_size),
         IntImm(DataType::Int(32), mult),     IntImm(DataType::Int(32), shift),
         IntImm(DataType::Int(32), diff_min), out_var};
-    tir::Stmt body =
-        tir::Evaluate(tvm::tir::Call(DataType::Int(8), 
tir::builtin::call_extern(), args));
 
-    Map<String, ObjectRef> dict_attrs;
-    dict_attrs.Set("global_symbol", func_name_);
-    dict_attrs.Set("tir.noalias", Bool(true));
+    CreatePrimFuncForExtern(func_signature, args);
+  }
 
-    primfunc_ = tir::PrimFunc(func_signature, body, VoidType(), Map<tir::Var, 
tir::Buffer>(),
-                              DictAttrs(dict_attrs));
+  void EmitMul(const Expr& expr) {
+    auto* mul_call = expr.as<CallNode>();
+
+    const float input_0_scale = 
ArgumentToConstantValue<float>(mul_call->args[2]);
+    const int32_t input_0_zero_point = 
ArgumentToConstantValue<int32_t>(mul_call->args[3]);
+    const float input_1_scale = 
ArgumentToConstantValue<float>(mul_call->args[4]);
+    const int32_t input_1_zero_point = 
ArgumentToConstantValue<int32_t>(mul_call->args[5]);
+    const float output_scale = 
ArgumentToConstantValue<float>(mul_call->args[6]);
+    const int32_t output_zero_point = 
ArgumentToConstantValue<int32_t>(mul_call->args[7]);
+
+    double quantized_multiplier = static_cast<double>(input_0_scale) *
+                                  static_cast<double>(input_1_scale) /
+                                  static_cast<double>(output_scale);
+    auto mult_shift_pair = 
tvm::relay::qnn::GetFixedPointMultiplierShift(quantized_multiplier);
+    int32_t output_multiplier = std::get<0>(mult_shift_pair);
+    int32_t output_shift = std::get<1>(mult_shift_pair);
+
+    PrimExpr tensor_size = mul_call->type_as<TensorTypeNode>()->Size();
+
+    tir::Var input_0("input_0", DataType::Handle(8));

Review comment:
       It's worth adding a comment explaining why we create an 8-bit handle
here. Why do we think 8 bits are sufficient for all cases?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to