This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0a3e178 [metal] update language version (#7116)
0a3e178 is described below
commit 0a3e1783b30910f4496e437b0ebefd998bf5a935
Author: Bing Xu <[email protected]>
AuthorDate: Tue Dec 15 23:24:46 2020 -0800
[metal] update language version (#7116)
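* [metal] update the Metal shading language version used for runtime compilation from 1.2 to 2.3
* fix mps: pass the tensor dtype to CopyDataFromTo in the MPS contrib helpers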
---
src/runtime/contrib/mps/conv.mm | 9 ++++++---
src/runtime/metal/metal_module.mm | 3 +--
2 files changed, 7 insertions(+), 5 deletions(-)
diff --git a/src/runtime/contrib/mps/conv.mm b/src/runtime/contrib/mps/conv.mm
index 3b16f08..b860ee2 100644
--- a/src/runtime/contrib/mps/conv.mm
+++ b/src/runtime/contrib/mps/conv.mm
@@ -34,7 +34,8 @@
TVM_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img").set_body([](TVMArgs args, TVMR
id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(buf->ctx);
id<MTLBuffer> temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]);
entry_ptr->metal_api->CopyDataFromTo((__bridge void*)mtlbuf, 0, (__bridge
void*)temp, 0,
- [mtlbuf length], buf -> ctx, buf ->
ctx, nullptr);
+ [mtlbuf length], buf -> ctx, buf ->
ctx, buf -> dtype,
+ nullptr);
MPSImageDescriptor* desc =
[MPSImageDescriptor
imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32
@@ -69,7 +70,8 @@
TVM_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer").set_body([](TVMArgs args, TVMR
imageIndex:0];
entry_ptr->metal_api->CopyDataFromTo((__bridge void*)temp, 0, (__bridge
void*)mtlbuf, 0,
- [mtlbuf length], buf -> ctx, buf ->
ctx, nullptr);
+ [mtlbuf length], buf -> ctx, buf ->
ctx, buf -> dtype,
+ nullptr);
});
TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d").set_body([](TVMArgs args,
TVMRetValue* ret) {
@@ -111,7 +113,8 @@
TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d").set_body([](TVMArgs args, TVMRetVa
id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(weight->data);
id<MTLBuffer> tempB = rt->GetTempBuffer(weight->ctx, [bufB length]);
entry_ptr->metal_api->CopyDataFromTo((__bridge void*)bufB, 0, (__bridge
void*)tempB, 0,
- [bufB length], weight -> ctx, weight ->
ctx, nullptr);
+ [bufB length], weight -> ctx, weight ->
ctx, tmp_in.dtype,
+ nullptr);
float* ptr_w = (float*)[tempB contents];
// output to MPSImage
DLTensor tmp_out;
diff --git a/src/runtime/metal/metal_module.mm
b/src/runtime/metal/metal_module.mm
index 7d46811..981dd61 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -88,8 +88,7 @@ class MetalModuleNode final : public runtime::ModuleNode {
if (e.lib == nil) {
if (fmt_ == "metal") {
MTLCompileOptions* opts = [MTLCompileOptions alloc];
- // Use the Metal 1.2 for now.
- opts.languageVersion = MTLLanguageVersion1_2;
+ opts.languageVersion = MTLLanguageVersion2_3;
opts.fastMathEnabled = YES;
// opts = nil;
e.lib = [w->devices[device_id]