westonpace commented on a change in pull request #11911:
URL: https://github.com/apache/arrow/pull/11911#discussion_r771119744



##########
File path: python/pyarrow/tests/test_dataset.py
##########
@@ -3621,6 +3621,189 @@ def compare_tables_ignoring_order(t1, t2):
     assert not extra_file.exists()
 
 
+def _generate_random_int_array(size=4, min=1, max=10):
+    return np.random.randint(min, max, size)
+
+
+def _generate_data_and_columns(num_of_columns, num_of_records):
+    data = []
+    column_names = []
+    for i in range(num_of_columns):
+        data.append(_generate_random_int_array(size=num_of_records,
+                                               min=1,
+                                               max=num_of_records))
+        column_names.append("c" + str(i))
+    record_batch = pa.record_batch(data=data, names=column_names)
+    return record_batch
+
+
+def _get_num_of_files_generated(base_directory, file_format):
+    return len(list(pathlib.Path(base_directory).glob(f'**/*.{file_format}')))
+
+
+def test_write_dataset_max_rows_per_file(tempdir):
+    directory = tempdir / 'ds'
+    max_rows_per_file = 10
+    max_rows_per_group = 10
+    num_of_columns = 2
+    num_of_records = 35
+
+    record_batch = _generate_data_and_columns(num_of_columns,
+                                              num_of_records)
+
+    ds.write_dataset(record_batch, directory, format="parquet",
+                     max_rows_per_file=max_rows_per_file,
+                     max_rows_per_group=max_rows_per_group)
+
+    files_in_dir = os.listdir(directory)
+
+    # number of partitions with max_rows and the partition with the remainder
+    expected_partitions = num_of_records // max_rows_per_file + 1
+
+    # test whether the expected amount of files are written
+    assert len(files_in_dir) == expected_partitions
+
+    # compute the number of rows per each file written
+    result_row_combination = []
+    for _, f_file in enumerate(files_in_dir):
+        f_path = directory / str(f_file)
+        dataset = ds.dataset(f_path, format="parquet")
+        result_row_combination.append(dataset.to_table().shape[0])
+
+    # test whether the generated files have the expected number of rows
+    assert expected_partitions == len(result_row_combination)
+    assert num_of_records == sum(result_row_combination)
+
+
+def test_write_dataset_min_rows_per_group(tempdir):
+    directory = tempdir / 'ds'
+    min_rows_per_group = 20
+    max_rows_per_group = 20
+    num_of_columns = 2
+    num_of_records = 50
+
+    record_batch = _generate_data_and_columns(num_of_columns,
+                                              num_of_records)
+
+    data_source = directory / "min_rows_group"
+
+    ds.write_dataset(data=record_batch, base_dir=data_source,
+                     min_rows_per_group=min_rows_per_group,
+                     max_rows_per_group=max_rows_per_group,
+                     format="parquet")
+
+    files_in_dir = os.listdir(data_source)
+    batched_data = []
+    for _, f_file in enumerate(files_in_dir):
+        f_path = data_source / str(f_file)
+        dataset = ds.dataset(f_path, format="parquet")
+        table = dataset.to_table()
+        batches = table.to_batches()
+        for batch in batches:
+            batched_data.append(batch.num_rows)
+
+    assert batched_data[0] >= min_rows_per_group and \
+        batched_data[0] <= max_rows_per_group
+    assert batched_data[1] >= min_rows_per_group and \
+        batched_data[1] <= max_rows_per_group
+
+
+def test_write_dataset_max_rows_per_group(tempdir):
+    directory = tempdir / 'ds'
+    max_rows_per_group = 18
+    num_of_columns = 2
+    num_of_records = 30
+
+    record_batch = _generate_data_and_columns(num_of_columns,
+                                              num_of_records)
+
+    data_source = directory / "max_rows_group"
+
+    ds.write_dataset(data=record_batch, base_dir=data_source,
+                     max_rows_per_group=max_rows_per_group,
+                     format="parquet")
+
+    files_in_dir = os.listdir(data_source)
+    batched_data = []
+    for f_file in files_in_dir:
+        f_path = data_source / str(f_file)
+        dataset = ds.dataset(f_path, format="parquet")
+        table = dataset.to_table()
+        batches = table.to_batches()
+        for batch in batches:
+            batched_data.append(batch.num_rows)
+
+    assert batched_data == [18, 12]
+
+
+def test_write_dataset_max_open_files(tempdir):
+    directory = tempdir / 'ds'
+    file_format = "parquet"
+    partition_column_id = 1
+    column_names = ['c1', 'c2']
+    record_batch_1 = pa.record_batch(data=[[1, 2, 3, 4, 0, 10],
+                                     ['a', 'b', 'c', 'd', 'e', 'a']],
+                                     names=column_names)
+    record_batch_2 = pa.record_batch(data=[[5, 6, 7, 8, 0, 1],
+                                     ['a', 'b', 'c', 'd', 'e', 'c']],
+                                     names=column_names)
+    record_batch_3 = pa.record_batch(data=[[9, 10, 11, 12, 0, 1],
+                                     ['a', 'b', 'c', 'd', 'e', 'd']],
+                                     names=column_names)
+    record_batch_4 = pa.record_batch(data=[[13, 14, 15, 16, 0, 1],
+                                     ['a', 'b', 'c', 'd', 'e', 'b']],
+                                     names=column_names)
+
+    table = pa.Table.from_batches([record_batch_1, record_batch_2,
+                                   record_batch_3, record_batch_4])
+
+    partitioning = ds.partitioning(
+        pa.schema([(column_names[partition_column_id], pa.string())]),
+        flavor="hive")
+
+    data_source_1 = directory / "default"
+
+    ds.write_dataset(data=table, base_dir=data_source_1,
+                     partitioning=partitioning, format=file_format)
+
+    # Here we consider the number of unique partitions created when partition
+    # column contains duplicate records.
+    #   Returns: (number_of_files_generated, number_of_unique_rows)
+    def _get_compare_pair(data_source, record_batch, file_format, col_id):
+        num_of_files_generated = _get_num_of_files_generated(
+            base_directory=data_source, file_format=file_format)
+        number_of_unique_rows = len(pa.compute.unique(record_batch[col_id]))
+        return num_of_files_generated, number_of_unique_rows
+
+    # CASE 1: when max_open_files=default & max_open_files >= num_of_partitions
+    #         In case of a writing to disk via partitioning based on a
+    #         particular column (considering row labels in that column),
+    #         the number of unique rows must be equal
+    #         to the number of files generated
+
+    num_of_files_generated, number_of_unique_rows \
+        = _get_compare_pair(data_source_1, record_batch_1, file_format,
+                            partition_column_id)
+    assert num_of_files_generated == number_of_unique_rows
+
+    # CASE 2: when max_open_files > 0 & max_open_files < num_of_partitions
+    #         the number of unique rows must be equal to
+    #         the number of files generated

