wjones127 commented on a change in pull request #11911:
URL: https://github.com/apache/arrow/pull/11911#discussion_r770688357



##########
File path: python/pyarrow/tests/test_dataset.py
##########
@@ -3621,6 +3621,204 @@ def compare_tables_ignoring_order(t1, t2):
     assert not extra_file.exists()
 
 
+def _generate_random_int_array(size=4, min=1, max=10):
+    return np.random.randint(min, max, size)
+
+
+def _generate_data_and_columns(num_of_columns, records_per_row,
+                               unique_records=None):
+    data = []
+    column_names = []
+    if unique_records is None:
+        unique_records = records_per_row
+    for i in range(num_of_columns):
+        data.append(_generate_random_int_array(size=records_per_row,
+                                               min=1,
+                                               max=unique_records))
+        column_names.append("c" + str(i))
+    return data, column_names
+
+
+def _get_num_of_files_generated(base_directory):
+    file_dirs = os.listdir(base_directory)
+    number_of_files = 0
+    for _, file_dir in enumerate(file_dirs):
+        sub_dir_path = base_directory / file_dir
+        number_of_files += len(os.listdir(sub_dir_path))
+    return number_of_files
+
+
+def _get_compare_pair(data_source, record_batch):
+    num_of_files_generated = _get_num_of_files_generated(
+        base_directory=data_source)
+    number_of_unique_rows = len(pa.compute.unique(record_batch[0]))
+    return num_of_files_generated, number_of_unique_rows
+
+
+def test_write_dataset_max_rows_per_file(tempdir):

Review comment:
       I'm talking about these three tests:
   
    * `test_write_dataset_max_rows_per_file`
    * `test_write_dataset_min_rows_per_group`
    * `test_write_dataset_max_rows_per_group`
   
   You could rewrite these as one test function that measures the number of rows in each group, file, and partition, and then asserts general properties based on those:
   
   ```python
   # Every file has at most max_rows_per_file rows
   assert all(file.nrows <= max_rows_per_file for file in dataset_files)
   # Every group has at most max_rows_per_group rows
   assert all(group.nrows <= max_rows_per_group
              for file in dataset_files for group in file.groups)
   # Every file has at least min_rows_per_file rows, unless the whole partition
   # has fewer than min_rows_per_file rows
   ...
   # Every group has at least min_rows_per_group rows, unless ...
   ...
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
