szha closed pull request #10083: [TENSOR] Fix DLTensor conversion for int64
URL: https://github.com/apache/incubator-mxnet/pull/10083

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
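For context on the bug: the previous DTypeTransform looked the mshadow type flag up in a std::unordered_map via operator[]. mshadow::kInt64 had no entry, and operator[] value-initializes a missing key instead of failing, so an int64 TBlob was silently handed to DLPack consumers with an all-zero DLDataType. The sketch below reproduces that failure mode in isolation; the DLDataType layout and the flag value 6 for kInt64 are simplified stand-ins for the real definitions in dlpack/dlpack.h and mshadow/base.h.

// Minimal, self-contained illustration of the old lookup's failure mode.
#include <cstdint>
#include <iostream>
#include <unordered_map>

// Simplified stand-in for DLPack's DLDataType (see dlpack/dlpack.h).
struct DLDataType {
  uint8_t code;   // 0 = int, 1 = uint, 2 = float in the DLPack convention
  uint8_t bits;
  uint16_t lanes;
};

int main() {
  // Same entries as the map this patch removes: mshadow type flags 0..5 only.
  std::unordered_map<int, DLDataType> mshadow_to_dlpack = {
      {0, {2, 32, 1}},  // Float32
      {1, {2, 64, 1}},  // Float64
      {2, {2, 16, 1}},  // Float16
      {3, {1,  8, 1}},  // UInt8
      {4, {0, 32, 1}},  // Int32
      {5, {0,  8, 1}},  // Int8
  };

  const int kInt64Flag = 6;  // assumed value of mshadow::kInt64
  // operator[] inserts a value-initialized element for a missing key,
  // so an int64 tensor came back with dtype {code=0, bits=0, lanes=0}.
  DLDataType dt = mshadow_to_dlpack[kInt64Flag];
  std::cout << "code=" << int(dt.code) << " bits=" << int(dt.bits)
            << " lanes=" << dt.lanes << std::endl;
  return 0;
}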

diff --git a/include/mxnet/tensor_blob.h b/include/mxnet/tensor_blob.h
index 59c1eacb2c5..6f604a5bb8d 100755
--- a/include/mxnet/tensor_blob.h
+++ b/include/mxnet/tensor_blob.h
@@ -322,16 +322,19 @@ class TBlob {
 
  private:
   static DLDataType DTypeTransform(int type_flag) {
-    static std::unordered_map<int, DLDataType>
-      MSHADOW_DTYPE_TO_DLPACK_DTYPE = {
-        {0, {2, 32, 1}},  // Float32
-        {1, {2, 64, 1}},  // Float64
-        {2, {2, 16, 1}},  // Float16
-        {3, {1,  8, 1}},  // UInt8
-        {4, {0, 32, 1}},  // Int32
-        {5, {0,  8, 1}}   // Int8
-      };
-    return MSHADOW_DTYPE_TO_DLPACK_DTYPE[type_flag];
+    switch (type_flag) {
+      case mshadow::kFloat32: return DLDataType{kDLFloat, 32, 1};
+      case mshadow::kFloat64: return DLDataType{kDLFloat, 64, 1};
+      case mshadow::kFloat16: return DLDataType{kDLFloat, 16, 1};
+      case mshadow::kUint8: return DLDataType{kDLUInt, 8, 1};
+      case mshadow::kInt32: return DLDataType{kDLInt, 32, 1};
+      case mshadow::kInt8: return DLDataType{kDLInt, 8, 1};
+      case mshadow::kInt64: return DLDataType{kDLInt, 64, 1};
+      default: {
+        LOG(FATAL) << "Unknown type_flag=" << type_flag;
+        return DLDataType();
+      }
+    }
   }
 
   inline void SetDLTensor(int dev_mask, int dev_id) {
diff --git a/tests/python/gpu/test_tvm_bridge.py b/tests/python/gpu/test_tvm_bridge.py
index 292b9d91e5f..69a713d6a28 100644
--- a/tests/python/gpu/test_tvm_bridge.py
+++ b/tests/python/gpu/test_tvm_bridge.py
@@ -30,13 +30,13 @@ def test_tvm_bridge():
         logging.warn("TVM bridge test skipped because TVM is missing...")
         return
 
-    def check(target):
+    def check(target, dtype):
         shape = (20,)
         scale = tvm.var("scale", dtype="float32")
-        x = tvm.placeholder(shape)
-        y = tvm.placeholder(shape)
+        x = tvm.placeholder(shape, dtype=dtype)
+        y = tvm.placeholder(shape, dtype=dtype)
         z = tvm.compute(shape, lambda i: x[i] + y[i])
-        zz = tvm.compute(shape, lambda *i: z(*i) * scale)
+        zz = tvm.compute(shape, lambda *i: z(*i) * scale.astype(dtype))
         ctx = mx.gpu(0) if target == "cuda" else mx.cpu(0)
         target = tvm.target.create(target)
 
@@ -47,17 +47,18 @@ def check(target):
 
         # get a mxnet version
         mxf = tvm.contrib.mxnet.to_mxnet_func(f, const_loc=[0, 1])
-        xx = mx.nd.uniform(shape=shape, ctx=ctx)
-        yy = mx.nd.uniform(shape=shape, ctx=ctx)
-        zz = mx.nd.empty(shape=shape, ctx=ctx)
+        xx = mx.nd.uniform(shape=shape, ctx=ctx).astype(dtype)
+        yy = mx.nd.uniform(shape=shape, ctx=ctx).astype(dtype)
+        zz = mx.nd.empty(shape=shape, ctx=ctx).astype(dtype)
         # invoke myf: this runs in mxnet engine
         mxf(xx, yy, zz, 10.0)
         np.testing.assert_allclose(
             zz.asnumpy(), (xx.asnumpy() + yy.asnumpy()) * 10)
 
-    check("llvm")
-    check("cuda")
-
+    for tgt in ["llvm", "cuda"]:
+        for dtype in ["int8", "uint8", "int64",
+                      "float32", "float64"]:
+            check(tgt, dtype)
 
 
 if __name__ == "__main__":
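
For anyone who wants to poke at the new mapping outside of MXNet, here is a standalone rendering of the switch added above. The enum values are assumptions that follow the DLPack convention (kDLInt=0, kDLUInt=1, kDLFloat=2) and mshadow's type-flag ordering; the real code takes them from dlpack/dlpack.h and mshadow/base.h, and uses LOG(FATAL) where this sketch aborts.

// Standalone sketch of the new switch-based DTypeTransform.
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <iostream>

// Assumed DLPack type codes and dtype layout (see dlpack/dlpack.h).
enum DLDataTypeCode : uint8_t { kDLInt = 0, kDLUInt = 1, kDLFloat = 2 };
struct DLDataType { uint8_t code; uint8_t bits; uint16_t lanes; };

// Assumed mshadow type-flag ordering (see mshadow/base.h).
namespace mshadow {
enum TypeFlag { kFloat32 = 0, kFloat64, kFloat16, kUint8, kInt32, kInt8, kInt64 };
}  // namespace mshadow

DLDataType DTypeTransform(int type_flag) {
  switch (type_flag) {
    case mshadow::kFloat32: return DLDataType{kDLFloat, 32, 1};
    case mshadow::kFloat64: return DLDataType{kDLFloat, 64, 1};
    case mshadow::kFloat16: return DLDataType{kDLFloat, 16, 1};
    case mshadow::kUint8:   return DLDataType{kDLUInt,  8, 1};
    case mshadow::kInt32:   return DLDataType{kDLInt,  32, 1};
    case mshadow::kInt8:    return DLDataType{kDLInt,   8, 1};
    case mshadow::kInt64:   return DLDataType{kDLInt,  64, 1};
    default:
      // The real implementation calls LOG(FATAL) here.
      std::cerr << "Unknown type_flag=" << type_flag << std::endl;
      std::abort();
  }
}

int main() {
  // The case this PR adds: int64 now maps to a 64-bit signed DLPack dtype.
  DLDataType dt = DTypeTransform(mshadow::kInt64);
  assert(dt.code == kDLInt && dt.bits == 64 && dt.lanes == 1);
  std::cout << "int64 -> {kDLInt, 64, 1}" << std::endl;
  return 0;
}

Compared with the removed map, the explicit switch makes an unmapped dtype a loud failure instead of the silent zeroed DLDataType that operator[] used to produce, which is what let the int64 case slip through unnoticed.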


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
