This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 34cacb0a64 [VM][Textures] Enable OpenCL textures for VM  (#15419)
34cacb0a64 is described below

commit 34cacb0a6487d6300ffc8bf0bd018879d0a8d548
Author: Egor Churaev <[email protected]>
AuthorDate: Tue Aug 8 17:56:55 2023 +0300

    [VM][Textures] Enable OpenCL textures for VM  (#15419)
    
    * [VM][Textures] Enable OpenCL textures for VM
    
    This commit introduces memory scope to VM and enables using textures.
    
    The following changes have been made:
      - AnnotateMemoryScope pass is used in VM compilation pipeline
      - VM allows using more than one device with the same device type.
        Also, virtual devices in VM contain information about memory
        scope.
      - Instructions LoadConst and AllocStorage were extended to support
        textures.
      - VM bytecode was updated to support memory scope.
      - Annotate texture storage pass was updated to support dynamic shape.
      - Some other minor changes have been made.
    
    * Implement tests for vm
    
    * Fix lint
    
    * Fix tests
    
    * Use union in allocate_storage struct
    
    * Apply comments
    
    * Fix copy ctor and assignment operator
---
 include/tvm/runtime/ndarray.h                      |   8 +-
 include/tvm/runtime/vm/bytecode.h                  |  19 +-
 include/tvm/runtime/vm/executable.h                |   5 +-
 src/relay/backend/vm/compiler.cc                   |  41 ++-
 src/relay/backend/vm/manifest_lifetimes.cc         |   4 +-
 src/relay/op/memory/memory.cc                      |  20 +-
 src/relay/op/memory/memory.h                       |   5 +-
 src/relay/transforms/annotate_texture_storage.cc   |   9 +
 src/relay/transforms/device_domains.cc             |   7 +-
 src/relay/transforms/memory_alloc.cc               |   4 +-
 src/runtime/c_runtime_api.cc                       |   2 +-
 src/runtime/opencl/opencl_device_api.cc            |   2 +-
 src/runtime/vm/bytecode.cc                         |  56 +++-
 src/runtime/vm/executable.cc                       |  34 +-
 src/runtime/vm/profiler/vm.cc                      |  16 +-
 src/runtime/vm/vm.cc                               |  47 ++-
 .../opencl_texture/test_conv2d_nchw_texture.py     | 361 ++++++++++++++++-----
 .../opencl_texture/test_conv2d_nhwc_texture.py     | 245 ++++++++++----
 .../test_depthwise_conv2d_nchw_texture.py          |  52 ++-
 .../test_depthwise_conv2d_nhwc_texture.py          |  50 ++-
 .../relay/opencl_texture/test_injection_texture.py |  33 +-
 tests/python/relay/opencl_texture/test_network.py  |  24 +-
 .../relay/opencl_texture/test_pool_texture.py      |  63 +++-
 .../relay/opencl_texture/test_reduction_texture.py |  87 +++--
 .../relay/opencl_texture/utils/adreno_utils.py     |  84 +++++
 .../relay/test_pass_dead_code_elimination.py       |  18 +-
 tests/python/relay/test_pass_plan_devices.py       |  11 +-
 27 files changed, 996 insertions(+), 311 deletions(-)

diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h
index 119d0f7fd3..2a06856fea 100644
--- a/include/tvm/runtime/ndarray.h
+++ b/include/tvm/runtime/ndarray.h
@@ -110,9 +110,10 @@ class NDArray : public ObjectRef {
   /*!
    * \brief Copy the data to another device.
    * \param dev The target device.
+   * \param mem_scope The memory scope of the target array.
    * \return The array under another device.
    */
-  inline NDArray CopyTo(const Device& dev) const;
+  inline NDArray CopyTo(const Device& dev, Optional<String> mem_scope = 
NullOpt) const;
   /*!
    * \brief Load NDArray from stream
    * \param stream The input data stream
@@ -398,10 +399,11 @@ inline void NDArray::CopyTo(const NDArray& other) const {
   CopyFromTo(&(get_mutable()->dl_tensor), &(other.get_mutable()->dl_tensor));
 }
 
-inline NDArray NDArray::CopyTo(const Device& dev) const {
+inline NDArray NDArray::CopyTo(const Device& dev, Optional<String> mem_scope) 
const {
   ICHECK(data_ != nullptr);
   const DLTensor* dptr = operator->();
-  NDArray ret = Empty(ShapeTuple(dptr->shape, dptr->shape + dptr->ndim), 
dptr->dtype, dev);
+  NDArray ret =
+      Empty(ShapeTuple(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype, 
dev, mem_scope);
   this->CopyTo(ret);
   return ret;
 }
diff --git a/include/tvm/runtime/vm/bytecode.h 
b/include/tvm/runtime/vm/bytecode.h
index 2fe855f964..637c1e70a7 100644
--- a/include/tvm/runtime/vm/bytecode.h
+++ b/include/tvm/runtime/vm/bytecode.h
@@ -157,6 +157,8 @@ struct Instruction {
     struct /* LoadConst Operands */ {
       /* \brief The index into the constant pool. */
       Index const_index;
+      /*! \brief The index of the device on which the load will be made. */
+      Index device_index;
     };
     struct /* LoadConsti Operands */ {
       /* \brief The index into the constant pool. */
@@ -195,12 +197,18 @@ struct Instruction {
       RegName* free_vars;
     };
     struct /* AllocStorage Operands */ {
-      /*! \brief The size of the allocation. */
-      RegName allocation_size;
       /*! \brief The alignment of the allocation. */
       Index alignment;
       /*! \brief The hint of the dtype. */
       DLDataType dtype_hint;
+      /*! \brief The number of dimensions. */
+      uint32_t ndim;
+      union {
+        /*! \brief The shape of tensor. */
+        int64_t* shape;
+        /*! \brief The size of the allocation. */
+        RegName allocation_size;
+      };
       /*! \brief The index of the device on which the allocation will be made. 
*/
       Index device_index;
     } alloc_storage;
@@ -332,10 +340,11 @@ struct Instruction {
   /*!
    * \brief Construct a load constant instruction.
    * \param const_index The index of the constant.
+   * \param device_index The index of the device to load on.
    * \param dst The destination register.
    * \return The load constant instruction.
    */
-  static Instruction LoadConst(Index const_index, RegName dst);
+  static Instruction LoadConst(Index const_index, Index device_index, RegName 
dst);
   /*!
    * \brief Construct a load_constanti instruction.
    * \param val The interger constant value.
@@ -356,11 +365,13 @@ struct Instruction {
    * \param alignment The allocation's alignment.
    * \param dtype_hint The data type hint for the allocator.
    * \param device_index The index of the device to allocate on.
+   * \param shape The shape of the allocation.
    * \param dst The destination to place the storage.
    * \return The alloc storage instruction.
    */
   static Instruction AllocStorage(RegName size, Index alignment, DLDataType 
dtype_hint,
-                                  Index device_index, RegName dst);
+                                  Index device_index, const 
std::vector<int64_t>& shape,
+                                  RegName dst);
   /*!
    * \brief Get the shape of an input tensor.
    * \param tensor The input tensor.
diff --git a/include/tvm/runtime/vm/executable.h 
b/include/tvm/runtime/vm/executable.h
index 0714847400..d4872837b0 100644
--- a/include/tvm/runtime/vm/executable.h
+++ b/include/tvm/runtime/vm/executable.h
@@ -34,6 +34,7 @@
 #include <map>
 #include <string>
 #include <unordered_map>
+#include <utility>
 #include <vector>
 
 namespace tvm {
@@ -262,9 +263,9 @@ class TVM_DLL Executable : public ModuleNode {
 
   /*!
    * \brief The (compile-time, virtual) devices corresponding to each device 
index.
-   * Currently we only support at most one device per device type.
+   * This vector contains pairs of a Device and its memory_scope.
    */
-  std::vector<Device> virtual_devices;
+  std::vector<std::pair<Device, std::string>> virtual_devices;
   /*!
    * \brief The device index corresponding to the 'host' device. That will 
hold and evaluate
    * shape-related data and code.
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index c5b6c7f2f0..848c23eba6 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -352,19 +352,6 @@ class VMFunctionCompiler : 
DeviceAwareExprFunctor<void(const Expr& n)> {
       return 0;
     }
 
-    // However, otherwise we allow at most one VirtualDevice per device type.
-    // TODO(mbs): This will eventually need to account for memory scopes 
somehow so device_copy
-    // instructions can do the right thing.
-    itr = std::find_if(context_->virtual_devices_.begin() + 1, 
context_->virtual_devices_.end(),
-                       [&virtual_device](const VirtualDevice& 
existing_virtual_device) {
-                         return existing_virtual_device->device_type() ==
-                                virtual_device->device_type();
-                       });
-    CHECK(itr == context_->virtual_devices_.end())
-        << "The VM does not currently support using more than one device with 
the same device type "
-           "for primitives, however the program is using the distinct scopes "
-        << virtual_device << " and " << *itr << " of device type " << 
virtual_device->device_type();
-
     ICHECK(virtual_device != host_virtual_device_);
     Index index = context_->virtual_devices_.size();
     VLOG(2) << "virtual_device[" << index << "] = " << virtual_device;
@@ -384,7 +371,7 @@ class VMFunctionCompiler : 
DeviceAwareExprFunctor<void(const Expr& n)> {
     VLOG(2) << "constant[" << const_index << "] on device[" << device_index << 
"]";
     context_->const_device_indexes.push_back(device_index);
     context_->constants.push_back(const_node->data);
-    Emit(Instruction::LoadConst(const_index, NewRegister()));
+    Emit(Instruction::LoadConst(const_index, device_index, NewRegister()));
   }
 
   void VisitExpr_(const VarNode* var_node) final {
@@ -602,13 +589,21 @@ class VMFunctionCompiler : 
DeviceAwareExprFunctor<void(const Expr& n)> {
                  })
           .Match("memory.alloc_storage",
                  [this](const Array<Expr>& args, const Attrs& attrs, const 
Array<Type>& type_arg) {
-                   ICHECK_EQ(args.size(), 2);
+                   ICHECK_EQ(args.size(), 3);
                    // Compute the size of the allocation.
                    this->VisitExpr(args[0]);
                    auto size_register = last_register_;
 
-                   ICHECK(args[1].as<ConstantNode>());  // Always a literal.
-                   NDArray alignment_arr = args[1].as<ConstantNode>()->data;
+                   auto const_shape = 
AsIgnoringOnDevice<ConstantNode>(args[1]);
+                   std::vector<int64_t> raw_shape;
+                   if (const_shape) {
+                     NDArray shape = const_shape->data;
+                     // TODO(@jroesch): we need to get an RFC done to 
standarize shape dtype
+                     raw_shape = ToAllocTensorShape(shape);
+                   }
+
+                   ICHECK(args[2].as<ConstantNode>());  // Always a literal.
+                   NDArray alignment_arr = args[2].as<ConstantNode>()->data;
                    ICHECK_EQ(alignment_arr->dtype.code, 0U)
                        << "The dtype of constant shape must be int32 or int64, 
but got "
                        << DLDataType2String(alignment_arr->dtype);
@@ -622,7 +617,7 @@ class VMFunctionCompiler : 
DeviceAwareExprFunctor<void(const Expr& n)> {
 
                    Emit(Instruction::AllocStorage(size_register, alignment, 
dtype,
                                                   
GetDeviceIndex(alloc_attrs->virtual_device),
-                                                  NewRegister()));
+                                                  raw_shape, NewRegister()));
                  })
           .Match("vm.shape_of",
                  [this](const Array<Expr>& args, const Attrs& attrs, const 
Array<Type>& type_arg) {
@@ -739,7 +734,7 @@ class VMFunctionCompiler : 
DeviceAwareExprFunctor<void(const Expr& n)> {
 
   /*!
    * \brief Compile a match value
-   * Generate byte code that compute the value specificed in val
+   * Generate byte code that compute the value specified in val
    *
    * \return The register number assigned for the final value
    */
@@ -946,9 +941,10 @@ void VMCompiler::LowerImpl(IRModule mod) {
   for (const auto& virtual_device : context_.virtual_devices_) {
     ICHECK(!virtual_device->IsFullyUnconstrained());
     ICHECK_GT(virtual_device->device_type(), 0);
-    // TODO(mbs): We forget the memory scope.
-    
exec_->virtual_devices.push_back(Device{/*device_type=*/virtual_device->device_type(),
-                                            
/*device_id=*/virtual_device->virtual_device_id});
+    exec_->virtual_devices.push_back(
+        std::make_pair(Device{/*device_type=*/virtual_device->device_type(),
+                              /*device_id=*/virtual_device->virtual_device_id},
+                       virtual_device->memory_scope));
   }
   exec_->host_device_index = kHostDeviceIndex;
 
@@ -1068,6 +1064,7 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
   }
 
   pass_seqs.push_back(transform::FuseOps());
+  pass_seqs.push_back(transform::AnnotateMemoryScope());
 
   // Do layout rewrite for auto-scheduler.
   transform::PassContext pass_ctx = PassContext::Current();
diff --git a/src/relay/backend/vm/manifest_lifetimes.cc 
b/src/relay/backend/vm/manifest_lifetimes.cc
index 7028c88f2e..892648d678 100644
--- a/src/relay/backend/vm/manifest_lifetimes.cc
+++ b/src/relay/backend/vm/manifest_lifetimes.cc
@@ -167,7 +167,9 @@ class AliasEliminator : public MixedModeMutator {
           if (copy_props.src_virtual_device->device_type() ==
                   copy_props.dst_virtual_device->device_type() &&
               copy_props.src_virtual_device->virtual_device_id ==
-                  copy_props.dst_virtual_device->virtual_device_id) {
+                  copy_props.dst_virtual_device->virtual_device_id &&
+              copy_props.src_virtual_device->memory_scope ==
+                  copy_props.dst_virtual_device->memory_scope) {
             Expr to_copy = Downcast<Call>(unwrapped)->args[0];
             if (const VarNode* alias_of_n = to_copy.as<VarNode>()) {
               alias_[var] = Downcast<Var>(VisitExpr_(alias_of_n));
diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc
index 6535156205..008dbff841 100644
--- a/src/relay/op/memory/memory.cc
+++ b/src/relay/op/memory/memory.cc
@@ -50,25 +50,32 @@ TVM_REGISTER_NODE_TYPE(AllocTensorAttrs);
 // The passing value in attrs and args doesn't seem super great.
 // We should consider a better solution, i.e the type relation
 // being able to see the arguments as well?
-Expr AllocStorage(Expr size, Expr alignment, VirtualDevice virtual_device, 
DataType dtype_hint) {
+Expr AllocStorage(Expr size, Expr shape, Expr alignment, VirtualDevice 
virtual_device,
+                  DataType dtype_hint) {
   auto attrs = make_object<AllocStorageAttrs>();
   attrs->dtype = dtype_hint;
   attrs->virtual_device = std::move(virtual_device);
   static const Op& op = Op::Get("memory.alloc_storage");
-  return Call(op, {std::move(size), std::move(alignment)}, 
Attrs(std::move(attrs)), {});
+  return Call(op, {std::move(size), std::move(shape), std::move(alignment)},
+              Attrs(std::move(attrs)), {});
 }
 
 
TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_storage").set_body_typed(AllocStorage);
 
 bool AllocStorageRel(const Array<Type>& types, int num_inputs, const Attrs& 
attrs,
                      const TypeReporter& reporter) {
-  ICHECK_EQ(types.size(), 3u);
+  ICHECK_EQ(types.size(), 4u);
   auto size_type = types[0];
   auto tensor_type = size_type.as<TensorTypeNode>();
   ICHECK(tensor_type != nullptr);
   ICHECK_EQ(tensor_type->dtype, DataType::Int(64));
   ICHECK_EQ(tensor_type->shape.size(), 0);
-  auto align_type = types[1];
+
+  // Tensor shape
+  auto tt = types[1].as<TensorTypeNode>();
+  ICHECK(tt != nullptr) << "must be tensor type";
+
+  auto align_type = types[2];
   auto align_ttype = align_type.as<TensorTypeNode>();
   ICHECK(align_ttype != nullptr);
   ICHECK_EQ(align_ttype->dtype, DataType::Int(64));
@@ -77,14 +84,15 @@ bool AllocStorageRel(const Array<Type>& types, int 
num_inputs, const Attrs& attr
   ICHECK(mod.defined());
   auto storage_name = mod->GetGlobalTypeVar("Storage");
   auto storage = TypeCall(storage_name, {});
-  reporter->Assign(types[2], storage);
+  reporter->Assign(types[3], storage);
   return true;
 }
 
 RELAY_REGISTER_OP("memory.alloc_storage")
     .describe(R"code(Explicitly allocate storage to be used by tensors.)code" 
TVM_ADD_FILELINE)
-    .set_num_inputs(2)
+    .set_num_inputs(3)
     .add_argument("size", "Tensor", "The size of the storage to allocate.")
+    .add_argument("shape", "Tensor", "The shape of the storage to allocate.")
     .add_argument("alignment", "Tensor", "The alignment of the storage.")
     .add_type_rel("AllocStorage", AllocStorageRel)
     .set_attrs_type_key("relay.attrs.AllocStorageAttrs")
diff --git a/src/relay/op/memory/memory.h b/src/relay/op/memory/memory.h
index 690854c382..5533553393 100644
--- a/src/relay/op/memory/memory.h
+++ b/src/relay/op/memory/memory.h
@@ -34,10 +34,11 @@
 namespace tvm {
 namespace relay {
 
-Expr AllocStorage(Expr size, Expr alignment, VirtualDevice virtual_device, 
DataType dtype_hint);
+Expr AllocStorage(Expr size, Expr shape, Expr alignment, VirtualDevice 
virtual_device,
+                  DataType dtype_hint);
 /*! \brief Returns the "memory.alloc_tensor" operator. */
 const Op& MemoryAllocTensorOp();
-Expr AllocTensor(Expr storage, Expr offset, tvm::relay::Expr shape, DataType 
dtype,
+Expr AllocTensor(Expr storage, Expr offset, Expr shape, DataType dtype,
                  Array<IndexExpr> assert_shape);
 Expr ToTupleType(const Type& ty, const std::vector<Expr>& exprs);
 std::vector<Expr> FromTupleType(const Type& type, const Expr& expr);
diff --git a/src/relay/transforms/annotate_texture_storage.cc 
b/src/relay/transforms/annotate_texture_storage.cc
index d3748449ad..4921cef4c8 100644
--- a/src/relay/transforms/annotate_texture_storage.cc
+++ b/src/relay/transforms/annotate_texture_storage.cc
@@ -407,6 +407,15 @@ class StorageInfo : private 
transform::DeviceAwareExprVisitor {
       if (pattern <= kCommReduce) {
         if (const auto* ttype = call->checked_type().as<TensorTypeNode>()) {
           if (ttype->shape.size() == 5) {
+            auto node0 = ttype->shape[0].as<IntImmNode>();
+            auto node1 = ttype->shape[1].as<IntImmNode>();
+            auto node2 = ttype->shape[2].as<IntImmNode>();
+            auto node3 = ttype->shape[3].as<IntImmNode>();
+            auto node4 = ttype->shape[4].as<IntImmNode>();
+            // if tensor has any dynamic dimension then textures are not supported
+            if (!node0 || !node1 || !node2 || !node3 || !node4) {
+              return false;
+            }
             supports_texture_storage = true;
           }
         }
diff --git a/src/relay/transforms/device_domains.cc 
b/src/relay/transforms/device_domains.cc
index e7d3a65dfe..e2af20022a 100644
--- a/src/relay/transforms/device_domains.cc
+++ b/src/relay/transforms/device_domains.cc
@@ -236,12 +236,13 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const 
Call& call) {
     
args_and_result.emplace_back(ForVirtualDevice(device_copy_props.body->checked_type(),
                                                   
device_copy_props.dst_virtual_device));
   } else if (call->op == alloc_storage_op) {
-    ICHECK_EQ(call->args.size(), 2U);
-    // alloc_storage(size, alignment, virtual_device=<t>)
-    // alloc_storage: fn(<cpu>, <cpu>):<t>
+    ICHECK_EQ(call->args.size(), 3U);
+    // alloc_storage(size, shape, alignment, virtual_device=<t>)
+    // alloc_storage: fn(<cpu>, <cpu>, <cpu>):<t>
     const auto* attrs = call->attrs.as<AllocStorageAttrs>();
     args_and_result.emplace_back(host_domain_);
     args_and_result.emplace_back(host_domain_);
+    args_and_result.emplace_back(host_domain_);
     args_and_result.emplace_back(ForVirtualDevice(call->checked_type(), 
attrs->virtual_device));
   } else if (call->op == alloc_tensor_op) {
     ICHECK_EQ(call->args.size(), 3U);
diff --git a/src/relay/transforms/memory_alloc.cc 
b/src/relay/transforms/memory_alloc.cc
index 5b584e199d..fcf8a784a9 100644
--- a/src/relay/transforms/memory_alloc.cc
+++ b/src/relay/transforms/memory_alloc.cc
@@ -260,7 +260,7 @@ class DialectRewriter : public 
transform::DeviceAwareExprMutator {
     Expr alignment = ComputeAlignment(type->dtype);
     // Run type inference later to get the correct type.
     Var var("storage_" + name_hint, Type(nullptr));
-    Expr value = AllocStorage(size, alignment, virtual_device, type->dtype);
+    Expr value = AllocStorage(size, shape, alignment, virtual_device, 
type->dtype);
     auto sto = scope->Push(var, MaybeOnDeviceFixed(value, virtual_device));
 
     // TODO(@jroesch): There is a bug with typing based on the constant shape.
@@ -366,7 +366,7 @@ class DialectRewriter : public 
transform::DeviceAwareExprMutator {
       // Alignment is directly captured in the instruction so don't wrap in 
"on_device".
       auto alignment = ComputeAlignment(out_type->dtype);
       Var sto_var("storage_" + std::to_string(i), Type(nullptr));
-      auto val = AllocStorage(size, alignment, virtual_device, 
out_type->dtype);
+      auto val = AllocStorage(size, out_shape, alignment, virtual_device, 
out_type->dtype);
       storages.push_back(scope->Push(sto_var, MaybeOnDeviceFixed(val, 
virtual_device)));
     }
 
diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc
index 0132e9009c..d7739b7b22 100644
--- a/src/runtime/c_runtime_api.cc
+++ b/src/runtime/c_runtime_api.cc
@@ -152,7 +152,7 @@ static size_t GetDataAlignment(const DLDataType dtype) {
 
 void* DeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape, 
DLDataType dtype,
                                 Optional<String> mem_scope) {
-  if (!mem_scope.defined() || mem_scope.value() == "global") {
+  if (!mem_scope.defined() || mem_scope.value() == "" || mem_scope.value() == 
"global") {
     // by default, we can always redirect to the flat memory allocations
     DLTensor temp;
     temp.data = nullptr;
diff --git a/src/runtime/opencl/opencl_device_api.cc 
b/src/runtime/opencl/opencl_device_api.cc
index 0d1f4af2bb..35e77eb6d1 100644
--- a/src/runtime/opencl/opencl_device_api.cc
+++ b/src/runtime/opencl/opencl_device_api.cc
@@ -239,7 +239,7 @@ void* OpenCLWorkspace::AllocDataSpace(Device dev, size_t 
size, size_t alignment,
 
 void* OpenCLWorkspace::AllocDataSpace(Device dev, int ndim, const int64_t* 
shape, DLDataType dtype,
                                       Optional<String> mem_scope) {
-  if (!mem_scope.defined() || mem_scope.value() == "global") {
+  if (!mem_scope.defined() || mem_scope.value().empty() || mem_scope.value() 
== "global") {
     return DeviceAPI::AllocDataSpace(dev, ndim, shape, dtype, mem_scope);
   }
   ICHECK(IsTextureStorage(std::string(mem_scope.value())))
diff --git a/src/runtime/vm/bytecode.cc b/src/runtime/vm/bytecode.cc
index 424dfe87c7..dc52e8c8f0 100644
--- a/src/runtime/vm/bytecode.cc
+++ b/src/runtime/vm/bytecode.cc
@@ -99,6 +99,7 @@ Instruction::Instruction(const Instruction& instr) {
       return;
     case Opcode::LoadConst:
       this->const_index = instr.const_index;
+      this->device_index = instr.device_index;
       return;
     case Opcode::LoadConsti:
       this->load_consti = instr.load_consti;
@@ -114,7 +115,15 @@ Instruction::Instruction(const Instruction& instr) {
       this->pc_offset = instr.pc_offset;
       return;
     case Opcode::AllocStorage:
-      this->alloc_storage = instr.alloc_storage;
+      this->alloc_storage.allocation_size = 
instr.alloc_storage.allocation_size;
+      this->alloc_storage.alignment = instr.alloc_storage.alignment;
+      this->alloc_storage.dtype_hint = instr.alloc_storage.dtype_hint;
+      this->alloc_storage.device_index = instr.alloc_storage.device_index;
+      this->alloc_storage.ndim = instr.alloc_storage.ndim;
+      if (this->alloc_storage.ndim > 0) {
+        this->alloc_storage.shape =
+            Duplicate<int64_t>(instr.alloc_storage.shape, 
instr.alloc_storage.ndim);
+      }
       return;
     case Opcode::ShapeOf:
       this->shape_of.tensor = instr.shape_of.tensor;
@@ -207,6 +216,7 @@ Instruction& Instruction::operator=(const Instruction& 
instr) {
       return *this;
     case Opcode::LoadConst:
       this->const_index = instr.const_index;
+      this->device_index = instr.device_index;
       return *this;
     case Opcode::GetField:
       this->object = instr.object;
@@ -219,7 +229,15 @@ Instruction& Instruction::operator=(const Instruction& 
instr) {
       this->pc_offset = instr.pc_offset;
       return *this;
     case Opcode::AllocStorage:
-      this->alloc_storage = instr.alloc_storage;
+      this->alloc_storage.allocation_size = 
instr.alloc_storage.allocation_size;
+      this->alloc_storage.alignment = instr.alloc_storage.alignment;
+      this->alloc_storage.dtype_hint = instr.alloc_storage.dtype_hint;
+      this->alloc_storage.device_index = instr.alloc_storage.device_index;
+      this->alloc_storage.ndim = instr.alloc_storage.ndim;
+      if (this->alloc_storage.ndim > 0) {
+        this->alloc_storage.shape =
+            Duplicate<int64_t>(instr.alloc_storage.shape, 
instr.alloc_storage.ndim);
+      }
       return *this;
     case Opcode::ShapeOf:
       this->shape_of.tensor = instr.shape_of.tensor;
@@ -250,13 +268,17 @@ Instruction::~Instruction() {
     case Opcode::GetTag:
     case Opcode::Goto:
     case Opcode::LoadConsti:
-    case Opcode::AllocStorage:
     case Opcode::ShapeOf:
     case Opcode::ReshapeTensor:
     case Opcode::DeviceCopy:
     case Opcode::Fatal:
     case Opcode::KillRegister:
       return;
+    case Opcode::AllocStorage:
+      if (this->alloc_storage.ndim > 0) {
+        delete[] this->alloc_storage.shape;
+      }
+      return;
     case Opcode::AllocTensor:
       delete[] this->alloc_tensor.shape;
       return;
@@ -338,7 +360,8 @@ Instruction Instruction::AllocTensorReg(RegName storage, 
RegName offset, RegName
 }
 
 Instruction Instruction::AllocStorage(RegName size, Index alignment, 
DLDataType dtype_hint,
-                                      Index device_index, RegName dst) {
+                                      Index device_index, const 
std::vector<int64_t>& shape,
+                                      RegName dst) {
   Instruction instr;
   instr.op = Opcode::AllocStorage;
   instr.dst = dst;
@@ -346,6 +369,13 @@ Instruction Instruction::AllocStorage(RegName size, Index 
alignment, DLDataType
   instr.alloc_storage.alignment = alignment;
   instr.alloc_storage.dtype_hint = dtype_hint;
   instr.alloc_storage.device_index = device_index;
+  instr.alloc_storage.ndim = static_cast<uint32_t>(shape.size());
+  if (instr.alloc_storage.ndim > 0) {
+    instr.alloc_storage.shape = new int64_t[shape.size()];
+    for (size_t i = 0; i < shape.size(); ++i) {
+      instr.alloc_storage.shape[i] = shape[i];
+    }
+  }
   return instr;
 }
 
@@ -474,11 +504,12 @@ Instruction Instruction::InvokeClosure(RegName closure, 
const std::vector<RegNam
   return instr;
 }
 
-Instruction Instruction::LoadConst(Index const_index, RegName dst) {
+Instruction Instruction::LoadConst(Index const_index, Index device_index, 
RegName dst) {
   Instruction instr;
   instr.op = Opcode::LoadConst;
   instr.dst = dst;
   instr.const_index = const_index;
+  instr.device_index = device_index;
   return instr;
 }
 
@@ -596,7 +627,8 @@ void InstructionPrint(std::ostream& os, const Instruction& 
instr) {
       break;
     }
     case Opcode::LoadConst: {
-      os << "load_const $" << instr.dst << " Const[" << instr.const_index << 
"]";
+      os << "load_const $" << instr.dst << " Const[" << instr.const_index << 
"] "
+         << instr.device_index;
       break;
     }
     case Opcode::LoadConsti: {
@@ -616,9 +648,15 @@ void InstructionPrint(std::ostream& os, const Instruction& 
instr) {
       break;
     }
     case Opcode::AllocStorage: {
-      os << "alloc_storage $" << instr.dst << " $" << 
instr.alloc_storage.allocation_size << " "
-         << instr.alloc_storage.alignment << " "
-         << DLDataType2String(instr.alloc_storage.dtype_hint) << " "
+      os << "alloc_storage $" << instr.dst << " ";
+      if (instr.alloc_storage.ndim > 0) {
+        os << "[" << StrJoin<int64_t>(instr.alloc_storage.shape, 0, 
instr.alloc_storage.ndim)
+           << "] ";
+      } else {
+        os << "$" << instr.alloc_storage.allocation_size << " " << 
instr.alloc_storage.alignment
+           << " ";
+      }
+      os << DLDataType2String(instr.alloc_storage.dtype_hint) << " "
          << instr.alloc_storage.device_index;
       break;
     }
diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc
index 2b3119b169..58c509f8d9 100644
--- a/src/runtime/vm/executable.cc
+++ b/src/runtime/vm/executable.cc
@@ -183,7 +183,7 @@ std::string Executable::GetConstants() const {
     const auto& constant = constants[i];
     auto ndarray = Downcast<NDArray>(constant);
     oss << "VM Const[" << i
-        << "]: " << RuntimeObject2String(ndarray, 
virtual_devices[host_device_index])
+        << "]: " << RuntimeObject2String(ndarray, 
virtual_devices[host_device_index].first)
         << " on device index " << const_device_indexes[i] << std::endl;
   }
   return oss.str();
@@ -192,9 +192,9 @@ std::string Executable::GetConstants() const {
 std::string Executable::GetVirtualDevices() const {
   std::ostringstream oss;
   for (size_t i = 0; i < virtual_devices.size(); ++i) {
-    const auto& device = virtual_devices[i];
-    oss << "VM VirtualDevice[" << i << "]: device type " << device.device_type 
<< " and id "
-        << device.device_id << std::endl;
+    const auto& [device, scope] = virtual_devices[i];
+    oss << "VM VirtualDevice[" << i << "]: device type " << device.device_type 
<< ", id "
+        << device.device_id << " and mem_scope " << scope << std::endl;
   }
   return oss.str();
 }
@@ -596,7 +596,13 @@ VMInstructionSerializer SerializeInstruction(const 
Instruction& instr) {
       fields.push_back(dtype.bits);
       fields.push_back(dtype.lanes);
       fields.push_back(instr.alloc_storage.device_index);
+      fields.push_back(instr.alloc_storage.ndim);
       fields.push_back(instr.dst);
+
+      // Save the shape of the tensor.
+      // Note that this field is rotated to the end of the list.
+      fields.insert(fields.end(), instr.alloc_storage.shape,
+                    instr.alloc_storage.shape + instr.alloc_storage.ndim);
       break;
     }
     case Opcode::AllocADT: {
@@ -639,8 +645,8 @@ VMInstructionSerializer SerializeInstruction(const 
Instruction& instr) {
       break;
     }
     case Opcode::LoadConst: {
-      // Number of fields = 2
-      fields.assign({instr.const_index, instr.dst});
+      // Number of fields = 3
+      fields.assign({instr.const_index, instr.device_index, instr.dst});
       break;
     }
     case Opcode::LoadConsti: {
@@ -910,8 +916,8 @@ Instruction DeserializeInstruction(const 
VMInstructionSerializer& instr) {
       return Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst);
     }
     case Opcode::AllocStorage: {
-      // Number of fields = 7
-      DCHECK_GE(instr.fields.size(), 7U);
+      // Number of fields = 9
+      DCHECK_GE(instr.fields.size(), 9U);
       Index allocation_size = instr.fields[0];
       Index alignment = instr.fields[1];
 
@@ -921,9 +927,11 @@ Instruction DeserializeInstruction(const 
VMInstructionSerializer& instr) {
       dtype.lanes = instr.fields[4];
 
       Index device_type = instr.fields[5];
-      RegName dst = instr.fields[6];
+      Index ndim = instr.fields[6];
+      RegName dst = instr.fields[7];
+      std::vector<Index> shape = ExtractFields(instr.fields, 8, ndim);
 
-      return Instruction::AllocStorage(allocation_size, alignment, dtype, 
device_type, dst);
+      return Instruction::AllocStorage(allocation_size, alignment, dtype, 
device_type, shape, dst);
     }
     case Opcode::If: {
       // Number of fields = 4
@@ -960,9 +968,9 @@ Instruction DeserializeInstruction(const 
VMInstructionSerializer& instr) {
       return Instruction::InvokeClosure(closure, args, dst);
     }
     case Opcode::LoadConst: {
-      // Number of fields = 2
-      DCHECK_EQ(instr.fields.size(), 2U);
-      return Instruction::LoadConst(instr.fields[0], instr.fields[1]);
+      // Number of fields = 3
+      DCHECK_EQ(instr.fields.size(), 3U);
+      return Instruction::LoadConst(instr.fields[0], instr.fields[1], 
instr.fields[2]);
     }
     case Opcode::LoadConsti: {
       // Number of fields = 2
diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc
index 360185aac5..7df6b928a3 100644
--- a/src/runtime/vm/profiler/vm.cc
+++ b/src/runtime/vm/profiler/vm.cc
@@ -129,9 +129,21 @@ void VirtualMachineDebug::OpStartHook(Instruction instr) {
           {{"Argument Shapes",
             profiling::ShapeString(shape_tensor, 
instr.alloc_tensor_reg.dtype)}});
     } else if (instr.op == Opcode::AllocStorage) {
-      auto size = LoadScalarInt(instr.alloc_storage.allocation_size);
       std::ostringstream shape;
-      shape << DLDataType2String(instr.alloc_storage.dtype_hint) << "[" << 
size << "]";
+      if (instr.alloc_storage.ndim > 0) {
+        std::string shape_str = "[";
+        for (uint32_t i = 0; i < instr.alloc_storage.ndim; ++i) {
+          if (i > 0) {
+            shape_str += ", ";
+          }
+          shape_str += std::to_string(instr.alloc_storage.shape[i]);
+        }
+        shape_str += "]";
+        shape << DLDataType2String(instr.alloc_storage.dtype_hint) << 
shape_str;
+      } else {
+        auto size = LoadScalarInt(instr.alloc_storage.allocation_size);
+        shape << DLDataType2String(instr.alloc_storage.dtype_hint) << "[" << 
size << "]";
+      }
       Device dev = GetDevice(instr.alloc_storage.device_index);
       prof_.operator*().StartCall("VM::AllocStorage", dev,
                                   {{"VM::Argument Shapes", 
String(shape.str())}});
diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index 50c757f8fb..188a4153e1 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -66,7 +66,7 @@ std::ostream& operator<<(std::ostream& os, const VMFunction& 
vm_func) {
   return os;
 }
 
-inline ObjectRef CopyTo(ObjectRef src, const DLDevice& dev) {
+inline ObjectRef CopyTo(ObjectRef src, const DLDevice& dev, Optional<String> 
mem_scope = NullOpt) {
   if (src->IsInstance<NDArray::ContainerType>()) {
     auto nd_array = Downcast<NDArray>(src);
     // TODO(mbs): Should respect device id also.
@@ -79,7 +79,7 @@ inline ObjectRef CopyTo(ObjectRef src, const DLDevice& dev) {
       VLOG(2) << "copying from " << nd_array->device.device_type << "["
               << nd_array->device.device_id << "] to " << dev.device_type << 
"[" << dev.device_id
               << "]";
-      return nd_array.CopyTo(dev);
+      return nd_array.CopyTo(dev, mem_scope);
     }
     return src;
   } else {
@@ -88,7 +88,7 @@ inline ObjectRef CopyTo(ObjectRef src, const DLDevice& dev) {
     std::vector<ObjectRef> ret;
     ADT adt = Downcast<ADT>(src);
     for (size_t i = 0; i < adt.size(); i++) {
-      ret.push_back(CopyTo(adt[i], dev));
+      ret.push_back(CopyTo(adt[i], dev, mem_scope));
     }
     return ADT(adt->tag, ret.begin(), ret.end());
   }
@@ -532,7 +532,7 @@ void VirtualMachine::Init(const std::vector<Device>& 
physical_devices,
   for (size_t device_index = 0; device_index < num_virtual_devices; 
++device_index) {
     // We'll retain the legacy behaviour and just match by device type.
     // TODO(mbs): Generalize.
-    DLDeviceType virtual_device_type = 
exec_->virtual_devices[device_index].device_type;
+    DLDeviceType virtual_device_type = 
exec_->virtual_devices[device_index].first.device_type;
     auto itr = std::find_if(physical_devices.begin(), physical_devices.end(),
                             [virtual_device_type](const Device& 
physical_device) {
                               return physical_device.device_type == 
virtual_device_type;
@@ -658,8 +658,9 @@ void VirtualMachine::RunLoop(const std::vector<Index>& 
output_tensor_reg_indices
         }
 
         if (!const_pool_[instr.const_index].defined()) {
-          Device dev = 
GetDevice(exec_->const_device_indexes[instr.const_index]);
-          const_pool_[instr.const_index] = CopyTo(constant_obj, dev);
+          auto& [dev, mem_scope] =
+              
exec_->virtual_devices[exec_->const_device_indexes[instr.const_index]];
+          const_pool_[instr.const_index] = CopyTo(constant_obj, dev, 
String(mem_scope));
         }
         WriteRegister(instr.dst, const_pool_[instr.const_index]);
         if (is_not_cached) {
@@ -819,17 +820,36 @@ void VirtualMachine::RunLoop(const std::vector<Index>& 
output_tensor_reg_indices
       }
       case Opcode::AllocStorage: {
         OpStartHook(instr);
-        auto size = LoadScalarInt(instr.alloc_storage.allocation_size);
-        auto alignment = instr.alloc_storage.alignment;
 
         auto storage_obj = SimpleObjAllocator().make_object<StorageObj>();
         Allocator* allocator = GetAllocator(instr.alloc_storage.device_index);
         ICHECK(allocator) << "Did you forget to init the VirtualMachine with 
devices?";
-        VLOG(2) << "allocating with allocation_size=" << size << ", 
alignment=" << alignment
-                << ", dtype_hint=" << 
DLDataType2String(instr.alloc_storage.dtype_hint)
-                << ", device_index=" << instr.alloc_storage.device_index;
 
-        storage_obj->buffer = allocator->Alloc(size, alignment, 
instr.alloc_storage.dtype_hint);
+        if (instr.alloc_storage.ndim > 0) {
+          std::string shape = "[";
+          for (uint32_t i = 0; i < instr.alloc_storage.ndim; ++i) {
+            if (i > 0) {
+              shape += ", ";
+            }
+            shape += std::to_string(instr.alloc_storage.shape[i]);
+          }
+          shape += "]";
+          std::string mem_scope = 
exec_->virtual_devices[instr.alloc_storage.device_index].second;
+          VLOG(2) << "allocating with ndims=" << instr.alloc_storage.ndim << 
", shape=" << shape
+                  << ", dtype_hint=" << 
DLDataType2String(instr.alloc_storage.dtype_hint)
+                  << ", device_index=" << instr.alloc_storage.device_index
+                  << ", memory_scope=" << mem_scope;
+          storage_obj->buffer =
+              allocator->Alloc(instr.alloc_storage.ndim, 
instr.alloc_storage.shape,
+                               instr.alloc_storage.dtype_hint, mem_scope);
+        } else {
+          auto size = LoadScalarInt(instr.alloc_storage.allocation_size);
+          auto alignment = instr.alloc_storage.alignment;
+          VLOG(2) << "allocating with allocation_size=" << size << ", 
alignment=" << alignment
+                  << ", dtype_hint=" << 
DLDataType2String(instr.alloc_storage.dtype_hint)
+                  << ", device_index=" << instr.alloc_storage.device_index;
+          storage_obj->buffer = allocator->Alloc(size, alignment, 
instr.alloc_storage.dtype_hint);
+        }
         Storage storage(storage_obj);
         WriteRegister(instr.dst, storage);
         OpStopHook();
@@ -899,8 +919,9 @@ void VirtualMachine::RunLoop(const std::vector<Index>& 
output_tensor_reg_indices
         ICHECK_EQ(actual_src_dev.device_type, inst_src_dev.device_type);
         ICHECK_EQ(actual_src_dev.device_id, inst_src_dev.device_id);
         Device dst_dev = GetDevice(instr.device_copy.dst_device_index);
+        auto mem_scope = 
exec_->virtual_devices[instr.device_copy.dst_device_index].second;
 
-        NDArray dst_data = src_data.CopyTo(dst_dev);
+        NDArray dst_data = src_data.CopyTo(dst_dev, String(mem_scope));
         WriteRegister(instr.dst, dst_data);
         OpStopHook();
         pc_++;
diff --git a/tests/python/relay/opencl_texture/test_conv2d_nchw_texture.py 
b/tests/python/relay/opencl_texture/test_conv2d_nchw_texture.py
index 3476037946..3c9c3f2caf 100644
--- a/tests/python/relay/opencl_texture/test_conv2d_nchw_texture.py
+++ b/tests/python/relay/opencl_texture/test_conv2d_nchw_texture.py
@@ -21,16 +21,17 @@ import numpy as np
 from tvm import relay
 from tvm.relay import testing
 from tvm.contrib import utils
-from utils.adreno_utils import gpu_preprocess, build_run_compare
+from utils.adreno_utils import gpu_preprocess, build_run_compare, 
build_run_compare_vm
 import pytest
 
 
+executor_type = tvm.testing.parameter("ge", "vm")
 dtype = tvm.testing.parameter("float32")
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad(remote, target, dtype):
+def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad(remote, target, 
executor_type, dtype):
     input_shape = (1, 32, 42, 42)
     filter_shape = (96, 32, 3, 3)
     bias_shape = (1, 96, 1, 1)
@@ -65,14 +66,19 @@ def 
test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad(remote, target, dtype):
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(
-        remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, 
[], gpu_preprocess
-    )
+    if executor_type == "ge":
+        build_run_compare(
+            remote, mod, params1, {"data": input_shape}, {"data": dtype}, 
target, [], gpu_preprocess
+        )
+    else:
+        build_run_compare_vm(
+            remote, mod, params1, {"data": input_shape}, {"data": dtype}, 
target, [], gpu_preprocess
+        )
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad_pass(remote, target, 
dtype):
+def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad_pass(remote, target, 
executor_type, dtype):
     input_shape = (1, 32, 40, 40)
     filter_shape = (96, 32, 2, 2)
     bias_shape = (1, 96, 1, 1)
@@ -107,14 +113,19 @@ def 
test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad_pass(remote, target, dtype)
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(
-        remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, 
[], gpu_preprocess
-    )
+    if executor_type == "ge":
+        build_run_compare(
+            remote, mod, params1, {"data": input_shape}, {"data": dtype}, 
target, [], gpu_preprocess
+        )
+    else:
+        build_run_compare_vm(
+            remote, mod, params1, {"data": input_shape}, {"data": dtype}, 
target, [], gpu_preprocess
+        )
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_inceptionv3_35_35_strides(remote, target, dtype):
+def test_conv2d_inceptionv3_35_35_strides(remote, target, executor_type, 
dtype):
     input_shape = (1, 48, 35, 35)
     filter_shape = (64, 48, 5, 5)
     bias_shape = (1, 64, 1, 1)
@@ -149,14 +160,19 @@ def test_conv2d_inceptionv3_35_35_strides(remote, target, 
dtype):
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(
-        remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, 
[], gpu_preprocess
-    )
+    if executor_type == "ge":
+        build_run_compare(
+            remote, mod, params1, {"data": input_shape}, {"data": dtype}, 
target, [], gpu_preprocess
+        )
+    else:
+        build_run_compare_vm(
+            remote, mod, params1, {"data": input_shape}, {"data": dtype}, 
target, [], gpu_preprocess
+        )
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_resnet50_v2_nchw_3c(remote, target, dtype):
+def test_conv2d_resnet50_v2_nchw_3c(remote, target, executor_type, dtype):
     input_shape = (1, 3, 224, 224)
     filter_shape = (64, 3, 7, 7)
     bias_shape = (1, 64, 1, 1)
@@ -192,12 +208,15 @@ def test_conv2d_resnet50_v2_nchw_3c(remote, target, 
dtype):
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_inceptionv3_nchw_3c(remote, target, dtype):
+def test_conv2d_inceptionv3_nchw_3c(remote, target, executor_type, dtype):
     input_shape = (1, 3, 299, 299)
     filter_shape = (64, 3, 3, 3)
     bias_shape = (1, 64, 1, 1)
@@ -232,12 +251,15 @@ def test_conv2d_inceptionv3_nchw_3c(remote, target, 
dtype):
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_1x1_16c16spatial(remote, target, dtype):
+def test_conv2d_1x1_16c16spatial(remote, target, executor_type, dtype):
     input_shape = (1, 16, 256, 256)
     filter_shape = (32, 16, 4, 4)
     bias_shape = (1, 32, 1, 1)
@@ -272,12 +294,15 @@ def test_conv2d_1x1_16c16spatial(remote, target, dtype):
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_4x4_16c16pad(remote, target, dtype):
+def test_conv2d_4x4_16c16pad(remote, target, executor_type, dtype):
     input_shape = (1, 32, 256, 256)
     filter_shape = (32, 32, 4, 4)
     bias_shape = (1, 32, 1, 1)
@@ -312,12 +337,15 @@ def test_conv2d_4x4_16c16pad(remote, target, dtype):
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_4x4x4_16c16pad(remote, target, dtype):
+def test_conv2d_4x4x4_16c16pad(remote, target, executor_type, dtype):
     input_shape = (1, 32, 256, 256)
     filter_shape = (4, 32, 4, 4)
     bias_shape = (1, 4, 1, 1)
@@ -352,12 +380,15 @@ def test_conv2d_4x4x4_16c16pad(remote, target, dtype):
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_yolov3_v2_nchw_3c(remote, target, dtype):
+def test_conv2d_yolov3_v2_nchw_3c(remote, target, executor_type, dtype):
     input_shape = (1, 1024, 13, 13)
     filter_shape = (255, 1024, 1, 1)
     A = relay.var("data", shape=input_shape, dtype=dtype)
@@ -385,12 +416,15 @@ def test_conv2d_yolov3_v2_nchw_3c(remote, target, dtype):
         "weight": tvm.nd.array(filter_data),
     }
 
-    build_run_compare(remote, mod, params, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_vgg16_winograd_4d(remote, target, dtype):
+def test_conv2d_vgg16_winograd_4d(remote, target, executor_type, dtype):
     input_shape = (1, 512, 28, 28)
     filter_shape = (512, 512, 3, 3)
     bias_shape = (1, 512, 1, 1)
@@ -429,16 +463,35 @@ def test_conv2d_vgg16_winograd_4d(remote, target, dtype):
         f.write(
             f'{{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno 
-max_num_threads=256", "conv2d_nchw_winograd.image2d", [["TENSOR", [1, 512, 28, 
28], "{dtype}"], ["TENSOR", [512, 512, 3, 3], "{dtype}"], [1, 1], [1, 1, 1, 1], 
[1, 1], "{dtype}"], {{}}], "config": {{"index": 1591, "code_hash": null, 
"entity": [["auto_unroll_max_step", "ot", 4], ["tile_y", "sp", [-1, 1, 32]], 
["tile_x", "sp", [-1, 4, 2]], ["tile_rc", "sp", [-1, 8]]]}}, "result": 
[[0.0037244], 0, 7.06374192237854, 165 [...]
         )
-    graph = build_run_compare(
-        remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, 
stat_file=stat_file
-    )
-    matches = re.findall("winograd", graph)
-    assert len(matches) > 0
+    if executor_type == "ge":
+        graph = build_run_compare(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            stat_file=stat_file,
+        )
+        matches = re.findall("winograd", graph)
+        assert len(matches) > 0
+    else:
+        vmc = build_run_compare_vm(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            stat_file=stat_file,
+        )
+        matches = re.findall("winograd", vmc.primitives)
+        assert len(matches) > 0
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_winograd_conv(remote, target, dtype):
+def test_conv2d_winograd_conv(remote, target, executor_type, dtype):
     input_shape = (1, 4, 3, 3)
     A = relay.var("data", shape=input_shape, dtype=dtype)
     filter_shape3 = (8, 4, 3, 3)
@@ -476,16 +529,35 @@ def test_conv2d_winograd_conv(remote, target, dtype):
         f.write(
             f'{{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno 
-max_num_threads=256", "conv2d_nchw_winograd.image2d", [["TENSOR", [1, 4, 3, 
3], "{dtype}"], ["TENSOR", [8, 4, 3, 3], "{dtype}"], [1, 1], [1, 1, 1, 1], [1, 
1], "{dtype}"], {{}}], "config": {{"index": 1591, "code_hash": null, "entity": 
[["auto_unroll_max_step", "ot", 4], ["tile_y", "sp", [-1, 1, 32]], ["tile_x", 
"sp", [-1, 4, 2]], ["tile_rc", "sp", [-1, 8]]]}}, "result": [[0.0037244], 0, 
7.06374192237854, 1653898629. [...]
         )
-    graph = build_run_compare(
-        remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, 
stat_file=stat_file
-    )
-    matches = re.findall("winograd", graph)
-    assert len(matches) > 0
+    if executor_type == "ge":
+        graph = build_run_compare(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            stat_file=stat_file,
+        )
+        matches = re.findall("winograd", graph)
+        assert len(matches) > 0
+    else:
+        vmc = build_run_compare_vm(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            stat_file=stat_file,
+        )
+        matches = re.findall("winograd", vmc.primitives)
+        assert len(matches) > 0
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_residual_block(remote, target, dtype):
+def test_residual_block(remote, target, executor_type, dtype):
     """
     - some kind of residual block followed by convolution to have texture 
after residual block
     - scalar data type verification which should be mapped to global memory 
scope
@@ -602,14 +674,31 @@ def test_residual_block(remote, target, dtype):
             "",
         ]
 
-    build_run_compare(
-        remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, 
static_memory_scope
-    )
+    if executor_type == "ge":
+        build_run_compare(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            static_memory_scope,
+        )
+    else:
+        build_run_compare_vm(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            static_memory_scope,
+        )
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_concat(remote, target, dtype):
+def test_concat(remote, target, executor_type, dtype):
     """
         layout_transform (NCHW->NCHW4c)
                   |                      <- buffer
@@ -716,14 +805,31 @@ def test_concat(remote, target, dtype):
 
     static_memory_scope = []
 
-    build_run_compare(
-        remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, 
static_memory_scope
-    )
+    if executor_type == "ge":
+        build_run_compare(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            static_memory_scope,
+        )
+    else:
+        build_run_compare_vm(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            static_memory_scope,
+        )
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_pooling_branching_texture_params(remote, target, dtype):
+def test_pooling_branching_texture_params(remote, target, executor_type, 
dtype):
     """
     Verification of the pooling and many branches having textures
                 layout_transform (NCHW->NCHW4c)
@@ -844,14 +950,31 @@ def test_pooling_branching_texture_params(remote, target, 
dtype):
         "",
     ]
 
-    build_run_compare(
-        remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, 
static_memory_scope
-    )
+    if executor_type == "ge":
+        build_run_compare(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            static_memory_scope,
+        )
+    else:
+        build_run_compare_vm(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            static_memory_scope,
+        )
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_branching_texture_params(remote, target, dtype):
+def test_branching_texture_params(remote, target, executor_type, dtype):
     """
     Verification of passing texture to several consumers markup of relay 
variables in
     primary functions + on_device
@@ -970,15 +1093,32 @@ def test_branching_texture_params(remote, target, dtype):
         "",
     ]
 
-    build_run_compare(
-        remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, 
static_memory_scope
-    )
+    if executor_type == "ge":
+        build_run_compare(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            static_memory_scope,
+        )
+    else:
+        build_run_compare_vm(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            static_memory_scope,
+        )
 
 
 # function repeat, params scope are different in reused functions
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_different_lowering_same_op(remote, target, dtype):
+def test_conv2d_different_lowering_same_op(remote, target, executor_type, 
dtype):
     """
     Use case for verification of caching compiled functions
     Three convolutions following by each other in this case should be
@@ -1054,14 +1194,31 @@ def test_conv2d_different_lowering_same_op(remote, 
target, dtype):
         "",
     ]
 
-    build_run_compare(
-        remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, 
static_memory_scope
-    )
+    if executor_type == "ge":
+        build_run_compare(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            static_memory_scope,
+        )
+    else:
+        build_run_compare_vm(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            static_memory_scope,
+        )
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_winograd_non_rect(remote, target, dtype):
+def test_conv2d_winograd_non_rect(remote, target, executor_type, dtype):
     input_shape = (1, 771, 36, 64)
     A = relay.var("data", shape=input_shape, dtype=dtype)
     filter_shape = (128, 771, 3, 3)
@@ -1085,17 +1242,36 @@ def test_conv2d_winograd_non_rect(remote, target, 
dtype):
         f.write(
             f'{{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno 
-max_num_threads=256 -texture_spatial_limit=16384 -thread_warp_size=1", 
"conv2d_nchw_winograd.image2d", [["TENSOR", [1, 771, 36, 64], "{dtype}"], 
["TENSOR", [128, 771, 3, 3], "{dtype}"], [1, 1], [1, 1, 1, 1], [1, 1], 
"{dtype}"], {{}}], "config": {{"index": 5399, "code_hash": null, "entity": 
[["auto_unroll_max_step", "ot", 16], ["tile_y", "sp", [-1, 1, 32]], ["tile_x", 
"sp", [-1, 4, 8]], ["tile_rc", "sp", [-1, 193]]] [...]
         )
-    graph = build_run_compare(
-        remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, 
stat_file=stat_file
-    )
-    matches = re.findall("winograd", graph)
-    assert len(matches) > 0
+    if executor_type == "ge":
+        graph = build_run_compare(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            stat_file=stat_file,
+        )
+        matches = re.findall("winograd", graph)
+        assert len(matches) > 0
+    else:
+        vmc = build_run_compare_vm(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            stat_file=stat_file,
+        )
+        matches = re.findall("winograd", vmc.primitives)
+        assert len(matches) > 0
 
 
 # function repeat, params scope are different in reused functions
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_injective_nwo_inputs1(remote, target, dtype):
+def test_injective_nwo_inputs1(remote, target, executor_type, dtype):
     """
     Use case for verification of stability of annotation primary functions
     having several ops accepting data outside of Primary function
@@ -1186,15 +1362,32 @@ def test_injective_nwo_inputs1(remote, target, dtype):
         "global",
         "global",
     ]
-    build_run_compare(
-        remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, 
static_memory_scope
-    )
+    if executor_type == "ge":
+        build_run_compare(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            static_memory_scope,
+        )
+    else:
+        build_run_compare_vm(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            static_memory_scope,
+        )
 
 
 # function repeat, params scope are different in reused functions
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_injective_nwo_inputs2(remote, target, dtype):
+def test_injective_nwo_inputs2(remote, target, executor_type, dtype):
     """
     Use case for verification of stability of annotation primary functions
     having several ops accepting data outside of Primary function
@@ -1284,14 +1477,31 @@ def test_injective_nwo_inputs2(remote, target, dtype):
         "global.texture",
         "global",
     ]
-    build_run_compare(
-        remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, 
static_memory_scope
-    )
+    if executor_type == "ge":
+        build_run_compare(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            static_memory_scope,
+        )
+    else:
+        build_run_compare_vm(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            static_memory_scope,
+        )
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_to_3_channels(remote, target, dtype):
+def test_conv2d_to_3_channels(remote, target, executor_type, dtype):
     input_shape = (1, 256, 200, 200)
     filter_shape = (3, 256, 1, 1)
     A = relay.var("data", shape=input_shape, dtype=dtype)
@@ -1316,7 +1526,12 @@ def test_conv2d_to_3_channels(remote, target, dtype):
         "weight": tvm.nd.array(filter_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target, [])
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target, [])
+    else:
+        build_run_compare_vm(
+            remote, mod, params1, {"data": input_shape}, {"data": dtype}, 
target, []
+        )
 
 
 if __name__ == "__main__":
diff --git a/tests/python/relay/opencl_texture/test_conv2d_nhwc_texture.py 
b/tests/python/relay/opencl_texture/test_conv2d_nhwc_texture.py
index 5f69e777d9..dc86a23187 100644
--- a/tests/python/relay/opencl_texture/test_conv2d_nhwc_texture.py
+++ b/tests/python/relay/opencl_texture/test_conv2d_nhwc_texture.py
@@ -22,16 +22,17 @@ import numpy as np
 from tvm import relay
 from tvm.relay import testing
 from tvm.contrib import utils
-from utils.adreno_utils import gpu_preprocess, build_run_compare
+from utils.adreno_utils import gpu_preprocess, build_run_compare, 
build_run_compare_vm
 import pytest
 
 
+executor_type = tvm.testing.parameter("ge", "vm")
 dtype = tvm.testing.parameter("float32")
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_deeplabv3_1_257_257_32x1_1_32_16(remote, target, dtype):
+def test_conv2d_deeplabv3_1_257_257_32x1_1_32_16(remote, target, 
executor_type, dtype):
     input_shape = (1, 257, 257, 32)
     filter_shape = (1, 1, 32, 16)
     bias_shape = (filter_shape[-1],)
@@ -63,12 +64,15 @@ def test_conv2d_deeplabv3_1_257_257_32x1_1_32_16(remote, 
target, dtype):
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_deeplabv3_1_257_257_32x1_1_32_16_with_padding(remote, target, 
dtype):
+def test_conv2d_deeplabv3_1_257_257_32x1_1_32_16_with_padding(remote, target, 
executor_type, dtype):
     input_shape = (1, 257, 257, 32)
     filter_shape = (1, 1, 32, 16)
     bias_shape = (filter_shape[-1],)
@@ -103,12 +107,15 @@ def 
test_conv2d_deeplabv3_1_257_257_32x1_1_32_16_with_padding(remote, target, dt
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_4_35_35_32x3_3_144_16(remote, target, dtype):
+def test_conv2d_4_35_35_32x3_3_144_16(remote, target, executor_type, dtype):
     input_shape = (4, 35, 35, 32)
     filter_shape = (3, 3, 32, 16)
     bias_shape = (filter_shape[-1],)
@@ -141,12 +148,15 @@ def test_conv2d_4_35_35_32x3_3_144_16(remote, target, 
dtype):
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_deeplabv3_1_513_513_3x3_3_3_32(remote, target, dtype):
+def test_conv2d_deeplabv3_1_513_513_3x3_3_3_32(remote, target, executor_type, 
dtype):
     input_shape = (1, 513, 513, 3)
     filter_shape = (3, 3, 3, 32)
     bias_shape = (filter_shape[-1],)
@@ -179,12 +189,15 @@ def test_conv2d_deeplabv3_1_513_513_3x3_3_3_32(remote, 
target, dtype):
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad(remote, target, dtype):
+def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad(remote, target, 
executor_type, dtype):
     input_shape = (1, 42, 42, 32)
     filter_shape = (3, 3, 32, 96)
     bias_shape = (1, 1, 1, 96)
@@ -219,14 +232,19 @@ def 
test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad(remote, target, dtype):
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(
-        remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, 
[], gpu_preprocess
-    )
+    if executor_type == "ge":
+        build_run_compare(
+            remote, mod, params1, {"data": input_shape}, {"data": dtype}, 
target, [], gpu_preprocess
+        )
+    else:
+        build_run_compare_vm(
+            remote, mod, params1, {"data": input_shape}, {"data": dtype}, 
target, [], gpu_preprocess
+        )
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad_pass(remote, target, 
dtype):
+def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad_pass(remote, target, 
executor_type, dtype):
     input_shape = (1, 40, 40, 32)
     filter_shape = (2, 2, 32, 96)
     bias_shape = (1, 1, 1, 96)
@@ -261,14 +279,19 @@ def 
test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad_pass(remote, target, dtype)
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(
-        remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, 
[], gpu_preprocess
-    )
+    if executor_type == "ge":
+        build_run_compare(
+            remote, mod, params1, {"data": input_shape}, {"data": dtype}, 
target, [], gpu_preprocess
+        )
+    else:
+        build_run_compare_vm(
+            remote, mod, params1, {"data": input_shape}, {"data": dtype}, 
target, [], gpu_preprocess
+        )
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_inceptionv3_35_35_strides(remote, target, dtype):
+def test_conv2d_inceptionv3_35_35_strides(remote, target, executor_type, 
dtype):
     input_shape = (1, 35, 35, 48)
     filter_shape = (5, 5, 48, 64)
     bias_shape = (1, 1, 1, 64)
@@ -303,14 +326,19 @@ def test_conv2d_inceptionv3_35_35_strides(remote, target, 
dtype):
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(
-        remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, 
[], gpu_preprocess
-    )
+    if executor_type == "ge":
+        build_run_compare(
+            remote, mod, params1, {"data": input_shape}, {"data": dtype}, 
target, [], gpu_preprocess
+        )
+    else:
+        build_run_compare_vm(
+            remote, mod, params1, {"data": input_shape}, {"data": dtype}, 
target, [], gpu_preprocess
+        )
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_resnet50_v2_nhwc_3c(remote, target, dtype):
+def test_conv2d_resnet50_v2_nhwc_3c(remote, target, executor_type, dtype):
     input_shape = (1, 224, 224, 3)
     filter_shape = (7, 7, 3, 64)
     bias_shape = (1, 1, 1, 64)
@@ -346,12 +374,15 @@ def test_conv2d_resnet50_v2_nhwc_3c(remote, target, 
dtype):
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_inceptionv3_nhwc_3c(remote, target, dtype):
+def test_conv2d_inceptionv3_nhwc_3c(remote, target, executor_type, dtype):
     input_shape = (1, 299, 299, 3)
     filter_shape = (3, 3, 3, 64)
     bias_shape = (1, 1, 1, 64)
@@ -386,12 +417,15 @@ def test_conv2d_inceptionv3_nhwc_3c(remote, target, 
dtype):
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_1x1_16c16spatial(remote, target, dtype):
+def test_conv2d_1x1_16c16spatial(remote, target, executor_type, dtype):
     input_shape = (1, 128, 128, 16)
     filter_shape = (4, 4, 16, 32)
     bias_shape = (1, 1, 1, 32)
@@ -426,12 +460,15 @@ def test_conv2d_1x1_16c16spatial(remote, target, dtype):
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_4x4_16c16pad(remote, target, dtype):
+def test_conv2d_4x4_16c16pad(remote, target, executor_type, dtype):
     input_shape = (1, 256, 256, 32)
     filter_shape = (4, 4, 32, 32)
     bias_shape = (1, 1, 1, 32)
@@ -466,12 +503,15 @@ def test_conv2d_4x4_16c16pad(remote, target, dtype):
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_4x4x4_16c16pad(remote, target, dtype):
+def test_conv2d_4x4x4_16c16pad(remote, target, executor_type, dtype):
     input_shape = (1, 256, 256, 32)
     filter_shape = (4, 4, 32, 4)
     bias_shape = (1, 1, 1, 4)
@@ -505,12 +545,15 @@ def test_conv2d_4x4x4_16c16pad(remote, target, dtype):
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_yolov3_v2_nhwc_3c(remote, target, dtype):
+def test_conv2d_yolov3_v2_nhwc_3c(remote, target, executor_type, dtype):
     input_shape = (1, 13, 13, 1024)
     filter_shape = (1, 1, 1024, 255)
     A = relay.var("data", shape=input_shape, dtype=dtype)
@@ -538,12 +581,15 @@ def test_conv2d_yolov3_v2_nhwc_3c(remote, target, dtype):
         "weight": tvm.nd.array(filter_data),
     }
 
-    build_run_compare(remote, mod, params, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_vgg16_winograd_4d(remote, target, dtype):
+def test_conv2d_vgg16_winograd_4d(remote, target, executor_type, dtype):
     input_shape = (1, 28, 28, 512)
     filter_shape = (3, 3, 512, 512)
     bias_shape = (1, 1, 1, 512)
@@ -582,16 +628,35 @@ def test_conv2d_vgg16_winograd_4d(remote, target, dtype):
         f.write(
             f'{{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno 
-max_num_threads=256", "conv2d_nhwc_winograd.image2d", [["TENSOR", [1, 28, 28, 
512], "{dtype}"], ["TENSOR", [3, 3, 512, 512], "{dtype}"], [1, 1], [1, 1, 1, 
1], [1, 1], "{dtype}"], {{}}], "config": {{"index": 1591, "code_hash": null, 
"entity": [["auto_unroll_max_step", "ot", 4], ["tile_y", "sp", [-1, 1, 32]], 
["tile_x", "sp", [-1, 4, 2]], ["tile_rc", "sp", [-1, 8]]]}}, "result": 
[[0.0037244], 0, 7.06374192237854, 165 [...]
         )
-    graph = build_run_compare(
-        remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, 
stat_file=stat_file
-    )
-    matches = re.findall("winograd", graph)
-    assert len(matches) > 0
+    if executor_type == "ge":
+        graph = build_run_compare(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            stat_file=stat_file,
+        )
+        matches = re.findall("winograd", graph)
+        assert len(matches) > 0
+    else:
+        vmc = build_run_compare_vm(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            stat_file=stat_file,
+        )
+        matches = re.findall("winograd", vmc.primitives)
+        assert len(matches) > 0
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_vgg16_winograd_4d_expand_spatial_dims(remote, target, dtype):
+def test_conv2d_vgg16_winograd_4d_expand_spatial_dims(remote, target, 
executor_type, dtype):
     input_shape = (1, 28, 28, 1)
     filter_shape = (3, 3, 1, 64)
     bias_shape = (1, 1, 1, 64)
@@ -629,16 +694,35 @@ def 
test_conv2d_vgg16_winograd_4d_expand_spatial_dims(remote, target, dtype):
         f.write(
             f'{{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno 
-max_num_threads=256", "conv2d_nhwc_winograd.image2d", [["TENSOR", [1, 28, 28, 
1], "{dtype}"], ["TENSOR", [3, 3, 1, 64], "{dtype}"], [1, 1], [0, 0, 0, 0], [1, 
1], "{dtype}"], {{}}], "config": {{"index": 1591, "code_hash": null, "entity": 
[["auto_unroll_max_step", "ot", 4], ["tile_y", "sp", [-1, 1, 32]], ["tile_x", 
"sp", [-1, 4, 2]], ["tile_rc", "sp", [-1, 8]]]}}, "result": [[0.0037244], 0, 
7.06374192237854, 16538986 [...]
         )
-    graph = build_run_compare(
-        remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, 
stat_file=stat_file
-    )
-    matches = re.findall("winograd", graph)
-    assert len(matches) > 0
+    if executor_type == "ge":
+        graph = build_run_compare(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            stat_file=stat_file,
+        )
+        matches = re.findall("winograd", graph)
+        assert len(matches) > 0
+    else:
+        vmc = build_run_compare_vm(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            stat_file=stat_file,
+        )
+        matches = re.findall("winograd", vmc.primitives)
+        assert len(matches) > 0
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_winograd_conv(remote, target, dtype):
+def test_conv2d_winograd_conv(remote, target, executor_type, dtype):
     input_shape = (1, 3, 3, 4)
     A = relay.var("data", shape=input_shape, dtype=dtype)
     filter_shape3 = (3, 3, 4, 8)
@@ -690,16 +774,35 @@ def test_conv2d_winograd_conv(remote, target, dtype):
         f.write(
             f'{{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno 
-max_num_threads=256", "conv2d_nhwc_winograd.image2d", [["TENSOR", [1, 3, 3, 
4], "{dtype}"], ["TENSOR", [3, 3, 4, 8], "{dtype}"], [1, 1], [1, 1, 1, 1], [1, 
1], "{dtype}"], {{}}], "config": {{"index": 1591, "code_hash": null, "entity": 
[["auto_unroll_max_step", "ot", 4], ["tile_y", "sp", [-1, 1, 32]], ["tile_x", 
"sp", [-1, 4, 2]], ["tile_rc", "sp", [-1, 8]]]}}, "result": [[0.0037244], 0, 
7.06374192237854, 1653898629. [...]
         )
-    graph = build_run_compare(
-        remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, 
stat_file=stat_file
-    )
-    matches = re.findall("winograd", graph)
-    assert len(matches) > 0
+    if executor_type == "ge":
+        graph = build_run_compare(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            stat_file=stat_file,
+        )
+        matches = re.findall("winograd", graph)
+        assert len(matches) > 0
+    else:
+        vmc = build_run_compare_vm(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            stat_file=stat_file,
+        )
+        matches = re.findall("winograd", vmc.primitives)
+        assert len(matches) > 0
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_winograd_non_rect(remote, target, dtype):
+def test_conv2d_winograd_non_rect(remote, target, executor_type, dtype):
     input_shape = (1, 36, 64, 771)
     A = relay.var("data", shape=input_shape, dtype=dtype)
     filter_shape = (3, 3, 771, 128)
@@ -730,16 +833,35 @@ def test_conv2d_winograd_non_rect(remote, target, dtype):
         f.write(
             f'{{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno 
-max_num_threads=256 -texture_spatial_limit=16384 -thread_warp_size=1", 
"conv2d_nhwc_winograd.image2d", [["TENSOR", [1, 36, 64, 771], "{dtype}"], 
["TENSOR", [3, 3, 771, 128], "{dtype}"], [1, 1], [1, 1, 1, 1], [1, 1], 
"{dtype}"], {{}}], "config": {{"index": 5399, "code_hash": null, "entity": 
[["auto_unroll_max_step", "ot", 16], ["tile_y", "sp", [-1, 1, 32]], ["tile_x", 
"sp", [-1, 4, 8]], ["tile_rc", "sp", [-1, 193]]] [...]
         )
-    graph = build_run_compare(
-        remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, 
stat_file=stat_file
-    )
-    matches = re.findall("winograd", graph)
-    assert len(matches) > 0
+    if executor_type == "ge":
+        graph = build_run_compare(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            stat_file=stat_file,
+        )
+        matches = re.findall("winograd", graph)
+        assert len(matches) > 0
+    else:
+        vmc = build_run_compare_vm(
+            remote,
+            mod,
+            params1,
+            {"data": input_shape},
+            {"data": dtype},
+            target,
+            stat_file=stat_file,
+        )
+        matches = re.findall("winograd", vmc.primitives)
+        assert len(matches) > 0
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_to_3_channels(remote, target, dtype):
+def test_conv2d_to_3_channels(remote, target, executor_type, dtype):
     input_shape = (1, 200, 200, 256)
     filter_shape = (1, 1, 256, 3)
     A = relay.var("data", shape=input_shape, dtype=dtype)
@@ -764,7 +886,12 @@ def test_conv2d_to_3_channels(remote, target, dtype):
         "weight": tvm.nd.array(filter_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target, [])
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target, [])
+    else:
+        build_run_compare_vm(
+            remote, mod, params1, {"data": input_shape}, {"data": dtype}, 
target, []
+        )
 
 
 if __name__ == "__main__":
diff --git 
a/tests/python/relay/opencl_texture/test_depthwise_conv2d_nchw_texture.py 
b/tests/python/relay/opencl_texture/test_depthwise_conv2d_nchw_texture.py
index 2c729a36eb..87e9542140 100644
--- a/tests/python/relay/opencl_texture/test_depthwise_conv2d_nchw_texture.py
+++ b/tests/python/relay/opencl_texture/test_depthwise_conv2d_nchw_texture.py
@@ -20,14 +20,15 @@ import tvm
 import numpy as np
 from tvm import relay
 from tvm.relay import testing
-from utils.adreno_utils import gpu_preprocess, build_run_compare
+from utils.adreno_utils import gpu_preprocess, build_run_compare, 
build_run_compare_vm
 
+executor_type = tvm.testing.parameter("ge", "vm")
 dtype = tvm.testing.parameter("float32")
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_depthwise_conv2d_bias_nchwc(remote, target, dtype):
+def test_depthwise_conv2d_bias_nchwc(remote, target, executor_type, dtype):
     input_shape = (1, 64, 112, 112)
     filter_shape = (64, 1, 3, 3)
     bias_shape = (1, 64, 1, 1)
@@ -64,14 +65,19 @@ def test_depthwise_conv2d_bias_nchwc(remote, target, dtype):
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(
-        remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, 
[], gpu_preprocess
-    )
+    if executor_type == "ge":
+        build_run_compare(
+            remote, mod, params1, {"data": input_shape}, {"data": dtype}, 
target, [], gpu_preprocess
+        )
+    else:
+        build_run_compare_vm(
+            remote, mod, params1, {"data": input_shape}, {"data": dtype}, 
target, [], gpu_preprocess
+        )
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_depthwise_conv2d_nchwc(remote, target, dtype):
+def test_depthwise_conv2d_nchwc(remote, target, executor_type, dtype):
     input_shape = (1, 64, 112, 112)
     filter_shape = (64, 1, 3, 3)
     bias_shape = (1, 64, 1, 1)
@@ -103,14 +109,19 @@ def test_depthwise_conv2d_nchwc(remote, target, dtype):
         "weight": tvm.nd.array(filter_data),
     }
 
-    build_run_compare(
-        remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, 
[], gpu_preprocess
-    )
+    if executor_type == "ge":
+        build_run_compare(
+            remote, mod, params1, {"data": input_shape}, {"data": dtype}, 
target, [], gpu_preprocess
+        )
+    else:
+        build_run_compare_vm(
+            remote, mod, params1, {"data": input_shape}, {"data": dtype}, 
target, [], gpu_preprocess
+        )
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_depthwise_conv2d_bias_nchw(remote, target, dtype):
+def test_depthwise_conv2d_bias_nchw(remote, target, executor_type, dtype):
     input_shape = (1, 64, 112, 112)
     filter_shape = (64, 1, 3, 3)
     bias_shape = (1, 64, 1, 1)
@@ -147,12 +158,15 @@ def test_depthwise_conv2d_bias_nchw(remote, target, 
dtype):
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_depthwise_conv2d_repack_bias_nchw(remote, target, dtype):
+def test_depthwise_conv2d_repack_bias_nchw(remote, target, executor_type, 
dtype):
     input_shape = (1, 63, 112, 112)
     filter_shape = (63, 1, 3, 3)
     bias_shape = (1, 63, 1, 1)
@@ -189,12 +203,15 @@ def test_depthwise_conv2d_repack_bias_nchw(remote, 
target, dtype):
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_to_3_channels(remote, target, dtype):
+def test_conv2d_to_3_channels(remote, target, executor_type, dtype):
     input_shape = (1, 3, 200, 200)
     filter_shape = (3, 1, 1, 1)
     A = relay.var("data", shape=input_shape, dtype=dtype)
@@ -220,7 +237,12 @@ def test_conv2d_to_3_channels(remote, target, dtype):
         "weight": tvm.nd.array(filter_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target, [])
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target, [])
+    else:
+        build_run_compare_vm(
+            remote, mod, params1, {"data": input_shape}, {"data": dtype}, 
target, []
+        )
 
 
 if __name__ == "__main__":
diff --git 
a/tests/python/relay/opencl_texture/test_depthwise_conv2d_nhwc_texture.py 
b/tests/python/relay/opencl_texture/test_depthwise_conv2d_nhwc_texture.py
index 28f0f4cefa..782c99a96a 100644
--- a/tests/python/relay/opencl_texture/test_depthwise_conv2d_nhwc_texture.py
+++ b/tests/python/relay/opencl_texture/test_depthwise_conv2d_nhwc_texture.py
@@ -20,14 +20,16 @@ import tvm
 import numpy as np
 from tvm import relay
 from tvm.relay import testing
-from utils.adreno_utils import build_run_compare
+from utils.adreno_utils import build_run_compare, build_run_compare_vm
 
+
+executor_type = tvm.testing.parameter("ge", "vm")
 dtype = tvm.testing.parameter("float32")
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_depthwise_conv2d_deeplabv3_1_129_129_144x3_3_144_1(remote, target, 
dtype):
+def test_depthwise_conv2d_deeplabv3_1_129_129_144x3_3_144_1(remote, target, 
executor_type, dtype):
     input_shape = (1, 129, 129, 144)
     filter_shape = (3, 3, 144, 1)
     kernel_size = (filter_shape[0], filter_shape[1])
@@ -62,12 +64,15 @@ def 
test_depthwise_conv2d_deeplabv3_1_129_129_144x3_3_144_1(remote, target, dtyp
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_depthwise_conv2d_deeplabv3_4_35_35_576x3_3_576_1(remote, target, 
dtype):
+def test_depthwise_conv2d_deeplabv3_4_35_35_576x3_3_576_1(remote, target, 
executor_type, dtype):
     input_shape = (4, 35, 35, 576)
     filter_shape = (3, 3, 576, 1)
     kernel_size = (filter_shape[0], filter_shape[1])
@@ -102,12 +107,17 @@ def 
test_depthwise_conv2d_deeplabv3_4_35_35_576x3_3_576_1(remote, target, dtype)
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def 
test_depthwise_conv2d_deeplabv3_1_129_129_144x3_3_144_1_with_padding(remote, 
target, dtype):
+def test_depthwise_conv2d_deeplabv3_1_129_129_144x3_3_144_1_with_padding(
+    remote, target, executor_type, dtype
+):
     input_shape = (1, 129, 129, 144)
     filter_shape = (3, 3, 144, 1)
     kernel_size = (filter_shape[0], filter_shape[1])
@@ -144,12 +154,15 @@ def 
test_depthwise_conv2d_deeplabv3_1_129_129_144x3_3_144_1_with_padding(remote,
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_depthwise_conv2d_1_513_513_7x3_3_7_1(remote, target, dtype):
+def test_depthwise_conv2d_1_513_513_7x3_3_7_1(remote, target, executor_type, 
dtype):
     input_shape = (1, 513, 513, 7)
     filter_shape = (3, 3, 7, 1)
     bias_shape = (filter_shape[2],)
@@ -183,12 +196,15 @@ def test_depthwise_conv2d_1_513_513_7x3_3_7_1(remote, 
target, dtype):
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_depthwise_conv2d_1_513_513_3x3_3_3_1(remote, target, dtype):
+def test_depthwise_conv2d_1_513_513_3x3_3_3_1(remote, target, executor_type, 
dtype):
     input_shape = (1, 513, 513, 3)
     filter_shape = (3, 3, 3, 1)
     bias_shape = (filter_shape[2],)
@@ -222,12 +238,15 @@ def test_depthwise_conv2d_1_513_513_3x3_3_3_1(remote, 
target, dtype):
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_conv2d_to_3_channels(remote, target, dtype):
+def test_conv2d_to_3_channels(remote, target, executor_type, dtype):
     input_shape = (1, 200, 200, 3)
     filter_shape = (1, 1, 3, 1)
     A = relay.var("data", shape=input_shape, dtype=dtype)
@@ -253,7 +272,12 @@ def test_conv2d_to_3_channels(remote, target, dtype):
         "weight": tvm.nd.array(filter_data),
     }
 
-    build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": 
dtype}, target, [])
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target, [])
+    else:
+        build_run_compare_vm(
+            remote, mod, params1, {"data": input_shape}, {"data": dtype}, 
target, []
+        )
 
 
 if __name__ == "__main__":
diff --git a/tests/python/relay/opencl_texture/test_injection_texture.py 
b/tests/python/relay/opencl_texture/test_injection_texture.py
index 991983706f..31c082c994 100644
--- a/tests/python/relay/opencl_texture/test_injection_texture.py
+++ b/tests/python/relay/opencl_texture/test_injection_texture.py
@@ -20,48 +20,56 @@ import pytest
 import tvm
 import numpy as np
 from tvm import relay
-from tvm.relay import testing
-from tvm.contrib import utils
-from utils.adreno_utils import gpu_preprocess, build_run_compare
+from utils.adreno_utils import build_run_compare, build_run_compare_vm
 
 
+executor_type = tvm.testing.parameter("ge", "vm")
 dtype = tvm.testing.parameter("float32")
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_layout_transform_to_block_nchw4c(remote, target, dtype):
+def test_layout_transform_to_block_nchw4c(remote, target, executor_type, 
dtype):
     """Verification of the case NCHW->NCHW4c"""
     input_shape = (1, 32, 720, 1280)
     A = relay.var("data", shape=input_shape, dtype=dtype)
     lt = relay.layout_transform(A, "NCHW", "NCHW4c")
     mod = relay.Function([A], lt)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_layout_transform_to_block_nchw(remote, target, dtype):
+def test_layout_transform_to_block_nchw(remote, target, executor_type, dtype):
     """Verification of the case NCHW4c->NCHW"""
     input_shape = (1, 36, 1, 1, 4)
     A = relay.var("data", shape=input_shape, dtype=dtype)
     lt = relay.layout_transform(A, "NCHW4c", "NCHW")
     mod = relay.Function([A], lt)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_layout_transform_to_block_nhwc4c(remote, target, dtype):
+def test_layout_transform_to_block_nhwc4c(remote, target, executor_type, 
dtype):
     """Verification of the case NHWC->NHWC4c"""
     input_shape = (1, 1, 1, 144)
     A = relay.var("data", shape=input_shape, dtype=dtype)
     lt = relay.layout_transform(A, "NHWC", "NHWC4c")
     mod = relay.Function([A], lt)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 @pytest.mark.skipif(
@@ -69,7 +77,7 @@ def test_layout_transform_to_block_nhwc4c(remote, target, 
dtype):
 )
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_layout_transform_to_block_nhwc(remote, target, dtype):
+def test_layout_transform_to_block_nhwc(remote, target, executor_type, dtype):
     """Verification of the case NHWC4c->NHWC"""
     input_shape = (1, 80, 80, 36, 4)
     A = relay.var("data", shape=input_shape, dtype=dtype)
@@ -78,7 +86,10 @@ def test_layout_transform_to_block_nhwc(remote, target, 
dtype):
     lt = relay.layout_transform(cast, "NHWC4c", "NHWC")
     mod = relay.Function([A], lt)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/relay/opencl_texture/test_network.py 
b/tests/python/relay/opencl_texture/test_network.py
index 1d0e996f9f..2b2f3741cb 100644
--- a/tests/python/relay/opencl_texture/test_network.py
+++ b/tests/python/relay/opencl_texture/test_network.py
@@ -24,10 +24,13 @@ from tvm import relay
 from tvm.contrib import utils
 from tvm.relay import testing
 from tvm.relay.op import register_mixed_precision_conversion
-from utils.adreno_utils import build_run_compare, get_model, gpu_preprocess
+from utils.adreno_utils import build_run_compare, build_run_compare_vm, 
get_model, gpu_preprocess
 
 
-def _test_mobilenet_v1(remote, target, calc_dtype, acc_dtype):
+executor_type = tvm.testing.parameter("ge", "vm")
+
+
+def _test_mobilenet_v1(remote, target, calc_dtype, executor_type, acc_dtype):
     mod, params, inputs, dtypes = get_model(
         
"https://github.com/mlcommons/mobile_models/raw/main/v0_7/tflite/mobilenet_edgetpu_224_1.0_float.tflite",
         "mobilenet_edgetpu_224_1.0_float.tflite",
@@ -46,29 +49,32 @@ def _test_mobilenet_v1(remote, target, calc_dtype, 
acc_dtype):
             },
         )
 
-    build_run_compare(remote, mod, params, inputs, dtypes, target, [])
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params, inputs, dtypes, target, [])
+    else:
+        build_run_compare_vm(remote, mod, params, inputs, dtypes, target, [])
 
 
 @pytest.mark.skip(reason="See https://github.com/apache/tvm/issues/13443")
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
 @pytest.mark.skipif(tvm.testing.utils.IS_IN_CI, reason="CI doesn't support 
fp16(half datatypes)")
-def test_mobilenet_v1_fp16(remote, target):
-    _test_mobilenet_v1(remote, target, "float16", "float16")
+def test_mobilenet_v1_fp16(remote, target, executor_type):
+    _test_mobilenet_v1(remote, target, "float16", executor_type, "float16")
 
 
 @pytest.mark.skip(reason="See https://github.com/apache/tvm/issues/13443")
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_mobilenet_v1_fp32(remote, target):
-    _test_mobilenet_v1(remote, target, "float32", "float32")
+def test_mobilenet_v1_fp32(remote, target, executor_type):
+    _test_mobilenet_v1(remote, target, "float32", executor_type, "float32")
 
 
 @pytest.mark.skip(reason="See https://github.com/apache/tvm/issues/13443")
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_mobilenet_v1_fp16_acc32(remote, target):
-    _test_mobilenet_v1(remote, target, "float16", "float32")
+def test_mobilenet_v1_fp16_acc32(remote, target, executor_type):
+    _test_mobilenet_v1(remote, target, "float16", executor_type, "float32")
 
 
 if __name__ == "__main__":
diff --git a/tests/python/relay/opencl_texture/test_pool_texture.py 
b/tests/python/relay/opencl_texture/test_pool_texture.py
index faeb121c80..6190790a3d 100644
--- a/tests/python/relay/opencl_texture/test_pool_texture.py
+++ b/tests/python/relay/opencl_texture/test_pool_texture.py
@@ -17,15 +17,16 @@
 
 import tvm
 from tvm import relay
-from utils.adreno_utils import build_run_compare
+from utils.adreno_utils import build_run_compare, build_run_compare_vm
 
 
+executor_type = tvm.testing.parameter("ge", "vm")
 dtype = tvm.testing.parameter("float32")
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_global_pool2d_nchw_wide(remote, target, dtype):
+def test_global_pool2d_nchw_wide(remote, target, executor_type, dtype):
     """
     Use case of NCHW global pooling with big spatial valies
     """
@@ -34,12 +35,15 @@ def test_global_pool2d_nchw_wide(remote, target, dtype):
     C = relay.nn.global_avg_pool2d(A)
     mod = relay.Function([A], C)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_global_pool2d_nchw4c_wide(remote, target, dtype):
+def test_global_pool2d_nchw4c_wide(remote, target, executor_type, dtype):
     """
     Use case of blocked NCHW4c global pooling with big spatial valies
     """
@@ -48,12 +52,15 @@ def test_global_pool2d_nchw4c_wide(remote, target, dtype):
     C = relay.nn.global_avg_pool2d(A, layout="NCHW4c")
     mod = relay.Function([A], C)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_global_pool2d_nchw_deep(remote, target, dtype):
+def test_global_pool2d_nchw_deep(remote, target, executor_type, dtype):
     """
     Use case of NCHW deep global pooling
     """
@@ -62,12 +69,15 @@ def test_global_pool2d_nchw_deep(remote, target, dtype):
     C = relay.nn.global_avg_pool2d(A)
     mod = relay.Function([A], C)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_global_pool2d_nchw4c_deep(remote, target, dtype):
+def test_global_pool2d_nchw4c_deep(remote, target, executor_type, dtype):
     """
     Use case of blocked NCHW4c deep global pooling
     """
@@ -76,12 +86,15 @@ def test_global_pool2d_nchw4c_deep(remote, target, dtype):
     C = relay.nn.global_avg_pool2d(A, layout="NCHW4c")
     mod = relay.Function([A], C)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_global_pool2d_nhwc(remote, target, dtype):
+def test_global_pool2d_nhwc(remote, target, executor_type, dtype):
     """
     Use case of NHWC global pooling with big spatial valies
     """
@@ -90,12 +103,15 @@ def test_global_pool2d_nhwc(remote, target, dtype):
     C = relay.nn.global_avg_pool2d(A, layout="NHWC")
     mod = relay.Function([A], C)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_global_pool2d_nhwc4c(remote, target, dtype):
+def test_global_pool2d_nhwc4c(remote, target, executor_type, dtype):
     """
     Use case of NHWC deep global pooling
     """
@@ -104,12 +120,15 @@ def test_global_pool2d_nhwc4c(remote, target, dtype):
     C = relay.nn.global_avg_pool2d(A, layout="NHWC4c")
     mod = relay.Function([A], C)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_global_max_pool2d_nchw_wide(remote, target, dtype):
+def test_global_max_pool2d_nchw_wide(remote, target, executor_type, dtype):
     """
     Use case of NCHW global pooling with big spatial valies
     """
@@ -118,12 +137,15 @@ def test_global_max_pool2d_nchw_wide(remote, target, 
dtype):
     C = relay.nn.global_max_pool2d(A)
     mod = relay.Function([A], C)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_global_max_pool2d_nchw4c_wide(remote, target, dtype):
+def test_global_max_pool2d_nchw4c_wide(remote, target, executor_type, dtype):
     """
     Use case of blocked NCHW4c global pooling with big spatial valies
     """
@@ -132,4 +154,11 @@ def test_global_max_pool2d_nchw4c_wide(remote, target, 
dtype):
     C = relay.nn.global_max_pool2d(A, layout="NCHW4c")
     mod = relay.Function([A], C)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relay/opencl_texture/test_reduction_texture.py 
b/tests/python/relay/opencl_texture/test_reduction_texture.py
index 5728e6294f..1016a7c88e 100644
--- a/tests/python/relay/opencl_texture/test_reduction_texture.py
+++ b/tests/python/relay/opencl_texture/test_reduction_texture.py
@@ -21,123 +21,151 @@ import numpy as np
 from tvm import relay
 from tvm.relay import testing
 from tvm.contrib import utils
-from utils.adreno_utils import gpu_preprocess, build_run_compare
+from utils.adreno_utils import gpu_preprocess, build_run_compare, 
build_run_compare_vm
 
 
+executor_type = tvm.testing.parameter("ge", "vm")
 dtype = tvm.testing.parameter("float32")
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_mean(remote, target, dtype):
+def test_mean(remote, target, executor_type, dtype):
     # NCHW
     input_shape = (1, 3, 720, 1280)
     A = relay.var("data", shape=input_shape, dtype=dtype)
     mean = relay.mean(A, axis=1, keepdims=True)
     mod = relay.Function([A], mean)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_argmax(remote, target, dtype):
+def test_argmax(remote, target, executor_type, dtype):
     # NCHW
     input_shape = (1, 3, 720, 1280)
     A = relay.var("data", shape=input_shape, dtype=dtype)
     argmax = relay.op.argmax(A, axis=[1])
     mod = relay.Function([A], argmax)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_reduction_max(remote, target, dtype):
+def test_reduction_max(remote, target, executor_type, dtype):
     # NCHW
     input_shape = (1, 3, 720, 1280)
     A = relay.var("data", shape=input_shape, dtype=dtype)
     argmax = relay.op.max(A, axis=[1])
     mod = relay.Function([A], argmax)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_mean_nd4(remote, target, dtype):
+def test_mean_nd4(remote, target, executor_type, dtype):
     # NCHW
     input_shape = (1, 3, 729, 729)
     A = relay.var("data", shape=input_shape, dtype=dtype)
     mean = relay.mean(A, axis=1, keepdims=True)
     mod = relay.Function([A], mean)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_argmax_nd4(remote, target, dtype):
+def test_argmax_nd4(remote, target, executor_type, dtype):
     # NCHW
     input_shape = (1, 3, 729, 729)
     A = relay.var("data", shape=input_shape, dtype=dtype)
     argmax = relay.op.argmax(A, axis=[1])
     mod = relay.Function([A], argmax)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_reduction_max_nd4(remote, target, dtype):
+def test_reduction_max_nd4(remote, target, executor_type, dtype):
     # NCHW
     input_shape = (1, 3, 729, 729)
     A = relay.var("data", shape=input_shape, dtype=dtype)
     argmax = relay.op.max(A, axis=[1])
     mod = relay.Function([A], argmax)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_mean_b4(remote, target, dtype):
+def test_mean_b4(remote, target, executor_type, dtype):
     # NCHW
     input_shape = (1, 3, 720, 320, 4)
     A = relay.var("data", shape=input_shape, dtype=dtype)
     mean = relay.mean(A, axis=1, keepdims=True)
     mod = relay.Function([A], mean)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_argmax_b4(remote, target, dtype):
+def test_argmax_b4(remote, target, executor_type, dtype):
     # NCHW
     input_shape = (1, 3, 720, 320, 4)
     A = relay.var("data", shape=input_shape, dtype=dtype)
     argmax = relay.op.argmax(A, axis=[1])
     mod = relay.Function([A], argmax)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_reduction_max_b4(remote, target, dtype):
+def test_reduction_max_b4(remote, target, executor_type, dtype):
     # NCHW
     input_shape = (1, 3, 720, 320, 4)
     A = relay.var("data", shape=input_shape, dtype=dtype)
     argmax = relay.op.max(A, axis=[1])
     mod = relay.Function([A], argmax)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_mean_global_pooling(remote, target, dtype):
+def test_mean_global_pooling(remote, target, executor_type, dtype):
     """
     Use case of blocked NCHW4c global pooling with big spatial valies
     """
@@ -146,12 +174,15 @@ def test_mean_global_pooling(remote, target, dtype):
     mean = relay.mean(A, axis=[1, 2], keepdims=True)
     mod = relay.Function([A], mean)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_mean_global_pooling_block4(remote, target, dtype):
+def test_mean_global_pooling_block4(remote, target, executor_type, dtype):
     """
     Use case of blocked NCHW4c global pooling with big spatial valies
     """
@@ -160,12 +191,15 @@ def test_mean_global_pooling_block4(remote, target, 
dtype):
     mean = relay.mean(A, axis=[1, 2], keepdims=True)
     mod = relay.Function([A], mean)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
-def test_max_global_pooling_block4(remote, target, dtype):
+def test_max_global_pooling_block4(remote, target, executor_type, dtype):
     """
     Use case of blocked NCHW4c global pooling with big spatial valies
     """
@@ -174,7 +208,10 @@ def test_max_global_pooling_block4(remote, target, dtype):
     mean = relay.max(A, axis=[1, 2], keepdims=True)
     mod = relay.Function([A], mean)
 
-    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, 
target)
+    if executor_type == "ge":
+        build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": 
dtype}, target)
 
 
 @tvm.testing.requires_opencl
diff --git a/tests/python/relay/opencl_texture/utils/adreno_utils.py 
b/tests/python/relay/opencl_texture/utils/adreno_utils.py
index e2a271d9f6..de325d822c 100644
--- a/tests/python/relay/opencl_texture/utils/adreno_utils.py
+++ b/tests/python/relay/opencl_texture/utils/adreno_utils.py
@@ -26,6 +26,7 @@ from tvm.contrib import utils, ndk
 from tvm.relay import testing
 from tvm.relay.transform import recast
 from tvm.contrib import graph_runtime
+from tvm.runtime.vm import VirtualMachine
 import json
 
 
@@ -122,6 +123,89 @@ def build_run_compare(
     return graph
 
 
+def build_run_compare_vm(
+    remote,
+    tvm_mod,
+    params1,
+    input_shape,
+    dtypes,
+    target="llvm",
+    static_mem_scopes=[],
+    gpu_preprocess=None,
+    stat_file=None,
+):
+    if remote is None:
+        target_host = "llvm"
+    else:
+        target_host = "llvm -mtriple=arm64-linux-android"
+
+    if gpu_preprocess:
+        tvm_mod_nchwc = gpu_preprocess(tvm_mod)
+    else:
+        tvm_mod_nchwc = tvm_mod
+
+    if isinstance(tvm_mod_nchwc, relay.Function):
+        module = tvm.IRModule({})
+        module["main"] = tvm_mod_nchwc
+        tvm_mod_nchwc = module
+
+    if stat_file is not None:
+        with autotvm.apply_history_best(stat_file):
+            with tvm.transform.PassContext(opt_level=3):
+                vmc = relay.vm.compile(
+                    tvm_mod_nchwc, target=target, target_host=target_host, 
params=params1
+                )
+    else:
+        with tvm.transform.PassContext(opt_level=3):
+            vmc = relay.vm.compile(
+                tvm_mod_nchwc, target=target, target_host=target_host, 
params=params1
+            )
+
+    # TODO(echuraev): enable scope checking
+    ## verification that storage_scope has expected textures scopes
+    # graph_json = json.loads(graph)
+    # if "storage_scope" in graph_json["attrs"]:
+    #    assert (
+    #        len(static_mem_scopes) == 
len(graph_json["attrs"]["storage_scope"][1])
+    #        or len(static_mem_scopes) == 0
+    #    )
+    # else:
+    #    assert len(static_mem_scopes) == 0
+
+    # for i in range(0, len(static_mem_scopes)):
+    #    assert static_mem_scopes[i] == 
graph_json["attrs"]["storage_scope"][1][i]
+
+    if remote is None:
+        dev = tvm.opencl()
+        vm = VirtualMachine(vmc, dev, "naive")
+    else:
+        temp = utils.tempdir()
+        dso_binary = "dev_lib_cl.so"
+        dso_binary_path = temp.relpath(dso_binary)
+        dev = remote.cl(0)
+        vmc.mod.export_library(dso_binary_path, ndk.create_shared)
+        remote.upload(dso_binary_path)
+        rlib = remote.load_module(dso_binary)
+        vm = VirtualMachine(rlib, dev, "naive")
+    data = {}
+    inputs = []
+    for key in input_shape:
+        
inputs.append(np.random.normal(size=input_shape[key]).astype(dtypes[key]))
+        data[key] = tvm.nd.array(inputs[-1], dev)
+    for k, v in params1.items():
+        data[k] = tvm.nd.array(v, dev)
+    vm.set_input("main", **data)
+    vm.invoke_stateful("main")
+
+    ref_outputs = get_cpu_reference(tvm_mod, params1, input_shape, inputs)
+    for i, ref_output in enumerate(ref_outputs):
+        tvm_output = vm.get_outputs()[i]
+        output = tvm_output.asnumpy()
+
+        np.testing.assert_allclose(output, ref_output, rtol=1e-1, atol=1e-1)
+    return vmc
+
+
 def gpu_preprocess(tvm_mod):
     layout_config = relay.transform.LayoutConfig()
     desired_layouts = {"nn.conv2d": ["NCHW4c", "OIHW4o"]}
diff --git a/tests/python/relay/test_pass_dead_code_elimination.py 
b/tests/python/relay/test_pass_dead_code_elimination.py
index 68d2919ec3..70dc1dd4f7 100644
--- a/tests/python/relay/test_pass_dead_code_elimination.py
+++ b/tests/python/relay/test_pass_dead_code_elimination.py
@@ -16,8 +16,10 @@
 # under the License.
 import tvm
 import tvm.testing
+from tvm import relay
 from tvm.relay import Function, transform
 from tvm.relay.testing import inception_v3
+import numpy as np
 import pytest
 
 cpu_scope = tvm.target.VirtualDevice(tvm.cpu(), tvm.target.Target("llvm"))
@@ -228,6 +230,11 @@ def test_inline_into_function():
 
 
 def test_impure_op():
+    shape = np.array([64, 2])
+    metatable = {
+        "VirtualDevice": [cpu_scope],
+        "relay.Constant": [relay.const(shape, dtype="int64")],
+    }
     """Don't elide calls to side-effecting operators."""
     before_program = tvm.relay.parse(
         """
@@ -235,7 +242,7 @@ def test_impure_op():
         def @main() {
            let %size: int64 = cast(1024, dtype="int64");
            let %alignment: int64 = cast(64, dtype="int64");
-           let %x = memory.alloc_storage(%size, %alignment, 
virtual_device=meta[VirtualDevice][0]);
+           let %x = memory.alloc_storage(%size, meta[relay.Constant][0], 
%alignment, virtual_device=meta[VirtualDevice][0]);
            let %_ = memory.kill(%x);
            0
         }
@@ -250,6 +257,7 @@ def test_impure_op():
         #[version = "0.0.5"]
         def @main() {
            %0 = memory.alloc_storage(cast(1024, dtype="int64"),
+                                     meta[relay.Constant][0],
                                      cast(64, dtype="int64"),
                                      virtual_device=meta[VirtualDevice][0]);
            let %_ = memory.kill(%0);
@@ -267,6 +275,11 @@ def test_impure_op():
 
 
 def test_impure_func():
+    shape = np.array([64, 2])
+    metatable = {
+        "VirtualDevice": [cpu_scope],
+        "relay.Constant": [relay.const(shape, dtype="int64")],
+    }
     """Don't elide calls to side-effecting functions."""
     before_program = tvm.relay.parse(
         """
@@ -274,7 +287,7 @@ def test_impure_func():
         def @f() -> int {
            let %size: int64 = cast(1024, dtype="int64");
            let %alignment: int64 = cast(64, dtype="int64");
-           let %x = memory.alloc_storage(%size, %alignment, 
virtual_device=meta[VirtualDevice][0]);
+           let %x = memory.alloc_storage(%size, meta[relay.Constant][0], 
%alignment, virtual_device=meta[VirtualDevice][0]);
            let %_ = memory.kill(%x);
            0
         }
@@ -293,6 +306,7 @@ def test_impure_func():
         #[version = "0.0.5"]
         def @f() -> int {
            %0 = memory.alloc_storage(cast(1024, dtype="int64"),
+                                     meta[relay.Constant][0],
                                      cast(64, dtype="int64"),
                                      virtual_device=meta[VirtualDevice][0]);
            let %_ = memory.kill(%0);
diff --git a/tests/python/relay/test_pass_plan_devices.py 
b/tests/python/relay/test_pass_plan_devices.py
index c7f42103ca..f654b4b453 100644
--- a/tests/python/relay/test_pass_plan_devices.py
+++ b/tests/python/relay/test_pass_plan_devices.py
@@ -761,14 +761,18 @@ def test_shape_of():
 
 
 def test_alloc_storage():
-    metatable = {"VirtualDevice": [HOST, GPU]}
+    shape = np.array([3, 2])
+    metatable = {
+        "VirtualDevice": [HOST, GPU],
+        "relay.Constant": [relay.const(shape, dtype="int64")],
+    }
 
     def input():
         return tvm.relay.parse(
             """
             #[version = "0.0.5"]
             def @main(%size: int64, %alignment: int64) {
-              memory.alloc_storage(%size, %alignment, 
virtual_device=meta[VirtualDevice][1])
+              memory.alloc_storage(%size, meta[relay.Constant][0], %alignment, 
virtual_device=meta[VirtualDevice][1])
             }
         """,
             "from_string",
@@ -782,7 +786,8 @@ def test_alloc_storage():
             #[version = "0.0.5"]
             def @main(%size {virtual_device=meta[VirtualDevice][0]}: int64, 
%alignment {virtual_device=meta[VirtualDevice][0]}: int64,
                       virtual_device=meta[VirtualDevice][1]) {
-              memory.alloc_storage(%size, %alignment, 
virtual_device=meta[VirtualDevice][1])
+              %0 = on_device(meta[relay.Constant][0], 
virtual_device=meta[VirtualDevice][0], constrain_result=True);
+              memory.alloc_storage(%size, %0, %alignment, 
virtual_device=meta[VirtualDevice][1])
             }
         """,
             "from_string",

Reply via email to