This is an automated email from the ASF dual-hosted git repository.
timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new d7e137ee Enable remaining pylints (#1298)
d7e137ee is described below
commit d7e137eef224d5ed6e72b8ed798b3cf8fe9db40d
Author: Tim Saucer <[email protected]>
AuthorDate: Sun Nov 2 07:16:54 2025 -0500
Enable remaining pylints (#1298)
* Now that we are on Python 3.10, change from Union and Optional to | (see the sketch after this list)
* Enable additional lint
* Add check for dead code
* Verify all Python arguments have type annotations
* Add return types on functions
* Clean up pyproject.toml
* More lints
* Enable ruff pathlib (PTH) checks
* Fix Path.glob code
* Remove deprecated test
* Expect deprecation warning
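
A minimal sketch of the annotation change from the first bullet, using a
hypothetical function (the `|` form needs Python 3.10+):

    from typing import Optional, Union

    def before(x: Union[int, str], y: Optional[int] = None) -> Optional[str]: ...

    def after(x: int | str, y: int | None = None) -> str | None: ...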
---
benchmarks/db-benchmark/groupby-datafusion.py | 5 +-
benchmarks/db-benchmark/join-datafusion.py | 11 +-
benchmarks/tpch/tpch.py | 7 +-
dev/create_license.py | 3 +-
dev/release/check-rat-report.py | 3 +-
examples/python-udf-comparisons.py | 6 +-
examples/tpch/convert_data_to_parquet.py | 10 +-
examples/tpch/q07_volume_shipping.py | 2 +-
examples/tpch/q12_ship_mode_order_priority.py | 2 +-
examples/tpch/util.py | 16 ++-
pyproject.toml | 30 ++----
python/datafusion/__init__.py | 8 +-
python/datafusion/dataframe.py | 49 +++++----
python/datafusion/dataframe_formatter.py | 14 +--
python/datafusion/expr.py | 53 +++++-----
python/datafusion/functions.py | 138 +++++++++++++-------------
python/datafusion/input/location.py | 4 +-
python/datafusion/user_defined.py | 33 +++---
python/tests/test_aggregation.py | 2 +-
python/tests/test_dataframe.py | 28 +-----
python/tests/test_pyclass_frozen.py | 5 +-
python/tests/test_sql.py | 25 +++--
22 files changed, 216 insertions(+), 238 deletions(-)
diff --git a/benchmarks/db-benchmark/groupby-datafusion.py b/benchmarks/db-benchmark/groupby-datafusion.py
index f9e8d638..53316669 100644
--- a/benchmarks/db-benchmark/groupby-datafusion.py
+++ b/benchmarks/db-benchmark/groupby-datafusion.py
@@ -18,6 +18,7 @@
import gc
import os
import timeit
+from pathlib import Path
import datafusion as df
import pyarrow as pa
@@ -34,7 +35,7 @@ from pyarrow import csv as pacsv
print("# groupby-datafusion.py", flush=True)
-exec(open("./_helpers/helpers.py").read())
+exec(Path("./_helpers/helpers.py").read_text())
def ans_shape(batches) -> tuple[int, int]:
@@ -65,7 +66,7 @@ on_disk = "FALSE"
sql = True
data_name = os.environ["SRC_DATANAME"]
-src_grp = os.path.join("data", data_name + ".csv")
+src_grp = Path("data") / f"{data_name}.csv"
print("loading dataset %s" % src_grp, flush=True)
schema = pa.schema(
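
As a reference for the os.path-to-pathlib changes in these benchmark files, a
sketch with a hypothetical dataset name (not taken from the benchmark itself):

    from pathlib import Path

    data_name = "G1_1e7_1e2_0_0"  # hypothetical SRC_DATANAME value
    # os.path.join("data", data_name + ".csv") becomes a Path join:
    src_grp = Path("data") / f"{data_name}.csv"
    # open(path).read() becomes read_text(), which also closes the file:
    helpers_source = Path("./_helpers/helpers.py").read_text()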
diff --git a/benchmarks/db-benchmark/join-datafusion.py b/benchmarks/db-benchmark/join-datafusion.py
index 03986803..3be296c8 100755
--- a/benchmarks/db-benchmark/join-datafusion.py
+++ b/benchmarks/db-benchmark/join-datafusion.py
@@ -18,6 +18,7 @@
import gc
import os
import timeit
+from pathlib import Path
import datafusion as df
from datafusion import col
@@ -26,7 +27,7 @@ from pyarrow import csv as pacsv
print("# join-datafusion.py", flush=True)
-exec(open("./_helpers/helpers.py").read())
+exec(Path("./_helpers/helpers.py").read_text())
def ans_shape(batches) -> tuple[int, int]:
@@ -49,12 +50,12 @@ cache = "TRUE"
on_disk = "FALSE"
data_name = os.environ["SRC_DATANAME"]
-src_jn_x = os.path.join("data", data_name + ".csv")
+src_jn_x = Path("data") / f"{data_name}.csv"
y_data_name = join_to_tbls(data_name)
src_jn_y = [
- os.path.join("data", y_data_name[0] + ".csv"),
- os.path.join("data", y_data_name[1] + ".csv"),
- os.path.join("data", y_data_name[2] + ".csv"),
+ Path("data") / f"{y_data_name[0]}.csv",
+ Path("data") / f"{y_data_name[1]}.csv",
+ Path("data") / f"{y_data_name[2]}.csv",
]
if len(src_jn_y) != 3:
error_msg = "Something went wrong in preparing files used for join"
diff --git a/benchmarks/tpch/tpch.py b/benchmarks/tpch/tpch.py
index 2d1bbae5..9cc897e7 100644
--- a/benchmarks/tpch/tpch.py
+++ b/benchmarks/tpch/tpch.py
@@ -17,12 +17,13 @@
import argparse
import time
+from pathlib import Path
from datafusion import SessionContext
def bench(data_path, query_path) -> None:
- with open("results.csv", "w") as results:
+ with Path("results.csv").open("w") as results:
# register tables
start = time.time()
total_time_millis = 0
@@ -45,7 +46,7 @@ def bench(data_path, query_path) -> None:
print("Configuration:\n", ctx)
# register tables
- with open("create_tables.sql") as f:
+ with Path("create_tables.sql").open() as f:
sql = ""
for line in f.readlines():
if line.startswith("--"):
@@ -65,7 +66,7 @@ def bench(data_path, query_path) -> None:
# run queries
for query in range(1, 23):
- with open(f"{query_path}/q{query}.sql") as f:
+ with Path(f"{query_path}/q{query}.sql").open() as f:
text = f.read()
tmp = text.split(";")
queries = [s.strip() for s in tmp if len(s.strip()) > 0]
diff --git a/dev/create_license.py b/dev/create_license.py
index 2a67cb8f..a28a0abe 100644
--- a/dev/create_license.py
+++ b/dev/create_license.py
@@ -20,6 +20,7 @@
import json
import subprocess
+from pathlib import Path
subprocess.check_output(["cargo", "install", "cargo-license"])
data = subprocess.check_output(
@@ -248,5 +249,5 @@ for item in data:
result += "------------------\n\n"
result += f"### {name} {version}\n* source: [{repository}]({repository})\n* license: {license}\n\n"
-with open("LICENSE.txt", "w") as f:
+with Path("LICENSE.txt").open("w") as f:
f.write(result)
diff --git a/dev/release/check-rat-report.py b/dev/release/check-rat-report.py
index 0c9f4c32..72a35212 100644
--- a/dev/release/check-rat-report.py
+++ b/dev/release/check-rat-report.py
@@ -21,6 +21,7 @@ import fnmatch
import re
import sys
import xml.etree.ElementTree as ET
+from pathlib import Path
if len(sys.argv) != 3:
sys.stderr.write("Usage: %s exclude_globs.lst rat_report.xml\n" % sys.argv[0])
@@ -29,7 +30,7 @@ if len(sys.argv) != 3:
exclude_globs_filename = sys.argv[1]
xml_filename = sys.argv[2]
-globs = [line.strip() for line in open(exclude_globs_filename)]
+globs = [line.strip() for line in Path(exclude_globs_filename).open()]
tree = ET.parse(xml_filename)
root = tree.getroot()
diff --git a/examples/python-udf-comparisons.py b/examples/python-udf-comparisons.py
index eb082501..b870645a 100644
--- a/examples/python-udf-comparisons.py
+++ b/examples/python-udf-comparisons.py
@@ -15,16 +15,16 @@
# specific language governing permissions and limitations
# under the License.
-import os
import time
+from pathlib import Path
import pyarrow as pa
import pyarrow.compute as pc
from datafusion import SessionContext, col, lit, udf
from datafusion import functions as F
-path = os.path.dirname(os.path.abspath(__file__))
-filepath = os.path.join(path, "./tpch/data/lineitem.parquet")
+path = Path(__file__).parent.resolve()
+filepath = path / "./tpch/data/lineitem.parquet"
# This example serves to demonstrate alternate approaches to answering the
# question "return all of the rows that have a specific combination of these
diff --git a/examples/tpch/convert_data_to_parquet.py b/examples/tpch/convert_data_to_parquet.py
index fd0fcca4..af554c39 100644
--- a/examples/tpch/convert_data_to_parquet.py
+++ b/examples/tpch/convert_data_to_parquet.py
@@ -22,7 +22,7 @@ the data generated resides in a path ../../benchmarks/tpch/data relative to the
as will be generated by the script provided in this repository.
"""
-import os
+from pathlib import Path
import datafusion
import pyarrow as pa
@@ -116,7 +116,7 @@ all_schemas["supplier"] = [
("S_COMMENT", pa.string()),
]
-curr_dir = os.path.dirname(os.path.abspath(__file__))
+curr_dir = Path(__file__).resolve().parent
for filename, curr_schema_val in all_schemas.items():
# For convenience, go ahead and convert the schema column names to lowercase
curr_schema = [(s[0].lower(), s[1]) for s in curr_schema_val]
@@ -132,10 +132,8 @@ for filename, curr_schema_val in all_schemas.items():
schema = pa.schema(curr_schema)
- source_file = os.path.abspath(
-     os.path.join(curr_dir, f"../../benchmarks/tpch/data/{filename}.csv")
- )
- dest_file = os.path.abspath(os.path.join(curr_dir, f"./data/{filename}.parquet"))
+ source_file = (curr_dir / f"../../benchmarks/tpch/data/{filename}.csv").resolve()
+ dest_file = (curr_dir / f"./data/{filename}.parquet").resolve()
df = ctx.read_csv(source_file, schema=schema, has_header=False, delimiter="|")
diff --git a/examples/tpch/q07_volume_shipping.py b/examples/tpch/q07_volume_shipping.py
index a84cf728..ff2f891f 100644
--- a/examples/tpch/q07_volume_shipping.py
+++ b/examples/tpch/q07_volume_shipping.py
@@ -80,7 +80,7 @@ df_lineitem = df_lineitem.filter(col("l_shipdate") >= start_date).filter(
# not match these will result in a null value and then get filtered out.
#
# To do the same using a simple filter would be:
-# df_nation = df_nation.filter((F.col("n_name") == nation_1) | (F.col("n_name") == nation_2))
+# df_nation = df_nation.filter((F.col("n_name") == nation_1) | (F.col("n_name") == nation_2))  # noqa: ERA001
df_nation = df_nation.with_column(
"n_name",
F.case(col("n_name"))
diff --git a/examples/tpch/q12_ship_mode_order_priority.py b/examples/tpch/q12_ship_mode_order_priority.py
index f1d89494..9071597f 100644
--- a/examples/tpch/q12_ship_mode_order_priority.py
+++ b/examples/tpch/q12_ship_mode_order_priority.py
@@ -73,7 +73,7 @@ df = df.filter(
# matches either of the two values, but we want to show doing some array operations in this
# example. If you want to see this done with filters, comment out the above line and uncomment
# this one.
-# df = df.filter((col("l_shipmode") == lit(SHIP_MODE_1)) | (col("l_shipmode") == lit(SHIP_MODE_2)))
+# df = df.filter((col("l_shipmode") == lit(SHIP_MODE_1)) | (col("l_shipmode") == lit(SHIP_MODE_2)))  # noqa: ERA001
# We need order priority, so join order df to line item
diff --git a/examples/tpch/util.py b/examples/tpch/util.py
index 7e3d659d..ec53bcd1 100644
--- a/examples/tpch/util.py
+++ b/examples/tpch/util.py
@@ -19,18 +19,16 @@
Common utilities for running TPC-H examples.
"""
-import os
+from pathlib import Path
-def get_data_path(filename: str) -> str:
- path = os.path.dirname(os.path.abspath(__file__))
+def get_data_path(filename: str) -> Path:
+ path = Path(__file__).resolve().parent
- return os.path.join(path, "data", filename)
+ return path / "data" / filename
-def get_answer_file(answer_file: str) -> str:
- path = os.path.dirname(os.path.abspath(__file__))
+def get_answer_file(answer_file: str) -> Path:
+ path = Path(__file__).resolve().parent
- return os.path.join(
- path, "../../benchmarks/tpch/data/answers", f"{answer_file}.out"
- )
+ return path / "../../benchmarks/tpch/data/answers" / f"{answer_file}.out"
diff --git a/pyproject.toml b/pyproject.toml
index f47a2b1c..25f30b8e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -84,25 +84,12 @@ ignore = [
"FBT002", # Allow boolean positional args
"ISC001", # Recommended to ignore these rules when using with ruff-format
"SLF001", # Allow accessing private members
- "TD002",
+ "TD002", # Do not require author names in TODO statements
"TD003", # Allow TODO lines
- "UP007", # Disallowing Union is pedantic
- # TODO: Enable all of the following, but this PR is getting too large already
- "PLR0913",
- "TRY003",
- "PLR2004",
- "PD901",
- "ERA001",
- "ANN001",
- "ANN202",
- "PTH",
- "N812",
- "INP001",
- "DTZ007",
- "RUF015",
- "A005",
- "TC001",
- "UP035",
+ "PLR0913", # Allow many arguments in function definition
+ "PD901", # Allow variable name df
+ "N812", # Allow importing functions as `F`
+ "A005", # Allow module named io
]
[tool.ruff.lint.pydocstyle]
@@ -131,10 +118,11 @@ extend-allowed-calls = ["lit", "datafusion.lit"]
"PLR0913",
"PT004",
]
-"examples/*" = ["D", "W505", "E501", "T201", "S101"]
-"dev/*" = ["D", "E", "T", "S", "PLR", "C", "SIM", "UP", "EXE", "N817"]
-"benchmarks/*" = ["D", "F", "T", "BLE", "FURB", "PLR", "E", "TD", "TRY", "S", "SIM", "EXE", "UP"]
+"examples/*" = ["D", "W505", "E501", "T201", "S101", "PLR2004", "ANN001", "ANN202", "INP001", "DTZ007", "RUF015"]
+"dev/*" = ["D", "E", "T", "S", "PLR", "C", "SIM", "UP", "EXE", "N817", "ERA001", "ANN001"]
+"benchmarks/*" = ["D", "F", "T", "BLE", "FURB", "PLR", "E", "TD", "TRY", "S", "SIM", "EXE", "UP", "ERA001", "ANN001", "INP001"]
"docs/*" = ["D"]
+"docs/source/conf.py" = ["ERA001", "ANN001", "INP001"]
[tool.codespell]
skip = [
diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py
index 77765223..784d4ccc 100644
--- a/python/datafusion/__init__.py
+++ b/python/datafusion/__init__.py
@@ -119,12 +119,12 @@ __all__ = [
]
-def literal(value) -> Expr:
+def literal(value: Any) -> Expr:
"""Create a literal expression."""
return Expr.literal(value)
-def string_literal(value):
+def string_literal(value: str) -> Expr:
"""Create a UTF8 literal expression.
It differs from `literal` which creates a UTF8view literal.
@@ -132,12 +132,12 @@ def string_literal(value):
return Expr.string_literal(value)
-def str_lit(value):
+def str_lit(value: str) -> Expr:
"""Alias for `string_literal`."""
return string_literal(value)
-def lit(value) -> Expr:
+def lit(value: Any) -> Expr:
"""Create a literal expression."""
return Expr.literal(value)
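
A short usage sketch of the distinction the docstring above describes,
assuming both helpers are exported at the package top level as this
__init__.py suggests:

    from datafusion import col, lit, string_literal

    # lit builds a Utf8View literal for strings; string_literal forces a
    # plain Utf8 literal, which matters when the column type must match.
    view_expr = col("name") == lit("alice")
    utf8_expr = col("name") == string_literal("alice")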
diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index 05b84c6b..8d692aca 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -22,14 +22,11 @@ See :ref:`user_guide_concepts` in the online documentation for more information.
from __future__ import annotations
import warnings
-from collections.abc import Sequence
+from collections.abc import Iterable, Sequence
from typing import (
TYPE_CHECKING,
Any,
- Iterable,
Literal,
- Optional,
- Union,
overload,
)
@@ -57,7 +54,7 @@ from datafusion.record_batch import RecordBatchStream
if TYPE_CHECKING:
import pathlib
- from typing import Callable
+ from collections.abc import Callable
import pandas as pd
import polars as pl
@@ -80,7 +77,7 @@ class Compression(Enum):
LZ4 = "lz4"
# lzo is not implemented yet
# https://github.com/apache/arrow-rs/issues/6970
- # LZO = "lzo"
+ # LZO = "lzo" # noqa: ERA001
ZSTD = "zstd"
LZ4_RAW = "lz4_raw"
@@ -107,7 +104,7 @@ class Compression(Enum):
"""
raise ValueError(error_msg) from err
- def get_default_level(self) -> Optional[int]:
+ def get_default_level(self) -> int | None:
"""Get the default compression level for the compression type.
Returns:
@@ -140,24 +137,24 @@ class ParquetWriterOptions:
write_batch_size: int = 1024,
writer_version: str = "1.0",
skip_arrow_metadata: bool = False,
- compression: Optional[str] = "zstd(3)",
- compression_level: Optional[int] = None,
- dictionary_enabled: Optional[bool] = True,
+ compression: str | None = "zstd(3)",
+ compression_level: int | None = None,
+ dictionary_enabled: bool | None = True,
dictionary_page_size_limit: int = 1024 * 1024,
- statistics_enabled: Optional[str] = "page",
+ statistics_enabled: str | None = "page",
max_row_group_size: int = 1024 * 1024,
created_by: str = "datafusion-python",
- column_index_truncate_length: Optional[int] = 64,
- statistics_truncate_length: Optional[int] = None,
+ column_index_truncate_length: int | None = 64,
+ statistics_truncate_length: int | None = None,
data_page_row_count_limit: int = 20_000,
- encoding: Optional[str] = None,
+ encoding: str | None = None,
bloom_filter_on_write: bool = False,
- bloom_filter_fpp: Optional[float] = None,
- bloom_filter_ndv: Optional[int] = None,
+ bloom_filter_fpp: float | None = None,
+ bloom_filter_ndv: int | None = None,
allow_single_file_parallelism: bool = True,
maximum_parallel_row_group_writers: int = 1,
maximum_buffered_record_batches_per_stream: int = 2,
- column_specific_options: Optional[dict[str, ParquetColumnOptions]] = None,
+ column_specific_options: dict[str, ParquetColumnOptions] | None = None,
) -> None:
"""Initialize the ParquetWriterOptions.
@@ -262,13 +259,13 @@ class ParquetColumnOptions:
def __init__(
self,
- encoding: Optional[str] = None,
- dictionary_enabled: Optional[bool] = None,
- compression: Optional[str] = None,
- statistics_enabled: Optional[str] = None,
- bloom_filter_enabled: Optional[bool] = None,
- bloom_filter_fpp: Optional[float] = None,
- bloom_filter_ndv: Optional[int] = None,
+ encoding: str | None = None,
+ dictionary_enabled: bool | None = None,
+ compression: str | None = None,
+ statistics_enabled: str | None = None,
+ bloom_filter_enabled: bool | None = None,
+ bloom_filter_fpp: float | None = None,
+ bloom_filter_ndv: int | None = None,
) -> None:
"""Initialize the ParquetColumnOptions.
@@ -831,7 +828,7 @@ class DataFrame:
# of a keyword argument.
if (
isinstance(on, tuple)
- and len(on) == 2
+ and len(on) == 2 # noqa: PLR2004
and isinstance(on[0], list)
and isinstance(on[1], list)
):
@@ -1063,7 +1060,7 @@ class DataFrame:
def write_parquet(
self,
path: str | pathlib.Path,
- compression: Union[str, Compression, ParquetWriterOptions] = Compression.ZSTD,
+ compression: str | Compression | ParquetWriterOptions = Compression.ZSTD,
compression_level: int | None = None,
write_options: DataFrameWriteOptions | None = None,
) -> None:
diff --git a/python/datafusion/dataframe_formatter.py b/python/datafusion/dataframe_formatter.py
index 2323224b..4082ff4e 100644
--- a/python/datafusion/dataframe_formatter.py
+++ b/python/datafusion/dataframe_formatter.py
@@ -19,15 +19,17 @@
from __future__ import annotations
from typing import (
+ TYPE_CHECKING,
Any,
- Callable,
- Optional,
Protocol,
runtime_checkable,
)
from datafusion._internal import DataFrame as DataFrameInternal
+if TYPE_CHECKING:
+ from collections.abc import Callable
+
def _validate_positive_int(value: Any, param_name: str) -> None:
"""Validate that a parameter is a positive integer.
@@ -144,9 +146,9 @@ class DataFrameHtmlFormatter:
min_rows_display: int = 20,
repr_rows: int = 10,
enable_cell_expansion: bool = True,
- custom_css: Optional[str] = None,
+ custom_css: str | None = None,
show_truncation_message: bool = True,
- style_provider: Optional[StyleProvider] = None,
+ style_provider: StyleProvider | None = None,
use_shared_styles: bool = True,
) -> None:
"""Initialize the HTML formatter.
@@ -226,8 +228,8 @@ class DataFrameHtmlFormatter:
# Registry for custom type formatters
self._type_formatters: dict[type, CellFormatter] = {}
# Custom cell builders
- self._custom_cell_builder: Optional[Callable[[Any, int, int, str], str]] = None
- self._custom_header_builder: Optional[Callable[[Any], str]] = None
+ self._custom_cell_builder: Callable[[Any, int, int, str], str] | None = None
+ self._custom_header_builder: Callable[[Any], str] | None = None
def register_formatter(self, type_class: type, formatter: CellFormatter) -> None:
"""Register a custom formatter for a specific data type.
diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py
index a84b2e6d..3a6d0441 100644
--- a/python/datafusion/expr.py
+++ b/python/datafusion/expr.py
@@ -22,8 +22,8 @@ See :ref:`Expressions` in the online documentation for more details.
from __future__ import annotations
-import typing as _typing
-from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Optional, Sequence
+from collections.abc import Iterable, Sequence
+from typing import TYPE_CHECKING, Any, ClassVar
try:
from warnings import deprecated # Python 3.13+
@@ -230,7 +230,7 @@ __all__ = [
]
-def ensure_expr(value: _typing.Union[Expr, Any]) -> expr_internal.Expr:
+def ensure_expr(value: Expr | Any) -> expr_internal.Expr:
"""Return the internal expression from ``Expr`` or raise ``TypeError``.
This helper rejects plain strings and other non-:class:`Expr` values so
@@ -252,7 +252,7 @@ def ensure_expr(value: _typing.Union[Expr, Any]) -> expr_internal.Expr:
def ensure_expr_list(
- exprs: Iterable[_typing.Union[Expr, Iterable[Expr]]],
+ exprs: Iterable[Expr | Iterable[Expr]],
) -> list[expr_internal.Expr]:
"""Flatten an iterable of expressions, validating each via ``ensure_expr``.
@@ -267,7 +267,7 @@ def ensure_expr_list(
"""
def _iter(
- items: Iterable[_typing.Union[Expr, Iterable[Expr]]],
+ items: Iterable[Expr | Iterable[Expr]],
) -> Iterable[expr_internal.Expr]:
for expr in items:
if isinstance(expr, Iterable) and not isinstance(
@@ -281,7 +281,7 @@ def ensure_expr_list(
return list(_iter(exprs))
-def _to_raw_expr(value: _typing.Union[Expr, str]) -> expr_internal.Expr:
+def _to_raw_expr(value: Expr | str) -> expr_internal.Expr:
"""Convert a Python expression or column name to its raw variant.
Args:
@@ -305,8 +305,8 @@ def _to_raw_expr(value: _typing.Union[Expr, str]) -> expr_internal.Expr:
def expr_list_to_raw_expr_list(
- expr_list: Optional[list[Expr] | Expr],
-) -> Optional[list[expr_internal.Expr]]:
+ expr_list: list[Expr] | Expr | None,
+) -> list[expr_internal.Expr] | None:
"""Convert a sequence of expressions or column names to raw expressions."""
if isinstance(expr_list, Expr | str):
expr_list = [expr_list]
@@ -315,7 +315,7 @@ def expr_list_to_raw_expr_list(
return [_to_raw_expr(e) for e in expr_list]
-def sort_or_default(e: _typing.Union[Expr, SortExpr]) -> expr_internal.SortExpr:
+def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:
"""Helper function to return a default Sort if an Expr is provided."""
if isinstance(e, SortExpr):
return e.raw_sort
@@ -323,8 +323,8 @@ def sort_or_default(e: _typing.Union[Expr, SortExpr]) -> expr_internal.SortExpr:
def sort_list_to_raw_sort_list(
- sort_list: Optional[_typing.Union[Sequence[SortKey], SortKey]],
-) -> Optional[list[expr_internal.SortExpr]]:
+ sort_list: Sequence[SortKey] | SortKey | None,
+) -> list[expr_internal.SortExpr] | None:
"""Helper function to return an optional sort list to raw variant."""
if isinstance(sort_list, Expr | SortExpr | str):
sort_list = [sort_list]
@@ -601,7 +601,7 @@ class Expr:
"""Creates a new expression representing a column."""
return Expr(expr_internal.RawExpr.column(value))
- def alias(self, name: str, metadata: Optional[dict[str, str]] = None) -> Expr:
+ def alias(self, name: str, metadata: dict[str, str] | None = None) -> Expr:
"""Assign a name to the expression.
Args:
@@ -630,13 +630,13 @@ class Expr:
"""Returns ``True`` if this expression is not null."""
return Expr(self.expr.is_not_null())
- def fill_nan(self, value: Optional[_typing.Union[Any, Expr]] = None) -> Expr:
+ def fill_nan(self, value: Any | Expr | None = None) -> Expr:
"""Fill NaN values with a provided value."""
if not isinstance(value, Expr):
value = Expr.literal(value)
return Expr(functions_internal.nanvl(self.expr, value.expr))
- def fill_null(self, value: Optional[_typing.Union[Any, Expr]] = None) -> Expr:
+ def fill_null(self, value: Any | Expr | None = None) -> Expr:
"""Fill NULL values with a provided value."""
if not isinstance(value, Expr):
value = Expr.literal(value)
@@ -649,7 +649,7 @@ class Expr:
bool: pa.bool_(),
}
- def cast(self, to: _typing.Union[pa.DataType[Any], type]) -> Expr:
+ def cast(self, to: pa.DataType[Any] | type) -> Expr:
"""Cast to a new data type."""
if not isinstance(to, pa.DataType):
try:
@@ -722,7 +722,7 @@ class Expr:
"""Compute the output column name based on the provided logical plan."""
return self.expr.column_name(plan._raw_plan)
- def order_by(self, *exprs: _typing.Union[Expr, SortExpr]) -> ExprFuncBuilder:
+ def order_by(self, *exprs: Expr | SortExpr) -> ExprFuncBuilder:
"""Set the ordering for a window or aggregate function.
This function will create an :py:class:`ExprFuncBuilder` that can be used to
@@ -1271,17 +1271,10 @@ class Window:
def __init__(
self,
- partition_by: Optional[_typing.Union[list[Expr], Expr]] = None,
- window_frame: Optional[WindowFrame] = None,
- order_by: Optional[
- _typing.Union[
- list[_typing.Union[SortExpr, Expr, str]],
- Expr,
- SortExpr,
- str,
- ]
- ] = None,
- null_treatment: Optional[NullTreatment] = None,
+ partition_by: list[Expr] | Expr | None = None,
+ window_frame: WindowFrame | None = None,
+ order_by: list[SortExpr | Expr | str] | Expr | SortExpr | str | None = None,
+ null_treatment: NullTreatment | None = None,
) -> None:
"""Construct a window definition.
@@ -1301,7 +1294,7 @@ class WindowFrame:
"""Defines a window frame for performing window operations."""
def __init__(
- self, units: str, start_bound: Optional[Any], end_bound: Optional[Any]
+ self, units: str, start_bound: Any | None, end_bound: Any | None
) -> None:
"""Construct a window frame using the given parameters.
@@ -1351,7 +1344,7 @@ class WindowFrameBound:
"""Constructs a window frame bound."""
self.frame_bound = frame_bound
- def get_offset(self) -> Optional[int]:
+ def get_offset(self) -> int | None:
"""Returns the offset of the window frame."""
return self.frame_bound.get_offset()
@@ -1435,4 +1428,4 @@ class SortExpr:
return self.raw_sort.__repr__()
-SortKey = _typing.Union[Expr, SortExpr, str]
+SortKey = Expr | SortExpr | str
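
On Python 3.10+ the `X | Y` form is a real runtime object, so it works both
as a module-level alias (like SortKey above) and inside isinstance(), which
rewritten checks such as `isinstance(expr_list, Expr | str)` rely on. A
standalone sketch:

    IntOrStr = int | str

    def is_int_or_str(value: object) -> bool:
        return isinstance(value, int | str)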
diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py
index 472a02fc..7ae59c00 100644
--- a/python/datafusion/functions.py
+++ b/python/datafusion/functions.py
@@ -18,7 +18,7 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any
import pyarrow as pa
@@ -379,7 +379,7 @@ def order_by(expr: Expr, ascending: bool = True, nulls_first: bool = True) -> So
return SortExpr(expr, ascending=ascending, nulls_first=nulls_first)
-def alias(expr: Expr, name: str, metadata: Optional[dict[str, str]] = None) -> Expr:
+def alias(expr: Expr, name: str, metadata: dict[str, str] | None = None) -> Expr:
"""Creates an alias expression with an optional metadata dictionary.
Args:
@@ -398,7 +398,7 @@ def col(name: str) -> Expr:
return Expr(f.col(name))
-def count_star(filter: Optional[Expr] = None) -> Expr:
+def count_star(filter: Expr | None = None) -> Expr:
"""Create a COUNT(1) aggregate expression.
This aggregate function will count all of the rows in the partition.
@@ -1647,7 +1647,7 @@ def empty(array: Expr) -> Expr:
# aggregate functions
def approx_distinct(
expression: Expr,
- filter: Optional[Expr] = None,
+ filter: Expr | None = None,
) -> Expr:
"""Returns the approximate number of distinct values.
@@ -1667,7 +1667,7 @@ def approx_distinct(
return Expr(f.approx_distinct(expression.expr, filter=filter_raw))
-def approx_median(expression: Expr, filter: Optional[Expr] = None) -> Expr:
+def approx_median(expression: Expr, filter: Expr | None = None) -> Expr:
"""Returns the approximate median value.
This aggregate function is similar to :py:func:`median`, but it will only
@@ -1687,8 +1687,8 @@ def approx_median(expression: Expr, filter: Optional[Expr] = None) -> Expr:
def approx_percentile_cont(
sort_expression: Expr | SortExpr,
percentile: float,
- num_centroids: Optional[int] = None,
- filter: Optional[Expr] = None,
+ num_centroids: int | None = None,
+ filter: Expr | None = None,
) -> Expr:
"""Returns the value that is approximately at a given percentile of ``expr``.
@@ -1724,8 +1724,8 @@ def approx_percentile_cont_with_weight(
sort_expression: Expr | SortExpr,
weight: Expr,
percentile: float,
- num_centroids: Optional[int] = None,
- filter: Optional[Expr] = None,
+ num_centroids: int | None = None,
+ filter: Expr | None = None,
) -> Expr:
"""Returns the value of the weighted approximate percentile.
@@ -1759,8 +1759,8 @@ def approx_percentile_cont_with_weight(
def array_agg(
expression: Expr,
distinct: bool = False,
- filter: Optional[Expr] = None,
- order_by: Optional[list[SortKey] | SortKey] = None,
+ filter: Expr | None = None,
+ order_by: list[SortKey] | SortKey | None = None,
) -> Expr:
"""Aggregate values into an array.
@@ -1793,7 +1793,7 @@ def array_agg(
def avg(
expression: Expr,
- filter: Optional[Expr] = None,
+ filter: Expr | None = None,
) -> Expr:
"""Returns the average value.
@@ -1810,7 +1810,7 @@ def avg(
return Expr(f.avg(expression.expr, filter=filter_raw))
-def corr(value_y: Expr, value_x: Expr, filter: Optional[Expr] = None) -> Expr:
+def corr(value_y: Expr, value_x: Expr, filter: Expr | None = None) -> Expr:
"""Returns the correlation coefficient between ``value1`` and ``value2``.
This aggregate function expects both values to be numeric and will return a float.
@@ -1830,7 +1830,7 @@ def corr(value_y: Expr, value_x: Expr, filter: Optional[Expr] = None) -> Expr:
def count(
expressions: Expr | list[Expr] | None = None,
distinct: bool = False,
- filter: Optional[Expr] = None,
+ filter: Expr | None = None,
) -> Expr:
"""Returns the number of rows that match the given arguments.
@@ -1856,7 +1856,7 @@ def count(
return Expr(f.count(*args, distinct=distinct, filter=filter_raw))
-def covar_pop(value_y: Expr, value_x: Expr, filter: Optional[Expr] = None) -> Expr:
+def covar_pop(value_y: Expr, value_x: Expr, filter: Expr | None = None) -> Expr:
"""Computes the population covariance.
This aggregate function expects both values to be numeric and will return a float.
@@ -1873,7 +1873,7 @@ def covar_pop(value_y: Expr, value_x: Expr, filter: Optional[Expr] = None) -> Ex
return Expr(f.covar_pop(value_y.expr, value_x.expr, filter=filter_raw))
-def covar_samp(value_y: Expr, value_x: Expr, filter: Optional[Expr] = None) -> Expr:
+def covar_samp(value_y: Expr, value_x: Expr, filter: Expr | None = None) -> Expr:
"""Computes the sample covariance.
This aggregate function expects both values to be numeric and will return a float.
@@ -1890,7 +1890,7 @@ def covar_samp(value_y: Expr, value_x: Expr, filter: Optional[Expr] = None) -> E
return Expr(f.covar_samp(value_y.expr, value_x.expr, filter=filter_raw))
-def covar(value_y: Expr, value_x: Expr, filter: Optional[Expr] = None) -> Expr:
+def covar(value_y: Expr, value_x: Expr, filter: Expr | None = None) -> Expr:
"""Computes the sample covariance.
This is an alias for :py:func:`covar_samp`.
@@ -1898,7 +1898,7 @@ def covar(value_y: Expr, value_x: Expr, filter: Optional[Expr] = None) -> Expr:
return covar_samp(value_y, value_x, filter)
-def max(expression: Expr, filter: Optional[Expr] = None) -> Expr:
+def max(expression: Expr, filter: Expr | None = None) -> Expr:
"""Aggregate function that returns the maximum value of the argument.
If using the builder functions described in ref:`_aggregation` this function ignores
@@ -1912,7 +1912,7 @@ def max(expression: Expr, filter: Optional[Expr] = None) -> Expr:
return Expr(f.max(expression.expr, filter=filter_raw))
-def mean(expression: Expr, filter: Optional[Expr] = None) -> Expr:
+def mean(expression: Expr, filter: Expr | None = None) -> Expr:
"""Returns the average (mean) value of the argument.
This is an alias for :py:func:`avg`.
@@ -1921,7 +1921,7 @@ def mean(expression: Expr, filter: Optional[Expr] = None) -> Expr:
def median(
- expression: Expr, distinct: bool = False, filter: Optional[Expr] = None
+ expression: Expr, distinct: bool = False, filter: Expr | None = None
) -> Expr:
"""Computes the median of a set of numbers.
@@ -1940,7 +1940,7 @@ def median(
return Expr(f.median(expression.expr, distinct=distinct, filter=filter_raw))
-def min(expression: Expr, filter: Optional[Expr] = None) -> Expr:
+def min(expression: Expr, filter: Expr | None = None) -> Expr:
"""Returns the minimum value of the argument.
If using the builder functions described in ref:`_aggregation` this function ignores
@@ -1956,7 +1956,7 @@ def min(expression: Expr, filter: Optional[Expr] = None) -> Expr:
def sum(
expression: Expr,
- filter: Optional[Expr] = None,
+ filter: Expr | None = None,
) -> Expr:
"""Computes the sum of a set of numbers.
@@ -1973,7 +1973,7 @@ def sum(
return Expr(f.sum(expression.expr, filter=filter_raw))
-def stddev(expression: Expr, filter: Optional[Expr] = None) -> Expr:
+def stddev(expression: Expr, filter: Expr | None = None) -> Expr:
"""Computes the standard deviation of the argument.
If using the builder functions described in ref:`_aggregation` this function ignores
@@ -1987,7 +1987,7 @@ def stddev(expression: Expr, filter: Optional[Expr] = None) -> Expr:
return Expr(f.stddev(expression.expr, filter=filter_raw))
-def stddev_pop(expression: Expr, filter: Optional[Expr] = None) -> Expr:
+def stddev_pop(expression: Expr, filter: Expr | None = None) -> Expr:
"""Computes the population standard deviation of the argument.
If using the builder functions described in ref:`_aggregation` this function ignores
@@ -2001,7 +2001,7 @@ def stddev_pop(expression: Expr, filter: Optional[Expr] = None) -> Expr:
return Expr(f.stddev_pop(expression.expr, filter=filter_raw))
-def stddev_samp(arg: Expr, filter: Optional[Expr] = None) -> Expr:
+def stddev_samp(arg: Expr, filter: Expr | None = None) -> Expr:
"""Computes the sample standard deviation of the argument.
This is an alias for :py:func:`stddev`.
@@ -2009,7 +2009,7 @@ def stddev_samp(arg: Expr, filter: Optional[Expr] = None) -> Expr:
return stddev(arg, filter=filter)
-def var(expression: Expr, filter: Optional[Expr] = None) -> Expr:
+def var(expression: Expr, filter: Expr | None = None) -> Expr:
"""Computes the sample variance of the argument.
This is an alias for :py:func:`var_samp`.
@@ -2017,7 +2017,7 @@ def var(expression: Expr, filter: Optional[Expr] = None) -> Expr:
return var_samp(expression, filter)
-def var_pop(expression: Expr, filter: Optional[Expr] = None) -> Expr:
+def var_pop(expression: Expr, filter: Expr | None = None) -> Expr:
"""Computes the population variance of the argument.
If using the builder functions described in ref:`_aggregation` this function ignores
@@ -2031,7 +2031,7 @@ def var_pop(expression: Expr, filter: Optional[Expr] = None) -> Expr:
return Expr(f.var_pop(expression.expr, filter=filter_raw))
-def var_samp(expression: Expr, filter: Optional[Expr] = None) -> Expr:
+def var_samp(expression: Expr, filter: Expr | None = None) -> Expr:
"""Computes the sample variance of the argument.
If using the builder functions described in ref:`_aggregation` this function ignores
@@ -2045,7 +2045,7 @@ def var_samp(expression: Expr, filter: Optional[Expr] = None) -> Expr:
return Expr(f.var_sample(expression.expr, filter=filter_raw))
-def var_sample(expression: Expr, filter: Optional[Expr] = None) -> Expr:
+def var_sample(expression: Expr, filter: Expr | None = None) -> Expr:
"""Computes the sample variance of the argument.
This is an alias for :py:func:`var_samp`.
@@ -2056,7 +2056,7 @@ def var_sample(expression: Expr, filter: Optional[Expr] = None) -> Expr:
def regr_avgx(
y: Expr,
x: Expr,
- filter: Optional[Expr] = None,
+ filter: Expr | None = None,
) -> Expr:
"""Computes the average of the independent variable ``x``.
@@ -2079,7 +2079,7 @@ def regr_avgx(
def regr_avgy(
y: Expr,
x: Expr,
- filter: Optional[Expr] = None,
+ filter: Expr | None = None,
) -> Expr:
"""Computes the average of the dependent variable ``y``.
@@ -2102,7 +2102,7 @@ def regr_avgy(
def regr_count(
y: Expr,
x: Expr,
- filter: Optional[Expr] = None,
+ filter: Expr | None = None,
) -> Expr:
"""Counts the number of rows in which both expressions are not null.
@@ -2125,7 +2125,7 @@ def regr_count(
def regr_intercept(
y: Expr,
x: Expr,
- filter: Optional[Expr] = None,
+ filter: Expr | None = None,
) -> Expr:
"""Computes the intercept from the linear regression.
@@ -2148,7 +2148,7 @@ def regr_intercept(
def regr_r2(
y: Expr,
x: Expr,
- filter: Optional[Expr] = None,
+ filter: Expr | None = None,
) -> Expr:
"""Computes the R-squared value from linear regression.
@@ -2171,7 +2171,7 @@ def regr_r2(
def regr_slope(
y: Expr,
x: Expr,
- filter: Optional[Expr] = None,
+ filter: Expr | None = None,
) -> Expr:
"""Computes the slope from linear regression.
@@ -2194,7 +2194,7 @@ def regr_slope(
def regr_sxx(
y: Expr,
x: Expr,
- filter: Optional[Expr] = None,
+ filter: Expr | None = None,
) -> Expr:
"""Computes the sum of squares of the independent variable ``x``.
@@ -2217,7 +2217,7 @@ def regr_sxx(
def regr_sxy(
y: Expr,
x: Expr,
- filter: Optional[Expr] = None,
+ filter: Expr | None = None,
) -> Expr:
"""Computes the sum of products of pairs of numbers.
@@ -2240,7 +2240,7 @@ def regr_sxy(
def regr_syy(
y: Expr,
x: Expr,
- filter: Optional[Expr] = None,
+ filter: Expr | None = None,
) -> Expr:
"""Computes the sum of squares of the dependent variable ``y``.
@@ -2262,8 +2262,8 @@ def regr_syy(
def first_value(
expression: Expr,
- filter: Optional[Expr] = None,
- order_by: Optional[list[SortKey] | SortKey] = None,
+ filter: Expr | None = None,
+ order_by: list[SortKey] | SortKey | None = None,
null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
) -> Expr:
"""Returns the first value in a group of values.
@@ -2299,8 +2299,8 @@ def first_value(
def last_value(
expression: Expr,
- filter: Optional[Expr] = None,
- order_by: Optional[list[SortKey] | SortKey] = None,
+ filter: Expr | None = None,
+ order_by: list[SortKey] | SortKey | None = None,
null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
) -> Expr:
"""Returns the last value in a group of values.
@@ -2337,8 +2337,8 @@ def last_value(
def nth_value(
expression: Expr,
n: int,
- filter: Optional[Expr] = None,
- order_by: Optional[list[SortKey] | SortKey] = None,
+ filter: Expr | None = None,
+ order_by: list[SortKey] | SortKey | None = None,
null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
) -> Expr:
"""Returns the n-th value in a group of values.
@@ -2374,7 +2374,7 @@ def nth_value(
)
-def bit_and(expression: Expr, filter: Optional[Expr] = None) -> Expr:
+def bit_and(expression: Expr, filter: Expr | None = None) -> Expr:
"""Computes the bitwise AND of the argument.
This aggregate function will bitwise compare every value in the input partition.
@@ -2390,7 +2390,7 @@ def bit_and(expression: Expr, filter: Optional[Expr] = None) -> Expr:
return Expr(f.bit_and(expression.expr, filter=filter_raw))
-def bit_or(expression: Expr, filter: Optional[Expr] = None) -> Expr:
+def bit_or(expression: Expr, filter: Expr | None = None) -> Expr:
"""Computes the bitwise OR of the argument.
This aggregate function will bitwise compare every value in the input partition.
@@ -2407,7 +2407,7 @@ def bit_or(expression: Expr, filter: Optional[Expr] = None) -> Expr:
def bit_xor(
- expression: Expr, distinct: bool = False, filter: Optional[Expr] = None
+ expression: Expr, distinct: bool = False, filter: Expr | None = None
) -> Expr:
"""Computes the bitwise XOR of the argument.
@@ -2425,7 +2425,7 @@ def bit_xor(
return Expr(f.bit_xor(expression.expr, distinct=distinct, filter=filter_raw))
-def bool_and(expression: Expr, filter: Optional[Expr] = None) -> Expr:
+def bool_and(expression: Expr, filter: Expr | None = None) -> Expr:
"""Computes the boolean AND of the argument.
This aggregate function will compare every value in the input partition. These are
@@ -2442,7 +2442,7 @@ def bool_and(expression: Expr, filter: Optional[Expr] = None) -> Expr:
return Expr(f.bool_and(expression.expr, filter=filter_raw))
-def bool_or(expression: Expr, filter: Optional[Expr] = None) -> Expr:
+def bool_or(expression: Expr, filter: Expr | None = None) -> Expr:
"""Computes the boolean OR of the argument.
This aggregate function will compare every value in the input partition. These are
@@ -2462,9 +2462,9 @@ def bool_or(expression: Expr, filter: Optional[Expr] = None) -> Expr:
def lead(
arg: Expr,
shift_offset: int = 1,
- default_value: Optional[Any] = None,
- partition_by: Optional[list[Expr] | Expr] = None,
- order_by: Optional[list[SortKey] | SortKey] = None,
+ default_value: Any | None = None,
+ partition_by: list[Expr] | Expr | None = None,
+ order_by: list[SortKey] | SortKey | None = None,
) -> Expr:
"""Create a lead window function.
@@ -2520,9 +2520,9 @@ def lead(
def lag(
arg: Expr,
shift_offset: int = 1,
- default_value: Optional[Any] = None,
- partition_by: Optional[list[Expr] | Expr] = None,
- order_by: Optional[list[SortKey] | SortKey] = None,
+ default_value: Any | None = None,
+ partition_by: list[Expr] | Expr | None = None,
+ order_by: list[SortKey] | SortKey | None = None,
) -> Expr:
"""Create a lag window function.
@@ -2573,8 +2573,8 @@ def lag(
def row_number(
- partition_by: Optional[list[Expr] | Expr] = None,
- order_by: Optional[list[SortKey] | SortKey] = None,
+ partition_by: list[Expr] | Expr | None = None,
+ order_by: list[SortKey] | SortKey | None = None,
) -> Expr:
"""Create a row number window function.
@@ -2612,8 +2612,8 @@ def row_number(
def rank(
- partition_by: Optional[list[Expr] | Expr] = None,
- order_by: Optional[list[SortKey] | SortKey] = None,
+ partition_by: list[Expr] | Expr | None = None,
+ order_by: list[SortKey] | SortKey | None = None,
) -> Expr:
"""Create a rank window function.
@@ -2656,8 +2656,8 @@ def rank(
def dense_rank(
- partition_by: Optional[list[Expr] | Expr] = None,
- order_by: Optional[list[SortKey] | SortKey] = None,
+ partition_by: list[Expr] | Expr | None = None,
+ order_by: list[SortKey] | SortKey | None = None,
) -> Expr:
"""Create a dense_rank window function.
@@ -2695,8 +2695,8 @@ def dense_rank(
def percent_rank(
- partition_by: Optional[list[Expr] | Expr] = None,
- order_by: Optional[list[SortKey] | SortKey] = None,
+ partition_by: list[Expr] | Expr | None = None,
+ order_by: list[SortKey] | SortKey | None = None,
) -> Expr:
"""Create a percent_rank window function.
@@ -2735,8 +2735,8 @@ def percent_rank(
def cume_dist(
- partition_by: Optional[list[Expr] | Expr] = None,
- order_by: Optional[list[SortKey] | SortKey] = None,
+ partition_by: list[Expr] | Expr | None = None,
+ order_by: list[SortKey] | SortKey | None = None,
) -> Expr:
"""Create a cumulative distribution window function.
@@ -2776,8 +2776,8 @@ def cume_dist(
def ntile(
groups: int,
- partition_by: Optional[list[Expr] | Expr] = None,
- order_by: Optional[list[SortKey] | SortKey] = None,
+ partition_by: list[Expr] | Expr | None = None,
+ order_by: list[SortKey] | SortKey | None = None,
) -> Expr:
"""Create a n-tile window function.
@@ -2822,8 +2822,8 @@ def ntile(
def string_agg(
expression: Expr,
delimiter: str,
- filter: Optional[Expr] = None,
- order_by: Optional[list[SortKey] | SortKey] = None,
+ filter: Expr | None = None,
+ order_by: list[SortKey] | SortKey | None = None,
) -> Expr:
"""Concatenates the input strings.
diff --git a/python/datafusion/input/location.py b/python/datafusion/input/location.py
index 08d98d11..b804ac18 100644
--- a/python/datafusion/input/location.py
+++ b/python/datafusion/input/location.py
@@ -17,7 +17,6 @@
"""The default input source for DataFusion."""
-import glob
from pathlib import Path
from typing import Any
@@ -84,6 +83,7 @@ class LocationInputPlugin(BaseInputSource):
raise RuntimeError(msg)
# Input could possibly be multiple files. Create a list if so
- input_files = glob.glob(input_item)
+ input_path = Path(input_item)
+ input_files = [str(p) for p in input_path.parent.glob(input_path.name)]
return SqlTable(table_name, columns, num_rows, input_files)
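
The glob replacement splits the pattern into a directory part and a filename
pattern, since Path.glob matches relative to its receiver; a sketch with a
hypothetical pattern:

    import glob
    from pathlib import Path

    pattern = "data/lineitem*.parquet"  # hypothetical input_item
    old_files = glob.glob(pattern)
    p = Path(pattern)
    new_files = [str(m) for m in p.parent.glob(p.name)]
    # Equivalent only while wildcards stay in the final path component; a
    # wildcard in the directory part would behave differently.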
diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py
index 67568e31..21b2de63 100644
--- a/python/datafusion/user_defined.py
+++ b/python/datafusion/user_defined.py
@@ -22,7 +22,7 @@ from __future__ import annotations
import functools
from abc import ABCMeta, abstractmethod
from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, TypeVar, overload
+from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload
import pyarrow as pa
@@ -31,6 +31,7 @@ from datafusion.expr import Expr
if TYPE_CHECKING:
_R = TypeVar("_R", bound=pa.DataType)
+ from collections.abc import Callable
class Volatility(Enum):
@@ -130,7 +131,7 @@ class ScalarUDF:
input_types: list[pa.DataType],
return_type: _R,
volatility: Volatility | str,
- name: Optional[str] = None,
+ name: str | None = None,
) -> Callable[..., ScalarUDF]: ...
@overload
@@ -140,7 +141,7 @@ class ScalarUDF:
input_types: list[pa.DataType],
return_type: _R,
volatility: Volatility | str,
- name: Optional[str] = None,
+ name: str | None = None,
) -> ScalarUDF: ...
@overload
@@ -194,7 +195,7 @@ class ScalarUDF:
input_types: list[pa.DataType],
return_type: _R,
volatility: Volatility | str,
- name: Optional[str] = None,
+ name: str | None = None,
) -> ScalarUDF:
if not callable(func):
msg = "`func` argument must be callable"
@@ -216,15 +217,15 @@ class ScalarUDF:
input_types: list[pa.DataType],
return_type: _R,
volatility: Volatility | str,
- name: Optional[str] = None,
+ name: str | None = None,
) -> Callable:
- def decorator(func: Callable):
+ def decorator(func: Callable) -> Callable:
udf_caller = ScalarUDF.udf(
func, input_types, return_type, volatility, name
)
@functools.wraps(func)
- def wrapper(*args: Any, **kwargs: Any):
+ def wrapper(*args: Any, **kwargs: Any) -> Callable:
return udf_caller(*args, **kwargs)
return wrapper
@@ -336,7 +337,7 @@ class AggregateUDF:
return_type: pa.DataType,
state_type: list[pa.DataType],
volatility: Volatility | str,
- name: Optional[str] = None,
+ name: str | None = None,
) -> Callable[..., AggregateUDF]: ...
@overload
@@ -347,7 +348,7 @@ class AggregateUDF:
return_type: pa.DataType,
state_type: list[pa.DataType],
volatility: Volatility | str,
- name: Optional[str] = None,
+ name: str | None = None,
) -> AggregateUDF: ...
@staticmethod
@@ -429,7 +430,7 @@ class AggregateUDF:
return_type: pa.DataType,
state_type: list[pa.DataType],
volatility: Volatility | str,
- name: Optional[str] = None,
+ name: str | None = None,
) -> AggregateUDF:
if not callable(accum):
msg = "`func` must be callable."
@@ -455,7 +456,7 @@ class AggregateUDF:
return_type: pa.DataType,
state_type: list[pa.DataType],
volatility: Volatility | str,
- name: Optional[str] = None,
+ name: str | None = None,
) -> Callable[..., Callable[..., Expr]]:
def decorator(accum: Callable[[], Accumulator]) -> Callable[..., Expr]:
udaf_caller = AggregateUDF.udaf(
@@ -708,7 +709,7 @@ class WindowUDF:
input_types: pa.DataType | list[pa.DataType],
return_type: pa.DataType,
volatility: Volatility | str,
- name: Optional[str] = None,
+ name: str | None = None,
) -> Callable[..., WindowUDF]: ...
@overload
@@ -718,7 +719,7 @@ class WindowUDF:
input_types: pa.DataType | list[pa.DataType],
return_type: pa.DataType,
volatility: Volatility | str,
- name: Optional[str] = None,
+ name: str | None = None,
) -> WindowUDF: ...
@staticmethod
@@ -787,7 +788,7 @@ class WindowUDF:
input_types: pa.DataType | list[pa.DataType],
return_type: pa.DataType,
volatility: Volatility | str,
- name: Optional[str] = None,
+ name: str | None = None,
) -> WindowUDF:
"""Create a WindowUDF instance from function arguments."""
if not callable(func):
@@ -825,7 +826,7 @@ class WindowUDF:
input_types: pa.DataType | list[pa.DataType],
return_type: pa.DataType,
volatility: Volatility | str,
- name: Optional[str] = None,
+ name: str | None = None,
) -> Callable[[Callable[[], WindowEvaluator]], Callable[..., Expr]]:
"""Create a decorator for a WindowUDF."""
@@ -922,7 +923,7 @@ class TableFunction:
@staticmethod
def _create_table_udf_decorator(
- name: Optional[str] = None,
+ name: str | None = None,
) -> Callable[[Callable[[], WindowEvaluator]], Callable[..., Expr]]:
"""Create a decorator for a WindowUDF."""
diff --git a/python/tests/test_aggregation.py b/python/tests/test_aggregation.py
index 17767ea1..f595127f 100644
--- a/python/tests/test_aggregation.py
+++ b/python/tests/test_aggregation.py
@@ -88,7 +88,7 @@ def df_aggregate_100():
f.covar_samp(column("b"), column("c")),
lambda a, b, c, d: np.array(np.cov(b, c, ddof=1)[0][1]),
),
- # f.grouping(col_a), # No physical plan implemented yet
+ # f.grouping(col_a), # noqa: ERA001 No physical plan implemented yet
(f.max(column("a")), lambda a, b, c, d: np.array(np.max(a))),
(f.mean(column("b")), lambda a, b, c, d: np.array(np.mean(b))),
(f.median(column("b")), lambda a, b, c, d: np.array(np.median(b))),
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index c3a5253c..aed477af 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -21,6 +21,7 @@ import os
import re
import threading
import time
+from pathlib import Path
from typing import Any
import pyarrow as pa
@@ -1040,33 +1041,14 @@ def test_invalid_window_frame(units, start_bound, end_bound):
def test_window_frame_defaults_match_postgres(partitioned_df):
- # ref: https://github.com/apache/datafusion-python/issues/688
-
- window_frame = WindowFrame("rows", None, None)
-
col_a = column("a")
- # Using `f.window` with or without an unbounded window_frame produces the same
- # results. These tests are included as a regression check but can be removed when
- # f.window() is deprecated in favor of using the .over() approach.
- no_frame = f.window("avg", [col_a]).alias("no_frame")
- with_frame = f.window("avg", [col_a], window_frame=window_frame).alias("with_frame")
- df_1 = partitioned_df.select(col_a, no_frame, with_frame)
-
- expected = {
- "a": [0, 1, 2, 3, 4, 5, 6],
- "no_frame": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
- "with_frame": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
- }
-
- assert df_1.sort(col_a).to_pydict() == expected
-
# When order is not set, the default frame should be unbounded preceding to
# unbounded following. When order is set, the default frame is unbounded preceding
# to current row.
no_order = f.avg(col_a).over(Window()).alias("over_no_order")
with_order = f.avg(col_a).over(Window(order_by=[col_a])).alias("over_with_order")
- df_2 = partitioned_df.select(col_a, no_order, with_order)
+ df = partitioned_df.select(col_a, no_order, with_order)
expected = {
"a": [0, 1, 2, 3, 4, 5, 6],
@@ -1074,7 +1056,7 @@ def test_window_frame_defaults_match_postgres(partitioned_df):
"over_with_order": [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0],
}
- assert df_2.sort(col_a).to_pydict() == expected
+ assert df.sort(col_a).to_pydict() == expected
def _build_last_value_df(df):
@@ -2413,11 +2395,11 @@ def test_write_parquet_with_options_bloom_filter(df, tmp_path):
size_no_bloom_filter = 0
for file in path_no_bloom_filter.rglob("*.parquet"):
- size_no_bloom_filter += os.path.getsize(file)
+ size_no_bloom_filter += Path(file).stat().st_size
size_bloom_filter = 0
for file in path_bloom_filter.rglob("*.parquet"):
- size_bloom_filter += os.path.getsize(file)
+ size_bloom_filter += Path(file).stat().st_size
assert size_no_bloom_filter < size_bloom_filter
diff --git a/python/tests/test_pyclass_frozen.py b/python/tests/test_pyclass_frozen.py
index 189ea8de..3500c5e3 100644
--- a/python/tests/test_pyclass_frozen.py
+++ b/python/tests/test_pyclass_frozen.py
@@ -22,7 +22,10 @@ from __future__ import annotations
import re
from dataclasses import dataclass
from pathlib import Path
-from typing import Iterator
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from collections.abc import Iterator
PYCLASS_RE = re.compile(
r"#\[\s*pyclass\s*(?:\((?P<args>.*?)\))?\s*\]",
diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py
index 635f0c33..8f57992d 100644
--- a/python/tests/test_sql.py
+++ b/python/tests/test_sql.py
@@ -181,13 +181,24 @@ def test_register_parquet_partitioned(ctx, tmp_path, path_to_str, legacy_data_ty
partition_data_type = "string" if legacy_data_type else pa.string()
- ctx.register_parquet(
- "datapp",
- dir_root,
- table_partition_cols=[("grp", partition_data_type)],
- parquet_pruning=True,
- file_extension=".parquet",
- )
+ if legacy_data_type:
+ with pytest.warns(DeprecationWarning):
+ ctx.register_parquet(
+ "datapp",
+ dir_root,
+ table_partition_cols=[("grp", partition_data_type)],
+ parquet_pruning=True,
+ file_extension=".parquet",
+ )
+ else:
+ ctx.register_parquet(
+ "datapp",
+ dir_root,
+ table_partition_cols=[("grp", partition_data_type)],
+ parquet_pruning=True,
+ file_extension=".parquet",
+ )
+
assert ctx.catalog().schema().names() == {"datapp"}
result = ctx.sql("SELECT grp, COUNT(*) AS cnt FROM datapp GROUP BY grp").collect()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]