This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5deb95a947 [Adreno][OpenCL] Get rid of extra memory copy (#12286)
5deb95a947 is described below
commit 5deb95a9472002c9fa36a150f9e348f4276d63c5
Author: Egor Churaev <[email protected]>
AuthorDate: Fri Aug 12 05:53:39 2022 +0300
[Adreno][OpenCL] Get rid of extra memory copy (#12286)
* Add annotation pass for device_copy where we get buffers but expect textures
* Fix issues with running device_copy
* Get rid of extra memory copy
* Fix build after cherry-picking
* Fix lint
* Fix CI
* Apply comments
Co-authored-by: Andrey Malyshev <[email protected]>
---
python/tvm/relay/op/strategy/adreno.py | 18 +++
python/tvm/topi/adreno/__init__.py | 1 +
python/tvm/topi/adreno/conv2d_nchw.py | 29 +++--
python/tvm/topi/adreno/conv2d_nhwc.py | 29 +++--
python/tvm/topi/adreno/depthwise_conv2d_nchw.py | 10 +-
python/tvm/topi/adreno/depthwise_conv2d_nhwc.py | 10 +-
python/tvm/topi/adreno/injective.py | 66 ++++++++++
python/tvm/topi/adreno/utils.py | 23 +++-
src/relay/transforms/annotate_texture_storage.cc | 152 +++++++++++++++++++----
src/runtime/opencl/opencl_device_api.cc | 29 ++++-
tests/python/relay/test_conv2d_nchw_texture.py | 8 +-
11 files changed, 312 insertions(+), 63 deletions(-)
diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py
index cc082c9d61..a537fa1e7b 100644
--- a/python/tvm/relay/op/strategy/adreno.py
+++ b/python/tvm/relay/op/strategy/adreno.py
@@ -257,3 +257,21 @@ def schedule_pool_adreno(attrs, outs, target):
if attrs.layout == "NCHW4c":
return topi.adreno.schedule_pool(outs, attrs.layout)
return topi.cuda.schedule_pool(outs, attrs.layout)
+
+
+@schedule_injective.register(["adreno"])
+def schedule_injective_adreno(attrs, outs, target):
+ """schedule injective ops for adreno"""
+ with target:
+ return topi.adreno.schedule_injective(outs)
+
+
+@concatenate_strategy.register(["adreno"])
+def concatenate_strategy_adreno(attrs, inputs, out_type, target):
+ strategy = _op.OpStrategy()
+ strategy.add_implementation(
+ wrap_compute_concat(topi.transform.concatenate),
+ wrap_topi_schedule(topi.adreno.schedule_injective),
+ name="concatenate.adreno",
+ )
+ return strategy
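For context (not part of the commit): once registered, these strategies are picked up by any Relay build that targets Adreno. A minimal sketch, assuming a TVM build with OpenCL enabled; the module and target strings are illustrative:

    import tvm
    from tvm import relay

    # Toy module: a single concatenate, which now lowers through
    # concatenate_strategy_adreno and topi.adreno.schedule_injective.
    x = relay.var("x", shape=(1, 32, 40, 40), dtype="float32")
    y = relay.var("y", shape=(1, 32, 40, 40), dtype="float32")
    f = relay.Function([x, y], relay.concatenate([x, y], axis=1))
    mod = tvm.IRModule.from_expr(f)

    # "-device=adreno" routes strategy lookups to the registrations above.
    target = tvm.target.Target("opencl -device=adreno", host="llvm -mtriple=aarch64-linux-gnu")
    lib = relay.build(mod, target=target)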
diff --git a/python/tvm/topi/adreno/__init__.py b/python/tvm/topi/adreno/__init__.py
index 57a9013b1a..227ca6aa9a 100644
--- a/python/tvm/topi/adreno/__init__.py
+++ b/python/tvm/topi/adreno/__init__.py
@@ -25,3 +25,4 @@ from .pooling import *
from .conv2d_alter_op import *
from .conv2d_nchw_winograd import *
from .conv2d_nhwc_winograd import *
+from .injective import schedule_injective
diff --git a/python/tvm/topi/adreno/conv2d_nchw.py b/python/tvm/topi/adreno/conv2d_nchw.py
index 16ecaa84d0..65cd8e0150 100644
--- a/python/tvm/topi/adreno/conv2d_nchw.py
+++ b/python/tvm/topi/adreno/conv2d_nchw.py
@@ -279,28 +279,35 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output):
): # len(latest.op.axis) == 4:
# manage scheduling of datacopy
pad_data, kernel = s[conv].op.input_tensors
- pack_data = pad_data.op.input_tensors[0]
- bind_data_copy(s[pack_data])
+ if "pad_temp" in pad_data.op.name:
+ pack_data = pad_data.op.input_tensors[0]
+ bind_data_copy(s[pack_data])
+ else:
+ bind_data_copy(s[pad_data])
bind_data_copy(s[kernel])
pad_data, kernel = s[conv].op.input_tensors
- s[pad_data].compute_inline()
-
- s[conv].set_scope("local")
- if latest_blocked == latest and output != latest:
- s[output].compute_inline()
-
- # create cache stage
- AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
- bind_data_copy(s[AT])
if (
autotvm.GLOBAL_SCOPE.in_tuning
or isinstance(kernel.op, tvm.te.ComputeOp)
and "filter_pack" in kernel.op.tag
):
+ if "pad_temp" in pad_data.op.name:
+ s[pad_data].compute_inline()
+ AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
+ bind_data_copy(s[AT])
WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
bind_data_copy(s[WT])
+ elif "pad_temp" in pad_data.op.name:
+ s[pad_data].compute_inline()
+ # create cache stage
+ AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
+ bind_data_copy(s[AT])
+
+ s[conv].set_scope("local")
+ if latest_blocked == latest and output != latest:
+ s[output].compute_inline()
# tile and bind spatial axes
n, fc, y, x, fb = s[latest_blocked].op.axis
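Aside (a hedged restatement, not code from the commit): the "pad_temp" name check is needed because add_pad (see the utils.py hunk below) may now return its input unchanged, so the producer of conv is no longer guaranteed to be a padding stage:

    # pad_data is either a materialized pad stage ("pad_temp...") or the
    # packed input itself; only a real pad has a packing stage beneath it.
    pad_data, kernel = s[conv].op.input_tensors
    if "pad_temp" in pad_data.op.name:
        bind_data_copy(s[pad_data.op.input_tensors[0]])  # schedule the pack stage
    else:
        bind_data_copy(s[pad_data])  # padding was elided upstream

The same guard recurs in the NHWC and depthwise schedules below.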
diff --git a/python/tvm/topi/adreno/conv2d_nhwc.py b/python/tvm/topi/adreno/conv2d_nhwc.py
index ce7bf0ccc9..b377169ca8 100644
--- a/python/tvm/topi/adreno/conv2d_nhwc.py
+++ b/python/tvm/topi/adreno/conv2d_nhwc.py
@@ -275,28 +275,35 @@ def schedule_conv2d_NHWC(cfg, s, output):
): # len(latest.op.axis) == 4:
# manage scheduling of datacopy
pad_data, kernel = s[conv].op.input_tensors
- pack_data = pad_data.op.input_tensors[0]
- bind_data_copy(s[pack_data])
+ if "pad_temp" in pad_data.op.name:
+ pack_data = pad_data.op.input_tensors[0]
+ bind_data_copy(s[pack_data])
+ else:
+ bind_data_copy(s[pad_data])
bind_data_copy(s[kernel])
pad_data, kernel = s[conv].op.input_tensors
- s[pad_data].compute_inline()
-
- s[conv].set_scope("local")
- if latest_blocked == latest and output != latest:
- s[output].compute_inline()
-
- # create cache stage
- AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
- bind_data_copy(s[AT])
if (
autotvm.GLOBAL_SCOPE.in_tuning
or isinstance(kernel.op, tvm.te.ComputeOp)
and "filter_pack" in kernel.op.tag
):
+ if "pad_temp" in pad_data.op.name:
+ s[pad_data].compute_inline()
+ AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
+ bind_data_copy(s[AT])
WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
bind_data_copy(s[WT])
+ elif "pad_temp" in pad_data.op.name:
+ s[pad_data].compute_inline()
+ # create cache stage
+ AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
+ bind_data_copy(s[AT])
+
+ s[conv].set_scope("local")
+ if latest_blocked == latest and output != latest:
+ s[output].compute_inline()
# tile and bind spatial axes
n, y, x, fc, fb = s[latest_blocked].op.axis
diff --git a/python/tvm/topi/adreno/depthwise_conv2d_nchw.py b/python/tvm/topi/adreno/depthwise_conv2d_nchw.py
index a11c3f3d36..37713b4584 100644
--- a/python/tvm/topi/adreno/depthwise_conv2d_nchw.py
+++ b/python/tvm/topi/adreno/depthwise_conv2d_nchw.py
@@ -253,13 +253,17 @@ def schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, output):
): # len(latest.op.axis) == 4:
# manage scheduling of datacopy
pad_data, kernel = s[conv].op.input_tensors
- pack_data = pad_data.op.input_tensors[0]
- bind_data_copy(s[pack_data])
+ if "pad_temp" in pad_data.op.name:
+ pack_data = pad_data.op.input_tensors[0]
+ bind_data_copy(s[pack_data])
+ else:
+ bind_data_copy(s[pad_data])
bind_data_copy(s[kernel])
pad_data, kernel = s[conv].op.input_tensors
- s[pad_data].compute_inline()
+ if "pad_temp" in pad_data.op.name:
+ s[pad_data].compute_inline()
s[conv].set_scope("local")
if latest_blocked == latest and output != latest:
diff --git a/python/tvm/topi/adreno/depthwise_conv2d_nhwc.py b/python/tvm/topi/adreno/depthwise_conv2d_nhwc.py
index 117daf825d..2b228b444f 100644
--- a/python/tvm/topi/adreno/depthwise_conv2d_nhwc.py
+++ b/python/tvm/topi/adreno/depthwise_conv2d_nhwc.py
@@ -247,13 +247,17 @@ def schedule_depthwise_conv2d_NHWC_HWOI(cfg, s, output):
): # len(latest.op.axis) == 4:
# manage scheduling of datacopy
pad_data, kernel = s[conv].op.input_tensors
- pack_data = pad_data.op.input_tensors[0]
- bind_data_copy(s[pack_data])
+ if "pad_temp" in pad_data.op.name:
+ pack_data = pad_data.op.input_tensors[0]
+ bind_data_copy(s[pack_data])
+ else:
+ bind_data_copy(s[pad_data])
bind_data_copy(s[kernel])
pad_data, kernel = s[conv].op.input_tensors
- s[pad_data].compute_inline()
+ if "pad_temp" in pad_data.op.name:
+ s[pad_data].compute_inline()
s[conv].set_scope("local")
if latest_blocked == latest and output != latest:
diff --git a/python/tvm/topi/adreno/injective.py b/python/tvm/topi/adreno/injective.py
new file mode 100644
index 0000000000..52ab0eab33
--- /dev/null
+++ b/python/tvm/topi/adreno/injective.py
@@ -0,0 +1,66 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable,
+"""Schedule for composition of injective operator"""
+import tvm
+from tvm import te
+from .utils import bind_data_copy
+from .. import utils
+
+
+def schedule_injective_from_existing(sch, out):
+ """Schedule for injective op from existing schedule.
+
+ Parameters
+ ----------
+ sch: Schedule
+ The schedule to update.
+ out: Tensor
+ The tensor representing the injective op.
+
+ Returns
+ -------
+ sch: Schedule
+ The updated schedule.
+ """
+
+ bind_data_copy(sch[out])
+ return sch
+
+
+def schedule_injective(outs):
+ """Schedule for injective op.
+
+ Parameters
+ ----------
+ outs: Array of Tensor
+ The computation graph description of injective in the format
+ of an array of tensors.
+
+ Returns
+ -------
+ sch: Schedule
+ The computation schedule for the op.
+ """
+ outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+ s = te.create_schedule([x.op for x in outs])
+
+ tvm.te.schedule.AutoInlineInjective(s)
+ for out in outs:
+ if not utils.is_empty_shape(out.shape):
+ schedule_injective_from_existing(s, out)
+ return s
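For illustration (not part of the commit), the new schedule can be exercised directly at the TE level; the shape is made up to match the texture-friendly layout with an innermost dimension of 4:

    import tvm
    from tvm import te, topi

    a = te.placeholder((1, 8, 40, 40, 4), name="a", dtype="float32")
    b = topi.add(a, 1.0)                     # an injective op
    s = topi.adreno.schedule_injective([b])  # binds the copy via bind_data_copy
    mod = tvm.lower(s, [a, b], name="add_one")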
diff --git a/python/tvm/topi/adreno/utils.py b/python/tvm/topi/adreno/utils.py
index ea19e7d77d..6ad5271744 100644
--- a/python/tvm/topi/adreno/utils.py
+++ b/python/tvm/topi/adreno/utils.py
@@ -474,7 +474,19 @@ def add_pad(
pad_after[x_axis] -= in_width + pad_before[x_axis] + pad_after[x_axis] - input_latest_w
if input_latest_h < in_height + pad_before[y_axis] + pad_after[y_axis]:
pad_after[y_axis] -= in_height + pad_before[y_axis] + pad_after[y_axis] - input_latest_h
- return nn.pad(data, pad_before, pad_after, name="pad_temp")
+ if (
+ pad_before[0] == 0
+ and pad_before[1] == 0
+ and pad_before[2] == 0
+ and pad_before[3] == 0
+ and pad_after[0] == 0
+ and pad_after[1] == 0
+ and pad_after[2] == 0
+ and pad_after[3] == 0
+ ):
+ return data
+ else:
+ return nn.pad(data, pad_before, pad_after, name="pad_temp")
def bind_data_copy(stage, axis_to_vectorize=None):
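The effect of the add_pad change above, restated outside the diff as a hedged sketch (maybe_pad is a hypothetical name, not a helper from the commit):

    from tvm.topi import nn

    def maybe_pad(data, pad_before, pad_after):
        # If no padding is needed on any axis, reuse the input tensor as-is:
        # no extra copy stage is created, and consumers detect this case by
        # the absence of "pad_temp" in the producer's op name.
        if all(p == 0 for p in pad_before) and all(p == 0 for p in pad_after):
            return data
        return nn.pad(data, pad_before, pad_after, name="pad_temp")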
@@ -522,6 +534,15 @@ def bind_data_copy(stage, axis_to_vectorize=None):
stage.bind(thread, te.thread_axis("threadIdx.x"))
if shape[-1] == 4:
stage.vectorize(axes[-1])
+ # 1024 is the maximum work group size for Adreno devices.
+ # See: CL_DEVICE_MAX_WORK_GROUP_SIZE
+ elif shape[-1] > 1024:
+ ftc = numpy.prod(shape[:-1])
+ div = get_div(ftc, 1024)
+ by, ty = stage.split(axes[-1], factor=div)
+ stage.bind(fused, te.thread_axis("blockIdx.x"))
+ stage.bind(by, te.thread_axis("blockIdx.y"))
+ stage.bind(ty, te.thread_axis("threadIdx.y"))
else:
stage.bind(fused, te.thread_axis("blockIdx.x"))
stage.bind(*axes[-1:], te.thread_axis("threadIdx.x"))
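The 1024 cap above mirrors CL_DEVICE_MAX_WORK_GROUP_SIZE on Adreno, and get_div (defined earlier in the same utils.py) supplies the largest divisor that fits under it. A hedged sketch of that contract (largest_divisor is illustrative, not the library helper):

    def largest_divisor(n, limit):
        # Largest d <= limit with n % d == 0, so the split below it is exact.
        for d in range(min(n, limit), 0, -1):
            if n % d == 0:
                return d
        return 1

    assert largest_divisor(2048, 1024) == 1024
    assert largest_divisor(1500, 1024) == 750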
diff --git a/src/relay/transforms/annotate_texture_storage.cc b/src/relay/transforms/annotate_texture_storage.cc
index 3dd918d962..b3ed28db45 100644
--- a/src/relay/transforms/annotate_texture_storage.cc
+++ b/src/relay/transforms/annotate_texture_storage.cc
@@ -23,7 +23,7 @@
* storage scope related information.
*
* - CollectStorageInfo returns a mapping from relay expr
- * to a list of output storage scopes for each output.
+ * to a map of storage scopes for each call argument.
* These scopes are used during memory planning as well
* as downstream when doing codegen and in the graph runtime when doing runtime dataspace
* allocations.
@@ -42,6 +42,8 @@
#include <memory>
#include <unordered_map>
+#include "../op/memory/device_copy.h"
+#include "../op/memory/memory.h"
#include "../transforms/device_aware_visitors.h"
namespace tvm {
@@ -55,15 +57,17 @@ class StorageInfo : private transform::DeviceAwareExprVisitor {
public:
StorageInfo() : transform::DeviceAwareExprVisitor(Optional<IRModule>()) {}
- static Map<Expr, Array<String>> GetStorageMap(const Expr& expr) {
+ static Map<Expr, Map<Expr, Array<String>>> GetStorageMap(const Expr& expr) {
StorageInfo storage_info;
storage_info.VisitExpr(expr);
storage_info.LegalizeProducerStorage();
- Map<Expr, Array<String>> storage_map;
+ Map<Expr, Map<Expr, Array<String>>> storage_map = storage_info.accept_textures_;
for (auto& kv : storage_info.storage_scope_) {
std::vector<String> storage_scopes;
std::copy(kv.second.begin(), kv.second.end(), std::back_inserter(storage_scopes));
- storage_map.Set(GetRef<Expr>(kv.first), Array<String>{storage_scopes});
+ Map<Expr, Array<String>> ent;
+ ent.Set(Expr(), Array<String>{storage_scopes});
+ storage_map.Set(GetRef<Expr>(kv.first), ent);
}
// Filling the input arguments by "global" scope to handle PlanDevice algo which propagates
@@ -75,7 +79,9 @@ class StorageInfo : private transform::DeviceAwareExprVisitor {
// even without verification of the consumer's outputs scope
if (storage_info.CanConsumeTextures(cs.second) &&
storage_map.find(GetRef<Expr>(cs.first)) == storage_map.end()) {
- storage_map.Set(GetRef<Expr>(cs.first), Array<String>{"global"});
+ Map<Expr, Array<String>> ent;
+ ent.Set(Expr(), Array<String>{"global"});
+ storage_map.Set(GetRef<Expr>(cs.first), ent);
}
}
@@ -85,6 +91,25 @@ class StorageInfo : private transform::DeviceAwareExprVisitor {
if (storage_map.count(a.first)) {
for (const auto& v : a.second) {
storage_map.Set(v, storage_map[a.first]);
+ if (storage_map[a.first][Expr()][0] == "global" &&
+ storage_info.accept_textures_.count(v)) {
+ Map<Expr, Array<String>> ent;
+ ent.Set(Expr(), storage_info.accept_textures_[v][Expr()]);
+ storage_map.Set(v, ent);
+ for (const auto& calls : storage_info.accept_textures_[v]) {
+ if (calls.first != Expr()) {
+ if (storage_map.count(a.first)) {
+ Map<Expr, Array<String>> ent_call = storage_map[a.first];
+ ent_call.Set(calls.first, calls.second);
+ storage_map.Set(a.first, ent_call);
+ } else {
+ Map<Expr, Array<String>> ent_call;
+ ent_call.Set(calls.first, calls.second);
+ storage_map.Set(a.first, ent_call);
+ }
+ }
+ }
+ }
}
}
}
@@ -109,6 +134,18 @@ class StorageInfo : private transform::DeviceAwareExprVisitor {
void VisitExpr_(const ConstantNode* cn) final { ApplyConsumerScopeToInputs(cn); }
+ void DeviceAwareVisitExpr_(const FunctionNode* function_node) final {
+ if (!function_node->HasNonzeroAttr(attr::kPrimitive)) {
+ for (auto&& param : function_node->params) {
+ auto virtual_device = GetVirtualDevice(param);
+ param->virtual_device_ =
+ VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id,
+ virtual_device->target, "global");
+ }
+ }
+ transform::DeviceAwareExprVisitor::DeviceAwareVisitExpr_(function_node);
+ }
+
void DeviceAwareVisitExpr_(const CallNode* call) final {
// Check the contents of this primitive function
if (const auto* fn = call->op.as<FunctionNode>()) {
@@ -135,6 +172,23 @@ class StorageInfo : private transform::DeviceAwareExprVisitor {
}
for (size_t i = 0; i < fn->params.size(); i++) {
args_to_vars_[call->args[i]].push_back(fn->params[i]);
+ // adding info about arguments if they can be converted to texture
+ for (const auto& ttype : FlattenTupleType(fn->params[i]->checked_type())) {
+ std::string scope = Scope(ttype->shape, GetVirtualDevice(GetRef<Expr>(call)));
+ if (scope.find("global.texture") != std::string::npos) {
+ if (accept_textures_.count(fn->params[i])) {
+ Map<Expr, Array<String>> ent = accept_textures_[fn->params[i]];
+ ent.Set(GetRef<Expr>(call), Array<String>{scope});
+ ent.Set(Expr(), Array<String>{scope});
+ accept_textures_.Set(fn->params[i], ent);
+ } else {
+ Map<Expr, Array<String>> ent;
+ ent.Set(GetRef<Expr>(call), Array<String>{scope});
+ ent.Set(Expr(), Array<String>{scope});
+ accept_textures_.Set(fn->params[i], ent);
+ }
+ }
+ }
}
}
// Add consumer storage scope information for call arguments
@@ -164,11 +218,6 @@ class StorageInfo : private transform::DeviceAwareExprVisitor {
if (consumer_storage_scopes_.count(arg.operator->()) &&
GetConsumerScope(consumer_storage_scopes_[arg.operator->()]) != "global.texture") {
storage_scope_.erase(arg.operator->());
- if (const auto* cn = arg.as<CallNode>()) {
- if (const auto* fn = cn->op.as<FunctionNode>()) {
- storage_scope_.erase(fn->body.operator->());
- }
- }
}
}
}
@@ -336,6 +385,16 @@ class StorageInfo : private transform::DeviceAwareExprVisitor {
if (attrs->layout == "NCHW4c") {
supports_texture_storage = true;
}
+ } else if (const OpNode* opnode = call->op.as<OpNode>()) {
+ auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+ auto pattern = fpattern[GetRef<Op>(opnode)];
+ if (pattern <= kInjective) {
+ if (const auto* ttype = call->checked_type().as<TensorTypeNode>()) {
+ if (ttype->shape.size() == 5) {
+ supports_texture_storage = true;
+ }
+ }
+ }
}
return supports_texture_storage;
@@ -350,6 +409,8 @@ class StorageInfo : private transform::DeviceAwareExprVisitor {
std::unordered_map<const ExprNode*, std::vector<std::string>> consumer_storage_scopes_;
/*! \brief mapping of arguments to call to function variables*/
std::unordered_map<Expr, std::vector<Var>, ObjectPtrHash, ObjectPtrEqual> args_to_vars_;
+ /*! \brief mapping of arguments that can be converted to texture*/
+ Map<Expr, Map<Expr, Array<String>>> accept_textures_;
};
} // namespace
@@ -365,43 +426,84 @@ class RewriteVDStorageScopes : public transform::DeviceAwareExprMutator {
using VarMap = std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual>;
public:
- explicit RewriteVDStorageScopes(const Map<Expr, Array<String>>& storage_scope)
+ explicit RewriteVDStorageScopes(const Map<Expr, Map<Expr, Array<String>>>& storage_scope)
: transform::DeviceAwareExprMutator(Optional<IRModule>()), storage_scope_(storage_scope) {}
Function Rewrite(const Expr& expr) { return Downcast<Function>(Mutate(expr)); }
Expr VisitExpr_(const VarNode* vn) final {
if (storage_scope_.find(GetRef<Expr>(vn)) != storage_scope_.end() &&
- storage_scope_[GetRef<Expr>(vn)][0] != "global") {
+ storage_scope_[GetRef<Expr>(vn)].find(Expr()) != storage_scope_[GetRef<Expr>(vn)].end() &&
+ storage_scope_[GetRef<Expr>(vn)][Expr()][0] != "global") {
Var c = Var(vn->vid, vn->type_annotation, vn->span);
auto virtual_device = GetVirtualDevice(GetRef<Expr>(vn));
c->virtual_device_ =
VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id,
- virtual_device->target, storage_scope_[GetRef<Expr>(vn)][0]);
+ virtual_device->target, storage_scope_[GetRef<Expr>(vn)][Expr()][0]);
return c;
}
return GetRef<Var>(vn);
}
Expr VisitExpr_(const ConstantNode* vn) final {
- if (storage_scope_.find(GetRef<Expr>(vn)) != storage_scope_.end()) {
+ if (storage_scope_.find(GetRef<Expr>(vn)) != storage_scope_.end() &&
+ storage_scope_[GetRef<Expr>(vn)].find(Expr()) != storage_scope_[GetRef<Expr>(vn)].end()) {
Expr c = Constant(vn->data, vn->span);
auto virtual_device = GetVirtualDevice(GetRef<Expr>(vn));
- c = OnDevice(c,
- VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id,
- virtual_device->target, storage_scope_[GetRef<Expr>(vn)][0]),
- true);
+ c = OnDevice(
+ c,
+ VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id,
+ virtual_device->target, storage_scope_[GetRef<Expr>(vn)][Expr()][0]),
+ true);
return c;
}
return GetRef<Constant>(vn);
}
Expr DeviceAwareVisitExpr_(const CallNode* call_node) final {
- auto new_call = transform::DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node);
+ // we need to duplicate ExprMutator::VisitExpr_ to correct argument scopes and
+ // put device_copy
+ auto new_op = this->Mutate(call_node->op);
+
+ tvm::Array<Type> ty_args;
+ ty_args.reserve(call_node->type_args.size());
+
+ for (auto ty_arg : call_node->type_args) {
+ auto new_ty_arg = this->VisitType(ty_arg);
+ ty_args.push_back(new_ty_arg);
+ }
+
+ tvm::Array<Expr> call_args;
+ call_args.reserve(call_node->args.size());
+ for (auto arg : call_node->args) {
+ auto new_arg = this->Mutate(arg);
+ // verification if we need to put device_copy
+ if (storage_scope_.count(arg) && storage_scope_[arg].count(GetRef<Expr>(call_node))) {
+ auto virtual_device = GetVirtualDevice(GetRef<Expr>(call_node));
+ VirtualDevice virtual_device_from =
+ VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id,
+ virtual_device->target, virtual_device->memory_scope);
+ VirtualDevice virtual_device_to =
+ VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id,
+ virtual_device->target, storage_scope_[arg][GetRef<Expr>(call_node)][0]);
+ new_arg = DeviceCopy(new_arg, virtual_device_from, virtual_device_to);
+ new_arg = OnDevice(
+ new_arg,
+ VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id,
+ virtual_device->target, storage_scope_[arg][GetRef<Expr>(call_node)][0]),
+ true);
+ }
+ call_args.push_back(new_arg);
+ }
+
+ auto new_call = WithFields(GetRef<Call>(call_node), new_op, call_args, {}, ty_args);
+
auto virtual_device = GetVirtualDevice(GetRef<Expr>(call_node));
std::string memory_scope = "";
- if (storage_scope_.find(GetRef<Expr>(call_node)) != storage_scope_.end()) {
- memory_scope = storage_scope_[GetRef<Expr>(call_node)][0];
+ if (storage_scope_.find(GetRef<Expr>(call_node)) != storage_scope_.end() &&
+ storage_scope_[GetRef<Expr>(call_node)].find(Expr()) !=
+ storage_scope_[GetRef<Expr>(call_node)].end()) {
+ memory_scope = storage_scope_[GetRef<Expr>(call_node)][Expr()][0];
} else if (virtual_device->memory_scope != "") {
memory_scope = virtual_device->memory_scope;
} else if (!call_node->op.as<FunctionNode>()) {
@@ -418,12 +520,12 @@ class RewriteVDStorageScopes : public transform::DeviceAwareExprMutator {
}
private:
- Map<Expr, Array<String>> storage_scope_;
+ Map<Expr, Map<Expr, Array<String>>> storage_scope_;
VarMap new_vars_;
Array<String> current_function_scope_;
};
-Map<Expr, Array<String>> CollectTextureStorage(const Expr& expr) {
+Map<Expr, Map<Expr, Array<String>>> CollectTextureStorage(const Expr& expr) {
return StorageInfo::GetStorageMap(expr);
}
@@ -479,7 +581,7 @@ class CollectVirtualDevices : public transform::DeviceAwareExprVisitor {
* \param expr The expression.
* \return The device based storage mapping.
*/
-Map<Expr, Array<String>> CollectStorageInfo(const Expr& expr) {
+Map<Expr, Map<Expr, Array<String>>> CollectStorageInfo(const Expr& expr) {
std::set<std::string> device_types = CollectVirtualDevices().GetDevices(expr);
// TODO(amalyshe): current approach collects all targets withing graph and call the only
// function corresponding to all these targets in alphabetic order
@@ -490,7 +592,7 @@ Map<Expr, Array<String>> CollectStorageInfo(const Expr& expr) {
ftarget_prefix += (std::string(".") + dev_id);
}
- Map<Expr, Array<String>> storage_info = {};
+ Map<Expr, Map<Expr, Array<String>>> storage_info = {};
if (const auto* f = runtime::Registry::Get(ftarget_prefix + "._CollectStorageInfo")) {
storage_info = (*f)(expr);
}
diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc
index cea0acc07c..d67864287d 100644
--- a/src/runtime/opencl/opencl_device_api.cc
+++ b/src/runtime/opencl/opencl_device_api.cc
@@ -273,12 +273,31 @@ void OpenCLWorkspace::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHand
if (IsOpenCLDevice(from->device) && IsOpenCLDevice(to->device)) {
const auto* from_desc = static_cast<const cl::BufferDescriptor*>(from->data);
- ICHECK(from_desc->layout == cl::BufferDescriptor::MemoryLayout::kBuffer1D)
- << "Device to device copying is currently only implemented for OpenCL
buffer storage";
auto* to_desc = static_cast<cl::BufferDescriptor*>(to->data);
- OPENCL_CALL(clEnqueueCopyBuffer(this->GetQueue(to->device), from_desc->buffer, to_desc->buffer,
- from->byte_offset, to->byte_offset, nbytes, 0, nullptr,
- nullptr));
+ if (to_desc->layout == cl::BufferDescriptor::MemoryLayout::kBuffer1D &&
+ from_desc->layout == cl::BufferDescriptor::MemoryLayout::kBuffer1D) {
+ OPENCL_CALL(clEnqueueCopyBuffer(this->GetQueue(to->device), from_desc->buffer,
+ to_desc->buffer, from->byte_offset, to->byte_offset, nbytes,
+ 0, nullptr, nullptr));
+ } else if (to_desc->layout != cl::BufferDescriptor::MemoryLayout::kBuffer1D &&
+ from_desc->layout == cl::BufferDescriptor::MemoryLayout::kBuffer1D) {
+ auto image_info = GetImageInfo(to_desc, to);
+ OPENCL_CALL(clEnqueueCopyBufferToImage(this->GetQueue(to->device), from_desc->buffer,
+ to_desc->buffer, from->byte_offset, image_info.origin,
+ image_info.region, 0, nullptr, nullptr));
+ } else if (to_desc->layout == cl::BufferDescriptor::MemoryLayout::kBuffer1D &&
+ from_desc->layout != cl::BufferDescriptor::MemoryLayout::kBuffer1D) {
+ auto image_info = GetImageInfo(from_desc, from);
+ OPENCL_CALL(clEnqueueCopyImageToBuffer(this->GetQueue(to->device), from_desc->buffer,
+ to_desc->buffer, image_info.origin, image_info.region,
+ to->byte_offset, 0, nullptr, nullptr));
+ } else {
+ auto to_image_info = GetImageInfo(to_desc, to);
+ auto from_image_info = GetImageInfo(from_desc, from);
+ OPENCL_CALL(clEnqueueCopyImage(this->GetQueue(to->device), from_desc->buffer, to_desc->buffer,
+ from_image_info.origin, to_image_info.origin,
+ to_image_info.region, 0, nullptr, nullptr));
+ }
} else if (IsOpenCLDevice(from->device) && to->device.device_type == kDLCPU) {
const auto* from_desc = static_cast<const cl::BufferDescriptor*>(from->data);
switch (from_desc->layout) {
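With these branches in place, a device-to-device copy between buffer-backed and texture-backed tensors stays on the GPU instead of bouncing through the host. A hedged usage sketch (assumes an OpenCL device with image support; mem_scope as exposed by tvm.nd.empty):

    import numpy as np
    import tvm

    dev = tvm.opencl(0)
    src = tvm.nd.array(np.random.rand(1, 8, 16, 16, 4).astype("float32"), dev)  # kBuffer1D
    dst = tvm.nd.empty(src.shape, src.dtype, dev, mem_scope="global.texture")   # image layout
    src.copyto(dst)  # now served by clEnqueueCopyBufferToImage on-device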
diff --git a/tests/python/relay/test_conv2d_nchw_texture.py b/tests/python/relay/test_conv2d_nchw_texture.py
index 58590998fd..6eadd8fc1c 100644
--- a/tests/python/relay/test_conv2d_nchw_texture.py
+++ b/tests/python/relay/test_conv2d_nchw_texture.py
@@ -590,8 +590,8 @@ def test_residual_block():
}
static_memory_scope = [
- "",
"global",
+ "global.texture",
"global.texture-weight",
"global.texture-weight",
"global.texture",
@@ -830,8 +830,8 @@ def test_pooling_branching_texture_params():
}
static_memory_scope = [
- "",
"global",
+ "global.texture",
"global.texture-weight",
"global.texture",
"global.texture",
@@ -957,8 +957,8 @@ def test_branching_texture_params():
}
static_memory_scope = [
- "",
"global",
+ "global.texture",
"global.texture-weight",
"global.texture",
"global.texture-weight",
@@ -1046,8 +1046,8 @@ def test_conv2d_different_lowering_same_op():
}
static_memory_scope = [
- "",
"global",
+ "global.texture",
"global.texture-weight",
"global.texture",
"global.texture",