This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 96df2c5 Add bytearray support back to imdecode (#12855, #12868) (#12912)
96df2c5 is described below
commit 96df2c5760ce9a1c1238f1f41228486a583de7ff
Author: Frank Liu <[email protected]>
AuthorDate: Wed Oct 24 17:28:48 2018 -0700
Add bytearray support back to imdecode (#12855, #12868) (#12912)
1. Avoid raising an exception when the input is a bytearray.
2. Avoid an OpenCV crash on empty input.
3. Add unit tests.
---
python/mxnet/image/image.py | 11 ++++++++---
tests/python/unittest/test_image.py | 16 ++++++++++++++++
2 files changed, 24 insertions(+), 3 deletions(-)
diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py
index eee2ccf..b846700 100644
--- a/python/mxnet/image/image.py
+++ b/python/mxnet/image/image.py
@@ -93,7 +93,7 @@ def imdecode(buf, *args, **kwargs):
Parameters
----------
- buf : str/bytes or numpy.ndarray
+ buf : str/bytes/bytearray or numpy.ndarray
Binary image data as string or numpy ndarray.
flag : int, optional, default=1
1 for three channel color output. 0 for grayscale output.
@@ -135,10 +135,15 @@ def imdecode(buf, *args, **kwargs):
<NDArray 224x224x3 @cpu(0)>
"""
if not isinstance(buf, nd.NDArray):
- if sys.version_info[0] == 3 and not isinstance(buf, (bytes,
np.ndarray)):
- raise ValueError('buf must be of type bytes or numpy.ndarray,'
+ if sys.version_info[0] == 3 and not isinstance(buf, (bytes, bytearray,
np.ndarray)):
+ raise ValueError('buf must be of type bytes, bytearray or
numpy.ndarray,'
'if you would like to input type str, please
convert to bytes')
buf = nd.array(np.frombuffer(buf, dtype=np.uint8), dtype=np.uint8)
+
+ if len(buf) == 0:
+ # empty buf causes OpenCV crash.
+ raise ValueError("input buf cannot be empty.")
+
return _internal._cvimdecode(buf, *args, **kwargs)
diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py
index 0df08af..c8022b6 100644
--- a/tests/python/unittest/test_image.py
+++ b/tests/python/unittest/test_image.py
@@ -92,6 +92,22 @@ class TestImage(unittest.TestCase):
cv_image = cv2.imread(img)
assert_almost_equal(image.asnumpy(), cv_image)
+ def test_imdecode_bytearray(self):
+ try:
+ import cv2
+ except ImportError:
+ return
+ for img in TestImage.IMAGES:
+ with open(img, 'rb') as fp:
+ str_image = bytearray(fp.read())
+ image = mx.image.imdecode(str_image, to_rgb=0)
+ cv_image = cv2.imread(img)
+ assert_almost_equal(image.asnumpy(), cv_image)
+
+ @raises(ValueError)
+ def test_imdecode_empty_buffer(self):
+ mx.image.imdecode(b'', to_rgb=0)
+
def test_scale_down(self):
assert mx.image.scale_down((640, 480), (720, 120)) == (640, 106)
assert mx.image.scale_down((360, 1000), (480, 500)) == (360, 375)