[https://issues.apache.org/jira/browse/ARROW-2039?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel=16365860#comment-16365860]
ASF GitHub Bot commented on ARROW-2039:
---
wesm closed pull request #1605: ARROW-2039: [Python] Avoid crashing on
uninitialized Buffer
URL: https://github.com/apache/arrow/pull/1605
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (GitHub does not otherwise display the diff of a merged pull request):
diff --git a/python/pyarrow/io.pxi b/python/pyarrow/io.pxi
index b0996f85e..aa2f7ed07 100644
--- a/python/pyarrow/io.pxi
+++ b/python/pyarrow/io.pxi
@@ -595,22 +595,30 @@ cdef class Buffer:
self.shape[0] = self.size
self.strides[0] = (1)
+cdef int _check_nullptr(self) except -1:
+if self.buffer.get() == NULL:
+raise ReferenceError("operation on uninitialized Buffer object")
+return 0
+
def __len__(self):
return self.size
property size:
def __get__(self):
+self._check_nullptr()
return self.buffer.get().size()
property is_mutable:
def __get__(self):
+self._check_nullptr()
return self.buffer.get().is_mutable()
property parent:
def __get__(self):
+self._check_nullptr()
cdef shared_ptr[CBuffer] parent_buf = self.buffer.get().parent()
if parent_buf.get() == NULL:
@@ -620,6 +628,7 @@ cdef class Buffer:
def __getitem__(self, key):
# TODO(wesm): buffer slicing
+self._check_nullptr()
raise NotImplementedError
def equals(self, Buffer other):
@@ -634,17 +643,21 @@ cdef class Buffer:
---
are_equal : True if buffer contents and size are equal
"""
+self._check_nullptr()
+other._check_nullptr()
cdef c_bool result = False
with nogil:
result = self.buffer.get().Equals(deref(other.buffer.get()))
return result
def to_pybytes(self):
+self._check_nullptr()
return cp.PyBytes_FromStringAndSize(
self.buffer.get().data(),
self.buffer.get().size())
def __getbuffer__(self, cp.Py_buffer* buffer, int flags):
+self._check_nullptr()
buffer.buf = self.buffer.get().data()
buffer.format = 'b'
@@ -662,11 +675,13 @@ cdef class Buffer:
buffer.suboffsets = NULL
def __getsegcount__(self, Py_ssize_t *len_out):
+self._check_nullptr()
if len_out != NULL:
len_out[0] = self.size
return 1
def __getreadbuffer__(self, Py_ssize_t idx, void **p):
+self._check_nullptr()
if idx != 0:
raise SystemError("accessing non-existent buffer segment")
if p != NULL:
@@ -674,6 +689,7 @@ cdef class Buffer:
return self.size
def __getwritebuffer__(self, Py_ssize_t idx, void **p):
+self._check_nullptr()
if not self.buffer.get().is_mutable():
raise SystemError("trying to write an immutable buffer")
if idx != 0:
diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd
index b1433ecde..31732a6e0 100644
--- a/python/pyarrow/lib.pxd
+++ b/python/pyarrow/lib.pxd
@@ -321,6 +321,7 @@ cdef class Buffer:
Py_ssize_t strides[1]
cdef void init(self, const shared_ptr[CBuffer]& buffer)
+cdef int _check_nullptr(self) except -1
cdef class ResizableBuffer(Buffer):
diff --git a/python/pyarrow/tests/test_io.py b/python/pyarrow/tests/test_io.py
index 0947cb7c7..736020f60 100644
--- a/python/pyarrow/tests/test_io.py
+++ b/python/pyarrow/tests/test_io.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
+from functools import partial
from io import BytesIO, TextIOWrapper
import gc
import os
@@ -176,6 +177,8 @@ def test_buffer_to_numpy():
buf = pa.frombuffer(byte_array)
array = np.frombuffer(buf, dtype="uint8")
assert array[0] == byte_array[0]
+byte_array[0] += 1
+assert array[0] == byte_array[0]
assert array.base == buf
@@ -192,6 +195,25 @@ def test_buffer_from_numpy():
buf = pa.frombuffer(arr.T[::2])
+def test_buffer_equals():
+# Buffer.equals() returns true iff the buffers have the same contents
+b1 = b'some data!'
+b2 = bytearray(b1)
+b3 = bytearray(b1)
+b3[0] = 42
+buf1 = pa.frombuffer(b1)
+buf2 = pa.frombuffer(b2)
+buf3 = pa.frombuffer(b2)
+buf4 = pa.frombuffer(b3)
+buf5 = pa.frombuffer(np.frombuffer(b2, dtype=np.int16))
+assert buf1.equals(buf1)
+assert buf1.equals(buf2)
+assert buf2.equals(buf3)
+assert not buf2.equals(buf4)
+# Data type is indifferent
+assert buf2.equals(buf5)
+
+
def