mbaret commented on a change in pull request #6222:
URL: https://github.com/apache/incubator-tvm/pull/6222#discussion_r467856126



##########
File path: src/relay/backend/contrib/ethosn/codegen.cc
##########
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/contrib/ethosn/codegen.cc
+ * \brief The Relay -> Ethos-N command stream compiler.
+ */
+#include <tvm/relay/expr_functor.h>
+#include <tvm/runtime/module.h>
+
+#include "codegen_ethosn.h"
+#include "ethosn_api.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace ethosn {
+
+sl::TensorInfo GetTensorInfo(std::map<Expr, std::vector<sl::TensorInfo>> 
tensor_table,
+                             const Call& call) {
+  if (tensor_table.find(call) != tensor_table.end()) return 
tensor_table[call][0];
+
+  return sl::TensorInfo();
+}
+
+void InferTensorsVisitor::InferCall(const CallNode* cn) {

Review comment:
       The motivation behind this is principally clarity rather than necessity. 
The InferCall function ends up getting very long as more operators are 
introduced and we wanted to separate this lengthy function from the traversal 
logic so that it is quick to reason about the traversal without having to scan 
through a huge block of code. If you don't think this clarity is worthwhile, 
then we can inline it.

##########
File path: src/relay/backend/contrib/ethosn/codegen.cc
##########
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/contrib/ethosn/codegen.cc
+ * \brief The Relay -> Ethos-N command stream compiler.
+ */
+#include <tvm/relay/expr_functor.h>
+#include <tvm/runtime/module.h>
+
+#include "codegen_ethosn.h"
+#include "ethosn_api.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace ethosn {
+
+sl::TensorInfo GetTensorInfo(std::map<Expr, std::vector<sl::TensorInfo>> 
tensor_table,
+                             const Call& call) {
+  if (tensor_table.find(call) != tensor_table.end()) return 
tensor_table[call][0];
+
+  return sl::TensorInfo();
+}
+
+void InferTensorsVisitor::InferCall(const CallNode* cn) {
+  EthosnError err;
+  Call call = GetRef<Call>(cn);
+  // Determine call -> NPU mapping
+  if (EthosnAPI::IsEthosOp(call, "qnn.concatenate")) {
+    ConcatenateParams params;
+    err = EthosnAPI::Concatenate(call, &params);
+    tensor_table_[cn->args[0]] = params.input_infos;
+  } else if (EthosnAPI::IsEthosOp(call, "split")) {
+    SplitParams params;
+    params.input_info = GetTensorInfo(tensor_table_, call);
+    err = EthosnAPI::Split(call, &params);
+    tensor_table_[cn->args[0]] = {params.input_info};
+  } else {
+    err = EthosnError("unknown operator");
+  }
+  if (err) {
+    ReportFatalError(call, err);
+  }
+}
+
+// This will only visit an expression if the expression's tensor info
+// has already been entirely inferred.
+// An example where this is important is a tuple node where each
+// get item node will only infer one field of the tuple's expression info.
+// We don't want to traverse the tuple until all of its fields have been 
inferred.
+void InferTensorsVisitor::VisitInferred(const Expr& expr) {
+  if (tensor_table_.find(expr) != tensor_table_.end()) {
+    for (const auto& tensor_info : tensor_table_[expr]) {
+      if (tensor_info == sl::TensorInfo()) return;
+    }
+    VisitExpr(expr);
+  }
+}
+
+void InferTensorsVisitor::VisitExpr_(const CallNode* cn) {
+  InferCall(cn);
+  // Pre-order visitor
+  for (const auto& arg : cn->args) {
+    VisitInferred(arg);
+  }
+}
+
+void InferTensorsVisitor::VisitExpr_(const TupleNode* tn) {
+  auto tuple = GetRef<Tuple>(tn);
+  CHECK(tensor_table_.find(tuple) != tensor_table_.end());
+  for (size_t i = 0; i < tn->fields.size(); i++) {
+    tensor_table_[tn->fields[i]] = {tensor_table_[tuple][i]};
+  }
+  // Pre-order visitor
+  for (const auto& field : tn->fields) {
+    VisitExpr(field);
+  }
+}
+
+void InferTensorsVisitor::VisitExpr_(const TupleGetItemNode* tgn) {
+  // Don't assume it must be targeting a TupleNode
+  // Vars and calls can still have TupleType
+  auto tg = GetRef<TupleGetItem>(tgn);
+  CHECK(tensor_table_.find(tg) != tensor_table_.end());
+  auto tuple = tg->tuple;
+  auto type = tuple->checked_type().as<TupleTypeNode>();
+  int index = tg->index;
+  // Resize the tensor infos to the tuple size if not already done
+  if (tensor_table_.find(tuple) == tensor_table_.end()) {
+    tensor_table_[tuple].resize(type->fields.size());
+  }
+  tensor_table_[tuple][index] = tensor_table_[tg][0];
+  // Pre-order visitor
+  VisitInferred(tuple);
+}
+
+sl::TensorsAndId MakeOps(const sl::TensorAndId<sl::Operand>& op) {
+  sl::TensorsAndId ops;
+  ops.tensors = {op.tensor};
+  ops.operationId = op.operationId;
+  return ops;
+}
+
+sl::TensorsAndId ConstructNetworkVisitor::HandleCall(const CallNode* cn) {

Review comment:
       Same comment as above.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to