This is an automated email from the ASF dual-hosted git repository.
timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new b194a877 feat/improve ruff test coverage (#1055)
b194a877 is described below
commit b194a8772e58ccefc697e11671113127a8038716
Author: Tim Saucer <[email protected]>
AuthorDate: Wed Mar 12 14:25:32 2025 -0400
feat/improve ruff test coverage (#1055)
* Run python tests on all currently supported python versions
* Update ruff checks to select all
* Ruff auto fix
* Applying ruff suggestions
* noqa rules updates per ruff checks
* Working through more ruff suggestions
* Working through more ruff suggestions
* update timestamps on tests
* More ruff updates
* More ruff updates
* Instead of importing udf static functions as variables, import
* More ruff formatting suggestions
* more ruff formatting suggestions
* More ruff formatting
* More ruff formatting
* Cut off lint errors for this PR
* Working through more ruff checks and disabling a bunch for now
* Address CI difference from local ruff
* UDWF isn't a proper abstract base class right now since users can opt in to all methods
* Update pre-commit to match the version of ruff used in CI
* To enable testing in python 3.9 we need numpy. Also going to the current minimal supported version
* Update min required version of python to 3.9 in pyproject.toml. The other changes will come in #1043, which is soon to be merged.
* Suppress UP035
* ruff format
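
As a quick orientation for reviewers: after this change `udf`, `udaf`, and `udwf` live in `datafusion.udf` and are re-exported at the package root, and `udf` works both as a plain function and as a decorator. A minimal, hedged sketch against the post-commit API (the column name and data below are invented for illustration):

```python
import pyarrow as pa
import pyarrow.compute as pc

from datafusion import SessionContext, column, udf  # udf is now a package-root export

# Function form: wrap an existing callable.
def double_func(x: pa.Array) -> pa.Array:
    return pc.multiply(x, 2)

double_udf = udf(double_func, [pa.int64()], pa.int64(), "immutable", "double_it")

# Decorator form: the same factory, called with parameters only.
@udf([pa.int64()], pa.int64(), "immutable", "triple_it")
def triple_udf(x: pa.Array) -> pa.Array:
    return pc.multiply(x, 3)

ctx = SessionContext()
df = ctx.from_pydict({"a": [1, 2, 3]})
df.select(double_udf(column("a")), triple_udf(column("a"))).show()
```

The dispatch between the two forms is driven by whether the first argument is callable, as implemented in `ScalarUDF.udf` in the diff below.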
---
.github/workflows/test.yaml | 2 +
.pre-commit-config.yaml | 2 +-
benchmarks/tpch/tpch.py | 14 +-
dev/release/check-rat-report.py | 2 +-
dev/release/generate-changelog.py | 10 +-
docs/source/conf.py | 2 +-
examples/python-udwf.py | 2 +-
examples/tpch/_tests.py | 15 +-
pyproject.toml | 76 +++++++-
python/datafusion/__init__.py | 50 +++--
python/datafusion/common.py | 14 +-
python/datafusion/context.py | 4 +-
python/datafusion/dataframe.py | 15 +-
python/datafusion/expr.py | 94 ++++-----
python/datafusion/functions.py | 46 ++---
python/datafusion/input/__init__.py | 2 +-
python/datafusion/input/base.py | 6 +-
python/datafusion/input/location.py | 40 ++--
python/datafusion/io.py | 20 +-
python/datafusion/object_store.py | 2 +-
python/datafusion/plan.py | 8 +-
python/datafusion/record_batch.py | 8 +-
python/datafusion/substrait.py | 21 +-
python/datafusion/udf.py | 236 ++++++++++++----------
python/tests/generic.py | 19 +-
python/tests/test_aggregation.py | 16 +-
python/tests/test_catalog.py | 9 +-
python/tests/test_context.py | 53 ++---
python/tests/test_dataframe.py | 38 ++--
python/tests/test_expr.py | 11 +-
python/tests/test_functions.py | 358 ++++++++++++++++++----------------
python/tests/test_imports.py | 7 +-
python/tests/test_input.py | 12 +-
python/tests/test_io.py | 13 +-
python/tests/test_sql.py | 35 ++--
python/tests/test_store.py | 13 +-
python/tests/test_substrait.py | 2 +-
python/tests/test_udaf.py | 10 +-
python/tests/test_udwf.py | 2 +-
python/tests/test_wrapper_coverage.py | 7 +-
40 files changed, 697 insertions(+), 599 deletions(-)
diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index c1d9ac83..da358276 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -33,9 +33,11 @@ jobs:
fail-fast: false
matrix:
python-version:
+ - "3.9"
- "3.10"
- "3.11"
- "3.12"
+ - "3.13"
toolchain:
- "stable"
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b548ff18..abcfcf32 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -22,7 +22,7 @@ repos:
- id: actionlint-docker
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
- rev: v0.3.0
+ rev: v0.9.10
hooks:
# Run the linter.
- id: ruff
diff --git a/benchmarks/tpch/tpch.py b/benchmarks/tpch/tpch.py
index fb86b12b..bfb9ac39 100644
--- a/benchmarks/tpch/tpch.py
+++ b/benchmarks/tpch/tpch.py
@@ -59,13 +59,13 @@ def bench(data_path, query_path):
end = time.time()
time_millis = (end - start) * 1000
total_time_millis += time_millis
- print("setup,{}".format(round(time_millis, 1)))
- results.write("setup,{}\n".format(round(time_millis, 1)))
+ print(f"setup,{round(time_millis, 1)}")
+ results.write(f"setup,{round(time_millis, 1)}\n")
results.flush()
# run queries
for query in range(1, 23):
- with open("{}/q{}.sql".format(query_path, query)) as f:
+ with open(f"{query_path}/q{query}.sql") as f:
text = f.read()
tmp = text.split(";")
queries = []
@@ -83,14 +83,14 @@ def bench(data_path, query_path):
end = time.time()
time_millis = (end - start) * 1000
total_time_millis += time_millis
- print("q{},{}".format(query, round(time_millis, 1)))
- results.write("q{},{}\n".format(query, round(time_millis,
1)))
+ print(f"q{query},{round(time_millis, 1)}")
+ results.write(f"q{query},{round(time_millis, 1)}\n")
results.flush()
except Exception as e:
print("query", query, "failed", e)
- print("total,{}".format(round(total_time_millis, 1)))
- results.write("total,{}\n".format(round(total_time_millis, 1)))
+ print(f"total,{round(total_time_millis, 1)}")
+ results.write(f"total,{round(total_time_millis, 1)}\n")
if __name__ == "__main__":
diff --git a/dev/release/check-rat-report.py b/dev/release/check-rat-report.py
index d3dd7c5d..0c9f4c32 100644
--- a/dev/release/check-rat-report.py
+++ b/dev/release/check-rat-report.py
@@ -29,7 +29,7 @@ if len(sys.argv) != 3:
exclude_globs_filename = sys.argv[1]
xml_filename = sys.argv[2]
-globs = [line.strip() for line in open(exclude_globs_filename, "r")]
+globs = [line.strip() for line in open(exclude_globs_filename)]
tree = ET.parse(xml_filename)
root = tree.getroot()
diff --git a/dev/release/generate-changelog.py b/dev/release/generate-changelog.py
index 2564eea8..e30e2def 100755
--- a/dev/release/generate-changelog.py
+++ b/dev/release/generate-changelog.py
@@ -26,15 +26,11 @@ from github import Github
def print_pulls(repo_name, title, pulls):
if len(pulls) > 0:
- print("**{}:**".format(title))
+ print(f"**{title}:**")
print()
for pull, commit in pulls:
- url = "https://github.com/{}/pull/{}".format(repo_name,
pull.number)
- print(
- "- {} [#{}]({}) ({})".format(
- pull.title, pull.number, url, commit.author.login
- )
- )
+ url = f"https://github.com/{repo_name}/pull/{pull.number}"
+ print(f"- {pull.title} [#{pull.number}]({url})
({commit.author.login})")
print()
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 2e5a4133..c82a189e 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -73,7 +73,7 @@ suppress_warnings = ["autoapi.python_import_resolution"]
autoapi_python_class_content = "both"
-def autoapi_skip_member_fn(app, what, name, obj, skip, options):
+def autoapi_skip_member_fn(app, what, name, obj, skip, options): # noqa: ARG001
skip_contents = [
# Re-exports
("class", "datafusion.DataFrame"),
diff --git a/examples/python-udwf.py b/examples/python-udwf.py
index 7d39dc1b..98d118bf 100644
--- a/examples/python-udwf.py
+++ b/examples/python-udwf.py
@@ -59,7 +59,7 @@ class SmoothBoundedFromPreviousRow(WindowEvaluator):
def supports_bounded_execution(self) -> bool:
return True
- def get_range(self, idx: int, num_rows: int) -> tuple[int, int]:
+ def get_range(self, idx: int, num_rows: int) -> tuple[int, int]: # noqa: ARG002
# Override the default range of current row since uses_window_frame is False
# So for the purpose of this test we just smooth from the previous row to
# current.
diff --git a/examples/tpch/_tests.py b/examples/tpch/_tests.py
index c4d87208..2be4dfab 100644
--- a/examples/tpch/_tests.py
+++ b/examples/tpch/_tests.py
@@ -27,28 +27,25 @@ from util import get_answer_file
def df_selection(col_name, col_type):
if col_type == pa.float64() or isinstance(col_type, pa.Decimal128Type):
return F.round(col(col_name), lit(2)).alias(col_name)
- elif col_type == pa.string() or col_type == pa.string_view():
+ if col_type == pa.string() or col_type == pa.string_view():
return F.trim(col(col_name)).alias(col_name)
- else:
- return col(col_name)
+ return col(col_name)
def load_schema(col_name, col_type):
if col_type == pa.int64() or col_type == pa.int32():
return col_name, pa.string()
- elif isinstance(col_type, pa.Decimal128Type):
+ if isinstance(col_type, pa.Decimal128Type):
return col_name, pa.float64()
- else:
- return col_name, col_type
+ return col_name, col_type
def expected_selection(col_name, col_type):
if col_type == pa.int64() or col_type == pa.int32():
return F.trim(col(col_name)).cast(col_type).alias(col_name)
- elif col_type == pa.string() or col_type == pa.string_view():
+ if col_type == pa.string() or col_type == pa.string_view():
return F.trim(col(col_name)).alias(col_name)
- else:
- return col(col_name)
+ return col(col_name)
def selections_and_schema(original_schema):
diff --git a/pyproject.toml b/pyproject.toml
index 1c273367..060e3b80 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,7 +65,57 @@ features = ["substrait"]
# Enable docstring linting using the google style guide
[tool.ruff.lint]
-select = ["E4", "E7", "E9", "F", "FA", "D", "W", "I"]
+select = ["ALL" ]
+ignore = [
+ "A001", # Allow using words like min as variable names
+ "A002", # Allow using words like filter as variable names
+ "ANN401", # Allow Any for wrapper classes
+ "COM812", # Recommended to ignore these rules when using with ruff-format
+ "FIX002", # Allow TODO lines - consider removing at some point
+ "FBT001", # Allow boolean positional args
+ "FBT002", # Allow boolean positional args
+ "ISC001", # Recommended to ignore these rules when using with ruff-format
+ "SLF001", # Allow accessing private members
+ "TD002",
+ "TD003", # Allow TODO lines
+ "UP007", # Disallowing Union is pedantic
+ # TODO: Enable all of the following, but this PR is getting too large already
+ "PT001",
+ "ANN204",
+ "B008",
+ "EM101",
+ "PLR0913",
+ "PLR1714",
+ "ANN201",
+ "C400",
+ "TRY003",
+ "B904",
+ "UP006",
+ "RUF012",
+ "FBT003",
+ "C416",
+ "SIM102",
+ "PGH003",
+ "PLR2004",
+ "PERF401",
+ "PD901",
+ "EM102",
+ "ERA001",
+ "SIM108",
+ "ICN001",
+ "ANN001",
+ "ANN202",
+ "PTH",
+ "N812",
+ "INP001",
+ "DTZ007",
+ "PLW2901",
+ "RET503",
+ "RUF015",
+ "A005",
+ "TC001",
+ "UP035",
+]
[tool.ruff.lint.pydocstyle]
convention = "google"
@@ -75,16 +125,30 @@ max-doc-length = 88
# Disable docstring checking for these directories
[tool.ruff.lint.per-file-ignores]
-"python/tests/*" = ["D"]
-"examples/*" = ["D", "W505"]
-"dev/*" = ["D"]
-"benchmarks/*" = ["D", "F"]
+"python/tests/*" = [
+ "ANN",
+ "ARG",
+ "BLE001",
+ "D",
+ "S101",
+ "SLF",
+ "PD",
+ "PLR2004",
+ "PT011",
+ "RUF015",
+ "S608",
+ "PLR0913",
+ "PT004",
+]
+"examples/*" = ["D", "W505", "E501", "T201", "S101"]
+"dev/*" = ["D", "E", "T", "S", "PLR", "C", "SIM", "UP", "EXE", "N817"]
+"benchmarks/*" = ["D", "F", "T", "BLE", "FURB", "PLR", "E", "TD", "TRY", "S",
"SIM", "EXE", "UP"]
"docs/*" = ["D"]
[dependency-groups]
dev = [
"maturin>=1.8.1",
- "numpy>1.24.4 ; python_full_version >= '3.10'",
+ "numpy>1.25.0",
"pytest>=7.4.4",
"ruff>=0.9.1",
"toml>=0.10.2",
diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py
index f11ce54a..286e5dc3 100644
--- a/python/datafusion/__init__.py
+++ b/python/datafusion/__init__.py
@@ -48,44 +48,47 @@ from .expr import (
from .io import read_avro, read_csv, read_json, read_parquet
from .plan import ExecutionPlan, LogicalPlan
from .record_batch import RecordBatch, RecordBatchStream
-from .udf import Accumulator, AggregateUDF, ScalarUDF, WindowUDF
+from .udf import Accumulator, AggregateUDF, ScalarUDF, WindowUDF, udaf, udf, udwf
__version__ = importlib_metadata.version(__name__)
__all__ = [
"Accumulator",
+ "AggregateUDF",
+ "Catalog",
"Config",
- "DataFrame",
- "SessionContext",
- "SessionConfig",
- "SQLOptions",
- "RuntimeEnvBuilder",
- "Expr",
- "ScalarUDF",
- "WindowFrame",
- "column",
- "col",
- "literal",
- "lit",
"DFSchema",
- "Catalog",
+ "DataFrame",
"Database",
- "Table",
- "AggregateUDF",
- "WindowUDF",
- "LogicalPlan",
"ExecutionPlan",
+ "Expr",
+ "LogicalPlan",
"RecordBatch",
"RecordBatchStream",
+ "RuntimeEnvBuilder",
+ "SQLOptions",
+ "ScalarUDF",
+ "SessionConfig",
+ "SessionContext",
+ "Table",
+ "WindowFrame",
+ "WindowUDF",
+ "col",
+ "column",
"common",
"expr",
"functions",
+ "lit",
+ "literal",
"object_store",
- "substrait",
- "read_parquet",
"read_avro",
"read_csv",
"read_json",
+ "read_parquet",
+ "substrait",
+ "udaf",
+ "udf",
+ "udwf",
]
@@ -120,10 +123,3 @@ def str_lit(value):
def lit(value):
"""Create a literal expression."""
return Expr.literal(value)
-
-
-udf = ScalarUDF.udf
-
-udaf = AggregateUDF.udaf
-
-udwf = WindowUDF.udwf
diff --git a/python/datafusion/common.py b/python/datafusion/common.py
index a2298c63..e762a993 100644
--- a/python/datafusion/common.py
+++ b/python/datafusion/common.py
@@ -20,7 +20,7 @@ from enum import Enum
from ._internal import common as common_internal
-# TODO these should all have proper wrapper classes
+# TODO: these should all have proper wrapper classes
DFSchema = common_internal.DFSchema
DataType = common_internal.DataType
@@ -38,15 +38,15 @@ __all__ = [
"DFSchema",
"DataType",
"DataTypeMap",
- "RexType",
- "PythonType",
- "SqlType",
"NullTreatment",
- "SqlTable",
+ "PythonType",
+ "RexType",
+ "SqlFunction",
"SqlSchema",
- "SqlView",
"SqlStatistics",
- "SqlFunction",
+ "SqlTable",
+ "SqlType",
+ "SqlView",
]
diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index 282b2a47..0ab1a908 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -393,8 +393,6 @@ class RuntimeEnvBuilder:
class RuntimeConfig(RuntimeEnvBuilder):
"""See `RuntimeEnvBuilder`."""
- pass
-
class SQLOptions:
"""Options to be used when performing SQL queries."""
@@ -498,7 +496,7 @@ class SessionContext:
self.ctx = SessionContextInternal(config, runtime)
- def enable_url_table(self) -> "SessionContext":
+ def enable_url_table(self) -> SessionContext:
"""Control if local files can be queried as tables.
Returns:
diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index de5d8376..d1c71c2b 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -29,6 +29,7 @@ from typing import (
List,
Literal,
Optional,
+ Type,
Union,
overload,
)
@@ -49,10 +50,11 @@ if TYPE_CHECKING:
import polars as pl
import pyarrow as pa
+ from datafusion._internal import DataFrame as DataFrameInternal
+ from datafusion._internal import expr as expr_internal
+
from enum import Enum
-from datafusion._internal import DataFrame as DataFrameInternal
-from datafusion._internal import expr as expr_internal
from datafusion.expr import Expr, SortExpr, sort_or_default
@@ -73,7 +75,7 @@ class Compression(Enum):
LZ4_RAW = "lz4_raw"
@classmethod
- def from_str(cls, value: str) -> "Compression":
+ def from_str(cls: Type[Compression], value: str) -> Compression:
"""Convert a string to a Compression enum value.
Args:
@@ -88,8 +90,9 @@ class Compression(Enum):
try:
return cls(value.lower())
except ValueError:
+ valid_values = str([item.value for item in Compression])
raise ValueError(
- f"{value} is not a valid Compression. Valid values are:
{[item.value for item in Compression]}"
+ f"{value} is not a valid Compression. Valid values are:
{valid_values}"
)
def get_default_level(self) -> Optional[int]:
@@ -104,9 +107,9 @@ class Compression(Enum):
# https://github.com/apache/datafusion-python/pull/981#discussion_r1904789223
if self == Compression.GZIP:
return 6
- elif self == Compression.BROTLI:
+ if self == Compression.BROTLI:
return 1
- elif self == Compression.ZSTD:
+ if self == Compression.ZSTD:
return 4
return None
diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py
index 3639abec..702f75ae 100644
--- a/python/datafusion/expr.py
+++ b/python/datafusion/expr.py
@@ -101,63 +101,63 @@ UnnestExpr = expr_internal.UnnestExpr
WindowExpr = expr_internal.WindowExpr
__all__ = [
- "Expr",
- "Column",
- "Literal",
- "BinaryExpr",
- "Literal",
+ "Aggregate",
"AggregateFunction",
- "Not",
- "IsNotNull",
- "IsNull",
- "IsTrue",
- "IsFalse",
- "IsUnknown",
- "IsNotTrue",
- "IsNotFalse",
- "IsNotUnknown",
- "Negative",
- "Like",
- "ILike",
- "SimilarTo",
- "ScalarVariable",
"Alias",
- "InList",
- "Exists",
- "Subquery",
- "InSubquery",
- "ScalarSubquery",
- "Placeholder",
- "GroupingSet",
+ "Analyze",
+ "Between",
+ "BinaryExpr",
"Case",
"CaseBuilder",
"Cast",
- "TryCast",
- "Between",
+ "Column",
+ "CreateMemoryTable",
+ "CreateView",
+ "Distinct",
+ "DropTable",
+ "EmptyRelation",
+ "Exists",
"Explain",
+ "Expr",
+ "Extension",
+ "Filter",
+ "GroupingSet",
+ "ILike",
+ "InList",
+ "InSubquery",
+ "IsFalse",
+ "IsNotFalse",
+ "IsNotNull",
+ "IsNotTrue",
+ "IsNotUnknown",
+ "IsNull",
+ "IsTrue",
+ "IsUnknown",
+ "Join",
+ "JoinConstraint",
+ "JoinType",
+ "Like",
"Limit",
- "Aggregate",
+ "Literal",
+ "Literal",
+ "Negative",
+ "Not",
+ "Partitioning",
+ "Placeholder",
+ "Projection",
+ "Repartition",
+ "ScalarSubquery",
+ "ScalarVariable",
+ "SimilarTo",
"Sort",
"SortExpr",
- "Analyze",
- "EmptyRelation",
- "Join",
- "JoinType",
- "JoinConstraint",
+ "Subquery",
+ "SubqueryAlias",
+ "TableScan",
+ "TryCast",
"Union",
"Unnest",
"UnnestExpr",
- "Extension",
- "Filter",
- "Projection",
- "TableScan",
- "CreateMemoryTable",
- "CreateView",
- "Distinct",
- "SubqueryAlias",
- "DropTable",
- "Partitioning",
- "Repartition",
"Window",
"WindowExpr",
"WindowFrame",
@@ -311,7 +311,7 @@ class Expr:
)
return Expr(self.expr.__getitem__(key))
- def __eq__(self, rhs: Any) -> Expr:
+ def __eq__(self, rhs: object) -> Expr:
"""Equal to.
Accepts either an expression or any valid PyArrow scalar literal value.
@@ -320,7 +320,7 @@ class Expr:
rhs = Expr.literal(rhs)
return Expr(self.expr.__eq__(rhs.expr))
- def __ne__(self, rhs: Any) -> Expr:
+ def __ne__(self, rhs: object) -> Expr:
"""Not equal to.
Accepts either an expression or any valid PyArrow scalar literal value.
diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py
index b449c486..0cc7434c 100644
--- a/python/datafusion/functions.py
+++ b/python/datafusion/functions.py
@@ -18,13 +18,12 @@
from __future__ import annotations
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
import pyarrow as pa
from datafusion._internal import functions as f
from datafusion.common import NullTreatment
-from datafusion.context import SessionContext
from datafusion.expr import (
CaseBuilder,
Expr,
@@ -34,6 +33,9 @@ from datafusion.expr import (
sort_list_to_raw_sort_list,
)
+if TYPE_CHECKING:
+ from datafusion.context import SessionContext
+
__all__ = [
"abs",
"acos",
@@ -81,8 +83,8 @@ __all__ = [
"array_sort",
"array_to_string",
"array_union",
- "arrow_typeof",
"arrow_cast",
+ "arrow_typeof",
"ascii",
"asin",
"asinh",
@@ -97,6 +99,7 @@ __all__ = [
"bool_and",
"bool_or",
"btrim",
+ "cardinality",
"case",
"cbrt",
"ceil",
@@ -116,6 +119,7 @@ __all__ = [
"covar",
"covar_pop",
"covar_samp",
+ "cume_dist",
"current_date",
"current_time",
"date_bin",
@@ -125,17 +129,17 @@ __all__ = [
"datetrunc",
"decode",
"degrees",
+ "dense_rank",
"digest",
"empty",
"encode",
"ends_with",
- "extract",
"exp",
+ "extract",
"factorial",
"find_in_set",
"first_value",
"flatten",
- "cardinality",
"floor",
"from_unixtime",
"gcd",
@@ -143,8 +147,10 @@ __all__ = [
"initcap",
"isnan",
"iszero",
+ "lag",
"last_value",
"lcm",
+ "lead",
"left",
"length",
"levenshtein",
@@ -166,10 +172,10 @@ __all__ = [
"list_prepend",
"list_push_back",
"list_push_front",
- "list_repeat",
"list_remove",
"list_remove_all",
"list_remove_n",
+ "list_repeat",
"list_replace",
"list_replace_all",
"list_replace_n",
@@ -180,14 +186,14 @@ __all__ = [
"list_union",
"ln",
"log",
- "log10",
"log2",
+ "log10",
"lower",
"lpad",
"ltrim",
"make_array",
- "make_list",
"make_date",
+ "make_list",
"max",
"md5",
"mean",
@@ -195,19 +201,22 @@ __all__ = [
"min",
"named_struct",
"nanvl",
- "nvl",
"now",
"nth_value",
+ "ntile",
"nullif",
+ "nvl",
"octet_length",
"order_by",
"overlay",
+ "percent_rank",
"pi",
"pow",
"power",
"radians",
"random",
"range",
+ "rank",
"regexp_like",
"regexp_match",
"regexp_replace",
@@ -225,6 +234,7 @@ __all__ = [
"reverse",
"right",
"round",
+ "row_number",
"rpad",
"rtrim",
"sha224",
@@ -252,8 +262,8 @@ __all__ = [
"to_hex",
"to_timestamp",
"to_timestamp_micros",
- "to_timestamp_nanos",
"to_timestamp_millis",
+ "to_timestamp_nanos",
"to_timestamp_seconds",
"to_unixtime",
"translate",
@@ -268,14 +278,6 @@ __all__ = [
"when",
# Window Functions
"window",
- "lead",
- "lag",
- "row_number",
- "rank",
- "dense_rank",
- "percent_rank",
- "cume_dist",
- "ntile",
]
@@ -292,14 +294,14 @@ def nullif(expr1: Expr, expr2: Expr) -> Expr:
return Expr(f.nullif(expr1.expr, expr2.expr))
-def encode(input: Expr, encoding: Expr) -> Expr:
+def encode(expr: Expr, encoding: Expr) -> Expr:
"""Encode the ``input``, using the ``encoding``. encoding can be base64 or
hex."""
- return Expr(f.encode(input.expr, encoding.expr))
+ return Expr(f.encode(expr.expr, encoding.expr))
-def decode(input: Expr, encoding: Expr) -> Expr:
+def decode(expr: Expr, encoding: Expr) -> Expr:
"""Decode the ``input``, using the ``encoding``. encoding can be base64 or
hex."""
- return Expr(f.decode(input.expr, encoding.expr))
+ return Expr(f.decode(expr.expr, encoding.expr))
def array_to_string(expr: Expr, delimiter: Expr) -> Expr:
diff --git a/python/datafusion/input/__init__.py b/python/datafusion/input/__init__.py
index f85ce21f..f0c1f42b 100644
--- a/python/datafusion/input/__init__.py
+++ b/python/datafusion/input/__init__.py
@@ -23,5 +23,5 @@ The primary class used within DataFusion is ``LocationInputPlugin``.
from .location import LocationInputPlugin
__all__ = [
- LocationInputPlugin,
+ "LocationInputPlugin",
]
diff --git a/python/datafusion/input/base.py b/python/datafusion/input/base.py
index 4eba1978..f67dde2a 100644
--- a/python/datafusion/input/base.py
+++ b/python/datafusion/input/base.py
@@ -38,11 +38,9 @@ class BaseInputSource(ABC):
"""
@abstractmethod
- def is_correct_input(self, input_item: Any, table_name: str, **kwargs) -> bool:
+ def is_correct_input(self, input_item: Any, table_name: str, **kwargs: Any) -> bool:
"""Returns `True` if the input is valid."""
- pass
@abstractmethod
- def build_table(self, input_item: Any, table_name: str, **kwarg) -> SqlTable:
+ def build_table(self, input_item: Any, table_name: str, **kwarg: Any) -> SqlTable: # type: ignore[invalid-type-form]
"""Create a table from the input source."""
- pass
diff --git a/python/datafusion/input/location.py b/python/datafusion/input/location.py
index 517cd157..08d98d11 100644
--- a/python/datafusion/input/location.py
+++ b/python/datafusion/input/location.py
@@ -18,7 +18,7 @@
"""The default input source for DataFusion."""
import glob
-import os
+from pathlib import Path
from typing import Any
from datafusion.common import DataTypeMap, SqlTable
@@ -31,7 +31,7 @@ class LocationInputPlugin(BaseInputSource):
This can be read in from a file (on disk, remote etc.).
"""
- def is_correct_input(self, input_item: Any, table_name: str, **kwargs):
+ def is_correct_input(self, input_item: Any, table_name: str, **kwargs: Any) -> bool: # noqa: ARG002
"""Returns `True` if the input is valid."""
return isinstance(input_item, str)
@@ -39,27 +39,28 @@ class LocationInputPlugin(BaseInputSource):
self,
input_item: str,
table_name: str,
- **kwargs,
- ) -> SqlTable:
+ **kwargs: Any, # noqa: ARG002
+ ) -> SqlTable: # type: ignore[invalid-type-form]
"""Create a table from the input source."""
- _, extension = os.path.splitext(input_item)
- format = extension.lstrip(".").lower()
+ extension = Path(input_item).suffix
+ file_format = extension.lstrip(".").lower()
num_rows = 0 # Total number of rows in the file. Used for statistics
columns = []
- if format == "parquet":
+ if file_format == "parquet":
import pyarrow.parquet as pq
# Read the Parquet metadata
metadata = pq.read_metadata(input_item)
num_rows = metadata.num_rows
# Iterate through the schema and build the SqlTable
- for col in metadata.schema:
- columns.append(
- (
- col.name,
- DataTypeMap.from_parquet_type_str(col.physical_type),
- )
+ columns = [
+ (
+ col.name,
+ DataTypeMap.from_parquet_type_str(col.physical_type),
)
+ for col in metadata.schema
+ ]
+
elif format == "csv":
import csv
@@ -69,19 +70,18 @@ class LocationInputPlugin(BaseInputSource):
# to get that information. However, this should only be occurring
# at table creation time and therefore shouldn't
# slow down query performance.
- with open(input_item, "r") as file:
+ with Path(input_item).open() as file:
reader = csv.reader(file)
- header_row = next(reader)
- print(header_row)
+ _header_row = next(reader)
for _ in reader:
num_rows += 1
# TODO: Need to actually consume this row into reasonable columns
- raise RuntimeError("TODO: Currently unable to support CSV input
files.")
+ msg = "TODO: Currently unable to support CSV input files."
+ raise RuntimeError(msg)
else:
- raise RuntimeError(
- f"Input of format: `{format}` is currently not supported.\
+ msg = f"Input of format: `{format}` is currently not supported.\
Only Parquet and CSV."
- )
+ raise RuntimeError(msg)
# Input could possibly be multiple files. Create a list if so
input_files = glob.glob(input_item)
diff --git a/python/datafusion/io.py b/python/datafusion/io.py
index 3b626494..3e39703e 100644
--- a/python/datafusion/io.py
+++ b/python/datafusion/io.py
@@ -19,15 +19,19 @@
from __future__ import annotations
-import pathlib
-
-import pyarrow
+from typing import TYPE_CHECKING
from datafusion.dataframe import DataFrame
-from datafusion.expr import Expr
from ._internal import SessionContext as SessionContextInternal
+if TYPE_CHECKING:
+ import pathlib
+
+ import pyarrow as pa
+
+ from datafusion.expr import Expr
+
def read_parquet(
path: str | pathlib.Path,
@@ -35,7 +39,7 @@ def read_parquet(
parquet_pruning: bool = True,
file_extension: str = ".parquet",
skip_metadata: bool = True,
- schema: pyarrow.Schema | None = None,
+ schema: pa.Schema | None = None,
file_sort_order: list[list[Expr]] | None = None,
) -> DataFrame:
"""Read a Parquet source into a
:py:class:`~datafusion.dataframe.Dataframe`.
@@ -79,7 +83,7 @@ def read_parquet(
def read_json(
path: str | pathlib.Path,
- schema: pyarrow.Schema | None = None,
+ schema: pa.Schema | None = None,
schema_infer_max_records: int = 1000,
file_extension: str = ".json",
table_partition_cols: list[tuple[str, str]] | None = None,
@@ -120,7 +124,7 @@ def read_json(
def read_csv(
path: str | pathlib.Path | list[str] | list[pathlib.Path],
- schema: pyarrow.Schema | None = None,
+ schema: pa.Schema | None = None,
has_header: bool = True,
delimiter: str = ",",
schema_infer_max_records: int = 1000,
@@ -173,7 +177,7 @@ def read_csv(
def read_avro(
path: str | pathlib.Path,
- schema: pyarrow.Schema | None = None,
+ schema: pa.Schema | None = None,
file_partition_cols: list[tuple[str, str]] | None = None,
file_extension: str = ".avro",
) -> DataFrame:
diff --git a/python/datafusion/object_store.py b/python/datafusion/object_store.py
index 7cc17506..6298526f 100644
--- a/python/datafusion/object_store.py
+++ b/python/datafusion/object_store.py
@@ -24,4 +24,4 @@ LocalFileSystem = object_store.LocalFileSystem
MicrosoftAzure = object_store.MicrosoftAzure
Http = object_store.Http
-__all__ = ["AmazonS3", "GoogleCloud", "LocalFileSystem", "MicrosoftAzure",
"Http"]
+__all__ = ["AmazonS3", "GoogleCloud", "Http", "LocalFileSystem",
"MicrosoftAzure"]
diff --git a/python/datafusion/plan.py b/python/datafusion/plan.py
index 133fc446..0b7bebcb 100644
--- a/python/datafusion/plan.py
+++ b/python/datafusion/plan.py
@@ -19,7 +19,7 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, List
+from typing import TYPE_CHECKING, Any
import datafusion._internal as df_internal
@@ -27,8 +27,8 @@ if TYPE_CHECKING:
from datafusion.context import SessionContext
__all__ = [
- "LogicalPlan",
"ExecutionPlan",
+ "LogicalPlan",
]
@@ -54,7 +54,7 @@ class LogicalPlan:
"""Convert the logical plan into its specific variant."""
return self._raw_plan.to_variant()
- def inputs(self) -> List[LogicalPlan]:
+ def inputs(self) -> list[LogicalPlan]:
"""Returns the list of inputs to the logical plan."""
return [LogicalPlan(p) for p in self._raw_plan.inputs()]
@@ -106,7 +106,7 @@ class ExecutionPlan:
"""This constructor should not be called by the end user."""
self._raw_plan = plan
- def children(self) -> List[ExecutionPlan]:
+ def children(self) -> list[ExecutionPlan]:
"""Get a list of children `ExecutionPlan` that act as inputs to this
plan.
The returned list will be empty for leaf nodes such as scans, will
contain a
diff --git a/python/datafusion/record_batch.py b/python/datafusion/record_batch.py
index 772cd908..556eaa78 100644
--- a/python/datafusion/record_batch.py
+++ b/python/datafusion/record_batch.py
@@ -26,14 +26,14 @@ from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
- import pyarrow
+ import pyarrow as pa
import typing_extensions
import datafusion._internal as df_internal
class RecordBatch:
- """This class is essentially a wrapper for
:py:class:`pyarrow.RecordBatch`."""
+ """This class is essentially a wrapper for :py:class:`pa.RecordBatch`."""
def __init__(self, record_batch: df_internal.RecordBatch) -> None:
"""This constructor is generally not called by the end user.
@@ -42,8 +42,8 @@ class RecordBatch:
"""
self.record_batch = record_batch
- def to_pyarrow(self) -> pyarrow.RecordBatch:
- """Convert to :py:class:`pyarrow.RecordBatch`."""
+ def to_pyarrow(self) -> pa.RecordBatch:
+ """Convert to :py:class:`pa.RecordBatch`."""
return self.record_batch.to_pyarrow()
diff --git a/python/datafusion/substrait.py b/python/datafusion/substrait.py
index 06302fe3..f10adfb0 100644
--- a/python/datafusion/substrait.py
+++ b/python/datafusion/substrait.py
@@ -23,7 +23,6 @@ information about substrait.
from __future__ import annotations
-import pathlib
from typing import TYPE_CHECKING
try:
@@ -36,11 +35,13 @@ from datafusion.plan import LogicalPlan
from ._internal import substrait as substrait_internal
if TYPE_CHECKING:
+ import pathlib
+
from datafusion.context import SessionContext
__all__ = [
- "Plan",
"Consumer",
+ "Plan",
"Producer",
"Serde",
]
@@ -68,11 +69,9 @@ class Plan:
@deprecated("Use `Plan` instead.")
-class plan(Plan):
+class plan(Plan): # noqa: N801
"""See `Plan`."""
- pass
-
class Serde:
"""Provides the ``Substrait`` serialization and deserialization."""
@@ -140,11 +139,9 @@ class Serde:
@deprecated("Use `Serde` instead.")
-class serde(Serde):
+class serde(Serde): # noqa: N801
"""See `Serde` instead."""
- pass
-
class Producer:
"""Generates substrait plans from a logical plan."""
@@ -168,11 +165,9 @@ class Producer:
@deprecated("Use `Producer` instead.")
-class producer(Producer):
+class producer(Producer): # noqa: N801
"""Use `Producer` instead."""
- pass
-
class Consumer:
"""Generates a logical plan from a substrait plan."""
@@ -194,7 +189,5 @@ class Consumer:
@deprecated("Use `Consumer` instead.")
-class consumer(Consumer):
+class consumer(Consumer): # noqa: N801
"""Use `Consumer` instead."""
-
- pass
diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py
index af7bcf2e..603b7063 100644
--- a/python/datafusion/udf.py
+++ b/python/datafusion/udf.py
@@ -22,15 +22,15 @@ from __future__ import annotations
import functools
from abc import ABCMeta, abstractmethod
from enum import Enum
-from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, overload
-import pyarrow
+import pyarrow as pa
import datafusion._internal as df_internal
from datafusion.expr import Expr
if TYPE_CHECKING:
- _R = TypeVar("_R", bound=pyarrow.DataType)
+ _R = TypeVar("_R", bound=pa.DataType)
class Volatility(Enum):
@@ -72,7 +72,7 @@ class Volatility(Enum):
for each output row, resulting in a unique random value for each row.
"""
- def __str__(self):
+ def __str__(self) -> str:
"""Returns the string equivalent."""
return self.name.lower()
@@ -88,7 +88,7 @@ class ScalarUDF:
self,
name: str,
func: Callable[..., _R],
- input_types: pyarrow.DataType | list[pyarrow.DataType],
+ input_types: pa.DataType | list[pa.DataType],
return_type: _R,
volatility: Volatility | str,
) -> None:
@@ -96,7 +96,7 @@ class ScalarUDF:
See helper method :py:func:`udf` for argument details.
"""
- if isinstance(input_types, pyarrow.DataType):
+ if isinstance(input_types, pa.DataType):
input_types = [input_types]
self._udf = df_internal.ScalarUDF(
name, func, input_types, return_type, str(volatility)
@@ -111,7 +111,27 @@ class ScalarUDF:
args_raw = [arg.expr for arg in args]
return Expr(self._udf.__call__(*args_raw))
- class udf:
+ @overload
+ @staticmethod
+ def udf(
+ input_types: list[pa.DataType],
+ return_type: _R,
+ volatility: Volatility | str,
+ name: Optional[str] = None,
+ ) -> Callable[..., ScalarUDF]: ...
+
+ @overload
+ @staticmethod
+ def udf(
+ func: Callable[..., _R],
+ input_types: list[pa.DataType],
+ return_type: _R,
+ volatility: Volatility | str,
+ name: Optional[str] = None,
+ ) -> ScalarUDF: ...
+
+ @staticmethod
+ def udf(*args: Any, **kwargs: Any): # noqa: D417
"""Create a new User-Defined Function (UDF).
This class can be used both as a **function** and as a **decorator**.
@@ -125,7 +145,7 @@ class ScalarUDF:
Args:
func (Callable, optional): **Only needed when calling as a function.**
Skip this argument when using `udf` as a decorator.
- input_types (list[pyarrow.DataType]): The data types of the arguments
+ input_types (list[pa.DataType]): The data types of the arguments
to `func`. This list must be of the same length as the number of
arguments.
return_type (_R): The data type of the return value from the function.
@@ -141,40 +161,28 @@ class ScalarUDF:
```
def double_func(x):
return x * 2
- double_udf = udf(double_func, [pyarrow.int32()], pyarrow.int32(),
+ double_udf = udf(double_func, [pa.int32()], pa.int32(),
"volatile", "double_it")
```
**Using `udf` as a decorator:**
```
- @udf([pyarrow.int32()], pyarrow.int32(), "volatile", "double_it")
+ @udf([pa.int32()], pa.int32(), "volatile", "double_it")
def double_udf(x):
return x * 2
```
"""
- def __new__(cls, *args, **kwargs):
- """Create a new UDF.
-
- Trigger UDF function or decorator depending on if the first args is callable
- """
- if args and callable(args[0]):
- # Case 1: Used as a function, require the first parameter to be callable
- return cls._function(*args, **kwargs)
- else:
- # Case 2: Used as a decorator with parameters
- return cls._decorator(*args, **kwargs)
-
- @staticmethod
def _function(
func: Callable[..., _R],
- input_types: list[pyarrow.DataType],
+ input_types: list[pa.DataType],
return_type: _R,
volatility: Volatility | str,
name: Optional[str] = None,
) -> ScalarUDF:
if not callable(func):
- raise TypeError("`func` argument must be callable")
+ msg = "`func` argument must be callable"
+ raise TypeError(msg)
if name is None:
if hasattr(func, "__qualname__"):
name = func.__qualname__.lower()
@@ -188,49 +196,50 @@ class ScalarUDF:
volatility=volatility,
)
- @staticmethod
def _decorator(
- input_types: list[pyarrow.DataType],
+ input_types: list[pa.DataType],
return_type: _R,
volatility: Volatility | str,
name: Optional[str] = None,
- ):
- def decorator(func):
+ ) -> Callable:
+ def decorator(func: Callable):
udf_caller = ScalarUDF.udf(
func, input_types, return_type, volatility, name
)
@functools.wraps(func)
- def wrapper(*args, **kwargs):
+ def wrapper(*args: Any, **kwargs: Any):
return udf_caller(*args, **kwargs)
return wrapper
return decorator
+ if args and callable(args[0]):
+ # Case 1: Used as a function, require the first parameter to be callable
+ return _function(*args, **kwargs)
+ # Case 2: Used as a decorator with parameters
+ return _decorator(*args, **kwargs)
+
class Accumulator(metaclass=ABCMeta):
"""Defines how an :py:class:`AggregateUDF` accumulates values."""
@abstractmethod
- def state(self) -> List[pyarrow.Scalar]:
+ def state(self) -> list[pa.Scalar]:
"""Return the current state."""
- pass
@abstractmethod
- def update(self, *values: pyarrow.Array) -> None:
+ def update(self, *values: pa.Array) -> None:
"""Evaluate an array of values and update state."""
- pass
@abstractmethod
- def merge(self, states: List[pyarrow.Array]) -> None:
+ def merge(self, states: list[pa.Array]) -> None:
"""Merge a set of states."""
- pass
@abstractmethod
- def evaluate(self) -> pyarrow.Scalar:
+ def evaluate(self) -> pa.Scalar:
"""Return the resultant value."""
- pass
class AggregateUDF:
@@ -244,9 +253,9 @@ class AggregateUDF:
self,
name: str,
accumulator: Callable[[], Accumulator],
- input_types: list[pyarrow.DataType],
- return_type: pyarrow.DataType,
- state_type: list[pyarrow.DataType],
+ input_types: list[pa.DataType],
+ return_type: pa.DataType,
+ state_type: list[pa.DataType],
volatility: Volatility | str,
) -> None:
"""Instantiate a user-defined aggregate function (UDAF).
@@ -272,7 +281,29 @@ class AggregateUDF:
args_raw = [arg.expr for arg in args]
return Expr(self._udaf.__call__(*args_raw))
- class udaf:
+ @overload
+ @staticmethod
+ def udaf(
+ input_types: pa.DataType | list[pa.DataType],
+ return_type: pa.DataType,
+ state_type: list[pa.DataType],
+ volatility: Volatility | str,
+ name: Optional[str] = None,
+ ) -> Callable[..., AggregateUDF]: ...
+
+ @overload
+ @staticmethod
+ def udaf(
+ accum: Callable[[], Accumulator],
+ input_types: pa.DataType | list[pa.DataType],
+ return_type: pa.DataType,
+ state_type: list[pa.DataType],
+ volatility: Volatility | str,
+ name: Optional[str] = None,
+ ) -> AggregateUDF: ...
+
+ @staticmethod
+ def udaf(*args: Any, **kwargs: Any): # noqa: D417
"""Create a new User-Defined Aggregate Function (UDAF).
This class allows you to define an **aggregate function** that can be used in
@@ -300,13 +331,13 @@ class AggregateUDF:
def __init__(self, bias: float = 0.0):
self._sum = pa.scalar(bias)
- def state(self) -> List[pa.Scalar]:
+ def state(self) -> list[pa.Scalar]:
return [self._sum]
def update(self, values: pa.Array) -> None:
self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())
- def merge(self, states: List[pa.Array]) -> None:
+ def merge(self, states: list[pa.Array]) -> None:
self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py())
def evaluate(self) -> pa.Scalar:
@@ -344,37 +375,23 @@ class AggregateUDF:
aggregation or window function calls.
"""
- def __new__(cls, *args, **kwargs):
- """Create a new UDAF.
-
- Trigger UDAF function or decorator depending on if the first args is
- callable
- """
- if args and callable(args[0]):
- # Case 1: Used as a function, require the first parameter to be callable
- return cls._function(*args, **kwargs)
- else:
- # Case 2: Used as a decorator with parameters
- return cls._decorator(*args, **kwargs)
-
- @staticmethod
def _function(
accum: Callable[[], Accumulator],
- input_types: pyarrow.DataType | list[pyarrow.DataType],
- return_type: pyarrow.DataType,
- state_type: list[pyarrow.DataType],
+ input_types: pa.DataType | list[pa.DataType],
+ return_type: pa.DataType,
+ state_type: list[pa.DataType],
volatility: Volatility | str,
name: Optional[str] = None,
) -> AggregateUDF:
if not callable(accum):
- raise TypeError("`func` must be callable.")
- if not isinstance(accum.__call__(), Accumulator):
- raise TypeError(
- "Accumulator must implement the abstract base class
Accumulator"
- )
+ msg = "`func` must be callable."
+ raise TypeError(msg)
+ if not isinstance(accum(), Accumulator):
+ msg = "Accumulator must implement the abstract base class
Accumulator"
+ raise TypeError(msg)
if name is None:
- name = accum.__call__().__class__.__qualname__.lower()
- if isinstance(input_types, pyarrow.DataType):
+ name = accum().__class__.__qualname__.lower()
+ if isinstance(input_types, pa.DataType):
input_types = [input_types]
return AggregateUDF(
name=name,
@@ -385,29 +402,34 @@ class AggregateUDF:
volatility=volatility,
)
- @staticmethod
def _decorator(
- input_types: pyarrow.DataType | list[pyarrow.DataType],
- return_type: pyarrow.DataType,
- state_type: list[pyarrow.DataType],
+ input_types: pa.DataType | list[pa.DataType],
+ return_type: pa.DataType,
+ state_type: list[pa.DataType],
volatility: Volatility | str,
name: Optional[str] = None,
- ):
- def decorator(accum: Callable[[], Accumulator]):
+ ) -> Callable[..., Callable[..., Expr]]:
+ def decorator(accum: Callable[[], Accumulator]) -> Callable[..., Expr]:
udaf_caller = AggregateUDF.udaf(
accum, input_types, return_type, state_type, volatility, name
)
@functools.wraps(accum)
- def wrapper(*args, **kwargs):
+ def wrapper(*args: Any, **kwargs: Any) -> Expr:
return udaf_caller(*args, **kwargs)
return wrapper
return decorator
+ if args and callable(args[0]):
+ # Case 1: Used as a function, require the first parameter to be callable
+ return _function(*args, **kwargs)
+ # Case 2: Used as a decorator with parameters
+ return _decorator(*args, **kwargs)
+
-class WindowEvaluator(metaclass=ABCMeta):
+class WindowEvaluator:
"""Evaluator class for user-defined window functions (UDWF).
It is up to the user to decide which evaluate function is appropriate.
@@ -423,7 +445,7 @@ class WindowEvaluator(metaclass=ABCMeta):
+------------------------+--------------------------------+------------------+---------------------------+
| True | True/False | True/False | ``evaluate`` |
+------------------------+--------------------------------+------------------+---------------------------+
- """ # noqa: W505
+ """ # noqa: W505, E501
def memoize(self) -> None:
"""Perform a memoize operation to improve performance.
@@ -436,9 +458,8 @@ class WindowEvaluator(metaclass=ABCMeta):
`memoize` is called after each input batch is processed, and
such functions can save whatever they need
"""
- pass
- def get_range(self, idx: int, num_rows: int) -> tuple[int, int]:
+ def get_range(self, idx: int, num_rows: int) -> tuple[int, int]: # noqa: ARG002
"""Return the range for the window fuction.
If `uses_window_frame` flag is `false`. This method is used to
@@ -460,14 +481,17 @@ class WindowEvaluator(metaclass=ABCMeta):
"""Get whether evaluator needs future data for its result."""
return False
- def evaluate_all(self, values: list[pyarrow.Array], num_rows: int) -> pyarrow.Array:
+ def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
"""Evaluate a window function on an entire input partition.
This function is called once per input *partition* for window functions that
*do not use* values from the window frame, such as
- :py:func:`~datafusion.functions.row_number`, :py:func:`~datafusion.functions.rank`,
- :py:func:`~datafusion.functions.dense_rank`, :py:func:`~datafusion.functions.percent_rank`,
- :py:func:`~datafusion.functions.cume_dist`, :py:func:`~datafusion.functions.lead`,
+ :py:func:`~datafusion.functions.row_number`,
+ :py:func:`~datafusion.functions.rank`,
+ :py:func:`~datafusion.functions.dense_rank`,
+ :py:func:`~datafusion.functions.percent_rank`,
+ :py:func:`~datafusion.functions.cume_dist`,
+ :py:func:`~datafusion.functions.lead`,
and :py:func:`~datafusion.functions.lag`.
It produces the result of all rows in a single pass. It
@@ -499,12 +523,11 @@ class WindowEvaluator(metaclass=ABCMeta):
.. code-block:: text
avg(x) OVER (PARTITION BY y ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING)
- """ # noqa: W505
- pass
+ """ # noqa: W505, E501
def evaluate(
- self, values: list[pyarrow.Array], eval_range: tuple[int, int]
- ) -> pyarrow.Scalar:
+ self, values: list[pa.Array], eval_range: tuple[int, int]
+ ) -> pa.Scalar:
"""Evaluate window function on a range of rows in an input partition.
This is the simplest and most general function to implement
@@ -519,11 +542,10 @@ class WindowEvaluator(metaclass=ABCMeta):
and evaluation results of ORDER BY expressions. If function has a
single argument, `values[1..]` will contain ORDER BY expression
results.
"""
- pass
def evaluate_all_with_rank(
self, num_rows: int, ranks_in_partition: list[tuple[int, int]]
- ) -> pyarrow.Array:
+ ) -> pa.Array:
"""Called for window functions that only need the rank of a row.
Evaluate the partition evaluator against the partition using
@@ -552,7 +574,6 @@ class WindowEvaluator(metaclass=ABCMeta):
The user must implement this method if ``include_rank`` returns True.
"""
- pass
def supports_bounded_execution(self) -> bool:
"""Can the window function be incrementally computed using bounded
memory?"""
@@ -567,10 +588,6 @@ class WindowEvaluator(metaclass=ABCMeta):
return False
-if TYPE_CHECKING:
- _W = TypeVar("_W", bound=WindowEvaluator)
-
-
class WindowUDF:
"""Class for performing window user-defined functions (UDF).
@@ -582,8 +599,8 @@ class WindowUDF:
self,
name: str,
func: Callable[[], WindowEvaluator],
- input_types: list[pyarrow.DataType],
- return_type: pyarrow.DataType,
+ input_types: list[pa.DataType],
+ return_type: pa.DataType,
volatility: Volatility | str,
) -> None:
"""Instantiate a user-defined window function (UDWF).
@@ -607,8 +624,8 @@ class WindowUDF:
@staticmethod
def udwf(
func: Callable[[], WindowEvaluator],
- input_types: pyarrow.DataType | list[pyarrow.DataType],
- return_type: pyarrow.DataType,
+ input_types: pa.DataType | list[pa.DataType],
+ return_type: pa.DataType,
volatility: Volatility | str,
name: Optional[str] = None,
) -> WindowUDF:
@@ -648,16 +665,16 @@ class WindowUDF:
Returns:
A user-defined window function.
- """ # noqa W505
+ """ # noqa: W505, E501
if not callable(func):
- raise TypeError("`func` must be callable.")
- if not isinstance(func.__call__(), WindowEvaluator):
- raise TypeError(
- "`func` must implement the abstract base class WindowEvaluator"
- )
+ msg = "`func` must be callable."
+ raise TypeError(msg)
+ if not isinstance(func(), WindowEvaluator):
+ msg = "`func` must implement the abstract base class
WindowEvaluator"
+ raise TypeError(msg)
if name is None:
- name = func.__call__().__class__.__qualname__.lower()
- if isinstance(input_types, pyarrow.DataType):
+ name = func().__class__.__qualname__.lower()
+ if isinstance(input_types, pa.DataType):
input_types = [input_types]
return WindowUDF(
name=name,
@@ -666,3 +683,10 @@ class WindowUDF:
return_type=return_type,
volatility=volatility,
)
+
+
+# Convenience exports so we can import instead of treating as
+# variables at the package root
+udf = ScalarUDF.udf
+udaf = AggregateUDF.udaf
+udwf = WindowUDF.udwf
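
To illustrate the corresponding `AggregateUDF.udaf` overloads in use, here is a hedged sketch only: `Summarize` mirrors the accumulator from the docstring example earlier in this diff, and the input data is invented.

```python
import pyarrow as pa
import pyarrow.compute as pc

from datafusion import Accumulator, SessionContext, column, udaf

class Summarize(Accumulator):
    """Accumulator that sums a single float64 column."""

    def __init__(self) -> None:
        self._sum = pa.scalar(0.0)

    def state(self) -> list[pa.Scalar]:
        return [self._sum]

    def update(self, values: pa.Array) -> None:
        self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())

    def merge(self, states: list[pa.Array]) -> None:
        self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py())

    def evaluate(self) -> pa.Scalar:
        return self._sum

# Function form: udaf(accumulator_factory, input_types, return_type, state_type, volatility)
my_sum = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], "stable")

ctx = SessionContext()
df = ctx.from_pydict({"a": [1.0, 2.0, 3.0]})
df.aggregate([], [my_sum(column("a")).alias("total")]).show()
```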
diff --git a/python/tests/generic.py b/python/tests/generic.py
index 0177e2df..1b98fdf9 100644
--- a/python/tests/generic.py
+++ b/python/tests/generic.py
@@ -16,6 +16,7 @@
# under the License.
import datetime
+from datetime import timezone
import numpy as np
import pyarrow as pa
@@ -26,29 +27,29 @@ import pyarrow.parquet as pq
def data():
- np.random.seed(1)
+ rng = np.random.default_rng(1)
data = np.concatenate(
[
- np.random.normal(0, 0.01, size=50),
- np.random.normal(50, 0.01, size=50),
+ rng.normal(0, 0.01, size=50),
+ rng.normal(50, 0.01, size=50),
]
)
return pa.array(data)
def data_with_nans():
- np.random.seed(0)
- data = np.random.normal(0, 0.01, size=50)
- mask = np.random.randint(0, 2, size=50)
+ rng = np.random.default_rng(0)
+ data = rng.normal(0, 0.01, size=50)
+ mask = rng.normal(0, 2, size=50)
data[mask == 0] = np.nan
return data
def data_datetime(f):
data = [
- datetime.datetime.now(),
- datetime.datetime.now() - datetime.timedelta(days=1),
- datetime.datetime.now() + datetime.timedelta(days=1),
+ datetime.datetime.now(tz=timezone.utc),
+ datetime.datetime.now(tz=timezone.utc) - datetime.timedelta(days=1),
+ datetime.datetime.now(tz=timezone.utc) + datetime.timedelta(days=1),
]
return pa.array(data, type=pa.timestamp(f), mask=np.array([False, True, False]))
diff --git a/python/tests/test_aggregation.py b/python/tests/test_aggregation.py
index 5ef46131..61b1c7d8 100644
--- a/python/tests/test_aggregation.py
+++ b/python/tests/test_aggregation.py
@@ -66,7 +66,7 @@ def df_aggregate_100():
@pytest.mark.parametrize(
- "agg_expr, calc_expected",
+ ("agg_expr", "calc_expected"),
[
(f.avg(column("a")), lambda a, b, c, d: np.array(np.average(a))),
(
@@ -114,7 +114,7 @@ def test_aggregation_stats(df, agg_expr, calc_expected):
@pytest.mark.parametrize(
- "agg_expr, expected, array_sort",
+ ("agg_expr", "expected", "array_sort"),
[
(f.approx_distinct(column("b")), pa.array([2], type=pa.uint64()),
False),
(
@@ -182,12 +182,11 @@ def test_aggregation(df, agg_expr, expected, array_sort):
agg_df.show()
result = agg_df.collect()[0]
- print(result)
assert result.column(0) == expected
@pytest.mark.parametrize(
- "name,expr,expected",
+ ("name", "expr", "expected"),
[
(
"approx_percentile_cont",
@@ -299,7 +298,9 @@ data_test_bitwise_and_boolean_functions = [
]
[email protected]("name,expr,result",
data_test_bitwise_and_boolean_functions)
[email protected](
+ ("name", "expr", "result"), data_test_bitwise_and_boolean_functions
+)
def test_bit_and_bool_fns(df, name, expr, result):
df = df.aggregate([], [expr.alias(name)])
@@ -311,7 +312,7 @@ def test_bit_and_bool_fns(df, name, expr, result):
@pytest.mark.parametrize(
- "name,expr,result",
+ ("name", "expr", "result"),
[
("first_value", f.first_value(column("a")), [0, 4]),
(
@@ -361,7 +362,6 @@ def test_bit_and_bool_fns(df, name, expr, result):
),
[8, 9],
),
- ("first_value", f.first_value(column("a")), [0, 4]),
(
"nth_value_ordered",
f.nth_value(column("a"), 2,
order_by=[column("a").sort(ascending=False)]),
@@ -401,7 +401,7 @@ def test_first_last_value(df_partitioned, name, expr, result) -> None:
@pytest.mark.parametrize(
- "name,expr,result",
+ ("name", "expr", "result"),
[
("string_agg", f.string_agg(column("a"), ","), "one,two,three,two"),
("string_agg", f.string_agg(column("b"), ""), "03124"),
diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py
index 214f6b16..23b32845 100644
--- a/python/tests/test_catalog.py
+++ b/python/tests/test_catalog.py
@@ -19,6 +19,9 @@ import pyarrow as pa
import pytest
+# Note we take in `database` as a variable even though we don't use
+# it because that will cause the fixture to set up the context with
+# the tables we need.
def test_basic(ctx, database):
with pytest.raises(KeyError):
ctx.catalog("non-existent")
@@ -26,10 +29,10 @@ def test_basic(ctx, database):
default = ctx.catalog()
assert default.names() == ["public"]
- for database in [default.database("public"), default.database()]:
- assert database.names() == {"csv1", "csv", "csv2"}
+ for db in [default.database("public"), default.database()]:
+ assert db.names() == {"csv1", "csv", "csv2"}
- table = database.table("csv")
+ table = db.table("csv")
assert table.kind == "physical"
assert table.schema == pa.schema(
[
diff --git a/python/tests/test_context.py b/python/tests/test_context.py
index 91046e6b..7a0a7aa0 100644
--- a/python/tests/test_context.py
+++ b/python/tests/test_context.py
@@ -16,7 +16,6 @@
# under the License.
import datetime as dt
import gzip
-import os
import pathlib
import pyarrow as pa
@@ -45,7 +44,7 @@ def test_create_context_runtime_config_only():
SessionContext(runtime=RuntimeEnvBuilder())
[email protected]("path_to_str", (True, False))
[email protected]("path_to_str", [True, False])
def test_runtime_configs(tmp_path, path_to_str):
path1 = tmp_path / "dir1"
path2 = tmp_path / "dir2"
@@ -62,7 +61,7 @@ def test_runtime_configs(tmp_path, path_to_str):
assert db is not None
[email protected]("path_to_str", (True, False))
[email protected]("path_to_str", [True, False])
def test_temporary_files(tmp_path, path_to_str):
path = str(tmp_path) if path_to_str else tmp_path
@@ -79,14 +78,14 @@ def test_create_context_with_all_valid_args():
runtime = RuntimeEnvBuilder().with_disk_manager_os().with_fair_spill_pool(10000000)
config = (
SessionConfig()
- .with_create_default_catalog_and_schema(True)
+ .with_create_default_catalog_and_schema(enabled=True)
.with_default_catalog_and_schema("foo", "bar")
.with_target_partitions(1)
- .with_information_schema(True)
- .with_repartition_joins(False)
- .with_repartition_aggregations(False)
- .with_repartition_windows(False)
- .with_parquet_pruning(False)
+ .with_information_schema(enabled=True)
+ .with_repartition_joins(enabled=False)
+ .with_repartition_aggregations(enabled=False)
+ .with_repartition_windows(enabled=False)
+ .with_parquet_pruning(enabled=False)
)
ctx = SessionContext(config, runtime)
@@ -167,7 +166,7 @@ def test_from_arrow_table(ctx):
def record_batch_generator(num_batches: int):
schema = pa.schema([("a", pa.int64()), ("b", pa.int64())])
- for i in range(num_batches):
+ for _i in range(num_batches):
yield pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3]), pa.array([4, 5, 6])], schema=schema
)
@@ -492,10 +491,10 @@ def test_table_not_found(ctx):
def test_read_json(ctx):
- path = os.path.dirname(os.path.abspath(__file__))
+ path = pathlib.Path(__file__).parent.resolve()
# Default
- test_data_path = os.path.join(path, "data_test_context", "data.json")
+ test_data_path = path / "data_test_context" / "data.json"
df = ctx.read_json(test_data_path)
result = df.collect()
@@ -515,7 +514,7 @@ def test_read_json(ctx):
assert result[0].schema == schema
# File extension
- test_data_path = os.path.join(path, "data_test_context", "data.json")
+ test_data_path = path / "data_test_context" / "data.json"
df = ctx.read_json(test_data_path, file_extension=".json")
result = df.collect()
@@ -524,15 +523,17 @@ def test_read_json(ctx):
def test_read_json_compressed(ctx, tmp_path):
- path = os.path.dirname(os.path.abspath(__file__))
- test_data_path = os.path.join(path, "data_test_context", "data.json")
+ path = pathlib.Path(__file__).parent.resolve()
+ test_data_path = path / "data_test_context" / "data.json"
# File compression type
gzip_path = tmp_path / "data.json.gz"
- with open(test_data_path, "rb") as csv_file:
- with gzip.open(gzip_path, "wb") as gzipped_file:
- gzipped_file.writelines(csv_file)
+ with (
+ pathlib.Path.open(test_data_path, "rb") as csv_file,
+ gzip.open(gzip_path, "wb") as gzipped_file,
+ ):
+ gzipped_file.writelines(csv_file)
df = ctx.read_json(gzip_path, file_extension=".gz",
file_compression_type="gz")
result = df.collect()
@@ -563,14 +564,16 @@ def test_read_csv_list(ctx):
def test_read_csv_compressed(ctx, tmp_path):
- test_data_path = "testing/data/csv/aggregate_test_100.csv"
+ test_data_path = pathlib.Path("testing/data/csv/aggregate_test_100.csv")
# File compression type
gzip_path = tmp_path / "aggregate_test_100.csv.gz"
- with open(test_data_path, "rb") as csv_file:
- with gzip.open(gzip_path, "wb") as gzipped_file:
- gzipped_file.writelines(csv_file)
+ with (
+ pathlib.Path.open(test_data_path, "rb") as csv_file,
+ gzip.open(gzip_path, "wb") as gzipped_file,
+ ):
+ gzipped_file.writelines(csv_file)
csv_df = ctx.read_csv(gzip_path, file_extension=".gz",
file_compression_type="gz")
csv_df.select(column("c1")).show()
@@ -603,7 +606,7 @@ def test_create_sql_options():
def test_sql_with_options_no_ddl(ctx):
sql = "CREATE TABLE IF NOT EXISTS valuetable AS
VALUES(1,'HELLO'),(12,'DATAFUSION')"
ctx.sql(sql)
- options = SQLOptions().with_allow_ddl(False)
+ options = SQLOptions().with_allow_ddl(allow=False)
with pytest.raises(Exception, match="DDL"):
ctx.sql_with_options(sql, options=options)
@@ -618,7 +621,7 @@ def test_sql_with_options_no_dml(ctx):
ctx.register_dataset(table_name, dataset)
sql = f'INSERT INTO "{table_name}" VALUES (1, 2), (2, 3);'
ctx.sql(sql)
- options = SQLOptions().with_allow_dml(False)
+ options = SQLOptions().with_allow_dml(allow=False)
with pytest.raises(Exception, match="DML"):
ctx.sql_with_options(sql, options=options)
@@ -626,6 +629,6 @@ def test_sql_with_options_no_dml(ctx):
def test_sql_with_options_no_statements(ctx):
sql = "SET time zone = 1;"
ctx.sql(sql)
- options = SQLOptions().with_allow_statements(False)
+ options = SQLOptions().with_allow_statements(allow=False)
with pytest.raises(Exception, match="SetVariable"):
ctx.sql_with_options(sql, options=options)
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index c636e896..d084f12d 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -339,7 +339,7 @@ def test_join():
# Verify we don't make a breaking change to pre-43.0.0
# where users would pass join_keys as a positional argument
- df2 = df.join(df1, (["a"], ["a"]), how="inner") # type: ignore
+ df2 = df.join(df1, (["a"], ["a"]), how="inner")
df2.show()
df2 = df2.sort(column("l.a"))
table = pa.Table.from_batches(df2.collect())
@@ -375,17 +375,17 @@ def test_join_invalid_params():
with pytest.raises(
ValueError, match=r"`left_on` or `right_on` should not provided with
`on`"
):
- df2 = df.join(df1, on="a", how="inner", right_on="test") # type:
ignore
+ df2 = df.join(df1, on="a", how="inner", right_on="test")
with pytest.raises(
ValueError, match=r"`left_on` and `right_on` should both be provided."
):
- df2 = df.join(df1, left_on="a", how="inner") # type: ignore
+ df2 = df.join(df1, left_on="a", how="inner")
with pytest.raises(
ValueError, match=r"either `on` or `left_on` and `right_on` should be
provided."
):
- df2 = df.join(df1, how="inner") # type: ignore
+ df2 = df.join(df1, how="inner")
def test_join_on():
@@ -567,7 +567,7 @@ data_test_window_functions = [
]
[email protected]("name,expr,result", data_test_window_functions)
[email protected](("name", "expr", "result"),
data_test_window_functions)
def test_window_functions(partitioned_df, name, expr, result):
df = partitioned_df.select(
column("a"), column("b"), column("c"), f.alias(expr, name)
@@ -731,7 +731,7 @@ def test_execution_plan(aggregate_df):
plan = aggregate_df.execution_plan()
expected = (
- "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1],
aggr=[sum(test.c2)]\n" # noqa: E501
+ "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1],
aggr=[sum(test.c2)]\n"
)
assert expected == plan.display()
@@ -756,7 +756,7 @@ def test_execution_plan(aggregate_df):
ctx = SessionContext()
rows_returned = 0
- for idx in range(0, plan.partition_count):
+ for idx in range(plan.partition_count):
stream = ctx.execute(plan, idx)
try:
batch = stream.next()
@@ -885,7 +885,7 @@ def test_union_distinct(ctx):
)
df_c = ctx.create_dataframe([[batch]]).sort(column("a"))
- df_a_u_b = df_a.union(df_b, True).sort(column("a"))
+ df_a_u_b = df_a.union(df_b, distinct=True).sort(column("a"))
assert df_c.collect() == df_a_u_b.collect()
assert df_c.collect() == df_a_u_b.collect()
@@ -954,8 +954,6 @@ def test_to_arrow_table(df):
def test_execute_stream(df):
stream = df.execute_stream()
- for s in stream:
- print(type(s))
assert all(batch is not None for batch in stream)
assert not list(stream) # after one iteration the generator must be exhausted
@@ -969,7 +967,7 @@ def test_execute_stream_to_arrow_table(df, schema):
(batch.to_pyarrow() for batch in stream), schema=df.schema()
)
else:
- pyarrow_table = pa.Table.from_batches((batch.to_pyarrow() for batch in stream))
+ pyarrow_table = pa.Table.from_batches(batch.to_pyarrow() for batch in stream)
assert isinstance(pyarrow_table, pa.Table)
assert pyarrow_table.shape == (3, 3)
@@ -1033,7 +1031,7 @@ def test_describe(df):
}
[email protected]("path_to_str", (True, False))
[email protected]("path_to_str", [True, False])
def test_write_csv(ctx, df, tmp_path, path_to_str):
path = str(tmp_path) if path_to_str else tmp_path
@@ -1046,7 +1044,7 @@ def test_write_csv(ctx, df, tmp_path, path_to_str):
assert result == expected
[email protected]("path_to_str", (True, False))
[email protected]("path_to_str", [True, False])
def test_write_json(ctx, df, tmp_path, path_to_str):
path = str(tmp_path) if path_to_str else tmp_path
@@ -1059,7 +1057,7 @@ def test_write_json(ctx, df, tmp_path, path_to_str):
assert result == expected
[email protected]("path_to_str", (True, False))
[email protected]("path_to_str", [True, False])
def test_write_parquet(df, tmp_path, path_to_str):
path = str(tmp_path) if path_to_str else tmp_path
@@ -1071,7 +1069,7 @@ def test_write_parquet(df, tmp_path, path_to_str):
@pytest.mark.parametrize(
- "compression, compression_level",
+ ("compression", "compression_level"),
[("gzip", 6), ("brotli", 7), ("zstd", 15)],
)
def test_write_compressed_parquet(df, tmp_path, compression, compression_level):
@@ -1082,7 +1080,7 @@ def test_write_compressed_parquet(df, tmp_path, compression, compression_level):
)
# test that the actual compression scheme is the one written
- for root, dirs, files in os.walk(path):
+ for _root, _dirs, files in os.walk(path):
for file in files:
if file.endswith(".parquet"):
metadata = pq.ParquetFile(tmp_path / file).metadata.to_dict()
@@ -1097,7 +1095,7 @@ def test_write_compressed_parquet(df, tmp_path, compression, compression_level):
@pytest.mark.parametrize(
- "compression, compression_level",
+ ("compression", "compression_level"),
[("gzip", 12), ("brotli", 15), ("zstd", 23), ("wrong", 12)],
)
def test_write_compressed_parquet_wrong_compression_level(
@@ -1152,7 +1150,7 @@ def test_dataframe_export(df) -> None:
table = pa.table(df, schema=desired_schema)
assert table.num_columns == 1
assert table.num_rows == 3
- for i in range(0, 3):
+ for i in range(3):
assert table[0][i].as_py() is None
# Expect an error when we cannot convert schema
@@ -1186,8 +1184,8 @@ def test_dataframe_transform(df):
result = df.to_pydict()
assert result["a"] == [1, 2, 3]
- assert result["string_col"] == ["string data" for _i in range(0, 3)]
- assert result["new_col"] == [3 for _i in range(0, 3)]
+ assert result["string_col"] == ["string data" for _i in range(3)]
+ assert result["new_col"] == [3 for _i in range(3)]
def test_dataframe_repr_html(df) -> None:
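As an aside, a small self-contained sketch (not from this commit) of the parametrize convention these test files move to, with a tuple of argument names and a list of value tuples; the test body here is a stand-in:

import pytest


@pytest.mark.parametrize(
    ("compression", "compression_level"),
    [("gzip", 6), ("brotli", 7), ("zstd", 15)],
)
def test_compression_args(compression, compression_level):
    # Placeholder assertions; the real test writes compressed parquet files.
    assert compression in {"gzip", "brotli", "zstd"}
    assert compression_level > 0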
diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py
index 354c7e18..926e6984 100644
--- a/python/tests/test_expr.py
+++ b/python/tests/test_expr.py
@@ -85,18 +85,14 @@ def test_limit(test_ctx):
plan = plan.to_variant()
assert isinstance(plan, Limit)
- # TODO: Upstream now has expressions for skip and fetch
- # REF: https://github.com/apache/datafusion/pull/12836
- # assert plan.skip() == 0
+ assert "Skip: None" in str(plan)
df = test_ctx.sql("select c1 from test LIMIT 10 OFFSET 5")
plan = df.logical_plan()
plan = plan.to_variant()
assert isinstance(plan, Limit)
- # TODO: Upstream now has expressions for skip and fetch
- # REF: https://github.com/apache/datafusion/pull/12836
- # assert plan.skip() == 5
+ assert "Skip: Some(Literal(Int64(5)))" in str(plan)
def test_aggregate_query(test_ctx):
@@ -165,6 +161,7 @@ def test_expr_to_variant():
res = traverse_logical_plan(input_plan)
if res is not None:
return res
+ return None
ctx = SessionContext()
data = {"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"]}
@@ -176,7 +173,7 @@ def test_expr_to_variant():
assert variant.expr().to_variant().qualified_name() == "table1.name"
assert (
str(variant.list())
- == '[Expr(Utf8("dfa")), Expr(Utf8("ad")), Expr(Utf8("dfre")), Expr(Utf8("vsa"))]'
+ == '[Expr(Utf8("dfa")), Expr(Utf8("ad")), Expr(Utf8("dfre")), Expr(Utf8("vsa"))]' # noqa: E501
)
assert not variant.negated()
diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py
index fca05bb8..ed88a16e 100644
--- a/python/tests/test_functions.py
+++ b/python/tests/test_functions.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import math
-from datetime import datetime
+from datetime import datetime, timezone
import numpy as np
import pyarrow as pa
@@ -25,6 +25,8 @@ from datafusion import functions as f
np.seterr(invalid="ignore")
+DEFAULT_TZ = timezone.utc
+
@pytest.fixture
def df():
@@ -37,9 +39,9 @@ def df():
pa.array(["hello ", " world ", " !"], type=pa.string_view()),
pa.array(
[
- datetime(2022, 12, 31),
- datetime(2027, 6, 26),
- datetime(2020, 7, 2),
+ datetime(2022, 12, 31, tzinfo=DEFAULT_TZ),
+ datetime(2027, 6, 26, tzinfo=DEFAULT_TZ),
+ datetime(2020, 7, 2, tzinfo=DEFAULT_TZ),
]
),
pa.array([False, True, True]),
@@ -221,12 +223,12 @@ def py_indexof(arr, v):
def py_arr_remove(arr, v, n=None):
new_arr = arr[:]
found = 0
- while found != n:
- try:
+ try:
+ while found != n:
new_arr.remove(v)
found += 1
- except ValueError:
- break
+ except ValueError:
+ pass
return new_arr
@@ -234,13 +236,13 @@ def py_arr_remove(arr, v, n=None):
def py_arr_replace(arr, from_, to, n=None):
new_arr = arr[:]
found = 0
- while found != n:
- try:
+ try:
+ while found != n:
idx = new_arr.index(from_)
new_arr[idx] = to
found += 1
- except ValueError:
- break
+ except ValueError:
+ pass
return new_arr
@@ -268,266 +270,266 @@ def py_flatten(arr):
@pytest.mark.parametrize(
("stmt", "py_expr"),
[
- [
+ (
lambda col: f.array_append(col, literal(99.0)),
lambda data: [np.append(arr, 99.0) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.array_push_back(col, literal(99.0)),
lambda data: [np.append(arr, 99.0) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.list_append(col, literal(99.0)),
lambda data: [np.append(arr, 99.0) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.list_push_back(col, literal(99.0)),
lambda data: [np.append(arr, 99.0) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.array_concat(col, col),
lambda data: [np.concatenate([arr, arr]) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.array_cat(col, col),
lambda data: [np.concatenate([arr, arr]) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.list_cat(col, col),
lambda data: [np.concatenate([arr, arr]) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.list_concat(col, col),
lambda data: [np.concatenate([arr, arr]) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.array_dims(col),
lambda data: [[len(r)] for r in data],
- ],
- [
+ ),
+ (
lambda col: f.array_distinct(col),
lambda data: [list(set(r)) for r in data],
- ],
- [
+ ),
+ (
lambda col: f.list_distinct(col),
lambda data: [list(set(r)) for r in data],
- ],
- [
+ ),
+ (
lambda col: f.list_dims(col),
lambda data: [[len(r)] for r in data],
- ],
- [
+ ),
+ (
lambda col: f.array_element(col, literal(1)),
lambda data: [r[0] for r in data],
- ],
- [
+ ),
+ (
lambda col: f.array_empty(col),
lambda data: [len(r) == 0 for r in data],
- ],
- [
+ ),
+ (
lambda col: f.empty(col),
lambda data: [len(r) == 0 for r in data],
- ],
- [
+ ),
+ (
lambda col: f.array_extract(col, literal(1)),
lambda data: [r[0] for r in data],
- ],
- [
+ ),
+ (
lambda col: f.list_element(col, literal(1)),
lambda data: [r[0] for r in data],
- ],
- [
+ ),
+ (
lambda col: f.list_extract(col, literal(1)),
lambda data: [r[0] for r in data],
- ],
- [
+ ),
+ (
lambda col: f.array_length(col),
lambda data: [len(r) for r in data],
- ],
- [
+ ),
+ (
lambda col: f.list_length(col),
lambda data: [len(r) for r in data],
- ],
- [
+ ),
+ (
lambda col: f.array_has(col, literal(1.0)),
lambda data: [1.0 in r for r in data],
- ],
- [
+ ),
+ (
lambda col: f.array_has_all(
col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]])
),
lambda data: [np.all([v in r for v in [1.0, 3.0, 5.0]]) for r in data],
- ],
- [
+ ),
+ (
lambda col: f.array_has_any(
col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]])
),
lambda data: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in data],
- ],
- [
+ ),
+ (
lambda col: f.array_position(col, literal(1.0)),
lambda data: [py_indexof(r, 1.0) for r in data],
- ],
- [
+ ),
+ (
lambda col: f.array_indexof(col, literal(1.0)),
lambda data: [py_indexof(r, 1.0) for r in data],
- ],
- [
+ ),
+ (
lambda col: f.list_position(col, literal(1.0)),
lambda data: [py_indexof(r, 1.0) for r in data],
- ],
- [
+ ),
+ (
lambda col: f.list_indexof(col, literal(1.0)),
lambda data: [py_indexof(r, 1.0) for r in data],
- ],
- [
+ ),
+ (
lambda col: f.array_positions(col, literal(1.0)),
lambda data: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data],
- ],
- [
+ ),
+ (
lambda col: f.list_positions(col, literal(1.0)),
lambda data: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data],
- ],
- [
+ ),
+ (
lambda col: f.array_ndims(col),
lambda data: [np.array(r).ndim for r in data],
- ],
- [
+ ),
+ (
lambda col: f.list_ndims(col),
lambda data: [np.array(r).ndim for r in data],
- ],
- [
+ ),
+ (
lambda col: f.array_prepend(literal(99.0), col),
lambda data: [np.insert(arr, 0, 99.0) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.array_push_front(literal(99.0), col),
lambda data: [np.insert(arr, 0, 99.0) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.list_prepend(literal(99.0), col),
lambda data: [np.insert(arr, 0, 99.0) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.list_push_front(literal(99.0), col),
lambda data: [np.insert(arr, 0, 99.0) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.array_pop_back(col),
lambda data: [arr[:-1] for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.array_pop_front(col),
lambda data: [arr[1:] for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.array_remove(col, literal(3.0)),
lambda data: [py_arr_remove(arr, 3.0, 1) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.list_remove(col, literal(3.0)),
lambda data: [py_arr_remove(arr, 3.0, 1) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.array_remove_n(col, literal(3.0), literal(2)),
lambda data: [py_arr_remove(arr, 3.0, 2) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.list_remove_n(col, literal(3.0), literal(2)),
lambda data: [py_arr_remove(arr, 3.0, 2) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.array_remove_all(col, literal(3.0)),
lambda data: [py_arr_remove(arr, 3.0) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.list_remove_all(col, literal(3.0)),
lambda data: [py_arr_remove(arr, 3.0) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.array_repeat(col, literal(2)),
lambda data: [[arr] * 2 for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.list_repeat(col, literal(2)),
lambda data: [[arr] * 2 for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.array_replace(col, literal(3.0), literal(4.0)),
lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.list_replace(col, literal(3.0), literal(4.0)),
lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.array_replace_n(col, literal(3.0), literal(4.0), literal(1)),
lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.list_replace_n(col, literal(3.0), literal(4.0), literal(2)),
lambda data: [py_arr_replace(arr, 3.0, 4.0, 2) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.array_replace_all(col, literal(3.0), literal(4.0)),
lambda data: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.list_replace_all(col, literal(3.0), literal(4.0)),
lambda data: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.array_sort(col, descending=True, null_first=True),
lambda data: [np.sort(arr)[::-1] for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.list_sort(col, descending=False, null_first=False),
lambda data: [np.sort(arr) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.array_slice(col, literal(2), literal(4)),
lambda data: [arr[1:4] for arr in data],
- ],
+ ),
pytest.param(
lambda col: f.list_slice(col, literal(-1), literal(2)),
lambda data: [arr[-1:2] for arr in data],
),
- [
+ (
lambda col: f.array_intersect(col, literal([3.0, 4.0])),
lambda data: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.list_intersect(col, literal([3.0, 4.0])),
lambda data: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.array_union(col, literal([12.0, 999.0])),
lambda data: [np.union1d(arr, [12.0, 999.0]) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.list_union(col, literal([12.0, 999.0])),
lambda data: [np.union1d(arr, [12.0, 999.0]) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.array_except(col, literal([3.0])),
lambda data: [np.setdiff1d(arr, [3.0]) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.list_except(col, literal([3.0])),
lambda data: [np.setdiff1d(arr, [3.0]) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.array_resize(col, literal(10), literal(0.0)),
lambda data: [py_arr_resize(arr, 10, 0.0) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.list_resize(col, literal(10), literal(0.0)),
lambda data: [py_arr_resize(arr, 10, 0.0) for arr in data],
- ],
- [
+ ),
+ (
lambda col: f.range(literal(1), literal(5), literal(2)),
lambda data: [np.arange(1, 5, 2)],
- ],
+ ),
],
)
def test_array_functions(stmt, py_expr):
@@ -611,22 +613,22 @@ def test_make_array_functions(make_func):
@pytest.mark.parametrize(
("stmt", "py_expr"),
[
- [
+ (
f.array_to_string(column("arr"), literal(",")),
lambda data: [",".join([str(int(v)) for v in r]) for r in data],
- ],
- [
+ ),
+ (
f.array_join(column("arr"), literal(",")),
lambda data: [",".join([str(int(v)) for v in r]) for r in data],
- ],
- [
+ ),
+ (
f.list_to_string(column("arr"), literal(",")),
lambda data: [",".join([str(int(v)) for v in r]) for r in data],
- ],
- [
+ ),
+ (
f.list_join(column("arr"), literal(",")),
lambda data: [",".join([str(int(v)) for v in r]) for r in data],
- ],
+ ),
],
)
def test_array_function_obj_tests(stmt, py_expr):
@@ -640,7 +642,7 @@ def test_array_function_obj_tests(stmt, py_expr):
@pytest.mark.parametrize(
- "function, expected_result",
+ ("function", "expected_result"),
[
(
f.ascii(column("a")),
@@ -894,54 +896,72 @@ def test_temporal_functions(df):
assert result.column(0) == pa.array([12, 6, 7], type=pa.int32())
assert result.column(1) == pa.array([2022, 2027, 2020], type=pa.int32())
assert result.column(2) == pa.array(
- [datetime(2022, 12, 1), datetime(2027, 6, 1), datetime(2020, 7, 1)],
- type=pa.timestamp("us"),
+ [
+ datetime(2022, 12, 1, tzinfo=DEFAULT_TZ),
+ datetime(2027, 6, 1, tzinfo=DEFAULT_TZ),
+ datetime(2020, 7, 1, tzinfo=DEFAULT_TZ),
+ ],
+ type=pa.timestamp("ns", tz=DEFAULT_TZ),
)
assert result.column(3) == pa.array(
- [datetime(2022, 12, 31), datetime(2027, 6, 26), datetime(2020, 7, 2)],
- type=pa.timestamp("us"),
+ [
+ datetime(2022, 12, 31, tzinfo=DEFAULT_TZ),
+ datetime(2027, 6, 26, tzinfo=DEFAULT_TZ),
+ datetime(2020, 7, 2, tzinfo=DEFAULT_TZ),
+ ],
+ type=pa.timestamp("ns", tz=DEFAULT_TZ),
)
assert result.column(4) == pa.array(
[
- datetime(2022, 12, 30, 23, 47, 30),
- datetime(2027, 6, 25, 23, 47, 30),
- datetime(2020, 7, 1, 23, 47, 30),
+ datetime(2022, 12, 30, 23, 47, 30, tzinfo=DEFAULT_TZ),
+ datetime(2027, 6, 25, 23, 47, 30, tzinfo=DEFAULT_TZ),
+ datetime(2020, 7, 1, 23, 47, 30, tzinfo=DEFAULT_TZ),
],
- type=pa.timestamp("ns"),
+ type=pa.timestamp("ns", tz=DEFAULT_TZ),
)
assert result.column(5) == pa.array(
- [datetime(2023, 1, 10, 20, 52, 54)] * 3, type=pa.timestamp("s")
+ [datetime(2023, 1, 10, 20, 52, 54, tzinfo=DEFAULT_TZ)] * 3,
+ type=pa.timestamp("s"),
)
assert result.column(6) == pa.array(
- [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("ns")
+ [datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3,
+ type=pa.timestamp("ns"),
)
assert result.column(7) == pa.array(
- [datetime(2023, 9, 7, 5, 6, 14)] * 3, type=pa.timestamp("s")
+ [datetime(2023, 9, 7, 5, 6, 14, tzinfo=DEFAULT_TZ)] * 3, type=pa.timestamp("s")
)
assert result.column(8) == pa.array(
- [datetime(2023, 9, 7, 5, 6, 14, 523000)] * 3, type=pa.timestamp("ms")
+ [datetime(2023, 9, 7, 5, 6, 14, 523000, tzinfo=DEFAULT_TZ)] * 3,
+ type=pa.timestamp("ms"),
)
assert result.column(9) == pa.array(
- [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("us")
+ [datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3,
+ type=pa.timestamp("us"),
)
assert result.column(10) == pa.array([31, 26, 2], type=pa.int32())
assert result.column(11) == pa.array(
- [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("ns")
+ [datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3,
+ type=pa.timestamp("ns"),
)
assert result.column(12) == pa.array(
- [datetime(2023, 9, 7, 5, 6, 14)] * 3, type=pa.timestamp("s")
+ [datetime(2023, 9, 7, 5, 6, 14, tzinfo=DEFAULT_TZ)] * 3,
+ type=pa.timestamp("s"),
)
assert result.column(13) == pa.array(
- [datetime(2023, 9, 7, 5, 6, 14, 523000)] * 3, type=pa.timestamp("ms")
+ [datetime(2023, 9, 7, 5, 6, 14, 523000, tzinfo=DEFAULT_TZ)] * 3,
+ type=pa.timestamp("ms"),
)
assert result.column(14) == pa.array(
- [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("us")
+ [datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3,
+ type=pa.timestamp("us"),
)
assert result.column(15) == pa.array(
- [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("ns")
+ [datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3,
+ type=pa.timestamp("ns"),
)
assert result.column(16) == pa.array(
- [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("ns")
+ [datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3,
+ type=pa.timestamp("ns"),
)
@@ -1057,7 +1077,7 @@ def test_regr_funcs_sql_2():
@pytest.mark.parametrize(
- "func, expected",
+ ("func", "expected"),
[
pytest.param(f.regr_slope(column("c2"), column("c1")), [4.6], id="regr_slope"),
pytest.param(
@@ -1160,7 +1180,7 @@ def test_binary_string_functions(df):
@pytest.mark.parametrize(
- "python_datatype, name, expected",
+ ("python_datatype", "name", "expected"),
[
pytest.param(bool, "e", pa.bool_(), id="bool"),
pytest.param(int, "b", pa.int64(), id="int"),
@@ -1179,7 +1199,7 @@ def test_cast(df, python_datatype, name: str, expected):
@pytest.mark.parametrize(
- "negated, low, high, expected",
+ ("negated", "low", "high", "expected"),
[
pytest.param(False, 3, 5, {"filtered": [4, 5]}),
pytest.param(False, 4, 5, {"filtered": [4, 5]}),
diff --git a/python/tests/test_imports.py b/python/tests/test_imports.py
index 0c155cbd..9ef7ed89 100644
--- a/python/tests/test_imports.py
+++ b/python/tests/test_imports.py
@@ -169,14 +169,15 @@ def test_class_module_is_datafusion():
def test_import_from_functions_submodule():
- from datafusion.functions import abs, sin # noqa
+ from datafusion.functions import abs as df_abs
+ from datafusion.functions import sin
- assert functions.abs is abs
+ assert functions.abs is df_abs
assert functions.sin is sin
msg = "cannot import name 'foobar' from 'datafusion.functions'"
with pytest.raises(ImportError, match=msg):
- from datafusion.functions import foobar # noqa
+ from datafusion.functions import foobar # noqa: F401
def test_classes_are_inheritable():
diff --git a/python/tests/test_input.py b/python/tests/test_input.py
index 80647135..4663f614 100644
--- a/python/tests/test_input.py
+++ b/python/tests/test_input.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-import os
+import pathlib
from datafusion.input.location import LocationInputPlugin
@@ -23,10 +23,10 @@ from datafusion.input.location import LocationInputPlugin
def test_location_input():
location_input = LocationInputPlugin()
- cwd = os.getcwd()
- input_file = cwd + "/testing/data/parquet/generated_simple_numerics/blogs.parquet"
+ cwd = pathlib.Path.cwd()
+ input_file = cwd / "testing/data/parquet/generated_simple_numerics/blogs.parquet"
table_name = "blog"
- tbl = location_input.build_table(input_file, table_name)
- assert "blog" == tbl.name
- assert 3 == len(tbl.columns)
+ tbl = location_input.build_table(str(input_file), table_name)
+ assert tbl.name == "blog"
+ assert len(tbl.columns) == 3
assert "blogs.parquet" in tbl.filepaths[0]
diff --git a/python/tests/test_io.py b/python/tests/test_io.py
index 21ad188e..7ca50968 100644
--- a/python/tests/test_io.py
+++ b/python/tests/test_io.py
@@ -14,8 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import os
-import pathlib
+from pathlib import Path
import pyarrow as pa
from datafusion import column
@@ -23,10 +22,10 @@ from datafusion.io import read_avro, read_csv, read_json, read_parquet
def test_read_json_global_ctx(ctx):
- path = os.path.dirname(os.path.abspath(__file__))
+ path = Path(__file__).parent.resolve()
# Default
- test_data_path = os.path.join(path, "data_test_context", "data.json")
+ test_data_path = Path(path) / "data_test_context" / "data.json"
df = read_json(test_data_path)
result = df.collect()
@@ -46,7 +45,7 @@ def test_read_json_global_ctx(ctx):
assert result[0].schema == schema
# File extension
- test_data_path = os.path.join(path, "data_test_context", "data.json")
+ test_data_path = Path(path) / "data_test_context" / "data.json"
df = read_json(test_data_path, file_extension=".json")
result = df.collect()
@@ -59,7 +58,7 @@ def test_read_parquet_global():
parquet_df.show()
assert parquet_df is not None
- path = pathlib.Path.cwd() / "parquet/data/alltypes_plain.parquet"
+ path = Path.cwd() / "parquet/data/alltypes_plain.parquet"
parquet_df = read_parquet(path=path)
assert parquet_df is not None
@@ -90,6 +89,6 @@ def test_read_avro():
avro_df.show()
assert avro_df is not None
- path = pathlib.Path.cwd() / "testing/data/avro/alltypes_plain.avro"
+ path = Path.cwd() / "testing/data/avro/alltypes_plain.avro"
avro_df = read_avro(path=path)
assert avro_df is not None
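For reference, a short sketch (not from this commit) of the pathlib idioms the IO tests switch to in place of os.path and os.getcwd; the paths are the ones already used in the hunks above:

from pathlib import Path

here = Path(__file__).parent.resolve()  # replaces os.path.dirname(os.path.abspath(__file__))
data_file = here / "data_test_context" / "data.json"  # replaces os.path.join(path, ...)
parquet_file = Path.cwd() / "parquet/data/alltypes_plain.parquet"  # replaces os.getcwd() + "/..."
print(data_file, parquet_file)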
diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py
index 862f745b..b6348e3a 100644
--- a/python/tests/test_sql.py
+++ b/python/tests/test_sql.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import gzip
-import os
+from pathlib import Path
import numpy as np
import pyarrow as pa
@@ -47,9 +47,8 @@ def test_register_csv(ctx, tmp_path):
)
write_csv(table, path)
- with open(path, "rb") as csv_file:
- with gzip.open(gzip_path, "wb") as gzipped_file:
- gzipped_file.writelines(csv_file)
+ with Path.open(path, "rb") as csv_file, gzip.open(gzip_path, "wb") as gzipped_file:
+ gzipped_file.writelines(csv_file)
ctx.register_csv("csv", path)
ctx.register_csv("csv1", str(path))
@@ -158,7 +157,7 @@ def test_register_parquet(ctx, tmp_path):
assert result.to_pydict() == {"cnt": [100]}
[email protected]("path_to_str", (True, False))
[email protected]("path_to_str", [True, False])
def test_register_parquet_partitioned(ctx, tmp_path, path_to_str):
dir_root = tmp_path / "dataset_parquet_partitioned"
dir_root.mkdir(exist_ok=False)
@@ -194,7 +193,7 @@ def test_register_parquet_partitioned(ctx, tmp_path, path_to_str):
assert dict(zip(rd["grp"], rd["cnt"])) == {"a": 3, "b": 1}
[email protected]("path_to_str", (True, False))
[email protected]("path_to_str", [True, False])
def test_register_dataset(ctx, tmp_path, path_to_str):
path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
path = str(path) if path_to_str else path
@@ -209,13 +208,15 @@ def test_register_dataset(ctx, tmp_path, path_to_str):
def test_register_json(ctx, tmp_path):
- path = os.path.dirname(os.path.abspath(__file__))
- test_data_path = os.path.join(path, "data_test_context", "data.json")
+ path = Path(__file__).parent.resolve()
+ test_data_path = Path(path) / "data_test_context" / "data.json"
gzip_path = tmp_path / "data.json.gz"
- with open(test_data_path, "rb") as json_file:
- with gzip.open(gzip_path, "wb") as gzipped_file:
- gzipped_file.writelines(json_file)
+ with (
+ Path.open(test_data_path, "rb") as json_file,
+ gzip.open(gzip_path, "wb") as gzipped_file,
+ ):
+ gzipped_file.writelines(json_file)
ctx.register_json("json", test_data_path)
ctx.register_json("json1", str(test_data_path))
@@ -470,16 +471,18 @@ def test_simple_select(ctx, tmp_path, arr):
# In DF 43.0.0 we now default to having BinaryView and StringView
# so the array that is saved to the parquet is slightly different
# than the array read. Convert to values for comparison.
- if isinstance(result, pa.BinaryViewArray) or isinstance(result, pa.StringViewArray):
+ if isinstance(result, (pa.BinaryViewArray, pa.StringViewArray)):
arr = arr.tolist()
result = result.tolist()
np.testing.assert_equal(result, arr)
[email protected]("file_sort_order", (None, [[col("int").sort(True,
True)]]))
[email protected]("pass_schema", (True, False))
[email protected]("path_to_str", (True, False))
[email protected](
+ "file_sort_order", [None, [[col("int").sort(ascending=True,
nulls_first=True)]]]
+)
[email protected]("pass_schema", [True, False])
[email protected]("path_to_str", [True, False])
def test_register_listing_table(
ctx, tmp_path, pass_schema, file_sort_order, path_to_str
):
@@ -528,7 +531,7 @@ def test_register_listing_table(
assert dict(zip(rd["grp"], rd["count"])) == {"a": 5, "b": 2}
result = ctx.sql(
- "SELECT grp, COUNT(*) AS count FROM my_table WHERE date_id=20201005
GROUP BY grp"
+ "SELECT grp, COUNT(*) AS count FROM my_table WHERE date_id=20201005
GROUP BY grp" # noqa: E501
).collect()
result = pa.Table.from_batches(result)
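One more illustrative sketch (not from this commit): the single parenthesized with statement used above to gzip a fixture file, wrapped in a small helper whose name is made up here:

import gzip
from pathlib import Path


def gzip_copy(src: Path, dst: Path) -> None:
    # Both handles are opened and closed by one with statement, as in the updated tests.
    with (
        Path.open(src, "rb") as src_file,
        gzip.open(dst, "wb") as gzipped_file,
    ):
        gzipped_file.writelines(src_file)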
diff --git a/python/tests/test_store.py b/python/tests/test_store.py
index 53ffc3ac..ac9af98f 100644
--- a/python/tests/test_store.py
+++ b/python/tests/test_store.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-import os
+from pathlib import Path
import pytest
from datafusion import SessionContext
@@ -23,17 +23,16 @@ from datafusion import SessionContext
@pytest.fixture
def ctx():
- ctx = SessionContext()
- return ctx
+ return SessionContext()
def test_read_parquet(ctx):
ctx.register_parquet(
"test",
- f"file://{os.getcwd()}/parquet/data/alltypes_plain.parquet",
- [],
- True,
- ".parquet",
+ f"file://{Path.cwd()}/parquet/data/alltypes_plain.parquet",
+ table_partition_cols=[],
+ parquet_pruning=True,
+ file_extension=".parquet",
)
df = ctx.sql("SELECT * FROM test")
assert isinstance(df.collect(), list)
diff --git a/python/tests/test_substrait.py b/python/tests/test_substrait.py
index feada7cd..f367a447 100644
--- a/python/tests/test_substrait.py
+++ b/python/tests/test_substrait.py
@@ -50,7 +50,7 @@ def test_substrait_serialization(ctx):
substrait_plan = ss.Producer.to_substrait_plan(df.logical_plan(), ctx)
[email protected]("path_to_str", (True, False))
[email protected]("path_to_str", [True, False])
def test_substrait_file_serialization(ctx, tmp_path, path_to_str):
batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3]), pa.array([4, 5, 6])],
diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py
index 97cf81f3..453ff6f4 100644
--- a/python/tests/test_udaf.py
+++ b/python/tests/test_udaf.py
@@ -17,8 +17,6 @@
from __future__ import annotations
-from typing import List
-
import pyarrow as pa
import pyarrow.compute as pc
import pytest
@@ -31,7 +29,7 @@ class Summarize(Accumulator):
def __init__(self, initial_value: float = 0.0):
self._sum = pa.scalar(initial_value)
- def state(self) -> List[pa.Scalar]:
+ def state(self) -> list[pa.Scalar]:
return [self._sum]
def update(self, values: pa.Array) -> None:
@@ -39,7 +37,7 @@ class Summarize(Accumulator):
# This breaks on `None`
self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())
- def merge(self, states: List[pa.Array]) -> None:
+ def merge(self, states: list[pa.Array]) -> None:
# Not nice since pyarrow scalars can't be summed yet.
# This breaks on `None`
self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py())
@@ -56,7 +54,7 @@ class MissingMethods(Accumulator):
def __init__(self):
self._sum = pa.scalar(0)
- def state(self) -> List[pa.Scalar]:
+ def state(self) -> list[pa.Scalar]:
return [self._sum]
@@ -86,7 +84,7 @@ def test_errors(df):
"evaluate, merge, update)"
)
with pytest.raises(Exception, match=msg):
- accum = udaf( # noqa F841
+ accum = udaf( # noqa: F841
MissingMethods,
pa.int64(),
pa.int64(),
diff --git a/python/tests/test_udwf.py b/python/tests/test_udwf.py
index 2fea34aa..3d6dcf9d 100644
--- a/python/tests/test_udwf.py
+++ b/python/tests/test_udwf.py
@@ -298,7 +298,7 @@ data_test_udwf_functions = [
]
[email protected]("name,expr,expected", data_test_udwf_functions)
[email protected](("name", "expr", "expected"),
data_test_udwf_functions)
def test_udwf_functions(df, name, expr, expected):
df = df.select("a", "b", f.round(expr, lit(3)).alias(name))
diff --git a/python/tests/test_wrapper_coverage.py b/python/tests/test_wrapper_coverage.py
index ac064ba9..d7f6f6e3 100644
--- a/python/tests/test_wrapper_coverage.py
+++ b/python/tests/test_wrapper_coverage.py
@@ -19,6 +19,7 @@ import datafusion
import datafusion.functions
import datafusion.object_store
import datafusion.substrait
+import pytest
# EnumType introduced in 3.11. 3.10 and prior it was called EnumMeta.
try:
@@ -41,10 +42,8 @@ def missing_exports(internal_obj, wrapped_obj) -> None:
internal_attr = getattr(internal_obj, attr)
wrapped_attr = getattr(wrapped_obj, attr)
- if internal_attr is not None:
- if wrapped_attr is None:
- print("Missing attribute: ", attr)
- assert False
+ if internal_attr is not None and wrapped_attr is None:
+ pytest.fail(f"Missing attribute: {attr}")
if attr in ["__self__", "__class__"]:
continue
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]