huajsj commented on a change in pull request #8497:
URL: https://github.com/apache/tvm/pull/8497#discussion_r675909127
##########
File path: tests/cpp/build_module_test.cc
##########
@@ -200,6 +200,134 @@ TEST(BuildModule, Heterogeneous) {
}
}
+TEST(BuildModule, ZeroCopy) {
+ /*
+ *
+ * A B
+ * \ /
+ * elemwise_add(out0)
+ * \
+ * C copy
+ * \ /
+ * elemwise_sub(out1)
+ */
+
+ using namespace tvm;
+ using namespace tvm::te;
+
+ auto target_llvm = Target("llvm");
+
+ // The shape of input tensors.
+ const int n = 4;
+ Array<PrimExpr> shape{n};
+
+ auto A = placeholder(shape, DataType::Float(32), "A");
+ auto B = placeholder(shape, DataType::Float(32), "B");
+ auto C = placeholder(shape, DataType::Float(32), "C");
+
+ auto elemwise_add = compute(
+ A->shape, [&A, &B](PrimExpr i) { return A[i] + B[i]; }, "elemwise_add");
+
+ auto copy = placeholder(shape, DataType::Float(32), "__copy");
+ auto elemwise_sub = compute(
+ C->shape, [&copy, &C](PrimExpr i) { return copy[i] - C[i]; },
"elemwise_sub");
+
+ With<Target> llvm_scope(target_llvm);
+ auto s1 = create_schedule({elemwise_add->op});
+ auto s2 = create_schedule({elemwise_sub->op});
+
+ auto args1 = Array<Tensor>({A, B, elemwise_add});
+ auto args2 = Array<Tensor>({copy, C, elemwise_sub});
+
+ std::unordered_map<Tensor, Buffer> binds;
+ auto lowered_s1 = LowerSchedule(s1, args1, "elemwise_add", binds);
+ auto lowered_s2 = LowerSchedule(s2, args2, "elemwise_sub", binds);
+ Map<tvm::Target, IRModule> inputs = {{target_llvm, lowered_s1},
{target_llvm, lowered_s2}};
+ auto module = build(inputs, Target());
+
+ // Execute the graph and check the correctness.
+ // Setup graph json.
+ std::string json =
+ "{\"nodes\": [{\"op\": \"null\", \"name\": \"A\", \"inputs\": []}, "
+ "{\"op\": \"null\", \"name\": \"B\", \"inputs\": []}, {\"op\": "
+ "\"tvm_op\", \"name\": \"elemwise_add\", \"attrs\": {\"flatten_data\": "
+ "\"1\", \"func_name\": \"elemwise_add\", \"num_inputs\": \"2\", "
+ "\"num_outputs\": \"1\"}, \"inputs\": [[0, 0, 0], [1, 0, 0]]}, {\"op\": "
+ "\"tvm_op\", \"name\": \"__copy_add_to_sub\", \"attrs\": "
+ "{\"flatten_data\": \"0\", \"func_name\": \"__copy\", \"num_inputs\": "
+ "\"1\", \"num_outputs\": \"1\"}, \"inputs\": [[2, 0, 0]]}, {\"op\": "
+ "\"null\", \"name\": \"C\", \"inputs\": []}, {\"op\": \"tvm_op\", "
+ "\"name\": \"elemwise_sub\", \"attrs\": {\"flatten_data\": \"0\", "
+ "\"func_name\": \"elemwise_sub\", \"num_inputs\": \"2\", "
+ "\"num_outputs\": \"1\"}, \"inputs\": [[3, 0, 0], [4, 0, 0]]}], "
+ "\"arg_nodes\": [0, 1, 4], \"node_row_ptr\": [0, 1, 2, 3, 4, 5, 6], "
+ "\"heads\": [[2, 0, 0], [5, 0, 0]], \"attrs\": {\"storage_id\":
[\"list_int\", [3, "
+ "4, 0, 1, 5, 2]], \"shape\": [\"list_shape\", [[4], [4], [4], [4], [4], "
+ "[4]]], \"device_index\": [\"list_int\", [2, 2, 2, 1, 1, 1]], \"dtype\":
"
+ "[\"list_int\", [0, 0, 0, 0, 0, 0]], \"dltype\": [\"list_str\", "
+ "[\"float32\", \"float32\", \"float32\", \"float32\", \"float32\", "
+ "\"float32\"]]}}";
+ // Setup inputs.
+ auto a_val = runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0});
+ auto b_val = runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0});
+ auto c_val = runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0});
+
+ auto pa = (float*)(a_val->data);
+ auto pb = (float*)(b_val->data);
+ auto pc = (float*)(c_val->data);
+
+ // Assign values.
+ for (int i = 0; i < n; i++) {
+ pa[i] = i;
+ pb[i] = i + 1.0;
+ pc[i] = i - 1.0;
+ }
+
+ // // Initialize graph executor.
+ int cpu_dev_ty = static_cast<int>(kDLCPU);
+ int cpu_dev_id = 0;
+
+ const runtime::PackedFunc* graph_executor =
+ tvm::runtime::Registry::Get("tvm.graph_executor.create");
+ runtime::Module mod = (*graph_executor)(json, module, cpu_dev_ty,
cpu_dev_id);
+
+ // test FFI for module.
+ auto test_ffi = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
+ int tcode = args[1];
+ ICHECK_EQ(args[0].type_code(), tcode);
+ });
+
+ test_ffi(runtime::Module(mod), static_cast<int>(kTVMModuleHandle));
+ test_ffi(Optional<runtime::Module>(mod), static_cast<int>(kTVMModuleHandle));
+
+ PackedFunc set_input = mod.GetFunction("set_input", false);
+ PackedFunc run = mod.GetFunction("run", false);
+ PackedFunc get_output = mod.GetFunction("get_output", false);
+ PackedFunc set_output_zero_copy = mod.GetFunction("set_output_zero_copy",
false);
+ set_input("A", a_val);
+ set_input("B", b_val);
+ set_input("C", c_val);
+
+ tvm::runtime::NDArray out0 = runtime::NDArray::Empty({n}, {kDLFloat, 32, 1},
{kDLCPU, 0});
+ tvm::runtime::NDArray out1 = runtime::NDArray::Empty({n}, {kDLFloat, 32, 1},
{kDLCPU, 0});
+ set_output_zero_copy(0, out0);
+ set_output_zero_copy(1, out1);
+
+ run();
+ // print_data_entry();
+ float* p_out0 = (float*)out0->data;
+ float* p_out1 = (float*)out1->data;
+
+ // Check correctness.
+ for (int i = 0; i < n; ++i) {
+ ICHECK_LT(std::fabs(p_out0[i] - (i + (i + 1.0))), 1e-5);
+ }
+
+ for (int i = 0; i < n; ++i) {
+ ICHECK_LT(std::fabs(p_out1[i] - (i + (i + 1.0) - (i - 1.0))), 1e-5);
+ }
Review comment:
this makes sense, thanks for the explanation.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]