This is an automated email from the ASF dual-hosted git repository.
fokko pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new 25360c00f5 Python: Add limit to table scan (#7163)
25360c00f5 is described below
commit 25360c00f53674bde2ad6dd448c1004d067ffbc7
Author: Daniel Rückert García <[email protected]>
AuthorDate: Fri Mar 24 09:30:53 2023 +0100
Python: Add limit to table scan (#7163)
* Python: Add support for ORC
Creates fragments based on the FileFormat.
Blocked by: https://github.com/apache/iceberg/pull/6997
* Revert
* TableScan add limit
* pyarrow limit number of rows fetched from files if limit is set
* add tests for scan limit
* python ci rebuild container if changes on python/dev/
* remove support for ORC
* remove unused imports
* increase sleep before running tests
* update python docs to include limit in table query
* docs fix format
---------
Co-authored-by: Fokko Driesprong <[email protected]>
Co-authored-by: Daniel Rückert García <[email protected]>
---
.github/workflows/python-integration.yml | 4 +--
python/dev/provision.py | 33 ++++++++++++++++++
python/mkdocs/docs/api.md | 3 +-
python/pyiceberg/files.py | 8 ++---
python/pyiceberg/io/pyarrow.py | 60 +++++++++++++++++++++++---------
python/pyiceberg/manifest.py | 6 ++++
python/pyiceberg/table/__init__.py | 15 ++++++--
python/tests/test_integration.py | 17 +++++++++
8 files changed, 120 insertions(+), 26 deletions(-)
diff --git a/.github/workflows/python-integration.yml b/.github/workflows/python-integration.yml
index 2ea1d9c464..895e9e8e64 100644
--- a/.github/workflows/python-integration.yml
+++ b/.github/workflows/python-integration.yml
@@ -46,7 +46,7 @@ jobs:
id: check_file_changed
run: |
$diff = git diff --name-only HEAD^ HEAD
- $SourceDiff = $diff | Where-Object { $_ -match '^python/dev/Dockerfile$' }
+ $SourceDiff = $diff | Where-Object { $_ -match '^python/dev/.+$' }
$HasDiff = $SourceDiff.Length -gt 0
Write-Host "::set-output name=docs_changed::$HasDiff"
- name: Restore image
@@ -84,4 +84,4 @@ jobs:
run: make test-integration
- name: Show debug logs
if: ${{ failure() }}
- run: docker-compose -f python/dev/docker-compose.yml logs
\ No newline at end of file
+ run: docker-compose -f python/dev/docker-compose.yml logs
diff --git a/python/dev/provision.py b/python/dev/provision.py
index 1e6f5b7319..81bd094c58 100644
--- a/python/dev/provision.py
+++ b/python/dev/provision.py
@@ -71,6 +71,39 @@ spark.sql(
"""
)
+spark.sql(
+ """
+ DROP TABLE IF EXISTS test_limit;
+"""
+)
+
+spark.sql(
+ """
+ CREATE TABLE test_limit
+ USING iceberg
+ AS SELECT
+ 1 AS idx
+ UNION ALL SELECT
+ 2 AS idx
+ UNION ALL SELECT
+ 3 AS idx
+ UNION ALL SELECT
+ 4 AS idx
+ UNION ALL SELECT
+ 5 AS idx
+ UNION ALL SELECT
+ 6 AS idx
+ UNION ALL SELECT
+ 7 AS idx
+ UNION ALL SELECT
+ 8 AS idx
+ UNION ALL SELECT
+ 9 AS idx
+ UNION ALL SELECT
+ 10 AS idx
+ """
+)
+
spark.sql(
"""
DROP TABLE IF EXISTS test_deletes;
diff --git a/python/mkdocs/docs/api.md b/python/mkdocs/docs/api.md
index 9516516859..1cb26714ad 100644
--- a/python/mkdocs/docs/api.md
+++ b/python/mkdocs/docs/api.md
@@ -281,7 +281,7 @@ Table(
## Query a table
-To query a table, a table scan is needed. A table scan accepts a filter, columns and optionally a snapshot ID:
+To query a table, a table scan is needed. A table scan accepts a filter, columns and optionally a limit and a snapshot ID:
```python
from pyiceberg.catalog import load_catalog
@@ -293,6 +293,7 @@ table = catalog.load_table("nyc.taxis")
scan = table.scan(
row_filter=GreaterThanOrEqual("trip_distance", 10.0),
selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"),
+ limit=100,
)
# Or filter using a string predicate
diff --git a/python/pyiceberg/files.py b/python/pyiceberg/files.py
index b832bbbfe7..df19f37e0c 100644
--- a/python/pyiceberg/files.py
+++ b/python/pyiceberg/files.py
@@ -28,7 +28,7 @@ class FileContentType(Enum):
class FileFormat(Enum):
"""An enum that includes all possible formats for an Iceberg data file"""
- ORC = auto()
- PARQUET = auto()
- AVRO = auto()
- METADATA = auto()
+ ORC = "ORC"
+ PARQUET = "PARQUET"
+ AVRO = "AVRO"
+ METADATA = "METADATA"
diff --git a/python/pyiceberg/io/pyarrow.py b/python/pyiceberg/io/pyarrow.py
index 07b59258dd..ae81733968 100644
--- a/python/pyiceberg/io/pyarrow.py
+++ b/python/pyiceberg/io/pyarrow.py
@@ -24,9 +24,11 @@ with the pyarrow library.
"""
from __future__ import annotations
+import multiprocessing
import os
from functools import lru_cache
from multiprocessing.pool import ThreadPool
+from multiprocessing.sharedctypes import Synchronized
from typing import (
TYPE_CHECKING,
Any,
@@ -42,7 +44,7 @@ from urllib.parse import urlparse
import pyarrow as pa
import pyarrow.compute as pc
-import pyarrow.parquet as pq
+import pyarrow.dataset as ds
from pyarrow.fs import (
FileInfo,
FileSystem,
@@ -491,14 +493,19 @@ def _file_to_table(
projected_schema: Schema,
projected_field_ids: Set[int],
case_sensitive: bool,
+ rows_counter: Synchronized[int],
+ limit: Optional[int] = None,
) -> Optional[pa.Table]:
- _, path = PyArrowFileIO.parse_location(task.file.file_path)
+ if limit and rows_counter.value >= limit:
+ return None
- # Get the schema
- with fs.open_input_file(path) as fout:
- parquet_schema = pq.read_schema(fout)
+ _, path = PyArrowFileIO.parse_location(task.file.file_path)
+ arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
+ with fs.open_input_file(path) as fin:
+ fragment = arrow_format.make_fragment(fin)
+ physical_schema = fragment.physical_schema
schema_raw = None
- if metadata := parquet_schema.metadata:
+ if metadata := physical_schema.metadata:
schema_raw = metadata.get(ICEBERG_SCHEMA)
if schema_raw is None:
raise ValueError(
@@ -517,15 +524,22 @@ def _file_to_table(
if file_schema is None:
raise ValueError(f"Missing Iceberg schema in Metadata for file: {path}")
- arrow_table = pq.read_table(
- source=fout,
- schema=parquet_schema,
- pre_buffer=True,
- buffer_size=8 * ONE_MEGABYTE,
- filters=pyarrow_filter,
+ fragment_scanner = ds.Scanner.from_fragment(
+ fragment=fragment,
+ schema=physical_schema,
+ filter=pyarrow_filter,
columns=[col.name for col in file_project_schema.columns],
)
+ if limit:
+ arrow_table = fragment_scanner.head(limit)
+ with rows_counter.get_lock():
+ if rows_counter.value >= limit:
+ return None
+ rows_counter.value += len(arrow_table)
+ else:
+ arrow_table = fragment_scanner.to_table()
+
# If there is no data, we don't have to go through the schema
if len(arrow_table) > 0:
return to_requested_schema(projected_schema, file_project_schema, arrow_table)
@@ -534,7 +548,12 @@ def _file_to_table(
def project_table(
- tasks: Iterable[FileScanTask], table: Table, row_filter: BooleanExpression, projected_schema: Schema, case_sensitive: bool
+ tasks: Iterable[FileScanTask],
+ table: Table,
+ row_filter: BooleanExpression,
+ projected_schema: Schema,
+ case_sensitive: bool,
+ limit: Optional[int] = None,
) -> pa.Table:
"""Resolves the right columns based on the identifier
@@ -570,23 +589,30 @@ def project_table(
id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType))
}.union(extract_field_ids(bound_row_filter))
+ rows_counter = multiprocessing.Value("i", 0)
+
with ThreadPool() as pool:
tables = [
table
for table in pool.starmap(
func=_file_to_table,
- iterable=[(fs, task, bound_row_filter, projected_schema, projected_field_ids, case_sensitive) for task in tasks],
+ iterable=[
+ (fs, task, bound_row_filter, projected_schema, projected_field_ids, case_sensitive, rows_counter, limit)
+ for task in tasks
+ ],
chunksize=None,  # we could use this to control how to materialize the generator of tasks (we should also make the expression above lazy)
)
if table is not None
]
if len(tables) > 1:
- return pa.concat_tables(tables)
+ final_table = pa.concat_tables(tables)
elif len(tables) == 1:
- return tables[0]
+ final_table = tables[0]
else:
- return pa.Table.from_batches([], schema=schema_to_pyarrow(projected_schema))
+ final_table = pa.Table.from_batches([], schema=schema_to_pyarrow(projected_schema))
+
+ return final_table.slice(0, limit)
def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.Table) -> pa.Table:
diff --git a/python/pyiceberg/manifest.py b/python/pyiceberg/manifest.py
index 757f3bd016..942c582b51 100644
--- a/python/pyiceberg/manifest.py
+++ b/python/pyiceberg/manifest.py
@@ -177,6 +177,12 @@ class DataFile(Record):
sort_order_id: Optional[int]
spec_id: Optional[int]
+ def __setattr__(self, name: str, value: Any) -> None:
+ # The file_format is written as a string, so we need to cast it to the Enum
+ if name == "file_format":
+ value = FileFormat[value]
+ super().__setattr__(name, value)
+
def __init__(self, *data: Any, **named_data: Any) -> None:
super().__init__(*data, **{"struct": DATA_FILE_TYPE, **named_data})
diff --git a/python/pyiceberg/table/__init__.py b/python/pyiceberg/table/__init__.py
index ef53087fee..8f19d93acf 100644
--- a/python/pyiceberg/table/__init__.py
+++ b/python/pyiceberg/table/__init__.py
@@ -100,6 +100,7 @@ class Table:
case_sensitive: bool = True,
snapshot_id: Optional[int] = None,
options: Properties = EMPTY_DICT,
+ limit: Optional[int] = None,
) -> DataScan:
return DataScan(
table=self,
@@ -108,6 +109,7 @@ class Table:
case_sensitive=case_sensitive,
snapshot_id=snapshot_id,
options=options,
+ limit=limit,
)
def schema(self) -> Schema:
@@ -220,6 +222,7 @@ class TableScan(ABC):
case_sensitive: bool
snapshot_id: Optional[int]
options: Properties
+ limit: Optional[int]
def __init__(
self,
@@ -229,6 +232,7 @@ class TableScan(ABC):
case_sensitive: bool = True,
snapshot_id: Optional[int] = None,
options: Properties = EMPTY_DICT,
+ limit: Optional[int] = None,
):
self.table = table
self.row_filter = _parse_row_filter(row_filter)
@@ -236,6 +240,7 @@ class TableScan(ABC):
self.case_sensitive = case_sensitive
self.snapshot_id = snapshot_id
self.options = options
+ self.limit = limit
def snapshot(self) -> Optional[Snapshot]:
if self.snapshot_id:
@@ -336,8 +341,9 @@ class DataScan(TableScan):
case_sensitive: bool = True,
snapshot_id: Optional[int] = None,
options: Properties = EMPTY_DICT,
+ limit: Optional[int] = None,
):
- super().__init__(table, row_filter, selected_fields, case_sensitive, snapshot_id, options)
+ super().__init__(table, row_filter, selected_fields, case_sensitive, snapshot_id, options, limit)
def _build_partition_projection(self, spec_id: int) -> BooleanExpression:
project = inclusive_projection(self.table.schema(),
self.table.specs()[spec_id])
@@ -403,7 +409,12 @@ class DataScan(TableScan):
from pyiceberg.io.pyarrow import project_table
return project_table(
- self.plan_files(), self.table, self.row_filter, self.projection(), case_sensitive=self.case_sensitive
+ self.plan_files(),
+ self.table,
+ self.row_filter,
+ self.projection(),
+ case_sensitive=self.case_sensitive,
+ limit=self.limit,
)
def to_pandas(self, **kwargs: Any) -> pd.DataFrame:
diff --git a/python/tests/test_integration.py b/python/tests/test_integration.py
index 653f803527..3eb24cd48e 100644
--- a/python/tests/test_integration.py
+++ b/python/tests/test_integration.py
@@ -49,6 +49,11 @@ def table_test_null_nan_rewritten(catalog: Catalog) -> Table:
return catalog.load_table("default.test_null_nan_rewritten")
+@pytest.fixture()
+def table_test_limit(catalog: Catalog) -> Table:
+ return catalog.load_table("default.test_limit")
+
+
@pytest.fixture()
def table_test_all_types(catalog: Catalog) -> Table:
return catalog.load_table("default.test_all_types")
@@ -87,6 +92,18 @@ def test_duckdb_nan(table_test_null_nan_rewritten: Table) -> None:
assert math.isnan(result[1])
+@pytest.mark.integration
+def test_pyarrow_limit(table_test_limit: Table) -> None:
+ limited_result = table_test_limit.scan(selected_fields=("idx",), limit=1).to_arrow()
+ assert len(limited_result) == 1
+
+ empty_result = table_test_limit.scan(selected_fields=("idx",), limit=0).to_arrow()
+ assert len(empty_result) == 0
+
+ full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow()
+ assert len(full_result) == 10
+
+
@pytest.mark.integration
def test_ray_nan(table_test_null_nan_rewritten: Table) -> None:
ray_dataset = table_test_null_nan_rewritten.scan().to_ray()