This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 738c2e9e90 [VM][Adreno] Fix using buffers for weights in VM (#15671)
738c2e9e90 is described below
commit 738c2e9e90a6daacc7e581aa1745369ddbdf93f9
Author: Egor Churaev <[email protected]>
AuthorDate: Wed Sep 6 21:46:45 2023 +0300
[VM][Adreno] Fix using buffers for weights in VM (#15671)
* [VM][Adreno] Fix using buffers for weights in VM
In the VM flow, `fn->attrs` doesn't contain information about `kernel_layout`,
so this value can instead be obtained from `expr_attrib`. In this PR the
function `CanUseBuffers` was modified to work with the VM.
A new test that checks the memory scope for the VM was added.
* Fix CI
---
src/relay/transforms/annotate_texture_storage.cc | 8 ++-
.../opencl_texture/test_conv2d_nchw_texture.py | 77 ++++++++++++++++++----
.../relay/opencl_texture/utils/adreno_utils.py | 18 ++---
3 files changed, 76 insertions(+), 27 deletions(-)
diff --git a/src/relay/transforms/annotate_texture_storage.cc
b/src/relay/transforms/annotate_texture_storage.cc
index 4921cef4c8..01d47b6953 100644
--- a/src/relay/transforms/annotate_texture_storage.cc
+++ b/src/relay/transforms/annotate_texture_storage.cc
@@ -174,8 +174,11 @@ class StorageInfo : private
transform::DeviceAwareExprVisitor {
for (const auto& ttype :
FlattenTupleType(fn->params[i]->checked_type())) {
std::string scope = Scope(ttype->shape,
GetVirtualDevice(GetRef<Expr>(call)));
if (expr_attrib.as<Conv2DAttrs>() ||
expr_attrib.as<Conv2DWinogradAttrs>()) {
+ String kernel_layout = expr_attrib.as<Conv2DAttrs>()
+ ?
expr_attrib.as<Conv2DAttrs>()->kernel_layout
+ :
expr_attrib.as<Conv2DWinogradAttrs>()->kernel_layout;
if ((i == weights_pos) && !ttype->dtype.is_float16() &&
- CanUseBuffers(call->args[i], ttype->shape, fn->attrs)) {
+ CanUseBuffers(call->args[i], ttype->shape, kernel_layout))
{
buffers_params.insert(fn->params[i]);
buffers_args.insert(call->args[i]);
scope = "global";
@@ -426,10 +429,9 @@ class StorageInfo : private
transform::DeviceAwareExprVisitor {
}
bool CanUseBuffers(const Expr param, const Array<PrimExpr> shape,
- const tvm::DictAttrs param_attrs) const {
+ const String kernel_layout) const {
bool use_buffer = false;
if (param.as<ConstantNode>() && shape.size() == 5) {
- auto kernel_layout = param_attrs.GetAttr<String>("kernel_layout");
if (kernel_layout == "HWOI4o" || kernel_layout == "HWIO4o") {
int a0 = shape[0].as<IntImmNode>()->value;
int a1 = shape[1].as<IntImmNode>()->value;
diff --git a/tests/python/relay/opencl_texture/test_conv2d_nchw_texture.py
b/tests/python/relay/opencl_texture/test_conv2d_nchw_texture.py
index 3c9c3f2caf..1dd5ca2abd 100644
--- a/tests/python/relay/opencl_texture/test_conv2d_nchw_texture.py
+++ b/tests/python/relay/opencl_texture/test_conv2d_nchw_texture.py
@@ -692,7 +692,6 @@ def test_residual_block(remote, target, executor_type,
dtype):
{"data": input_shape},
{"data": dtype},
target,
- static_memory_scope,
)
@@ -790,11 +789,12 @@ def test_concat(remote, target, executor_type, dtype):
static_memory_scope = [
"",
+ "global.texture",
"global",
"global.texture-weight",
- "global.texture-weight",
"global",
- "global.texture-weight",
+ "global.texture-nhwc",
+ "global",
"global.texture-weight",
"",
"",
@@ -803,8 +803,6 @@ def test_concat(remote, target, executor_type, dtype):
"",
]
- static_memory_scope = []
-
if executor_type == "ge":
build_run_compare(
remote,
@@ -823,7 +821,6 @@ def test_concat(remote, target, executor_type, dtype):
{"data": input_shape},
{"data": dtype},
target,
- static_memory_scope,
)
@@ -968,7 +965,6 @@ def test_pooling_branching_texture_params(remote, target,
executor_type, dtype):
{"data": input_shape},
{"data": dtype},
target,
- static_memory_scope,
)
@@ -1111,7 +1107,6 @@ def test_branching_texture_params(remote, target,
executor_type, dtype):
{"data": input_shape},
{"data": dtype},
target,
- static_memory_scope,
)
@@ -1212,7 +1207,6 @@ def test_conv2d_different_lowering_same_op(remote,
target, executor_type, dtype)
{"data": input_shape},
{"data": dtype},
target,
- static_memory_scope,
)
@@ -1380,7 +1374,6 @@ def test_injective_nwo_inputs1(remote, target,
executor_type, dtype):
{"data": input_shape},
{"data": dtype},
target,
- static_memory_scope,
)
@@ -1495,7 +1488,6 @@ def test_injective_nwo_inputs2(remote, target,
executor_type, dtype):
{"data": input_shape},
{"data": dtype},
target,
- static_memory_scope,
)
@@ -1534,5 +1526,68 @@ def test_conv2d_to_3_channels(remote, target,
executor_type, dtype):
)
[email protected]_opencl
[email protected]_targets("opencl -device=adreno")
+def test_conv2d_weight_on_buffers(remote, target, executor_type, dtype):
+ target = "opencl -device=adreno"
+ input_shape = (1, 64, 75, 75)
+ filter_shape = (64, 64, 3, 3)
+ bias_shape = (64,)
+ A = relay.var("data", shape=input_shape, dtype=dtype)
+ W = relay.var("weight", shape=filter_shape, dtype=dtype)
+ BS = relay.var("bias", shape=bias_shape, dtype=dtype)
+ conv = relay.nn.conv2d(A, W, padding=[1, 1, 1, 1], channels=64,
kernel_size=(3, 3))
+ conv = relay.nn.bias_add(conv, BS)
+ conv = relay.op.nn.relu(conv)
+
+ mod = relay.Function([A, W, BS], conv)
+ np.random.seed(0)
+ initializer = relay.testing.init.Xavier()
+ filter_data = np.zeros(filter_shape).astype(dtype)
+ bias_data = np.zeros(bias_shape).astype(dtype)
+ initializer("weight", filter_data)
+ initializer("bias", bias_data)
+ params1 = {
+ "weight": tvm.nd.array(filter_data),
+ "bias": tvm.nd.array(bias_data),
+ }
+
+ if executor_type == "ge":
+ static_memory_scope = [
+ "",
+ "global.texture",
+ "global",
+ "global.texture-weight",
+ "",
+ "",
+ ]
+ build_run_compare(
+ remote,
+ mod,
+ params1,
+ {"data": input_shape},
+ {"data": dtype},
+ target,
+ static_memory_scope,
+ )
+ else:
+ static_memory_scope = """
+ VM VirtualDevice[0]: device type 1, id 0 and mem_scope
+ VM VirtualDevice[1]: device type 4, id 0 and mem_scope
+ VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture
+ VM VirtualDevice[3]: device type 4, id 0 and mem_scope global
+ VM VirtualDevice[4]: device type 4, id 0 and mem_scope
global.texture-weight
+ """
+ build_run_compare_vm(
+ remote,
+ mod,
+ params1,
+ {"data": input_shape},
+ {"data": dtype},
+ target,
+ static_memory_scope,
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relay/opencl_texture/utils/adreno_utils.py
b/tests/python/relay/opencl_texture/utils/adreno_utils.py
index 309243df16..d9e52f8847 100644
--- a/tests/python/relay/opencl_texture/utils/adreno_utils.py
+++ b/tests/python/relay/opencl_texture/utils/adreno_utils.py
@@ -161,19 +161,11 @@ def build_run_compare_vm(
tvm_mod_nchwc, target=target, target_host=target_host,
params=params1
)
- # TODO(echuraev): enable scope checking
- ## verification that storage_scope has expected textures scopes
- # graph_json = json.loads(graph)
- # if "storage_scope" in graph_json["attrs"]:
- # assert (
- # len(static_mem_scopes) ==
len(graph_json["attrs"]["storage_scope"][1])
- # or len(static_mem_scopes) == 0
- # )
- # else:
- # assert len(static_mem_scopes) == 0
-
- # for i in range(0, len(static_mem_scopes)):
- # assert static_mem_scopes[i] ==
graph_json["attrs"]["storage_scope"][1][i]
+ if len(static_mem_scopes) > 0:
+ mem_scopes_lines = static_mem_scopes.strip().split("\n")
+ vm_lines = vmc._get_virtual_devices().strip().split("\n")
+ for i in range(0, len(mem_scopes_lines)):
+ assert mem_scopes_lines[i].strip() == vm_lines[i].strip()
if remote is None:
dev = tvm.opencl()