This is an automated email from the ASF dual-hosted git repository.
timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new b8dd97bc Add additional ruff suggestions (#1062)
b8dd97bc is described below
commit b8dd97bc8eefcfecfa8dcc864c4898c654b236a9
Author: Spaarsh <[email protected]>
AuthorDate: Mon Mar 17 20:08:16 2025 +0530
Add additional ruff suggestions (#1062)
* Enabled ruff rules PT001 and ANN204
* Enabled ruff rule B008
* Enabled ruff rule EM101
* Enabled ruff rule PLR1714
* Enabled ruff rule ANN201
* Enabled ruff rule C400
* Enabled ruff rule B904
* Enabled ruff rule UP006
* Enabled ruff rule RUF012
* Enabled ruff rule FBT003
* Enabled ruff rule C416
* Enabled ruff rule SIM102
* Enabled ruff rule PGH003
* Enabled ruff rule PERF401
* Enabled ruff rule EM102
* Enabled ruff rule SIM108
* Enabled ruff rule ICN001
* Implemented review feedback
* Update pyproject.toml to ignore `SIM102`
* Enabled ruff rule PLW2901
* Enabled ruff rule RET503
* Fixed failing ruff tests
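For readers skimming the diff, here is a brief illustrative sketch of what several of these rules ask for, using the same patterns the commit applies (the helper function below is invented and not part of the repository): ICN001 imports pyarrow under the conventional pa alias, EM101/EM102 together with B904 bind the exception message to a variable and chain it with from err, and FBT003 passes boolean arguments by keyword.

    import pyarrow as pa  # ICN001: conventional alias instead of `import pyarrow`
    from datafusion import SessionConfig

    # FBT003: boolean arguments are passed by keyword at the call site
    config = SessionConfig().with_repartition_joins(enabled=False)


    def to_batch(values: list) -> pa.RecordBatch:
        """Invented helper showing the EM101/EM102 and B904 pattern."""
        try:
            array = pa.array(values)
        except pa.ArrowInvalid as err:
            # EM101/EM102: assign the message to a variable before raising
            error_msg = "values could not be converted to an Arrow array"
            # B904: chain the original exception instead of swallowing it
            raise ValueError(error_msg) from err
        return pa.RecordBatch.from_arrays([array], names=["v"])

The diff below applies these same rewrites across the benchmarks, examples, and the python/datafusion package.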
---
benchmarks/db-benchmark/groupby-datafusion.py | 24 ++--
benchmarks/db-benchmark/join-datafusion.py | 5 +-
benchmarks/tpch/tpch.py | 7 +-
dev/release/generate-changelog.py | 6 +-
docs/source/conf.py | 4 +-
examples/create-context.py | 12 +-
examples/python-udaf.py | 36 +++---
examples/python-udf-comparisons.py | 9 +-
examples/python-udf.py | 12 +-
examples/query-pyarrow-data.py | 10 +-
examples/sql-using-python-udaf.py | 2 +-
examples/tpch/_tests.py | 4 +-
examples/tpch/convert_data_to_parquet.py | 134 ++++++++++-----------
examples/tpch/q08_market_share.py | 2 +-
examples/tpch/q19_discounted_revenue.py | 4 +-
examples/tpch/q21_suppliers_kept_orders_waiting.py | 2 +-
pyproject.toml | 20 ---
python/datafusion/__init__.py | 8 +-
python/datafusion/catalog.py | 4 +-
python/datafusion/context.py | 51 ++++----
python/datafusion/dataframe.py | 55 +++++----
python/datafusion/expr.py | 31 ++---
python/datafusion/functions.py | 9 +-
python/tests/test_functions.py | 2 +-
python/tests/test_wrapper_coverage.py | 7 +-
25 files changed, 213 insertions(+), 247 deletions(-)
diff --git a/benchmarks/db-benchmark/groupby-datafusion.py b/benchmarks/db-benchmark/groupby-datafusion.py
index 04bf7a14..f9e8d638 100644
--- a/benchmarks/db-benchmark/groupby-datafusion.py
+++ b/benchmarks/db-benchmark/groupby-datafusion.py
@@ -20,7 +20,7 @@ import os
import timeit
import datafusion as df
-import pyarrow
+import pyarrow as pa
from datafusion import (
RuntimeEnvBuilder,
SessionConfig,
@@ -37,7 +37,7 @@ print("# groupby-datafusion.py", flush=True)
exec(open("./_helpers/helpers.py").read())
-def ans_shape(batches):
+def ans_shape(batches) -> tuple[int, int]:
rows, cols = 0, 0
for batch in batches:
rows += batch.num_rows
@@ -48,7 +48,7 @@ def ans_shape(batches):
return rows, cols
-def execute(df):
+def execute(df) -> list:
print(df.execution_plan().display_indent())
return df.collect()
@@ -68,14 +68,14 @@ data_name = os.environ["SRC_DATANAME"]
src_grp = os.path.join("data", data_name + ".csv")
print("loading dataset %s" % src_grp, flush=True)
-schema = pyarrow.schema(
+schema = pa.schema(
[
- ("id4", pyarrow.int32()),
- ("id5", pyarrow.int32()),
- ("id6", pyarrow.int32()),
- ("v1", pyarrow.int32()),
- ("v2", pyarrow.int32()),
- ("v3", pyarrow.float64()),
+ ("id4", pa.int32()),
+ ("id5", pa.int32()),
+ ("id6", pa.int32()),
+ ("v1", pa.int32()),
+ ("v2", pa.int32()),
+ ("v3", pa.float64()),
]
)
@@ -93,8 +93,8 @@ runtime = (
)
config = (
SessionConfig()
- .with_repartition_joins(False)
- .with_repartition_aggregations(False)
+ .with_repartition_joins(enabled=False)
+ .with_repartition_aggregations(enabled=False)
.set("datafusion.execution.coalesce_batches", "false")
)
ctx = SessionContext(config, runtime)
diff --git a/benchmarks/db-benchmark/join-datafusion.py b/benchmarks/db-benchmark/join-datafusion.py
index b45ebf63..03986803 100755
--- a/benchmarks/db-benchmark/join-datafusion.py
+++ b/benchmarks/db-benchmark/join-datafusion.py
@@ -29,7 +29,7 @@ print("# join-datafusion.py", flush=True)
exec(open("./_helpers/helpers.py").read())
-def ans_shape(batches):
+def ans_shape(batches) -> tuple[int, int]:
rows, cols = 0, 0
for batch in batches:
rows += batch.num_rows
@@ -57,7 +57,8 @@ src_jn_y = [
os.path.join("data", y_data_name[2] + ".csv"),
]
if len(src_jn_y) != 3:
- raise Exception("Something went wrong in preparing files used for join")
+ error_msg = "Something went wrong in preparing files used for join"
+ raise Exception(error_msg)
print(
"loading datasets "
diff --git a/benchmarks/tpch/tpch.py b/benchmarks/tpch/tpch.py
index bfb9ac39..2d1bbae5 100644
--- a/benchmarks/tpch/tpch.py
+++ b/benchmarks/tpch/tpch.py
@@ -21,7 +21,7 @@ import time
from datafusion import SessionContext
-def bench(data_path, query_path):
+def bench(data_path, query_path) -> None:
with open("results.csv", "w") as results:
# register tables
start = time.time()
@@ -68,10 +68,7 @@ def bench(data_path, query_path):
with open(f"{query_path}/q{query}.sql") as f:
text = f.read()
tmp = text.split(";")
- queries = []
- for str in tmp:
- if len(str.strip()) > 0:
- queries.append(str.strip())
+ queries = [s.strip() for s in tmp if len(s.strip()) > 0]
try:
start = time.time()
diff --git a/dev/release/generate-changelog.py b/dev/release/generate-changelog.py
index e30e2def..d8673677 100755
--- a/dev/release/generate-changelog.py
+++ b/dev/release/generate-changelog.py
@@ -24,7 +24,7 @@ import sys
from github import Github
-def print_pulls(repo_name, title, pulls):
+def print_pulls(repo_name, title, pulls) -> None:
if len(pulls) > 0:
print(f"**{title}:**")
print()
@@ -34,7 +34,7 @@ def print_pulls(repo_name, title, pulls):
print()
-def generate_changelog(repo, repo_name, tag1, tag2, version):
+def generate_changelog(repo, repo_name, tag1, tag2, version) -> None:
# get a list of commits between two tags
    print(f"Fetching list of commits between {tag1} and {tag2}", file=sys.stderr)
comparison = repo.compare(tag1, tag2)
@@ -154,7 +154,7 @@ under the License.
)
-def cli(args=None):
+def cli(args=None) -> None:
"""Process command line arguments."""
if not args:
args = sys.argv[1:]
diff --git a/docs/source/conf.py b/docs/source/conf.py
index c82a189e..0be03d81 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -73,7 +73,7 @@ suppress_warnings = ["autoapi.python_import_resolution"]
autoapi_python_class_content = "both"
-def autoapi_skip_member_fn(app, what, name, obj, skip, options): # noqa: ARG001
+def autoapi_skip_member_fn(app, what, name, obj, skip, options) -> bool: # noqa: ARG001
skip_contents = [
# Re-exports
("class", "datafusion.DataFrame"),
@@ -93,7 +93,7 @@ def autoapi_skip_member_fn(app, what, name, obj, skip, options): # noqa: ARG001
return skip
-def setup(sphinx):
+def setup(sphinx) -> None:
sphinx.connect("autoapi-skip-member", autoapi_skip_member_fn)
diff --git a/examples/create-context.py b/examples/create-context.py
index 760c8513..0026d616 100644
--- a/examples/create-context.py
+++ b/examples/create-context.py
@@ -25,14 +25,14 @@ print(ctx)
runtime = RuntimeEnvBuilder().with_disk_manager_os().with_fair_spill_pool(10000000)
config = (
SessionConfig()
- .with_create_default_catalog_and_schema(True)
+ .with_create_default_catalog_and_schema(enabled=True)
.with_default_catalog_and_schema("foo", "bar")
.with_target_partitions(8)
- .with_information_schema(True)
- .with_repartition_joins(False)
- .with_repartition_aggregations(False)
- .with_repartition_windows(False)
- .with_parquet_pruning(False)
+ .with_information_schema(enabled=True)
+ .with_repartition_joins(enabled=False)
+ .with_repartition_aggregations(enabled=False)
+ .with_repartition_windows(enabled=False)
+ .with_parquet_pruning(enabled=False)
.set("datafusion.execution.parquet.pushdown_filters", "true")
)
ctx = SessionContext(config, runtime)
diff --git a/examples/python-udaf.py b/examples/python-udaf.py
index 538f6957..6655edb0 100644
--- a/examples/python-udaf.py
+++ b/examples/python-udaf.py
@@ -16,7 +16,7 @@
# under the License.
import datafusion
-import pyarrow
+import pyarrow as pa
import pyarrow.compute
from datafusion import Accumulator, col, udaf
@@ -26,25 +26,21 @@ class MyAccumulator(Accumulator):
Interface of a user-defined accumulation.
"""
- def __init__(self):
- self._sum = pyarrow.scalar(0.0)
+ def __init__(self) -> None:
+ self._sum = pa.scalar(0.0)
- def update(self, values: pyarrow.Array) -> None:
+ def update(self, values: pa.Array) -> None:
        # not nice since pyarrow scalars can't be summed yet. This breaks on `None`
- self._sum = pyarrow.scalar(
- self._sum.as_py() + pyarrow.compute.sum(values).as_py()
- )
+        self._sum = pa.scalar(self._sum.as_py() + pa.compute.sum(values).as_py())
- def merge(self, states: pyarrow.Array) -> None:
+ def merge(self, states: pa.Array) -> None:
        # not nice since pyarrow scalars can't be summed yet. This breaks on `None`
- self._sum = pyarrow.scalar(
- self._sum.as_py() + pyarrow.compute.sum(states).as_py()
- )
+        self._sum = pa.scalar(self._sum.as_py() + pa.compute.sum(states).as_py())
- def state(self) -> pyarrow.Array:
- return pyarrow.array([self._sum.as_py()])
+ def state(self) -> pa.Array:
+ return pa.array([self._sum.as_py()])
- def evaluate(self) -> pyarrow.Scalar:
+ def evaluate(self) -> pa.Scalar:
return self._sum
@@ -52,17 +48,17 @@ class MyAccumulator(Accumulator):
ctx = datafusion.SessionContext()
# create a RecordBatch and a new DataFrame from it
-batch = pyarrow.RecordBatch.from_arrays(
- [pyarrow.array([1, 2, 3]), pyarrow.array([4, 5, 6])],
+batch = pa.RecordBatch.from_arrays(
+ [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
names=["a", "b"],
)
df = ctx.create_dataframe([[batch]])
my_udaf = udaf(
MyAccumulator,
- pyarrow.float64(),
- pyarrow.float64(),
- [pyarrow.float64()],
+ pa.float64(),
+ pa.float64(),
+ [pa.float64()],
"stable",
)
@@ -70,4 +66,4 @@ df = df.aggregate([], [my_udaf(col("a"))])
result = df.collect()[0]
-assert result.column(0) == pyarrow.array([6.0])
+assert result.column(0) == pa.array([6.0])
diff --git a/examples/python-udf-comparisons.py b/examples/python-udf-comparisons.py
index c5d5ec8d..eb082501 100644
--- a/examples/python-udf-comparisons.py
+++ b/examples/python-udf-comparisons.py
@@ -112,8 +112,8 @@ def is_of_interest_impl(
returnflag_arr: pa.Array,
) -> pa.Array:
result = []
- for idx, partkey in enumerate(partkey_arr):
- partkey = partkey.as_py()
+ for idx, partkey_val in enumerate(partkey_arr):
+ partkey = partkey_val.as_py()
suppkey = suppkey_arr[idx].as_py()
returnflag = returnflag_arr[idx].as_py()
value = (partkey, suppkey, returnflag)
@@ -162,10 +162,7 @@ def udf_using_pyarrow_compute_impl(
resultant_arr = pc.and_(filtered_partkey_arr, filtered_suppkey_arr)
resultant_arr = pc.and_(resultant_arr, filtered_returnflag_arr)
- if results is None:
- results = resultant_arr
- else:
- results = pc.or_(results, resultant_arr)
+    results = resultant_arr if results is None else pc.or_(results, resultant_arr)
return results
diff --git a/examples/python-udf.py b/examples/python-udf.py
index fb2bc253..1c08acd1 100644
--- a/examples/python-udf.py
+++ b/examples/python-udf.py
@@ -15,23 +15,23 @@
# specific language governing permissions and limitations
# under the License.
-import pyarrow
+import pyarrow as pa
from datafusion import SessionContext, udf
from datafusion import functions as f
-def is_null(array: pyarrow.Array) -> pyarrow.Array:
+def is_null(array: pa.Array) -> pa.Array:
return array.is_null()
-is_null_arr = udf(is_null, [pyarrow.int64()], pyarrow.bool_(), "stable")
+is_null_arr = udf(is_null, [pa.int64()], pa.bool_(), "stable")
# create a context
ctx = SessionContext()
# create a RecordBatch and a new DataFrame from it
-batch = pyarrow.RecordBatch.from_arrays(
- [pyarrow.array([1, 2, 3]), pyarrow.array([4, 5, 6])],
+batch = pa.RecordBatch.from_arrays(
+ [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
names=["a", "b"],
)
df = ctx.create_dataframe([[batch]])
@@ -40,4 +40,4 @@ df = df.select(is_null_arr(f.col("a")))
result = df.collect()[0]
-assert result.column(0) == pyarrow.array([False] * 3)
+assert result.column(0) == pa.array([False] * 3)
diff --git a/examples/query-pyarrow-data.py b/examples/query-pyarrow-data.py
index e3456fb5..9cfe8a62 100644
--- a/examples/query-pyarrow-data.py
+++ b/examples/query-pyarrow-data.py
@@ -16,15 +16,15 @@
# under the License.
import datafusion
-import pyarrow
+import pyarrow as pa
from datafusion import col
# create a context
ctx = datafusion.SessionContext()
# create a RecordBatch and a new DataFrame from it
-batch = pyarrow.RecordBatch.from_arrays(
- [pyarrow.array([1, 2, 3]), pyarrow.array([4, 5, 6])],
+batch = pa.RecordBatch.from_arrays(
+ [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
names=["a", "b"],
)
df = ctx.create_dataframe([[batch]])
@@ -38,5 +38,5 @@ df = df.select(
# execute and collect the first (and only) batch
result = df.collect()[0]
-assert result.column(0) == pyarrow.array([5, 7, 9])
-assert result.column(1) == pyarrow.array([-3, -3, -3])
+assert result.column(0) == pa.array([5, 7, 9])
+assert result.column(1) == pa.array([-3, -3, -3])
diff --git a/examples/sql-using-python-udaf.py b/examples/sql-using-python-udaf.py
index 60ab8d13..32ce3890 100644
--- a/examples/sql-using-python-udaf.py
+++ b/examples/sql-using-python-udaf.py
@@ -25,7 +25,7 @@ class MyAccumulator(Accumulator):
Interface of a user-defined accumulation.
"""
- def __init__(self):
+ def __init__(self) -> None:
self._sum = pa.scalar(0.0)
def update(self, values: pa.Array) -> None:
diff --git a/examples/tpch/_tests.py b/examples/tpch/_tests.py
index 2be4dfab..80ff8024 100644
--- a/examples/tpch/_tests.py
+++ b/examples/tpch/_tests.py
@@ -91,7 +91,7 @@ def check_q17(df):
("q22_global_sales_opportunity", "q22"),
],
)
-def test_tpch_query_vs_answer_file(query_code: str, answer_file: str):
+def test_tpch_query_vs_answer_file(query_code: str, answer_file: str) -> None:
module = import_module(query_code)
df: DataFrame = module.df
@@ -122,3 +122,5 @@ def test_tpch_query_vs_answer_file(query_code: str, answer_file: str):
assert df.join(df_expected, on=cols, how="anti").count() == 0
assert df.count() == df_expected.count()
+
+ return None
diff --git a/examples/tpch/convert_data_to_parquet.py b/examples/tpch/convert_data_to_parquet.py
index 73097fac..fd0fcca4 100644
--- a/examples/tpch/convert_data_to_parquet.py
+++ b/examples/tpch/convert_data_to_parquet.py
@@ -25,112 +25,112 @@ as will be generated by the script provided in this repository.
import os
import datafusion
-import pyarrow
+import pyarrow as pa
ctx = datafusion.SessionContext()
all_schemas = {}
all_schemas["customer"] = [
- ("C_CUSTKEY", pyarrow.int64()),
- ("C_NAME", pyarrow.string()),
- ("C_ADDRESS", pyarrow.string()),
- ("C_NATIONKEY", pyarrow.int64()),
- ("C_PHONE", pyarrow.string()),
- ("C_ACCTBAL", pyarrow.decimal128(15, 2)),
- ("C_MKTSEGMENT", pyarrow.string()),
- ("C_COMMENT", pyarrow.string()),
+ ("C_CUSTKEY", pa.int64()),
+ ("C_NAME", pa.string()),
+ ("C_ADDRESS", pa.string()),
+ ("C_NATIONKEY", pa.int64()),
+ ("C_PHONE", pa.string()),
+ ("C_ACCTBAL", pa.decimal128(15, 2)),
+ ("C_MKTSEGMENT", pa.string()),
+ ("C_COMMENT", pa.string()),
]
all_schemas["lineitem"] = [
- ("L_ORDERKEY", pyarrow.int64()),
- ("L_PARTKEY", pyarrow.int64()),
- ("L_SUPPKEY", pyarrow.int64()),
- ("L_LINENUMBER", pyarrow.int32()),
- ("L_QUANTITY", pyarrow.decimal128(15, 2)),
- ("L_EXTENDEDPRICE", pyarrow.decimal128(15, 2)),
- ("L_DISCOUNT", pyarrow.decimal128(15, 2)),
- ("L_TAX", pyarrow.decimal128(15, 2)),
- ("L_RETURNFLAG", pyarrow.string()),
- ("L_LINESTATUS", pyarrow.string()),
- ("L_SHIPDATE", pyarrow.date32()),
- ("L_COMMITDATE", pyarrow.date32()),
- ("L_RECEIPTDATE", pyarrow.date32()),
- ("L_SHIPINSTRUCT", pyarrow.string()),
- ("L_SHIPMODE", pyarrow.string()),
- ("L_COMMENT", pyarrow.string()),
+ ("L_ORDERKEY", pa.int64()),
+ ("L_PARTKEY", pa.int64()),
+ ("L_SUPPKEY", pa.int64()),
+ ("L_LINENUMBER", pa.int32()),
+ ("L_QUANTITY", pa.decimal128(15, 2)),
+ ("L_EXTENDEDPRICE", pa.decimal128(15, 2)),
+ ("L_DISCOUNT", pa.decimal128(15, 2)),
+ ("L_TAX", pa.decimal128(15, 2)),
+ ("L_RETURNFLAG", pa.string()),
+ ("L_LINESTATUS", pa.string()),
+ ("L_SHIPDATE", pa.date32()),
+ ("L_COMMITDATE", pa.date32()),
+ ("L_RECEIPTDATE", pa.date32()),
+ ("L_SHIPINSTRUCT", pa.string()),
+ ("L_SHIPMODE", pa.string()),
+ ("L_COMMENT", pa.string()),
]
all_schemas["nation"] = [
- ("N_NATIONKEY", pyarrow.int64()),
- ("N_NAME", pyarrow.string()),
- ("N_REGIONKEY", pyarrow.int64()),
- ("N_COMMENT", pyarrow.string()),
+ ("N_NATIONKEY", pa.int64()),
+ ("N_NAME", pa.string()),
+ ("N_REGIONKEY", pa.int64()),
+ ("N_COMMENT", pa.string()),
]
all_schemas["orders"] = [
- ("O_ORDERKEY", pyarrow.int64()),
- ("O_CUSTKEY", pyarrow.int64()),
- ("O_ORDERSTATUS", pyarrow.string()),
- ("O_TOTALPRICE", pyarrow.decimal128(15, 2)),
- ("O_ORDERDATE", pyarrow.date32()),
- ("O_ORDERPRIORITY", pyarrow.string()),
- ("O_CLERK", pyarrow.string()),
- ("O_SHIPPRIORITY", pyarrow.int32()),
- ("O_COMMENT", pyarrow.string()),
+ ("O_ORDERKEY", pa.int64()),
+ ("O_CUSTKEY", pa.int64()),
+ ("O_ORDERSTATUS", pa.string()),
+ ("O_TOTALPRICE", pa.decimal128(15, 2)),
+ ("O_ORDERDATE", pa.date32()),
+ ("O_ORDERPRIORITY", pa.string()),
+ ("O_CLERK", pa.string()),
+ ("O_SHIPPRIORITY", pa.int32()),
+ ("O_COMMENT", pa.string()),
]
all_schemas["part"] = [
- ("P_PARTKEY", pyarrow.int64()),
- ("P_NAME", pyarrow.string()),
- ("P_MFGR", pyarrow.string()),
- ("P_BRAND", pyarrow.string()),
- ("P_TYPE", pyarrow.string()),
- ("P_SIZE", pyarrow.int32()),
- ("P_CONTAINER", pyarrow.string()),
- ("P_RETAILPRICE", pyarrow.decimal128(15, 2)),
- ("P_COMMENT", pyarrow.string()),
+ ("P_PARTKEY", pa.int64()),
+ ("P_NAME", pa.string()),
+ ("P_MFGR", pa.string()),
+ ("P_BRAND", pa.string()),
+ ("P_TYPE", pa.string()),
+ ("P_SIZE", pa.int32()),
+ ("P_CONTAINER", pa.string()),
+ ("P_RETAILPRICE", pa.decimal128(15, 2)),
+ ("P_COMMENT", pa.string()),
]
all_schemas["partsupp"] = [
- ("PS_PARTKEY", pyarrow.int64()),
- ("PS_SUPPKEY", pyarrow.int64()),
- ("PS_AVAILQTY", pyarrow.int32()),
- ("PS_SUPPLYCOST", pyarrow.decimal128(15, 2)),
- ("PS_COMMENT", pyarrow.string()),
+ ("PS_PARTKEY", pa.int64()),
+ ("PS_SUPPKEY", pa.int64()),
+ ("PS_AVAILQTY", pa.int32()),
+ ("PS_SUPPLYCOST", pa.decimal128(15, 2)),
+ ("PS_COMMENT", pa.string()),
]
all_schemas["region"] = [
- ("r_REGIONKEY", pyarrow.int64()),
- ("r_NAME", pyarrow.string()),
- ("r_COMMENT", pyarrow.string()),
+ ("r_REGIONKEY", pa.int64()),
+ ("r_NAME", pa.string()),
+ ("r_COMMENT", pa.string()),
]
all_schemas["supplier"] = [
- ("S_SUPPKEY", pyarrow.int64()),
- ("S_NAME", pyarrow.string()),
- ("S_ADDRESS", pyarrow.string()),
- ("S_NATIONKEY", pyarrow.int32()),
- ("S_PHONE", pyarrow.string()),
- ("S_ACCTBAL", pyarrow.decimal128(15, 2)),
- ("S_COMMENT", pyarrow.string()),
+ ("S_SUPPKEY", pa.int64()),
+ ("S_NAME", pa.string()),
+ ("S_ADDRESS", pa.string()),
+ ("S_NATIONKEY", pa.int32()),
+ ("S_PHONE", pa.string()),
+ ("S_ACCTBAL", pa.decimal128(15, 2)),
+ ("S_COMMENT", pa.string()),
]
curr_dir = os.path.dirname(os.path.abspath(__file__))
-for filename, curr_schema in all_schemas.items():
+for filename, curr_schema_val in all_schemas.items():
    # For convenience, go ahead and convert the schema column names to lowercase
- curr_schema = [(s[0].lower(), s[1]) for s in curr_schema]
+ curr_schema = [(s[0].lower(), s[1]) for s in curr_schema_val]
# Pre-collect the output columns so we can ignore the null field we add
# in to handle the trailing | in the file
output_cols = [r[0] for r in curr_schema]
-    curr_schema = [pyarrow.field(r[0], r[1], nullable=False) for r in curr_schema]
+ curr_schema = [pa.field(r[0], r[1], nullable=False) for r in curr_schema]
# Trailing | requires extra field for in processing
- curr_schema.append(("some_null", pyarrow.null()))
+ curr_schema.append(("some_null", pa.null()))
- schema = pyarrow.schema(curr_schema)
+ schema = pa.schema(curr_schema)
source_file = os.path.abspath(
os.path.join(curr_dir, f"../../benchmarks/tpch/data/{filename}.csv")
diff --git a/examples/tpch/q08_market_share.py b/examples/tpch/q08_market_share.py
index d46df30f..4bf50efb 100644
--- a/examples/tpch/q08_market_share.py
+++ b/examples/tpch/q08_market_share.py
@@ -150,7 +150,7 @@ df = df_regional_customers.join(
df = df.with_column(
"national_volume",
F.case(col("s_suppkey").is_null())
- .when(lit(False), col("volume"))
+ .when(lit(value=False), col("volume"))
.otherwise(lit(0.0)),
)
diff --git a/examples/tpch/q19_discounted_revenue.py b/examples/tpch/q19_discounted_revenue.py
index 2b87e112..bd492aac 100644
--- a/examples/tpch/q19_discounted_revenue.py
+++ b/examples/tpch/q19_discounted_revenue.py
@@ -89,8 +89,8 @@ def is_of_interest(
same number of rows in the output.
"""
result = []
- for idx, brand in enumerate(brand_arr):
- brand = brand.as_py()
+ for idx, brand_val in enumerate(brand_arr):
+ brand = brand_val.as_py()
if brand in items_of_interest:
values_of_interest = items_of_interest[brand]
diff --git a/examples/tpch/q21_suppliers_kept_orders_waiting.py b/examples/tpch/q21_suppliers_kept_orders_waiting.py
index 9bbaad77..619c4406 100644
--- a/examples/tpch/q21_suppliers_kept_orders_waiting.py
+++ b/examples/tpch/q21_suppliers_kept_orders_waiting.py
@@ -65,7 +65,7 @@ df = df_lineitem.join(df, left_on="l_orderkey", right_on="o_orderkey", how="inne
df = df.with_column(
"failed_supp",
F.case(col("l_receiptdate") > col("l_commitdate"))
- .when(lit(True), col("l_suppkey"))
+ .when(lit(value=True), col("l_suppkey"))
.end(),
)
diff --git a/pyproject.toml b/pyproject.toml
index a4ed18c4..d86b657e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -80,37 +80,17 @@ ignore = [
"TD003", # Allow TODO lines
"UP007", # Disallowing Union is pedantic
    # TODO: Enable all of the following, but this PR is getting too large already
- "PT001",
- "ANN204",
- "B008",
- "EM101",
"PLR0913",
- "PLR1714",
- "ANN201",
- "C400",
"TRY003",
- "B904",
- "UP006",
- "RUF012",
- "FBT003",
- "C416",
- "SIM102",
- "PGH003",
"PLR2004",
- "PERF401",
"PD901",
- "EM102",
"ERA001",
- "SIM108",
- "ICN001",
"ANN001",
"ANN202",
"PTH",
"N812",
"INP001",
"DTZ007",
- "PLW2901",
- "RET503",
"RUF015",
"A005",
"TC001",
diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py
index 286e5dc3..d871fdb7 100644
--- a/python/datafusion/__init__.py
+++ b/python/datafusion/__init__.py
@@ -92,17 +92,17 @@ __all__ = [
]
-def column(value: str):
+def column(value: str) -> Expr:
"""Create a column expression."""
return Expr.column(value)
-def col(value: str):
+def col(value: str) -> Expr:
"""Create a column expression."""
return Expr.column(value)
-def literal(value):
+def literal(value) -> Expr:
"""Create a literal expression."""
return Expr.literal(value)
@@ -120,6 +120,6 @@ def str_lit(value):
return string_literal(value)
-def lit(value):
+def lit(value) -> Expr:
"""Create a literal expression."""
return Expr.literal(value)
diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py
index 0560f470..6c3f188c 100644
--- a/python/datafusion/catalog.py
+++ b/python/datafusion/catalog.py
@@ -24,7 +24,7 @@ from typing import TYPE_CHECKING
import datafusion._internal as df_internal
if TYPE_CHECKING:
- import pyarrow
+ import pyarrow as pa
class Catalog:
@@ -67,7 +67,7 @@ class Table:
self.table = table
@property
- def schema(self) -> pyarrow.Schema:
+ def schema(self) -> pa.Schema:
"""Returns the schema associated with this table."""
return self.table.schema
diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index 58ad9a94..1429a497 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -40,9 +40,9 @@ from ._internal import SQLOptions as SQLOptionsInternal
if TYPE_CHECKING:
import pathlib
- import pandas
- import polars
- import pyarrow
+ import pandas as pd
+ import polars as pl
+ import pyarrow as pa
from datafusion.plan import ExecutionPlan, LogicalPlan
@@ -537,7 +537,7 @@ class SessionContext:
path: str | pathlib.Path,
table_partition_cols: list[tuple[str, str]] | None = None,
file_extension: str = ".parquet",
- schema: pyarrow.Schema | None = None,
+ schema: pa.Schema | None = None,
file_sort_order: list[list[Expr | SortExpr]] | None = None,
) -> None:
"""Register multiple files as a single table.
@@ -606,14 +606,14 @@ class SessionContext:
def create_dataframe(
self,
- partitions: list[list[pyarrow.RecordBatch]],
+ partitions: list[list[pa.RecordBatch]],
name: str | None = None,
- schema: pyarrow.Schema | None = None,
+ schema: pa.Schema | None = None,
) -> DataFrame:
"""Create and return a dataframe using the provided partitions.
Args:
- partitions: :py:class:`pyarrow.RecordBatch` partitions to register.
+ partitions: :py:class:`pa.RecordBatch` partitions to register.
name: Resultant dataframe name.
schema: Schema for the partitions.
@@ -684,16 +684,14 @@ class SessionContext:
return DataFrame(self.ctx.from_arrow(data, name))
@deprecated("Use ``from_arrow`` instead.")
- def from_arrow_table(
- self, data: pyarrow.Table, name: str | None = None
- ) -> DataFrame:
+    def from_arrow_table(self, data: pa.Table, name: str | None = None) -> DataFrame:
        """Create a :py:class:`~datafusion.dataframe.DataFrame` from an Arrow table.
This is an alias for :py:func:`from_arrow`.
"""
return self.from_arrow(data, name)
-    def from_pandas(self, data: pandas.DataFrame, name: str | None = None) -> DataFrame:
+    def from_pandas(self, data: pd.DataFrame, name: str | None = None) -> DataFrame:
        """Create a :py:class:`~datafusion.dataframe.DataFrame` from a Pandas DataFrame.
Args:
@@ -705,7 +703,7 @@ class SessionContext:
"""
return DataFrame(self.ctx.from_pandas(data, name))
-    def from_polars(self, data: polars.DataFrame, name: str | None = None) -> DataFrame:
+    def from_polars(self, data: pl.DataFrame, name: str | None = None) -> DataFrame:
        """Create a :py:class:`~datafusion.dataframe.DataFrame` from a Polars DataFrame.
Args:
@@ -719,7 +717,7 @@ class SessionContext:
    # https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
# is the discussion on how we arrived at adding register_view
- def register_view(self, name: str, df: DataFrame):
+ def register_view(self, name: str, df: DataFrame) -> None:
"""Register a :py:class: `~datafusion.detaframe.DataFrame` as a view.
Args:
@@ -755,7 +753,7 @@ class SessionContext:
self.ctx.register_table_provider(name, provider)
def register_record_batches(
- self, name: str, partitions: list[list[pyarrow.RecordBatch]]
+ self, name: str, partitions: list[list[pa.RecordBatch]]
) -> None:
"""Register record batches as a table.
@@ -776,7 +774,7 @@ class SessionContext:
parquet_pruning: bool = True,
file_extension: str = ".parquet",
skip_metadata: bool = True,
- schema: pyarrow.Schema | None = None,
+ schema: pa.Schema | None = None,
file_sort_order: list[list[SortExpr]] | None = None,
) -> None:
"""Register a Parquet file as a table.
@@ -817,7 +815,7 @@ class SessionContext:
self,
name: str,
path: str | pathlib.Path | list[str | pathlib.Path],
- schema: pyarrow.Schema | None = None,
+ schema: pa.Schema | None = None,
has_header: bool = True,
delimiter: str = ",",
schema_infer_max_records: int = 1000,
@@ -843,10 +841,7 @@ class SessionContext:
selected for data input.
file_compression_type: File compression type.
"""
- if isinstance(path, list):
- path = [str(p) for p in path]
- else:
- path = str(path)
+ path = [str(p) for p in path] if isinstance(path, list) else str(path)
self.ctx.register_csv(
name,
@@ -863,7 +858,7 @@ class SessionContext:
self,
name: str,
path: str | pathlib.Path,
- schema: pyarrow.Schema | None = None,
+ schema: pa.Schema | None = None,
schema_infer_max_records: int = 1000,
file_extension: str = ".json",
table_partition_cols: list[tuple[str, str]] | None = None,
@@ -901,7 +896,7 @@ class SessionContext:
self,
name: str,
path: str | pathlib.Path,
- schema: pyarrow.Schema | None = None,
+ schema: pa.Schema | None = None,
file_extension: str = ".avro",
table_partition_cols: list[tuple[str, str]] | None = None,
) -> None:
@@ -923,8 +918,8 @@ class SessionContext:
name, str(path), schema, file_extension, table_partition_cols
)
-    def register_dataset(self, name: str, dataset: pyarrow.dataset.Dataset) -> None:
- """Register a :py:class:`pyarrow.dataset.Dataset` as a table.
+ def register_dataset(self, name: str, dataset: pa.dataset.Dataset) -> None:
+ """Register a :py:class:`pa.dataset.Dataset` as a table.
Args:
name: Name of the table to register.
@@ -975,7 +970,7 @@ class SessionContext:
def read_json(
self,
path: str | pathlib.Path,
- schema: pyarrow.Schema | None = None,
+ schema: pa.Schema | None = None,
schema_infer_max_records: int = 1000,
file_extension: str = ".json",
table_partition_cols: list[tuple[str, str]] | None = None,
@@ -1012,7 +1007,7 @@ class SessionContext:
def read_csv(
self,
path: str | pathlib.Path | list[str] | list[pathlib.Path],
- schema: pyarrow.Schema | None = None,
+ schema: pa.Schema | None = None,
has_header: bool = True,
delimiter: str = ",",
schema_infer_max_records: int = 1000,
@@ -1065,7 +1060,7 @@ class SessionContext:
parquet_pruning: bool = True,
file_extension: str = ".parquet",
skip_metadata: bool = True,
- schema: pyarrow.Schema | None = None,
+ schema: pa.Schema | None = None,
file_sort_order: list[list[Expr | SortExpr]] | None = None,
    ) -> DataFrame:
        """Read a Parquet source into a :py:class:`~datafusion.dataframe.Dataframe`.
@@ -1110,7 +1105,7 @@ class SessionContext:
def read_avro(
self,
path: str | pathlib.Path,
- schema: pyarrow.Schema | None = None,
+ schema: pa.Schema | None = None,
file_partition_cols: list[tuple[str, str]] | None = None,
file_extension: str = ".avro",
) -> DataFrame:
diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index d1c71c2b..26fe8f45 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -26,10 +26,8 @@ from typing import (
TYPE_CHECKING,
Any,
Iterable,
- List,
Literal,
Optional,
- Type,
Union,
overload,
)
@@ -75,7 +73,7 @@ class Compression(Enum):
LZ4_RAW = "lz4_raw"
@classmethod
- def from_str(cls: Type[Compression], value: str) -> Compression:
+ def from_str(cls: type[Compression], value: str) -> Compression:
"""Convert a string to a Compression enum value.
Args:
@@ -89,11 +87,13 @@ class Compression(Enum):
"""
try:
return cls(value.lower())
- except ValueError:
+ except ValueError as err:
valid_values = str([item.value for item in Compression])
- raise ValueError(
-                f"{value} is not a valid Compression. Valid values are: {valid_values}"
- )
+ error_msg = f"""
+ {value} is not a valid Compression.
+ Valid values are: {valid_values}
+ """
+ raise ValueError(error_msg) from err
def get_default_level(self) -> Optional[int]:
"""Get the default compression level for the compression type.
@@ -132,7 +132,7 @@ class DataFrame:
        """Convert DataFrame as a ViewTable which can be used in register_table."""
return self.df.into_view()
- def __getitem__(self, key: str | List[str]) -> DataFrame:
+    def __getitem__(self, key: str | list[str]) -> DataFrame:
        """Return a new :py:class`DataFrame` with the specified column or columns.
Args:
@@ -287,8 +287,7 @@ class DataFrame:
if isinstance(expr, Expr):
expr_list.append(expr.expr)
elif isinstance(expr, Iterable):
- for inner_expr in expr:
- expr_list.append(inner_expr.expr)
+ expr_list.extend(inner_expr.expr for inner_expr in expr)
else:
raise NotImplementedError
if named_exprs:
@@ -513,10 +512,15 @@ class DataFrame:
# This check is to prevent breaking API changes where users prior to
# DF 43.0.0 would pass the join_keys as a positional argument instead
# of a keyword argument.
- if isinstance(on, tuple) and len(on) == 2:
- if isinstance(on[0], list) and isinstance(on[1], list):
- join_keys = on # type: ignore
- on = None
+ if (
+ isinstance(on, tuple)
+ and len(on) == 2
+ and isinstance(on[0], list)
+ and isinstance(on[1], list)
+ ):
+ # We know this is safe because we've checked the types
+ join_keys = on # type: ignore[assignment]
+ on = None
if join_keys is not None:
warnings.warn(
@@ -529,18 +533,17 @@ class DataFrame:
if on is not None:
if left_on is not None or right_on is not None:
- raise ValueError(
- "`left_on` or `right_on` should not provided with `on`"
- )
+                error_msg = "`left_on` or `right_on` should not provided with `on`"
+ raise ValueError(error_msg)
left_on = on
right_on = on
elif left_on is not None or right_on is not None:
if left_on is None or right_on is None:
-            raise ValueError("`left_on` and `right_on` should both be provided.")
+ error_msg = "`left_on` and `right_on` should both be provided."
+ raise ValueError(error_msg)
else:
- raise ValueError(
- "either `on` or `left_on` and `right_on` should be provided."
- )
+            error_msg = "either `on` or `left_on` and `right_on` should be provided."
+ raise ValueError(error_msg)
if isinstance(left_on, str):
left_on = [left_on]
if isinstance(right_on, str):
@@ -726,9 +729,11 @@ class DataFrame:
if isinstance(compression, str):
compression = Compression.from_str(compression)
-        if compression in {Compression.GZIP, Compression.BROTLI, Compression.ZSTD}:
- if compression_level is None:
- compression_level = compression.get_default_level()
+ if (
+            compression in {Compression.GZIP, Compression.BROTLI, Compression.ZSTD}
+ and compression_level is None
+ ):
+ compression_level = compression.get_default_level()
self.df.write_parquet(str(path), compression.value, compression_level)
@@ -824,7 +829,7 @@ class DataFrame:
Returns:
A DataFrame with the columns expanded.
"""
- columns = [c for c in columns]
+ columns = list(columns)
        return DataFrame(self.df.unnest_columns(columns, preserve_nulls=preserve_nulls))
def __arrow_c_stream__(self, requested_schema: pa.Schema) -> Any:
diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py
index 77b6c272..2697d814 100644
--- a/python/datafusion/expr.py
+++ b/python/datafusion/expr.py
@@ -22,7 +22,7 @@ See :ref:`Expressions` in the online documentation for more details.
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Optional, Type
+from typing import TYPE_CHECKING, Any, ClassVar, Optional
import pyarrow as pa
@@ -176,7 +176,7 @@ def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:
"""Helper function to return a default Sort if an Expr is provided."""
if isinstance(e, SortExpr):
return e.raw_sort
- return SortExpr(e, True, True).raw_sort
+ return SortExpr(e, ascending=True, nulls_first=True).raw_sort
def sort_list_to_raw_sort_list(
@@ -439,24 +439,21 @@ class Expr:
value = Expr.literal(value)
return Expr(functions_internal.nvl(self.expr, value.expr))
- _to_pyarrow_types = {
+ _to_pyarrow_types: ClassVar[dict[type, pa.DataType]] = {
float: pa.float64(),
int: pa.int64(),
str: pa.string(),
bool: pa.bool_(),
}
- def cast(
-        self, to: pa.DataType[Any] | Type[float] | Type[int] | Type[str] | Type[bool]
- ) -> Expr:
+    def cast(self, to: pa.DataType[Any] | type[float | int | str | bool]) -> Expr:
"""Cast to a new data type."""
if not isinstance(to, pa.DataType):
try:
to = self._to_pyarrow_types[to]
- except KeyError:
- raise TypeError(
- "Expected instance of pyarrow.DataType or builtins.type"
- )
+ except KeyError as err:
+                error_msg = "Expected instance of pyarrow.DataType or builtins.type"
+ raise TypeError(error_msg) from err
return Expr(self.expr.cast(to))
@@ -565,9 +562,7 @@ class Expr:
        set parameters for either window or aggregate functions. If used on any other
        type of expression, an error will be generated when ``build()`` is called.
"""
- return ExprFuncBuilder(
- self.expr.partition_by(list(e.expr for e in partition_by))
- )
+        return ExprFuncBuilder(self.expr.partition_by([e.expr for e in partition_by]))
def window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder:
"""Set the frame fora window function.
@@ -610,7 +605,7 @@ class Expr:
class ExprFuncBuilder:
- def __init__(self, builder: expr_internal.ExprFuncBuilder):
+ def __init__(self, builder: expr_internal.ExprFuncBuilder) -> None:
self.builder = builder
def order_by(self, *exprs: Expr) -> ExprFuncBuilder:
@@ -638,7 +633,7 @@ class ExprFuncBuilder:
def partition_by(self, *partition_by: Expr) -> ExprFuncBuilder:
"""Set partitioning for window functions."""
return ExprFuncBuilder(
- self.builder.partition_by(list(e.expr for e in partition_by))
+ self.builder.partition_by([e.expr for e in partition_by])
)
def window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder:
@@ -693,11 +688,11 @@ class WindowFrame:
"""
if not isinstance(start_bound, pa.Scalar) and start_bound is not None:
start_bound = pa.scalar(start_bound)
- if units == "rows" or units == "groups":
+ if units in ("rows", "groups"):
start_bound = start_bound.cast(pa.uint64())
if not isinstance(end_bound, pa.Scalar) and end_bound is not None:
end_bound = pa.scalar(end_bound)
- if units == "rows" or units == "groups":
+ if units in ("rows", "groups"):
end_bound = end_bound.cast(pa.uint64())
        self.window_frame = expr_internal.WindowFrame(units, start_bound, end_bound)
@@ -709,7 +704,7 @@ class WindowFrame:
"""Returns starting bound."""
return WindowFrameBound(self.window_frame.get_lower_bound())
- def get_upper_bound(self):
+ def get_upper_bound(self) -> WindowFrameBound:
"""Returns end bound."""
return WindowFrameBound(self.window_frame.get_upper_bound())
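The expr.py hunk above also illustrates RUF012 (annotate mutable class attributes with ClassVar) and UP006 (builtin generics such as list[str] and type[...] instead of typing.List and typing.Type). A minimal, hypothetical sketch of the same pattern, using an invented class that is not part of datafusion-python:

    from typing import ClassVar


    class AliasRegistry:
        # RUF012: mutable class-level defaults should be annotated with ClassVar
        known_aliases: ClassVar[dict[str, str]] = {"pyarrow": "pa"}

        # UP006: prefer the builtin generic list[str] over typing.List[str]
        def names(self) -> list[str]:
            return list(self.known_aliases)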
diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py
index 26bac149..5cf914e1 100644
--- a/python/datafusion/functions.py
+++ b/python/datafusion/functions.py
@@ -790,10 +790,7 @@ def regexp_count(
"""
if flags is not None:
flags = flags.expr
- if start is not None:
- start = start.expr
- else:
- start = Expr.expr
+ start = start.expr if start is not None else Expr.expr
return Expr(f.regexp_count(string.expr, pattern.expr, start, flags))
@@ -817,13 +814,15 @@ def right(string: Expr, n: Expr) -> Expr:
return Expr(f.right(string.expr, n.expr))
-def round(value: Expr, decimal_places: Expr = Expr.literal(0)) -> Expr:
+def round(value: Expr, decimal_places: Expr | None = None) -> Expr:
"""Round the argument to the nearest integer.
    If the optional ``decimal_places`` is specified, round to the nearest number of
    decimal places. You can specify a negative number of decimal places. For example
``round(lit(125.2345), lit(-2))`` would yield a value of ``100.0``.
"""
+ if decimal_places is None:
+ decimal_places = Expr.literal(0)
return Expr(f.round(value.expr, decimal_places.expr))
diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py
index 161e1e3b..37f2075f 100644
--- a/python/tests/test_functions.py
+++ b/python/tests/test_functions.py
@@ -81,7 +81,7 @@ def test_literal(df):
literal("1"),
literal("OK"),
literal(3.14),
- literal(True),
+ literal(value=True),
literal(b"hello world"),
)
result = df.collect()
diff --git a/python/tests/test_wrapper_coverage.py b/python/tests/test_wrapper_coverage.py
index a2de2d32..926a6596 100644
--- a/python/tests/test_wrapper_coverage.py
+++ b/python/tests/test_wrapper_coverage.py
@@ -28,7 +28,7 @@ except ImportError:
from enum import EnumMeta as EnumType
-def missing_exports(internal_obj, wrapped_obj) -> None: # noqa: C901
+def missing_exports(internal_obj, wrapped_obj) -> None:
"""
    Identify if any of the rust exposted structs or functions do not have wrappers.
@@ -56,9 +56,8 @@ def missing_exports(internal_obj, wrapped_obj) -> None: # noqa: C901
# __kwdefaults__ and __doc__. As long as these are None on the internal
    # object, it's okay to skip them. However if they do exist on the internal
# object they must also exist on the wrapped object.
- if internal_attr is not None:
- if wrapped_attr is None:
- pytest.fail(f"Missing attribute: {internal_attr_name}")
+ if internal_attr is not None and wrapped_attr is None:
+ pytest.fail(f"Missing attribute: {internal_attr_name}")
if internal_attr_name in ["__self__", "__class__"]:
continue
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]