Author: Mehdi Amini Date: 2024-05-29T23:20:35-06:00 New Revision: 52ef9864abecea0cf8d20e7eaf49c256248af5f7
URL: https://github.com/llvm/llvm-project/commit/52ef9864abecea0cf8d20e7eaf49c256248af5f7 DIFF: https://github.com/llvm/llvm-project/commit/52ef9864abecea0cf8d20e7eaf49c256248af5f7.diff LOG: Revert "[MLIR][Python] add ctype python binding support for bf16 (#92489)" This reverts commit 89801c74c3e25f5a1eaa3999863be398f6a82abb. Added: Modified: mlir/python/mlir/runtime/np_to_memref.py mlir/python/requirements.txt mlir/test/python/execution_engine.py Removed: ################################################################################ diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py index 882b2751921bf..f6b706f9bc8ae 100644 --- a/mlir/python/mlir/runtime/np_to_memref.py +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -7,12 +7,6 @@ import numpy as np import ctypes -try: - import ml_dtypes -except ModuleNotFoundError: - # The third-party ml_dtypes provides some optional low precision data-types for NumPy. - ml_dtypes = None - class C128(ctypes.Structure): """A ctype representation for MLIR's Double Complex.""" @@ -32,12 +26,6 @@ class F16(ctypes.Structure): _fields_ = [("f16", ctypes.c_int16)] -class BF16(ctypes.Structure): - """A ctype representation for MLIR's BFloat16.""" - - _fields_ = [("bf16", ctypes.c_int16)] - - # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype def as_ctype(dtp): """Converts dtype to ctype.""" @@ -47,8 +35,6 @@ def as_ctype(dtp): return C64 if dtp == np.dtype(np.float16): return F16 - if ml_dtypes is not None and dtp == ml_dtypes.bfloat16: - return BF16 return np.ctypeslib.as_ctypes_type(dtp) @@ -60,11 +46,6 @@ def to_numpy(array): return array.view("complex64") if array.dtype == F16: return array.view("float16") - assert not ( - array.dtype == BF16 and ml_dtypes is None - ), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n" - if array.dtype == BF16: - return array.view("bfloat16") return array diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index 6ec63e43adf89..acd6dbb25edaf 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,4 +1,3 @@ numpy>=1.19.5, <=1.26 pybind11>=2.9.0, <=2.10.3 -PyYAML>=5.3.1, <=6.0.1 -ml_dtypes # provides several NumPy dtype extensions, including the bf16 \ No newline at end of file +PyYAML>=5.3.1, <=6.0.1 \ No newline at end of file diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py index 8125bf3fb8fc9..e8b47007a8907 100644 --- a/mlir/test/python/execution_engine.py +++ b/mlir/test/python/execution_engine.py @@ -5,7 +5,6 @@ from mlir.passmanager import * from mlir.execution_engine import * from mlir.runtime import * -from ml_dtypes import bfloat16 # Log everything to stderr and flush so that we have a unified stream to match @@ -522,45 +521,6 @@ def testComplexUnrankedMemrefAdd(): run(testComplexUnrankedMemrefAdd) -# Test bf16 memrefs -# CHECK-LABEL: TEST: testBF16Memref -def testBF16Memref(): - with Context(): - module = Module.parse( - """ - module { - func.func @main(%arg0: memref<1xbf16>, - %arg1: memref<1xbf16>) attributes { llvm.emit_c_interface } { - %0 = arith.constant 0 : index - %1 = memref.load %arg0[%0] : memref<1xbf16> - memref.store %1, %arg1[%0] : memref<1xbf16> - return - } - } """ - ) - - arg1 = np.array([0.5]).astype(bfloat16) - arg2 = np.array([0.0]).astype(bfloat16) - - arg1_memref_ptr = ctypes.pointer( - ctypes.pointer(get_ranked_memref_descriptor(arg1)) - ) - arg2_memref_ptr = ctypes.pointer( - ctypes.pointer(get_ranked_memref_descriptor(arg2)) - ) - - execution_engine = ExecutionEngine(lowerToLLVM(module)) - execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr) - - # test to-numpy utility - # CHECK: [0.5] - npout = ranked_memref_to_numpy(arg2_memref_ptr[0]) - log(npout) - - -run(testBF16Memref) - - # Test addition of two 2d_memref # CHECK-LABEL: TEST: testDynamicMemrefAdd2D def testDynamicMemrefAdd2D(): _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits