masahi commented on a change in pull request #7626:
URL: https://github.com/apache/tvm/pull/7626#discussion_r591162725
##########
File path: src/auto_scheduler/search_task.cc
##########
@@ -106,6 +106,28 @@ HardwareParams
HardwareParamsNode::GetDefaultHardwareParams(const Target& target
auto target_device = target->GetAttr<String>("device", "");
LOG(FATAL) << "No default hardware parameters for opencl target device:
" << target_device;
}
+ } else if (device_type == kDLVulkan) {
+ auto ctx = TVMContext{static_cast<DLDeviceType>(device_type), 0};
+ auto device_name = "device_api.vulkan";
+ auto func = tvm::runtime::Registry::Get(device_name);
+ ICHECK(func != nullptr) << "Cannot find Vulkan device_api in registry";
+ auto device_api =
static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*());
+
+ tvm::runtime::TVMRetValue ret;
+ device_api->GetAttr(ctx,
tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret);
+ int max_shared_memory_per_block = ret;
+
+ int max_local_memory_per_block = INT32_MAX;
+
+ device_api->GetAttr(ctx,
tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret);
+ int max_threads_per_block = ret;
+
+ device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kWarpSize, &ret);
+ int warp_size = ret;
+
+ int max_vthread_extent = warp_size / 4;
Review comment:
Good catch. Indeed, the spec only requires the warp size to be greater than
or equal to 1, so `warp_size / 4` can evaluate to 0 for warp sizes below 4.
In practice it is always greater than 1, but I'll update this to
`int max_vthread_extent = std::max(1, warp_size / 4);`.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]