Review comment:
       This comment doesn't seem to line up with the assert on line 3804,
`assert num_of_files_generated > number_of_unique_rows`: the comment says the
two counts must be equal, but the assert checks that more files are generated
than there are unique values.  Also, I'm a little confused about how
`number_of_unique_rows` is related to the test.  Could you maybe name that
variable `number_of_unique_partitions`?  (Even `number_of_partitions` would be
fine.)

##########
File path: python/pyarrow/tests/test_dataset.py
##########
@@ -3621,6 +3621,189 @@ def compare_tables_ignoring_order(t1, t2):
     assert not extra_file.exists()
 
 
+def _generate_random_int_array(size=4, min=1, max=10):
+    return np.random.randint(min, max, size)
+
+
+def _generate_data_and_columns(num_of_columns, num_of_records):
+    data = []
+    column_names = []
+    for i in range(num_of_columns):
+        data.append(_generate_random_int_array(size=num_of_records,
+                                               min=1,
+                                               max=num_of_records))
+        column_names.append("c" + str(i))
+    record_batch = pa.record_batch(data=data, names=column_names)
+    return record_batch
+
+
+def _get_num_of_files_generated(base_directory, file_format):
+    return len(list(pathlib.Path(base_directory).glob(f'**/*.{file_format}')))
+
+
+def test_write_dataset_max_rows_per_file(tempdir):
+    directory = tempdir / 'ds'
+    max_rows_per_file = 10
+    max_rows_per_group = 10
+    num_of_columns = 2
+    num_of_records = 35
+
+    record_batch = _generate_data_and_columns(num_of_columns,
+                                              num_of_records)
+
+    ds.write_dataset(record_batch, directory, format="parquet",
+                     max_rows_per_file=max_rows_per_file,
+                     max_rows_per_group=max_rows_per_group)
+
+    files_in_dir = os.listdir(directory)
+
+    # number of partitions with max_rows and the partition with the remainder
+    expected_partitions = num_of_records // max_rows_per_file + 1
+
+    # test whether the expected amount of files are written
+    assert len(files_in_dir) == expected_partitions
+
+    # compute the number of rows per each file written
+    result_row_combination = []
+    for _, f_file in enumerate(files_in_dir):
+        f_path = directory / str(f_file)
+        dataset = ds.dataset(f_path, format="parquet")
+        result_row_combination.append(dataset.to_table().shape[0])
+
+    # test whether the generated files have the expected number of rows
+    assert expected_partitions == len(result_row_combination)
+    assert num_of_records == sum(result_row_combination)
+
+
+def test_write_dataset_min_rows_per_group(tempdir):
+    directory = tempdir / 'ds'
+    min_rows_per_group = 20
+    max_rows_per_group = 20
+    num_of_columns = 2
+    num_of_records = 50
+
+    record_batch = _generate_data_and_columns(num_of_columns,
+                                              num_of_records)
+
+    data_source = directory / "min_rows_group"
+
+    ds.write_dataset(data=record_batch, base_dir=data_source,
+                     min_rows_per_group=min_rows_per_group,
+                     max_rows_per_group=max_rows_per_group,
+                     format="parquet")
+
+    files_in_dir = os.listdir(data_source)
+    batched_data = []
+    for _, f_file in enumerate(files_in_dir):
+        f_path = data_source / str(f_file)
+        dataset = ds.dataset(f_path, format="parquet")
+        table = dataset.to_table()
+        batches = table.to_batches()
+        for batch in batches:
+            batched_data.append(batch.num_rows)
+
+    assert batched_data[0] >= min_rows_per_group and \
+        batched_data[0] <= max_rows_per_group
+    assert batched_data[1] >= min_rows_per_group and \
+        batched_data[1] <= max_rows_per_group

Review comment:
       I think this test is still a bit off.  I don't think you will be able to
completely test `min_rows_per_group` without writing multiple batches.

       The key logic for this option is "This batch is not big enough, I must
hold onto this batch until I get more data".  I don't see that behavior being
tested here.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to