jorisvandenbossche commented on code in PR #14395:
URL: https://github.com/apache/arrow/pull/14395#discussion_r1016297190
##########
python/pyarrow/tests/test_compute.py:
##########
@@ -2929,3 +2930,114 @@ def test_cast_table_raises():
with pytest.raises(pa.lib.ArrowInvalid):
pc.cast(table, pa.int64())
+
+
@pytest.mark.parametrize("start,stop,expected", (
    (0, 1, [[1], [4], [6], None]),
    (0, 2, [[1, 2], [4, 5], [6, None], None]),
    (1, 2, [[2], [5], [None], None]),
    (2, 4, [[3, None], [None, None], [None, None], None])
))
@pytest.mark.parametrize("value_type", (pa.string, pa.int16, pa.float64))
@pytest.mark.parametrize("list_type", (pa.list_, pa.large_list, "fixed"))
def test_list_slice_output_fixed(start, stop, expected, value_type, list_type):
    """Fixed-size list_slice output pads short lists with nulls.

    Every output slot holds exactly ``stop - start`` items: positions past
    the end of an input list become null, and null input lists stay null.
    """
    if list_type == "fixed":
        arr = pa.array([[1, 2, 3], [4, 5, None], [6, None, None], None],
                       pa.list_(pa.int8(), 3)).cast(pa.list_(value_type(), 3))
    else:
        arr = pa.array([[1, 2, 3], [4, 5], [6], None],
                       pa.list_(pa.int8())).cast(list_type(value_type()))
    result = pc.list_slice(arr, start, stop, return_fixed_size_list=True)
    # Cast back to int8 so one `expected` value serves every value_type.
    pylist = result.cast(pa.list_(pa.int8(), stop - start)).to_pylist()
    assert pylist == expected
+
+
@pytest.mark.parametrize("start,stop", (
    (0, 1),
    (0, 2),
    (1, 2),
    (2, 4),
))
@pytest.mark.parametrize("value_type", (pa.string, pa.int16, pa.float64))
@pytest.mark.parametrize("list_type", (pa.list_, pa.large_list, "fixed"))
def test_list_slice_output_variable(start, stop, value_type, list_type):
    """Variable-size list_slice output follows Python's slice semantics."""
    if list_type == "fixed":
        data = [[1, 2, 3], [4, 5, None], [6, None, None], None]
        arr = pa.array(data, pa.list_(pa.int8(), 3)).cast(
            pa.list_(value_type(), 3))
        # A fixed-size input produces a regular (non-large) variable list.
        expected_type = pa.list_(value_type())
    else:
        data = [[1, 2, 3], [4, 5], [6], None]
        arr = pa.array(data, pa.list_(pa.int8())).cast(
            list_type(value_type()))
        # The output keeps the input's list type (list vs large_list).
        expected_type = list_type(value_type())

    result = pc.list_slice(arr, start, stop, return_fixed_size_list=False)
    assert result.type == expected_type

    # Cast back to int8 so the expected values can be plain Python ints.
    pylist = result.cast(pa.list_(pa.int8())).to_pylist()

    # Variable output slicing follows Python's slice semantics
    expected = [d[start:stop] if d is not None else None for d in data]
    assert pylist == expected
+
+
@pytest.mark.parametrize("return_fixed_size", (True, False, None))
@pytest.mark.parametrize("make_type", (
    lambda: pa.list_(pa.field('col', pa.int8())),
    lambda: pa.list_(pa.field('col', pa.int8()), 1),
    lambda: pa.large_list(pa.field('col', pa.int8()))))
def test_list_slice_field_names_retained(return_fixed_size, make_type):
    """list_slice keeps the input's value field name (here ``'col'``)."""
    arr = pa.array([[1]], make_type())
    out = pc.list_slice(arr, 0, 1, return_fixed_size_list=return_fixed_size)
    assert arr.type.field(0).name == out.type.field(0).name

    # With return_fixed_size_list=None the output type must equal the
    # input type exactly.
    if return_fixed_size is None:
        assert arr.type == out.type
+
+
def test_list_slice_bad_parameters():
    """Invalid and not-yet-supported list_slice arguments raise errors."""
    arr = pa.array([[1]], pa.list_(pa.int8(), 1))
    msg = r"`start`(.*) should be greater than 0 and smaller than `stop`(.*)"
    with pytest.raises(pa.ArrowInvalid, match=msg):
        pc.list_slice(arr, -1, 1)  # negative start?
    with pytest.raises(pa.ArrowInvalid, match=msg):
        pc.list_slice(arr, 2, 1)  # start > stop?

    # TODO(ARROW-18281): start==stop -> empty lists
    with pytest.raises(pa.ArrowInvalid, match=msg):
        pc.list_slice(arr, 0, 0)  # start == stop?

    # TODO(ARROW-18282): support step in slicing
    msg = ("Setting `step` to anything other than 1 is not supported; "
           "got step=2")
    with pytest.raises(NotImplementedError, match=msg):
        pc.list_slice(arr, 0, 1, step=2)

    # TODO(ARROW-18280): support stop == None; slice to end
    # This fails first at resolve, because it doesn't know how big the
    # resulting FixedSizeListArray item size will be
    msg = "Unable to produce FixedSizeListArray without `stop`"
    with pytest.raises(NotImplementedError, match=msg):
        pc.list_slice(arr, 0, return_fixed_size_list=True)

    # cont. This fails inside of the kernel function; the resolver doesn't
    # need to know the item size for ListArray.
    # `match` is a regex searched anywhere in the message, so no trailing
    # wildcard is needed (a bare `*` would only repeat the final "d").
    msg = "Slicing to end not yet implemented"
    with pytest.raises(NotImplementedError, match=msg):
        pc.list_slice(arr, 0, return_fixed_size_list=False)
+
+
def test_list_slice_non_nulls():
    """Slicing list arrays that contain no nulls works and is correct.

    An array without nulls might not have a validity buffer, so C++ code
    that unconditionally takes that buffer could segfault. Most other
    list_slice tests use inputs containing nulls, so this test exercises
    the no-nulls path explicitly and also checks the sliced values.
    """
    arrays = (
        pa.array([[1]], pa.list_(pa.int8())),
        pa.array([[1]], pa.list_(pa.int8(), 1)),
    )
    for arr in arrays:
        for fixed in (True, False):
            result = pc.list_slice(arr, 0, 1, return_fixed_size_list=fixed)
            assert result.to_pylist() == [[1]]
Review Comment:
Do we need to keep this one? (Isn't it also covered indirectly by the other tests?)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]