tqchen commented on a change in pull request #4482: [Relay] External codegen
URL: https://github.com/apache/incubator-tvm/pull/4482#discussion_r357492770
 
 

 ##########
 File path: src/relay/backend/contrib/contrib_codegen.h
 ##########
 @@ -0,0 +1,284 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/contrib/contrib_codegen.h
+ * \brief The base class for external codegen tools.
+ */
+#ifndef TVM_RELAY_BACKEND_CONTRIB_CONTRIB_CODEGEN_H_
+#define TVM_RELAY_BACKEND_CONTRIB_CONTRIB_CODEGEN_H_
+
+#include <tvm/relay/expr.h>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+
+class ExternCodegenBase {
+ public:
+  ExternCodegenBase() = default;
+
+  /*!
+   * \brief Create a runtime module for the external library. For example, it
+   * could be a CSourceModule that can be directly compiled and linked together
+   * with a DSOModule, or a json style module that emitts a json artifact that
+   * is able to be executed by a customized json runtime.
+   *
+   * \param ref The subgraph Relay expression/module to be executed using 
extern ops.
+   *
+   * \return A runtime module.
+   */
+  virtual runtime::Module CreateExternModule(const NodeRef& ref) = 0;
+
+  /*!
+   * \brief Split the Relay function name to tokens.
+   *
+   * \param func The provided function.
+   * \param prefix The prefix of the function name, i.e. dnnl.
+   *
+   * \return A vector of tokenized function name splitted by "_".
+   */
+  std::string GetSubgraphID(const Function& func, const std::string& prefix) 
const {
+    const auto name_node =
+        FunctionGetAttr(func, attr::kFuncName).as<tvm::ir::StringImm>();
+    CHECK(name_node != nullptr) << "Fail to retrieve subgraph name.";
+    std::string name = name_node->value;
+    return GetSubgraphID(name, prefix);
+  }
+
+  /*!
+   * \brief Split the encoded function name to tokens.
+   *
+   * \param the function name string.
+   *
+   * \return a vector of tokenized function name splitted by "_".
+   */
+  std::string GetSubgraphID(const std::string& name, const std::string& 
prefix) const {
+    std::string temp = name;
+    std::vector<std::string> tokens;
+    std::string delimiter = "_";
+    size_t pos = 0;
+    std::string token;
+    while ((pos = temp.find(delimiter)) != std::string::npos) {
+      token = temp.substr(0, pos);
+      tokens.push_back(token);
+      temp.erase(0, pos + delimiter.length());
+    }
+    tokens.push_back(temp);
+
+    CHECK(tokens.size() >= 2) << "Invalid subgraph name: " << name;
+    CHECK(tokens[0] == prefix)
+        << "Function name: " << name
+        << " does not start with: " << prefix;
+    return tokens[1];
+  }
+};
+
+// A helper class to write the declaration of external functions.
+class ExternSourcePrinter {
+ protected:
+  /*! \brief Print indents using spaces. */
+  void PrintIndents() {
+    for (int i = 0; i < indent_; i++) {
+      code_stream_ << ' ';
+    }
+  }
+
+  /*!
+   * \brief Enter a new scope.
+   */
+  void EnterScope() { indent_ += 2; }
+
+  /*!
+   * \brief Exit a scope.
+   */
+  void ExitScope() {
+    CHECK_GE(indent_, 2U) << "Wrong ident found.";
+    indent_ -= 2;
+  }
+
+  /*!
 +   * \brief Generate a wrapper for the subgraph that will use external codegen.
+   *
+   * \param func_name The name of wrapper function.
+   * \param arg_cnt The expected number of arguments for the wrapper.
+   *
+   * \code
+   *
+   * // An example code for the wrapper.
 +   * extern "C" int foo(TVMValue* value, int* type_code, int nargs) {
+   *   if (nargs != 3) {
+   *     printf("foo expects 3 args, but received %d\n", nargs);
+   *     return 1;
+   *   }
+   *
+   *   DLTensor* arg0 = static_cast<DLTensor*>(value[0].v_handle);
+   *   DLTensor* arg1 = static_cast<DLTensor*>(value[1].v_handle);
+   *   DLTensor* out = static_cast<DLTensor*>(value[2].v_handle);
+   *
+   *   foo_(static_cast<float*>(arg0->data),
+   *        static_cast<float*>(arg1->data),
+   *        static_cast<float*>(out->data));
+   *   return 0;
+   * }
+   *
+   * \endcode
+   */
+  void GenerateSubgraphWrapper(const std::string& func_name, int arg_cnt) {
 
 Review comment:
   GenerateBackendCFunc

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to