Mousius commented on a change in pull request #9510:
URL: https://github.com/apache/tvm/pull/9510#discussion_r750249633



##########
File path: tests/python/relay/aot/aot_test_utils.py
##########
@@ -651,7 +651,17 @@ def run_and_check(
         t.extractall(base_path)
 
         workspace_bytes += model.extra_memory_in_bytes
-        workspace_bytes += mlf_extract_workspace_size_bytes(tar_file)
+        mlf_workspace_size = mlf_extract_workspace_size_bytes(tar_file)
+        workspace_bytes += mlf_workspace_size
+
+        if interface_api == "c":
+            header_path = os.path.join(base_path, 
f"codegen/host/include/tvmgen_{model.name}.h")
+            with open(header_path, "r") as header_file:
+                header_contents = header_file.read()
+                assert (
+                    f"#define TVMGEN_{model.name.upper()}_WORKSPACE_SIZE 
{mlf_workspace_size}"
+                    in header_contents
+                )

Review comment:
       Rather than asserting that the constant is present in the header, the AOT tests should use the constant when allocating the workspace and stop using the JSON variant, so that we test whether it works end to end (E2E).

##########
File path: src/target/source/interface_c.cc
##########
@@ -140,15 +145,26 @@ class InterfaceCNode : public runtime::ModuleNode {
     code_stream << ");\n";
   }
 
+  void EmitWorkspaceSize(std::stringstream& code_stream) {
+    std::string workspace_size_name =
+        ToCConstantStyle(PrefixGeneratedName({module_name_, 
"WORKSPACE_SIZE"}));
+    code_stream << "/*!\n"
+                << " * \\brief Workspace size \n"

Review comment:
       ```suggestion
                   << " * \\brief Workspace size for TVM module \"" << 
module_name_ << "\"\n"
   ```

##########
File path: python/tvm/micro/model_library_format.py
##########
@@ -319,7 +321,10 @@ def _export_graph_model_library_format(
         include_path.mkdir()
         inputs, outputs = _get_inputs_and_outputs_from_module(mod)
         devices = mod.get_devices()
-        generate_c_interface_header(mod.libmod_name, inputs, outputs, devices, 
include_path)
+        workspace_size = 
str(metadata["memory"]["functions"]["main"][0]["workspace_size_bytes"])

Review comment:
       We should pass this as an integer rather than a string (i.e. drop the `str(...)` conversion).

##########
File path: tests/cpp/target/source/interface_c_test.cc
##########
@@ -257,18 +263,27 @@ TEST(InterfaceAPI, ContainsDeviceStructSanitised) {
                 << "};\n\n";
 
   runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, 
{"device+1", "device+2"});
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, 
{"device+1", "device+2"}, "");
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(device_struct.str()));
 }
 
 TEST(InterfaceAPI, ContainsDeviceStructClash) {
   runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, 
{"device+", "device-"});
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, 
{"device+", "device-"}, "");
   ASSERT_THROW(test_module->GetSource(), InternalError);
 }
 
+TEST(InterfaceAPI, ContainsWorkspaceSize) {
+  runtime::Module test_module =
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, 
"765432");
+  std::string header_source = test_module->GetSource();
+
+  ASSERT_THAT(header_source,
+              HasSubstr("#define TVMGEN_ULTIMATE_CAT_SPOTTER_WORKSPACE_SIZE 
765432"));

Review comment:
       Can we check the `\brief` comment as well, to ensure we don't break it in the future? (I have broken it by accident before when refactoring.)




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to