This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new 25360c00f5 Python: Add limit to table scan (#7163)
25360c00f5 is described below

commit 25360c00f53674bde2ad6dd448c1004d067ffbc7
Author: Daniel Rückert García <[email protected]>
AuthorDate: Fri Mar 24 09:30:53 2023 +0100

    Python: Add limit to table scan (#7163)
    
    * Python: Add support for ORC
    
    Creates fragments based on the FileFormat.
    
    Blocked by: https://github.com/apache/iceberg/pull/6997
    
    * Revert
    
    * TableScan add limit
    
    * PyArrow: limit the number of rows fetched from files when a limit is set
    
    * add tests for scan limit
    
    * Python CI: rebuild the container when anything under python/dev/ changes
    
    * remove support for ORC
    
    * remove unused imports
    
    * increase sleep before running tests
    
    * Update the Python docs to include limit in the table query example
    
    * Docs: fix formatting
    
    ---------
    
    Co-authored-by: Fokko Driesprong <[email protected]>
    Co-authored-by: Daniel Rückert García <[email protected]>
---
 .github/workflows/python-integration.yml |  4 +--
 python/dev/provision.py                  | 33 ++++++++++++++++++
 python/mkdocs/docs/api.md                |  3 +-
 python/pyiceberg/files.py                |  8 ++---
 python/pyiceberg/io/pyarrow.py           | 60 +++++++++++++++++++++++---------
 python/pyiceberg/manifest.py             |  6 ++++
 python/pyiceberg/table/__init__.py       | 15 ++++++--
 python/tests/test_integration.py         | 17 +++++++++
 8 files changed, 120 insertions(+), 26 deletions(-)

diff --git a/.github/workflows/python-integration.yml b/.github/workflows/python-integration.yml
index 2ea1d9c464..895e9e8e64 100644
--- a/.github/workflows/python-integration.yml
+++ b/.github/workflows/python-integration.yml
@@ -46,7 +46,7 @@ jobs:
       id: check_file_changed
       run: |
         $diff = git diff --name-only HEAD^ HEAD
-        $SourceDiff = $diff | Where-Object { $_ -match 
'^python/dev/Dockerfile$' }
+        $SourceDiff = $diff | Where-Object { $_ -match '^python/dev/.+$' }
         $HasDiff = $SourceDiff.Length -gt 0
         Write-Host "::set-output name=docs_changed::$HasDiff"
     - name: Restore image
@@ -84,4 +84,4 @@ jobs:
       run: make test-integration
     - name: Show debug logs
       if: ${{ failure() }}
-      run: docker-compose -f python/dev/docker-compose.yml logs
\ No newline at end of file
+      run: docker-compose -f python/dev/docker-compose.yml logs
diff --git a/python/dev/provision.py b/python/dev/provision.py
index 1e6f5b7319..81bd094c58 100644
--- a/python/dev/provision.py
+++ b/python/dev/provision.py
@@ -71,6 +71,39 @@ spark.sql(
 """
 )
 
+spark.sql(
+    """
+  DROP TABLE IF EXISTS test_limit;
+"""
+)
+
+spark.sql(
+    """
+    CREATE TABLE test_limit
+    USING iceberg
+      AS SELECT
+          1            AS idx
+      UNION ALL SELECT
+          2            AS idx
+      UNION ALL SELECT
+          3            AS idx
+      UNION ALL SELECT
+          4            AS idx
+      UNION ALL SELECT
+          5            AS idx
+      UNION ALL SELECT
+          6            AS idx
+      UNION ALL SELECT
+          7            AS idx
+      UNION ALL SELECT
+          8            AS idx
+      UNION ALL SELECT
+          9            AS idx
+      UNION ALL SELECT
+          10           AS idx
+    """
+)
+
 spark.sql(
     """
   DROP TABLE IF EXISTS test_deletes;
diff --git a/python/mkdocs/docs/api.md b/python/mkdocs/docs/api.md
index 9516516859..1cb26714ad 100644
--- a/python/mkdocs/docs/api.md
+++ b/python/mkdocs/docs/api.md
@@ -281,7 +281,7 @@ Table(
 
 ## Query a table
 
-To query a table, a table scan is needed. A table scan accepts a filter, columns and optionally a snapshot ID:
+To query a table, a table scan is needed. A table scan accepts a row filter, the columns to select, and optionally a limit on the number of rows returned and a snapshot ID:
 
 ```python
 from pyiceberg.catalog import load_catalog
@@ -293,6 +293,7 @@ table = catalog.load_table("nyc.taxis")
 scan = table.scan(
     row_filter=GreaterThanOrEqual("trip_distance", 10.0),
     selected_fields=("VendorID", "tpep_pickup_datetime", 
"tpep_dropoff_datetime"),
+    limit=100,
 )
 
 # Or filter using a string predicate
diff --git a/python/pyiceberg/files.py b/python/pyiceberg/files.py
index b832bbbfe7..df19f37e0c 100644
--- a/python/pyiceberg/files.py
+++ b/python/pyiceberg/files.py
@@ -28,7 +28,7 @@ class FileContentType(Enum):
 class FileFormat(Enum):
     """An enum that includes all possible formats for an Iceberg data file"""
 
-    ORC = auto()
-    PARQUET = auto()
-    AVRO = auto()
-    METADATA = auto()
+    ORC = "ORC"
+    PARQUET = "PARQUET"
+    AVRO = "AVRO"
+    METADATA = "METADATA"
diff --git a/python/pyiceberg/io/pyarrow.py b/python/pyiceberg/io/pyarrow.py
index 07b59258dd..ae81733968 100644
--- a/python/pyiceberg/io/pyarrow.py
+++ b/python/pyiceberg/io/pyarrow.py
@@ -24,9 +24,11 @@ with the pyarrow library.
 """
 from __future__ import annotations
 
+import multiprocessing
 import os
 from functools import lru_cache
 from multiprocessing.pool import ThreadPool
+from multiprocessing.sharedctypes import Synchronized
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -42,7 +44,7 @@ from urllib.parse import urlparse
 
 import pyarrow as pa
 import pyarrow.compute as pc
-import pyarrow.parquet as pq
+import pyarrow.dataset as ds
 from pyarrow.fs import (
     FileInfo,
     FileSystem,
@@ -491,14 +493,19 @@ def _file_to_table(
     projected_schema: Schema,
     projected_field_ids: Set[int],
     case_sensitive: bool,
+    rows_counter: Synchronized[int],
+    limit: Optional[int] = None,
 ) -> Optional[pa.Table]:
-    _, path = PyArrowFileIO.parse_location(task.file.file_path)
+    if limit and rows_counter.value >= limit:
+        return None
 
-    # Get the schema
-    with fs.open_input_file(path) as fout:
-        parquet_schema = pq.read_schema(fout)
+    _, path = PyArrowFileIO.parse_location(task.file.file_path)
+    arrow_format = ds.ParquetFileFormat(pre_buffer=True, 
buffer_size=(ONE_MEGABYTE * 8))
+    with fs.open_input_file(path) as fin:
+        fragment = arrow_format.make_fragment(fin)
+        physical_schema = fragment.physical_schema
         schema_raw = None
-        if metadata := parquet_schema.metadata:
+        if metadata := physical_schema.metadata:
             schema_raw = metadata.get(ICEBERG_SCHEMA)
         if schema_raw is None:
             raise ValueError(
@@ -517,15 +524,22 @@ def _file_to_table(
         if file_schema is None:
             raise ValueError(f"Missing Iceberg schema in Metadata for file: 
{path}")
 
-        arrow_table = pq.read_table(
-            source=fout,
-            schema=parquet_schema,
-            pre_buffer=True,
-            buffer_size=8 * ONE_MEGABYTE,
-            filters=pyarrow_filter,
+        fragment_scanner = ds.Scanner.from_fragment(
+            fragment=fragment,
+            schema=physical_schema,
+            filter=pyarrow_filter,
             columns=[col.name for col in file_project_schema.columns],
         )
 
+        if limit:
+            arrow_table = fragment_scanner.head(limit)
+            with rows_counter.get_lock():
+                if rows_counter.value >= limit:
+                    return None
+                rows_counter.value += len(arrow_table)
+        else:
+            arrow_table = fragment_scanner.to_table()
+
         # If there is no data, we don't have to go through the schema
         if len(arrow_table) > 0:
             return to_requested_schema(projected_schema, file_project_schema, 
arrow_table)
@@ -534,7 +548,12 @@ def _file_to_table(
 
 
 def project_table(
-    tasks: Iterable[FileScanTask], table: Table, row_filter: 
BooleanExpression, projected_schema: Schema, case_sensitive: bool
+    tasks: Iterable[FileScanTask],
+    table: Table,
+    row_filter: BooleanExpression,
+    projected_schema: Schema,
+    case_sensitive: bool,
+    limit: Optional[int] = None,
 ) -> pa.Table:
     """Resolves the right columns based on the identifier
 
@@ -570,23 +589,30 @@ def project_table(
         id for id in projected_schema.field_ids if not 
isinstance(projected_schema.find_type(id), (MapType, ListType))
     }.union(extract_field_ids(bound_row_filter))
 
+    rows_counter = multiprocessing.Value("i", 0)
+
     with ThreadPool() as pool:
         tables = [
             table
             for table in pool.starmap(
                 func=_file_to_table,
-                iterable=[(fs, task, bound_row_filter, projected_schema, 
projected_field_ids, case_sensitive) for task in tasks],
+                iterable=[
+                    (fs, task, bound_row_filter, projected_schema, 
projected_field_ids, case_sensitive, rows_counter, limit)
+                    for task in tasks
+                ],
                 chunksize=None,  # we could use this to control how to 
materialize the generator of tasks (we should also make the expression above 
lazy)
             )
             if table is not None
         ]
 
     if len(tables) > 1:
-        return pa.concat_tables(tables)
+        final_table = pa.concat_tables(tables)
     elif len(tables) == 1:
-        return tables[0]
+        final_table = tables[0]
     else:
-        return pa.Table.from_batches([], 
schema=schema_to_pyarrow(projected_schema))
+        final_table = pa.Table.from_batches([], 
schema=schema_to_pyarrow(projected_schema))
+
+    return final_table.slice(0, limit)
 
 
 def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: 
pa.Table) -> pa.Table:
diff --git a/python/pyiceberg/manifest.py b/python/pyiceberg/manifest.py
index 757f3bd016..942c582b51 100644
--- a/python/pyiceberg/manifest.py
+++ b/python/pyiceberg/manifest.py
@@ -177,6 +177,12 @@ class DataFile(Record):
     sort_order_id: Optional[int]
     spec_id: Optional[int]
 
+    def __setattr__(self, name: str, value: Any) -> None:
+        # The file_format is written as a string, so we need to cast it to the 
Enum
+        if name == "file_format":
+            value = FileFormat[value]
+        super().__setattr__(name, value)
+
     def __init__(self, *data: Any, **named_data: Any) -> None:
         super().__init__(*data, **{"struct": DATA_FILE_TYPE, **named_data})
 
diff --git a/python/pyiceberg/table/__init__.py b/python/pyiceberg/table/__init__.py
index ef53087fee..8f19d93acf 100644
--- a/python/pyiceberg/table/__init__.py
+++ b/python/pyiceberg/table/__init__.py
@@ -100,6 +100,7 @@ class Table:
         case_sensitive: bool = True,
         snapshot_id: Optional[int] = None,
         options: Properties = EMPTY_DICT,
+        limit: Optional[int] = None,
     ) -> DataScan:
         return DataScan(
             table=self,
@@ -108,6 +109,7 @@ class Table:
             case_sensitive=case_sensitive,
             snapshot_id=snapshot_id,
             options=options,
+            limit=limit,
         )
 
     def schema(self) -> Schema:
@@ -220,6 +222,7 @@ class TableScan(ABC):
     case_sensitive: bool
     snapshot_id: Optional[int]
     options: Properties
+    limit: Optional[int]
 
     def __init__(
         self,
@@ -229,6 +232,7 @@ class TableScan(ABC):
         case_sensitive: bool = True,
         snapshot_id: Optional[int] = None,
         options: Properties = EMPTY_DICT,
+        limit: Optional[int] = None,
     ):
         self.table = table
         self.row_filter = _parse_row_filter(row_filter)
@@ -236,6 +240,7 @@ class TableScan(ABC):
         self.case_sensitive = case_sensitive
         self.snapshot_id = snapshot_id
         self.options = options
+        self.limit = limit
 
     def snapshot(self) -> Optional[Snapshot]:
         if self.snapshot_id:
@@ -336,8 +341,9 @@ class DataScan(TableScan):
         case_sensitive: bool = True,
         snapshot_id: Optional[int] = None,
         options: Properties = EMPTY_DICT,
+        limit: Optional[int] = None,
     ):
-        super().__init__(table, row_filter, selected_fields, case_sensitive, 
snapshot_id, options)
+        super().__init__(table, row_filter, selected_fields, case_sensitive, 
snapshot_id, options, limit)
 
     def _build_partition_projection(self, spec_id: int) -> BooleanExpression:
         project = inclusive_projection(self.table.schema(), 
self.table.specs()[spec_id])
@@ -403,7 +409,12 @@ class DataScan(TableScan):
         from pyiceberg.io.pyarrow import project_table
 
         return project_table(
-            self.plan_files(), self.table, self.row_filter, self.projection(), 
case_sensitive=self.case_sensitive
+            self.plan_files(),
+            self.table,
+            self.row_filter,
+            self.projection(),
+            case_sensitive=self.case_sensitive,
+            limit=self.limit,
         )
 
     def to_pandas(self, **kwargs: Any) -> pd.DataFrame:
diff --git a/python/tests/test_integration.py b/python/tests/test_integration.py
index 653f803527..3eb24cd48e 100644
--- a/python/tests/test_integration.py
+++ b/python/tests/test_integration.py
@@ -49,6 +49,11 @@ def table_test_null_nan_rewritten(catalog: Catalog) -> Table:
     return catalog.load_table("default.test_null_nan_rewritten")
 
 
[email protected]()
+def table_test_limit(catalog: Catalog) -> Table:
+    return catalog.load_table("default.test_limit")
+
+
 @pytest.fixture()
 def table_test_all_types(catalog: Catalog) -> Table:
     return catalog.load_table("default.test_all_types")
@@ -87,6 +92,18 @@ def test_duckdb_nan(table_test_null_nan_rewritten: Table) -> 
None:
     assert math.isnan(result[1])
 
 
[email protected]
+def test_pyarrow_limit(table_test_limit: Table) -> None:
+    limited_result = table_test_limit.scan(selected_fields=("idx",), 
limit=1).to_arrow()
+    assert len(limited_result) == 1
+
+    empty_result = table_test_limit.scan(selected_fields=("idx",), 
limit=0).to_arrow()
+    assert len(empty_result) == 0
+
+    full_result = table_test_limit.scan(selected_fields=("idx",), 
limit=999).to_arrow()
+    assert len(full_result) == 10
+
+
 @pytest.mark.integration
 def test_ray_nan(table_test_null_nan_rewritten: Table) -> None:
     ray_dataset = table_test_null_nan_rewritten.scan().to_ray()

Reply via email to