This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new dfbb3ca tsaucer/run TPC-H examples in CI (#711)
dfbb3ca is described below
commit dfbb3ca1e7080c26a23ec774618bbdd4b6b1d34e
Author: Tim Saucer <[email protected]>
AuthorDate: Tue May 21 12:48:22 2024 -0400
tsaucer/run TPC-H examples in CI (#711)
* Mostly small updates to the TPC-H examples to make their results consistent
with the spec
* Prepares the TPC-H examples for CI by resolving the path to the data files
regardless of where they are run from. There are a few small updates to make
the tests match the expected answer files provided by dbgen.
* Expose the substring function
* Add a script to run all TPC-H examples in pytest
* Update the TPC-H generator script to allow for non-interactive terminals, such
as when running in CI
* Add the TPC-H examples to the GitHub workflow for testing
---
.github/workflows/test.yaml | 20 ++++
benchmarks/tpch/tpch-gen.sh | 13 ++-
examples/tpch/_tests.py | 111 +++++++++++++++++++++
examples/tpch/convert_data_to_parquet.py | 18 ++--
examples/tpch/q01_pricing_summary_report.py | 11 +-
examples/tpch/q02_minimum_cost_supplier.py | 14 +--
examples/tpch/q03_shipping_priority.py | 11 +-
examples/tpch/q04_order_priority_checking.py | 5 +-
examples/tpch/q05_local_supplier_volume.py | 13 +--
examples/tpch/q06_forecasting_revenue_change.py | 3 +-
examples/tpch/q07_volume_shipping.py | 11 +-
examples/tpch/q08_market_share.py | 15 +--
examples/tpch/q09_product_type_profit_measure.py | 15 +--
examples/tpch/q10_returned_item_reporting.py | 13 +--
.../tpch/q11_important_stock_identification.py | 9 +-
examples/tpch/q12_ship_mode_order_priority.py | 5 +-
examples/tpch/q13_customer_distribution.py | 7 +-
examples/tpch/q14_promotion_effect.py | 7 +-
examples/tpch/q15_top_supplier.py | 9 +-
examples/tpch/q16_part_supplier_relationship.py | 11 +-
examples/tpch/q17_small_quantity_order.py | 9 +-
examples/tpch/q18_large_volume_customer.py | 7 +-
examples/tpch/q19_discounted_revenue.py | 12 +--
examples/tpch/q20_potential_part_promotion.py | 15 +--
examples/tpch/q21_suppliers_kept_orders_waiting.py | 11 +-
examples/tpch/q22_global_sales_opportunity.py | 14 +--
examples/tpch/util.py | 33 ++++++
src/functions.rs | 2 +
28 files changed, 311 insertions(+), 113 deletions(-)
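
For context, the newly exposed substring function is called from Python in the
updated examples along the lines of this minimal sketch (the parquet path is an
assumption here; the examples themselves resolve it with the new get_data_path
helper):

    from datafusion import SessionContext, col, lit, functions as F

    ctx = SessionContext()
    # Assumed location of the part table written by convert_data_to_parquet.py
    df_part = ctx.read_parquet("data/part.parquet")
    # Keep only parts whose type starts with PROMO, mirroring q14_promotion_effect.py
    df_part = df_part.filter(F.substring(col("p_type"), lit(0), lit(6)) == lit("PROMO"))
    df_part.show()
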
diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index a486f27..d8909b6 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -111,3 +111,23 @@ jobs:
source venv/bin/activate
pip install -e . -vv
pytest -v .
+
+ - name: Cache the generated dataset
+ id: cache-tpch-dataset
+ uses: actions/cache@v3
+ with:
+ path: benchmarks/tpch/data
+ key: tpch-data-2.18.0
+
+ - name: Run dbgen to create 1 Gb dataset
+ if: ${{ steps.cache-tpch-dataset.outputs.cache-hit != 'true' }}
+ run: |
+ cd benchmarks/tpch
+ RUN_IN_CI=TRUE ./tpch-gen.sh 1
+
+ - name: Run TPC-H examples
+ run: |
+ source venv/bin/activate
+ cd examples/tpch
+ python convert_data_to_parquet.py
+ pytest _tests.py
diff --git a/benchmarks/tpch/tpch-gen.sh b/benchmarks/tpch/tpch-gen.sh
index 15cab12..139c300 100755
--- a/benchmarks/tpch/tpch-gen.sh
+++ b/benchmarks/tpch/tpch-gen.sh
@@ -20,6 +20,15 @@ mkdir -p data/answers 2>/dev/null
set -e
+# If RUN_IN_CI is set, then do not produce verbose output or use an interactive terminal
+if [[ -z "${RUN_IN_CI}" ]]; then
+ TERMINAL_FLAG="-it"
+ VERBOSE_OUTPUT="-vf"
+else
+ TERMINAL_FLAG=""
+ VERBOSE_OUTPUT="-f"
+fi
+
#pushd ..
#. ./dev/build-set-env.sh
#popd
@@ -29,7 +38,7 @@ FILE=./data/supplier.tbl
if test -f "$FILE"; then
echo "$FILE exists."
else
- docker run -v `pwd`/data:/data -it --rm ghcr.io/scalytics/tpch-docker:main -vf -s $1
+ docker run -v `pwd`/data:/data $TERMINAL_FLAG --rm ghcr.io/scalytics/tpch-docker:main $VERBOSE_OUTPUT -s $1
# workaround for https://github.com/apache/arrow-datafusion/issues/6147
mv data/customer.tbl data/customer.csv
@@ -49,5 +58,5 @@ FILE=./data/answers/q1.out
if test -f "$FILE"; then
echo "$FILE exists."
else
- docker run -v `pwd`/data:/data -it --entrypoint /bin/bash --rm ghcr.io/scalytics/tpch-docker:main -c "cp /opt/tpch/2.18.0_rc2/dbgen/answers/* /data/answers/"
+ docker run -v `pwd`/data:/data $TERMINAL_FLAG --entrypoint /bin/bash --rm ghcr.io/scalytics/tpch-docker:main -c "cp /opt/tpch/2.18.0_rc2/dbgen/answers/* /data/answers/"
fi
diff --git a/examples/tpch/_tests.py b/examples/tpch/_tests.py
new file mode 100644
index 0000000..049b43d
--- /dev/null
+++ b/examples/tpch/_tests.py
@@ -0,0 +1,111 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+from importlib import import_module
+import pyarrow as pa
+from datafusion import col, lit, functions as F
+from util import get_answer_file
+
+def df_selection(col_name, col_type):
+ if col_type == pa.float64() or isinstance(col_type, pa.Decimal128Type):
+ return F.round(col(col_name), lit(2)).alias(col_name)
+ elif col_type == pa.string():
+ return F.trim(col(col_name)).alias(col_name)
+ else:
+ return col(col_name)
+
+def load_schema(col_name, col_type):
+ if col_type == pa.int64() or col_type == pa.int32():
+ return col_name, pa.string()
+ elif isinstance(col_type, pa.Decimal128Type):
+ return col_name, pa.float64()
+ else:
+ return col_name, col_type
+
+def expected_selection(col_name, col_type):
+ if col_type == pa.int64() or col_type == pa.int32():
+ return F.trim(col(col_name)).cast(col_type).alias(col_name)
+ elif col_type == pa.string():
+ return F.trim(col(col_name)).alias(col_name)
+ else:
+ return col(col_name)
+
+def selections_and_schema(original_schema):
+ columns = [ (c, original_schema.field(c).type) for c in original_schema.names ]
+
+ df_selections = [ df_selection(c, t) for (c, t) in columns]
+ expected_schema = [ load_schema(c, t) for (c, t) in columns]
+ expected_selections = [ expected_selection(c, t) for (c, t) in columns]
+
+ return (df_selections, expected_schema, expected_selections)
+
+def check_q17(df):
+ raw_value = float(df.collect()[0]["avg_yearly"][0].as_py())
+ value = round(raw_value, 2)
+ assert abs(value - 348406.05) < 0.001
+
+@pytest.mark.parametrize(
+ ("query_code", "answer_file"),
+ [
+ ("q01_pricing_summary_report", "q1"),
+ ("q02_minimum_cost_supplier", "q2"),
+ ("q03_shipping_priority", "q3"),
+ ("q04_order_priority_checking", "q4"),
+ ("q05_local_supplier_volume", "q5"),
+ ("q06_forecasting_revenue_change", "q6"),
+ ("q07_volume_shipping", "q7"),
+ ("q08_market_share", "q8"),
+ ("q09_product_type_profit_measure", "q9"),
+ ("q10_returned_item_reporting", "q10"),
+ ("q11_important_stock_identification", "q11"),
+ ("q12_ship_mode_order_priority", "q12"),
+ ("q13_customer_distribution", "q13"),
+ ("q14_promotion_effect", "q14"),
+ ("q15_top_supplier", "q15"),
+ ("q16_part_supplier_relationship", "q16"),
+ ("q17_small_quantity_order", "q17"),
+ ("q18_large_volume_customer", "q18"),
+ ("q19_discounted_revenue", "q19"),
+ ("q20_potential_part_promotion", "q20"),
+ ("q21_suppliers_kept_orders_waiting", "q21"),
+ ("q22_global_sales_opportunity", "q22"),
+ ],
+)
+def test_tpch_query_vs_answer_file(query_code: str, answer_file: str):
+ module = import_module(query_code)
+ df = module.df
+
+ # Treat q17 as a special case. The answer file does not match the spec. Running at
+ # scale factor 1, we have manually verified this result does match the expected value.
+ if answer_file == "q17":
+ return check_q17(df)
+
+ (df_selections, expected_schema, expected_selections) = selections_and_schema(df.schema())
+
+ df = df.select(*df_selections)
+
+ read_schema = pa.schema(expected_schema)
+
+ df_expected = module.ctx.read_csv(get_answer_file(answer_file), schema=read_schema, delimiter="|", file_extension=".out")
+
+ df_expected = df_expected.select(*expected_selections)
+
+ cols = list(read_schema.names)
+
+ assert df.join(df_expected, (cols, cols), "anti").count() == 0
+ assert df.count() == df_expected.count()
diff --git a/examples/tpch/convert_data_to_parquet.py b/examples/tpch/convert_data_to_parquet.py
index 178b7fb..5da60bc 100644
--- a/examples/tpch/convert_data_to_parquet.py
+++ b/examples/tpch/convert_data_to_parquet.py
@@ -36,7 +36,7 @@ all_schemas["customer"] = [
("C_ADDRESS", pyarrow.string()),
("C_NATIONKEY", pyarrow.int32()),
("C_PHONE", pyarrow.string()),
- ("C_ACCTBAL", pyarrow.float32()),
+ ("C_ACCTBAL", pyarrow.decimal128(15, 2)),
("C_MKTSEGMENT", pyarrow.string()),
("C_COMMENT", pyarrow.string()),
]
@@ -46,10 +46,10 @@ all_schemas["lineitem"] = [
("L_PARTKEY", pyarrow.int32()),
("L_SUPPKEY", pyarrow.int32()),
("L_LINENUMBER", pyarrow.int32()),
- ("L_QUANTITY", pyarrow.float32()),
- ("L_EXTENDEDPRICE", pyarrow.float32()),
- ("L_DISCOUNT", pyarrow.float32()),
- ("L_TAX", pyarrow.float32()),
+ ("L_QUANTITY", pyarrow.decimal128(15, 2)),
+ ("L_EXTENDEDPRICE", pyarrow.decimal128(15, 2)),
+ ("L_DISCOUNT", pyarrow.decimal128(15, 2)),
+ ("L_TAX", pyarrow.decimal128(15, 2)),
("L_RETURNFLAG", pyarrow.string()),
("L_LINESTATUS", pyarrow.string()),
("L_SHIPDATE", pyarrow.date32()),
@@ -71,7 +71,7 @@ all_schemas["orders"] = [
("O_ORDERKEY", pyarrow.int32()),
("O_CUSTKEY", pyarrow.int32()),
("O_ORDERSTATUS", pyarrow.string()),
- ("O_TOTALPRICE", pyarrow.float32()),
+ ("O_TOTALPRICE", pyarrow.decimal128(15, 2)),
("O_ORDERDATE", pyarrow.date32()),
("O_ORDERPRIORITY", pyarrow.string()),
("O_CLERK", pyarrow.string()),
@@ -87,7 +87,7 @@ all_schemas["part"] = [
("P_TYPE", pyarrow.string()),
("P_SIZE", pyarrow.int32()),
("P_CONTAINER", pyarrow.string()),
- ("P_RETAILPRICE", pyarrow.float32()),
+ ("P_RETAILPRICE", pyarrow.decimal128(15, 2)),
("P_COMMENT", pyarrow.string()),
]
@@ -95,7 +95,7 @@ all_schemas["partsupp"] = [
("PS_PARTKEY", pyarrow.int32()),
("PS_SUPPKEY", pyarrow.int32()),
("PS_AVAILQTY", pyarrow.int32()),
- ("PS_SUPPLYCOST", pyarrow.float32()),
+ ("PS_SUPPLYCOST", pyarrow.decimal128(15, 2)),
("PS_COMMENT", pyarrow.string()),
]
@@ -111,7 +111,7 @@ all_schemas["supplier"] = [
("S_ADDRESS", pyarrow.string()),
("S_NATIONKEY", pyarrow.int32()),
("S_PHONE", pyarrow.string()),
- ("S_ACCTBAL", pyarrow.float32()),
+ ("S_ACCTBAL", pyarrow.decimal128(15, 2)),
("S_COMMENT", pyarrow.string()),
]
diff --git a/examples/tpch/q01_pricing_summary_report.py b/examples/tpch/q01_pricing_summary_report.py
index 1aafcca..7e86055 100644
--- a/examples/tpch/q01_pricing_summary_report.py
+++ b/examples/tpch/q01_pricing_summary_report.py
@@ -31,10 +31,11 @@ as part of their TPC Benchmark H Specification revision 2.18.0.
import pyarrow as pa
from datafusion import SessionContext, col, lit, functions as F
+from util import get_data_path
ctx = SessionContext()
-df = ctx.read_parquet("data/lineitem.parquet")
+df = ctx.read_parquet(get_data_path("lineitem.parquet"))
# It may be that the date can be hard coded, based on examples shown.
# This approach will work with any date range in the provided data set.
@@ -45,7 +46,7 @@ greatest_ship_date = df.aggregate(
# From the given problem, this is how close to the last date in the database we
# want to report results for. It should be between 60-120 days before the end.
-DAYS_BEFORE_FINAL = 68
+DAYS_BEFORE_FINAL = 90
# Note: this is a hack on setting the values. It should be set differently once
# https://github.com/apache/datafusion-python/issues/665 is resolved.
@@ -63,13 +64,13 @@ df = df.aggregate(
[
F.sum(col("l_quantity")).alias("sum_qty"),
F.sum(col("l_extendedprice")).alias("sum_base_price"),
- F.sum(col("l_extendedprice") * (lit(1.0) - col("l_discount"))).alias(
+ F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias(
"sum_disc_price"
),
F.sum(
col("l_extendedprice")
- * (lit(1.0) - col("l_discount"))
- * (lit(1.0) + col("l_tax"))
+ * (lit(1) - col("l_discount"))
+ * (lit(1) + col("l_tax"))
).alias("sum_charge"),
F.avg(col("l_quantity")).alias("avg_qty"),
F.avg(col("l_extendedprice")).alias("avg_price"),
diff --git a/examples/tpch/q02_minimum_cost_supplier.py b/examples/tpch/q02_minimum_cost_supplier.py
index 262e2cf..f4020d7 100644
--- a/examples/tpch/q02_minimum_cost_supplier.py
+++ b/examples/tpch/q02_minimum_cost_supplier.py
@@ -31,8 +31,10 @@ as part of their TPC Benchmark H Specification revision 2.18.0.
import datafusion
from datafusion import SessionContext, col, lit, functions as F
+from util import get_data_path
-# This is the part we're looking for
+# This is the part we're looking for. Values selected here differ from the spec in order to run
+# unit tests on a small data set.
SIZE_OF_INTEREST = 15
TYPE_OF_INTEREST = "BRASS"
REGION_OF_INTEREST = "EUROPE"
@@ -41,10 +43,10 @@ REGION_OF_INTEREST = "EUROPE"
ctx = SessionContext()
-df_part = ctx.read_parquet("data/part.parquet").select_columns(
+df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns(
"p_partkey", "p_mfgr", "p_type", "p_size"
)
-df_supplier = ctx.read_parquet("data/supplier.parquet").select_columns(
+df_supplier = ctx.read_parquet(get_data_path("supplier.parquet")).select_columns(
"s_acctbal",
"s_name",
"s_address",
@@ -53,13 +55,13 @@ df_supplier = ctx.read_parquet("data/supplier.parquet").select_columns(
"s_nationkey",
"s_suppkey",
)
-df_partsupp = ctx.read_parquet("data/partsupp.parquet").select_columns(
+df_partsupp = ctx.read_parquet(get_data_path("partsupp.parquet")).select_columns(
"ps_partkey", "ps_suppkey", "ps_supplycost"
)
-df_nation = ctx.read_parquet("data/nation.parquet").select_columns(
+df_nation = ctx.read_parquet(get_data_path("nation.parquet")).select_columns(
"n_nationkey", "n_regionkey", "n_name"
)
-df_region = ctx.read_parquet("data/region.parquet").select_columns(
+df_region = ctx.read_parquet(get_data_path("region.parquet")).select_columns(
"r_regionkey", "r_name"
)
diff --git a/examples/tpch/q03_shipping_priority.py b/examples/tpch/q03_shipping_priority.py
index 78993e9..6a4886d 100644
--- a/examples/tpch/q03_shipping_priority.py
+++ b/examples/tpch/q03_shipping_priority.py
@@ -28,6 +28,7 @@ as part of their TPC Benchmark H Specification revision 2.18.0.
"""
from datafusion import SessionContext, col, lit, functions as F
+from util import get_data_path
SEGMENT_OF_INTEREST = "BUILDING"
DATE_OF_INTEREST = "1995-03-15"
@@ -36,13 +37,13 @@ DATE_OF_INTEREST = "1995-03-15"
ctx = SessionContext()
-df_customer = ctx.read_parquet("data/customer.parquet").select_columns(
+df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns(
"c_mktsegment", "c_custkey"
)
-df_orders = ctx.read_parquet("data/orders.parquet").select_columns(
+df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns(
"o_orderdate", "o_shippriority", "o_custkey", "o_orderkey"
)
-df_lineitem = ctx.read_parquet("data/lineitem.parquet").select_columns(
+df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
"l_orderkey", "l_extendedprice", "l_discount", "l_shipdate"
)
@@ -73,9 +74,9 @@ df = df.aggregate(
df = df.sort(col("revenue").sort(ascending=False), col("o_orderdate").sort())
-# Only return 100 results
+# Only return 10 results
-df = df.limit(100)
+df = df.limit(10)
# Change the order that the columns are reported in just to match the spec
diff --git a/examples/tpch/q04_order_priority_checking.py b/examples/tpch/q04_order_priority_checking.py
index b691d5b..40eab69 100644
--- a/examples/tpch/q04_order_priority_checking.py
+++ b/examples/tpch/q04_order_priority_checking.py
@@ -29,6 +29,7 @@ as part of their TPC Benchmark H Specification revision 2.18.0.
from datetime import datetime
import pyarrow as pa
from datafusion import SessionContext, col, lit, functions as F
+from util import get_data_path
# Ideally we could put 3 months into the interval. See note below.
INTERVAL_DAYS = 92
@@ -38,10 +39,10 @@ DATE_OF_INTEREST = "1993-07-01"
ctx = SessionContext()
-df_orders = ctx.read_parquet("data/orders.parquet").select_columns(
+df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns(
"o_orderdate", "o_orderpriority", "o_orderkey"
)
-df_lineitem = ctx.read_parquet("data/lineitem.parquet").select_columns(
+df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
"l_orderkey", "l_commitdate", "l_receiptdate"
)
diff --git a/examples/tpch/q05_local_supplier_volume.py b/examples/tpch/q05_local_supplier_volume.py
index 7cb6e63..27b4b84 100644
--- a/examples/tpch/q05_local_supplier_volume.py
+++ b/examples/tpch/q05_local_supplier_volume.py
@@ -32,6 +32,7 @@ as part of their TPC Benchmark H Specification revision 2.18.0.
from datetime import datetime
import pyarrow as pa
from datafusion import SessionContext, col, lit, functions as F
+from util import get_data_path
DATE_OF_INTEREST = "1994-01-01"
@@ -48,22 +49,22 @@ interval = pa.scalar((0, 0, INTERVAL_DAYS), type=pa.month_day_nano_interval())
ctx = SessionContext()
-df_customer = ctx.read_parquet("data/customer.parquet").select_columns(
+df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns(
"c_custkey", "c_nationkey"
)
-df_orders = ctx.read_parquet("data/orders.parquet").select_columns(
+df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns(
"o_custkey", "o_orderkey", "o_orderdate"
)
-df_lineitem = ctx.read_parquet("data/lineitem.parquet").select_columns(
+df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
"l_orderkey", "l_suppkey", "l_extendedprice", "l_discount"
)
-df_supplier = ctx.read_parquet("data/supplier.parquet").select_columns(
+df_supplier = ctx.read_parquet(get_data_path("supplier.parquet")).select_columns(
"s_suppkey", "s_nationkey"
)
-df_nation = ctx.read_parquet("data/nation.parquet").select_columns(
+df_nation = ctx.read_parquet(get_data_path("nation.parquet")).select_columns(
"n_nationkey", "n_regionkey", "n_name"
)
-df_region = ctx.read_parquet("data/region.parquet").select_columns(
+df_region = ctx.read_parquet(get_data_path("region.parquet")).select_columns(
"r_regionkey", "r_name"
)
diff --git a/examples/tpch/q06_forecasting_revenue_change.py b/examples/tpch/q06_forecasting_revenue_change.py
index 5fbb917..3f58c5e 100644
--- a/examples/tpch/q06_forecasting_revenue_change.py
+++ b/examples/tpch/q06_forecasting_revenue_change.py
@@ -32,6 +32,7 @@ as part of their TPC Benchmark H Specification revision 2.18.0.
from datetime import datetime
import pyarrow as pa
from datafusion import SessionContext, col, lit, functions as F
+from util import get_data_path
# Variables from the example query
@@ -52,7 +53,7 @@ interval = pa.scalar((0, 0, INTERVAL_DAYS), type=pa.month_day_nano_interval())
ctx = SessionContext()
-df_lineitem = ctx.read_parquet("data/lineitem.parquet").select_columns(
+df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
"l_shipdate", "l_quantity", "l_extendedprice", "l_discount"
)
diff --git a/examples/tpch/q07_volume_shipping.py b/examples/tpch/q07_volume_shipping.py
index 3c87f93..fd7323b 100644
--- a/examples/tpch/q07_volume_shipping.py
+++ b/examples/tpch/q07_volume_shipping.py
@@ -31,6 +31,7 @@ as part of their TPC Benchmark H Specification revision 2.18.0.
from datetime import datetime
import pyarrow as pa
from datafusion import SessionContext, col, lit, functions as F
+from util import get_data_path
# Variables of interest to query over
@@ -48,19 +49,19 @@ end_date = lit(datetime.strptime(END_DATE, "%Y-%m-%d").date())
ctx = SessionContext()
-df_supplier = ctx.read_parquet("data/supplier.parquet").select_columns(
+df_supplier = ctx.read_parquet(get_data_path("supplier.parquet")).select_columns(
"s_suppkey", "s_nationkey"
)
-df_lineitem = ctx.read_parquet("data/lineitem.parquet").select_columns(
+df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
"l_shipdate", "l_extendedprice", "l_discount", "l_suppkey", "l_orderkey"
)
-df_orders = ctx.read_parquet("data/orders.parquet").select_columns(
+df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns(
"o_orderkey", "o_custkey"
)
-df_customer = ctx.read_parquet("data/customer.parquet").select_columns(
+df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns(
"c_custkey", "c_nationkey"
)
-df_nation = ctx.read_parquet("data/nation.parquet").select_columns(
+df_nation = ctx.read_parquet(get_data_path("nation.parquet")).select_columns(
"n_nationkey", "n_name"
)
diff --git a/examples/tpch/q08_market_share.py b/examples/tpch/q08_market_share.py
index a415156..d13a71d 100644
--- a/examples/tpch/q08_market_share.py
+++ b/examples/tpch/q08_market_share.py
@@ -30,6 +30,7 @@ as part of their TPC Benchmark H Specification revision 2.18.0.
from datetime import datetime
import pyarrow as pa
from datafusion import SessionContext, col, lit, functions as F
+from util import get_data_path
supplier_nation = lit("BRAZIL")
customer_region = lit("AMERICA")
@@ -46,23 +47,23 @@ end_date = lit(datetime.strptime(END_DATE, "%Y-%m-%d").date())
ctx = SessionContext()
-df_part = ctx.read_parquet("data/part.parquet").select_columns("p_partkey",
"p_type")
-df_supplier = ctx.read_parquet("data/supplier.parquet").select_columns(
+df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_type")
+df_supplier = ctx.read_parquet(get_data_path("supplier.parquet")).select_columns(
"s_suppkey", "s_nationkey"
)
-df_lineitem = ctx.read_parquet("data/lineitem.parquet").select_columns(
+df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
"l_partkey", "l_extendedprice", "l_discount", "l_suppkey", "l_orderkey"
)
-df_orders = ctx.read_parquet("data/orders.parquet").select_columns(
+df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns(
"o_orderkey", "o_custkey", "o_orderdate"
)
-df_customer = ctx.read_parquet("data/customer.parquet").select_columns(
+df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns(
"c_custkey", "c_nationkey"
)
-df_nation = ctx.read_parquet("data/nation.parquet").select_columns(
+df_nation = ctx.read_parquet(get_data_path("nation.parquet")).select_columns(
"n_nationkey", "n_name", "n_regionkey"
)
-df_region = ctx.read_parquet("data/region.parquet").select_columns(
+df_region = ctx.read_parquet(get_data_path("region.parquet")).select_columns(
"r_regionkey", "r_name"
)
diff --git a/examples/tpch/q09_product_type_profit_measure.py b/examples/tpch/q09_product_type_profit_measure.py
index 4fdfc1c..29ffcee 100644
--- a/examples/tpch/q09_product_type_profit_measure.py
+++ b/examples/tpch/q09_product_type_profit_measure.py
@@ -31,6 +31,7 @@ as part of their TPC Benchmark H Specification revision 2.18.0.
import pyarrow as pa
from datafusion import SessionContext, col, lit, functions as F
+from util import get_data_path
part_color = lit("green")
@@ -38,14 +39,14 @@ part_color = lit("green")
ctx = SessionContext()
-df_part = ctx.read_parquet("data/part.parquet").select_columns("p_partkey", "p_name")
-df_supplier = ctx.read_parquet("data/supplier.parquet").select_columns(
+df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_name")
+df_supplier = ctx.read_parquet(get_data_path("supplier.parquet")).select_columns(
"s_suppkey", "s_nationkey"
)
-df_partsupp = ctx.read_parquet("data/partsupp.parquet").select_columns(
+df_partsupp = ctx.read_parquet(get_data_path("partsupp.parquet")).select_columns(
"ps_suppkey", "ps_partkey", "ps_supplycost"
)
-df_lineitem = ctx.read_parquet("data/lineitem.parquet").select_columns(
+df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
"l_partkey",
"l_extendedprice",
"l_discount",
@@ -53,10 +54,10 @@ df_lineitem = ctx.read_parquet("data/lineitem.parquet").select_columns(
"l_orderkey",
"l_quantity",
)
-df_orders = ctx.read_parquet("data/orders.parquet").select_columns(
+df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns(
"o_orderkey", "o_custkey", "o_orderdate"
)
-df_nation = ctx.read_parquet("data/nation.parquet").select_columns(
+df_nation = ctx.read_parquet(get_data_path("nation.parquet")).select_columns(
"n_nationkey", "n_name", "n_regionkey"
)
@@ -77,7 +78,7 @@ df = df.select(
col("n_name").alias("nation"),
F.datepart(lit("year"),
col("o_orderdate")).cast(pa.int32()).alias("o_year"),
(
- col("l_extendedprice") * (lit(1.0) - col("l_discount"))
+ (col("l_extendedprice") * (lit(1) - col("l_discount")))
- (col("ps_supplycost") * col("l_quantity"))
).alias("amount"),
)
diff --git a/examples/tpch/q10_returned_item_reporting.py b/examples/tpch/q10_returned_item_reporting.py
index 1879027..ed88c29 100644
--- a/examples/tpch/q10_returned_item_reporting.py
+++ b/examples/tpch/q10_returned_item_reporting.py
@@ -32,6 +32,7 @@ as part of their TPC Benchmark H Specification revision 2.18.0.
from datetime import datetime
import pyarrow as pa
from datafusion import SessionContext, col, lit, functions as F
+from util import get_data_path
DATE_START_OF_QUARTER = "1993-10-01"
@@ -39,13 +40,13 @@ date_start_of_quarter = lit(datetime.strptime(DATE_START_OF_QUARTER, "%Y-%m-%d")
# Note: this is a hack on setting the values. It should be set differently once
# https://github.com/apache/datafusion-python/issues/665 is resolved.
-interval_one_quarter = lit(pa.scalar((0, 0, 120), type=pa.month_day_nano_interval()))
+interval_one_quarter = lit(pa.scalar((0, 0, 92), type=pa.month_day_nano_interval()))
# Load the dataframes we need
ctx = SessionContext()
-df_customer = ctx.read_parquet("data/customer.parquet").select_columns(
+df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns(
"c_custkey",
"c_nationkey",
"c_name",
@@ -54,13 +55,13 @@ df_customer = ctx.read_parquet("data/customer.parquet").select_columns(
"c_phone",
"c_comment",
)
-df_lineitem = ctx.read_parquet("data/lineitem.parquet").select_columns(
+df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
"l_extendedprice", "l_discount", "l_orderkey", "l_returnflag"
)
-df_orders = ctx.read_parquet("data/orders.parquet").select_columns(
+df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns(
"o_orderkey", "o_custkey", "o_orderdate"
)
-df_nation = ctx.read_parquet("data/nation.parquet").select_columns(
+df_nation = ctx.read_parquet(get_data_path("nation.parquet")).select_columns(
"n_nationkey", "n_name", "n_regionkey"
)
@@ -80,7 +81,7 @@ df = df.join(df_lineitem, (["o_orderkey"], ["l_orderkey"]), how="inner")
# Compute the revenue
df = df.aggregate(
[col("o_custkey")],
- [F.sum(col("l_extendedprice") * (lit(1.0) -
col("l_discount"))).alias("revenue")],
+ [F.sum(col("l_extendedprice") * (lit(1) -
col("l_discount"))).alias("revenue")],
)
# Now join in the customer data
diff --git a/examples/tpch/q11_important_stock_identification.py b/examples/tpch/q11_important_stock_identification.py
index 78fe26d..2672487 100644
--- a/examples/tpch/q11_important_stock_identification.py
+++ b/examples/tpch/q11_important_stock_identification.py
@@ -28,6 +28,7 @@ as part of their TPC Benchmark H Specification revision 2.18.0.
"""
from datafusion import SessionContext, WindowFrame, col, lit, functions as F
+from util import get_data_path
NATION = "GERMANY"
FRACTION = 0.0001
@@ -36,13 +37,13 @@ FRACTION = 0.0001
ctx = SessionContext()
-df_supplier = ctx.read_parquet("data/supplier.parquet").select_columns(
+df_supplier = ctx.read_parquet(get_data_path("supplier.parquet")).select_columns(
"s_suppkey", "s_nationkey"
)
-df_partsupp = ctx.read_parquet("data/partsupp.parquet").select_columns(
+df_partsupp = ctx.read_parquet(get_data_path("partsupp.parquet")).select_columns(
"ps_supplycost", "ps_availqty", "ps_suppkey", "ps_partkey"
)
-df_nation = ctx.read_parquet("data/nation.parquet").select_columns(
+df_nation = ctx.read_parquet(get_data_path("nation.parquet")).select_columns(
"n_nationkey", "n_name"
)
@@ -71,7 +72,7 @@ df = df.with_column(
)
# Limit to the parts for which there is a significant value based on the fraction of the total
-df = df.filter(col("value") / col("total_value") > lit(FRACTION))
+df = df.filter(col("value") / col("total_value") >= lit(FRACTION))
# We only need to report on these two columns
df = df.select_columns("ps_partkey", "value")
diff --git a/examples/tpch/q12_ship_mode_order_priority.py b/examples/tpch/q12_ship_mode_order_priority.py
index e76efa5..d3dd7d2 100644
--- a/examples/tpch/q12_ship_mode_order_priority.py
+++ b/examples/tpch/q12_ship_mode_order_priority.py
@@ -32,6 +32,7 @@ as part of their TPC Benchmark H Specification revision 2.18.0.
from datetime import datetime
import pyarrow as pa
from datafusion import SessionContext, col, lit, functions as F
+from util import get_data_path
SHIP_MODE_1 = "MAIL"
SHIP_MODE_2 = "SHIP"
@@ -41,10 +42,10 @@ DATE_OF_INTEREST = "1994-01-01"
ctx = SessionContext()
-df_orders = ctx.read_parquet("data/orders.parquet").select_columns(
+df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns(
"o_orderkey", "o_orderpriority"
)
-df_lineitem = ctx.read_parquet("data/lineitem.parquet").select_columns(
+df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
"l_orderkey", "l_shipmode", "l_commitdate", "l_shipdate", "l_receiptdate"
)
diff --git a/examples/tpch/q13_customer_distribution.py b/examples/tpch/q13_customer_distribution.py
index 1eb9ca3..2b6e7e2 100644
--- a/examples/tpch/q13_customer_distribution.py
+++ b/examples/tpch/q13_customer_distribution.py
@@ -29,6 +29,7 @@ as part of their TPC Benchmark H Specification revision 2.18.0.
"""
from datafusion import SessionContext, col, lit, functions as F
+from util import get_data_path
WORD_1 = "special"
WORD_2 = "requests"
@@ -37,10 +38,10 @@ WORD_2 = "requests"
ctx = SessionContext()
-df_orders = ctx.read_parquet("data/orders.parquet").select_columns(
+df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns(
"o_custkey", "o_comment"
)
-df_customer = ctx.read_parquet("data/customer.parquet").select_columns("c_custkey")
+df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns("c_custkey")
# Use a regex to remove special cases
df_orders = df_orders.filter(
@@ -51,7 +52,7 @@ df_orders = df_orders.filter(
df = df_customer.join(df_orders, (["c_custkey"], ["o_custkey"]), how="left")
# Find the number of orders for each customer
-df = df.aggregate([col("c_custkey")],
[F.count(col("c_custkey")).alias("c_count")])
+df = df.aggregate([col("c_custkey")],
[F.count(col("o_custkey")).alias("c_count")])
# Ultimately we want to know the number of customers that have that customer
count
df = df.aggregate([col("c_count")],
[F.count(col("c_count")).alias("custdist")])
diff --git a/examples/tpch/q14_promotion_effect.py b/examples/tpch/q14_promotion_effect.py
index 9ec3836..333398c 100644
--- a/examples/tpch/q14_promotion_effect.py
+++ b/examples/tpch/q14_promotion_effect.py
@@ -29,6 +29,7 @@ as part of their TPC Benchmark H Specification revision 2.18.0.
from datetime import datetime
import pyarrow as pa
from datafusion import SessionContext, col, lit, functions as F
+from util import get_data_path
DATE = "1995-09-01"
@@ -41,15 +42,15 @@ interval_one_month = lit(pa.scalar((0, 0, 30), type=pa.month_day_nano_interval()
ctx = SessionContext()
-df_lineitem = ctx.read_parquet("data/lineitem.parquet").select_columns(
+df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
"l_partkey", "l_shipdate", "l_extendedprice", "l_discount"
)
-df_part = ctx.read_parquet("data/part.parquet").select_columns("p_partkey", "p_type")
+df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_type")
# Check part type begins with PROMO
df_part = df_part.filter(
- F.substr(col("p_type"), lit(0), lit(6)) == lit("PROMO")
+ F.substring(col("p_type"), lit(0), lit(6)) == lit("PROMO")
).with_column("promo_factor", lit(1.0))
df_lineitem = df_lineitem.filter(col("l_shipdate") >= date_of_interest).filter(
diff --git a/examples/tpch/q15_top_supplier.py b/examples/tpch/q15_top_supplier.py
index 7113e04..91af34a 100644
--- a/examples/tpch/q15_top_supplier.py
+++ b/examples/tpch/q15_top_supplier.py
@@ -29,22 +29,23 @@ as part of their TPC Benchmark H Specification revision 2.18.0.
from datetime import datetime
import pyarrow as pa
from datafusion import SessionContext, WindowFrame, col, lit, functions as F
+from util import get_data_path
DATE = "1996-01-01"
date_of_interest = lit(datetime.strptime(DATE, "%Y-%m-%d").date())
# Note: this is a hack on setting the values. It should be set differently once
# https://github.com/apache/datafusion-python/issues/665 is resolved.
-interval_3_months = lit(pa.scalar((0, 0, 90), type=pa.month_day_nano_interval()))
+interval_3_months = lit(pa.scalar((0, 0, 91), type=pa.month_day_nano_interval()))
# Load the dataframes we need
ctx = SessionContext()
-df_lineitem = ctx.read_parquet("data/lineitem.parquet").select_columns(
+df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
"l_suppkey", "l_shipdate", "l_extendedprice", "l_discount"
)
-df_supplier = ctx.read_parquet("data/supplier.parquet").select_columns(
+df_supplier = ctx.read_parquet(get_data_path("supplier.parquet")).select_columns(
"s_suppkey",
"s_name",
"s_address",
@@ -59,7 +60,7 @@ df_lineitem = df_lineitem.filter(col("l_shipdate") >= date_of_interest).filter(
df = df_lineitem.aggregate(
[col("l_suppkey")],
[
- F.sum(col("l_extendedprice") * (lit(1.0) - col("l_discount"))).alias(
+ F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias(
"total_revenue"
)
],
diff --git a/examples/tpch/q16_part_supplier_relationship.py b/examples/tpch/q16_part_supplier_relationship.py
index 5f941d5..0db2d1b 100644
--- a/examples/tpch/q16_part_supplier_relationship.py
+++ b/examples/tpch/q16_part_supplier_relationship.py
@@ -30,6 +30,7 @@ as part of their TPC Benchmark H Specification revision 2.18.0.
import pyarrow as pa
from datafusion import SessionContext, col, lit, functions as F
+from util import get_data_path
BRAND = "Brand#45"
TYPE_TO_IGNORE = "MEDIUM POLISHED"
@@ -39,13 +40,13 @@ SIZES_OF_INTEREST = [49, 14, 23, 45, 19, 3, 36, 9]
ctx = SessionContext()
-df_part = ctx.read_parquet("data/part.parquet").select_columns(
+df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns(
"p_partkey", "p_brand", "p_type", "p_size"
)
-df_partsupp = ctx.read_parquet("data/partsupp.parquet").select_columns(
+df_partsupp = ctx.read_parquet(get_data_path("partsupp.parquet")).select_columns(
"ps_suppkey", "ps_partkey"
)
-df_supplier = ctx.read_parquet("data/supplier.parquet").select_columns(
+df_supplier = ctx.read_parquet(get_data_path("supplier.parquet")).select_columns(
"s_suppkey", "s_comment"
)
@@ -59,9 +60,9 @@ df_partsupp = df_partsupp.join(
)
# Select the parts we are interested in
-df_part = df_part.filter(col("p_brand") == lit(BRAND))
+df_part = df_part.filter(col("p_brand") != lit(BRAND))
df_part = df_part.filter(
- F.substr(col("p_type"), lit(0), lit(len(TYPE_TO_IGNORE) + 1)) !=
lit(TYPE_TO_IGNORE)
+ F.substring(col("p_type"), lit(0), lit(len(TYPE_TO_IGNORE) + 1)) !=
lit(TYPE_TO_IGNORE)
)
# Python conversion of integer to literal casts it to int64 but the data for
diff --git a/examples/tpch/q17_small_quantity_order.py b/examples/tpch/q17_small_quantity_order.py
index aae238b..5880e7e 100644
--- a/examples/tpch/q17_small_quantity_order.py
+++ b/examples/tpch/q17_small_quantity_order.py
@@ -29,6 +29,7 @@ as part of their TPC Benchmark H Specification revision 2.18.0.
"""
from datafusion import SessionContext, WindowFrame, col, lit, functions as F
+from util import get_data_path
BRAND = "Brand#23"
CONTAINER = "MED BOX"
@@ -37,10 +38,10 @@ CONTAINER = "MED BOX"
ctx = SessionContext()
-df_part = ctx.read_parquet("data/part.parquet").select_columns(
+df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns(
"p_partkey", "p_brand", "p_container"
)
-df_lineitem = ctx.read_parquet("data/lineitem.parquet").select_columns(
+df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
"l_partkey", "l_quantity", "l_extendedprice"
)
@@ -55,7 +56,7 @@ df = df.join(df_lineitem, (["p_partkey"], ["l_partkey"]), "inner")
# Find the average quantity
window_frame = WindowFrame("rows", None, None)
df = df.with_column(
- "avg_quantity", F.window("avg", [col("l_quantity")],
window_frame=window_frame)
+ "avg_quantity", F.window("avg", [col("l_quantity")],
window_frame=window_frame, partition_by=[col("l_partkey")])
)
df = df.filter(col("l_quantity") < lit(0.2) * col("avg_quantity"))
@@ -64,6 +65,6 @@ df = df.filter(col("l_quantity") < lit(0.2) * col("avg_quantity"))
df = df.aggregate([], [F.sum(col("l_extendedprice")).alias("total")])
# Divide by number of years in the problem statement to get average
-df = df.select((col("total") / lit(7.0)).alias("avg_yearly"))
+df = df.select((col("total") / lit(7)).alias("avg_yearly"))
df.show()
diff --git a/examples/tpch/q18_large_volume_customer.py b/examples/tpch/q18_large_volume_customer.py
index 96ca08f..10c5f6e 100644
--- a/examples/tpch/q18_large_volume_customer.py
+++ b/examples/tpch/q18_large_volume_customer.py
@@ -27,6 +27,7 @@ as part of their TPC Benchmark H Specification revision 2.18.0.
"""
from datafusion import SessionContext, col, lit, functions as F
+from util import get_data_path
QUANTITY = 300
@@ -34,13 +35,13 @@ QUANTITY = 300
ctx = SessionContext()
-df_customer = ctx.read_parquet("data/customer.parquet").select_columns(
+df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns(
"c_custkey", "c_name"
)
-df_orders = ctx.read_parquet("data/orders.parquet").select_columns(
+df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns(
"o_orderkey", "o_custkey", "o_orderdate", "o_totalprice"
)
-df_lineitem = ctx.read_parquet("data/lineitem.parquet").select_columns(
+df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
"l_orderkey", "l_quantity", "l_extendedprice"
)
diff --git a/examples/tpch/q19_discounted_revenue.py b/examples/tpch/q19_discounted_revenue.py
index 20ad48a..b15cd98 100644
--- a/examples/tpch/q19_discounted_revenue.py
+++ b/examples/tpch/q19_discounted_revenue.py
@@ -28,6 +28,7 @@ as part of their TPC Benchmark H Specification revision 2.18.0.
import pyarrow as pa
from datafusion import SessionContext, col, lit, udf, functions as F
+from util import get_data_path
items_of_interest = {
"Brand#12": {
@@ -51,10 +52,10 @@ items_of_interest = {
ctx = SessionContext()
-df_part = ctx.read_parquet("data/part.parquet").select_columns(
+df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns(
"p_partkey", "p_brand", "p_container", "p_size"
)
-df_lineitem = ctx.read_parquet("data/lineitem.parquet").select_columns(
+df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
"l_partkey",
"l_quantity",
"l_shipmode",
@@ -67,9 +68,8 @@ df_lineitem = ctx.read_parquet("data/lineitem.parquet").select_columns(
df = df_lineitem.filter(col("l_shipinstruct") == lit("DELIVER IN PERSON"))
-# Small note: The data generated uses "REG AIR" but the spec says "AIR REG"
df = df.filter(
- (col("l_shipmode") == lit("AIR")) | (col("l_shipmode") == lit("REG AIR"))
+ (col("l_shipmode") == lit("AIR")) | (col("l_shipmode") == lit("AIR REG"))
)
df = df.join(df_part, (["l_partkey"], ["p_partkey"]), "inner")
@@ -117,7 +117,7 @@ def is_of_interest(
# Turn the above function into a UDF that DataFusion can understand
is_of_interest_udf = udf(
is_of_interest,
- [pa.utf8(), pa.utf8(), pa.float32(), pa.int32()],
+ [pa.utf8(), pa.utf8(), pa.decimal128(15, 2), pa.int32()],
pa.bool_(),
"stable",
)
@@ -131,7 +131,7 @@ df = df.filter(
df = df.aggregate(
[],
- [F.sum(col("l_extendedprice") * (lit(1.0) -
col("l_discount"))).alias("revenue")],
+ [F.sum(col("l_extendedprice") * (lit(1) -
col("l_discount"))).alias("revenue")],
)
df.show()
diff --git a/examples/tpch/q20_potential_part_promotion.py b/examples/tpch/q20_potential_part_promotion.py
index 09686db..4a60284 100644
--- a/examples/tpch/q20_potential_part_promotion.py
+++ b/examples/tpch/q20_potential_part_promotion.py
@@ -30,6 +30,7 @@ as part of their TPC Benchmark H Specification revision 2.18.0.
from datetime import datetime
import pyarrow as pa
from datafusion import SessionContext, col, lit, functions as F
+from util import get_data_path
COLOR_OF_INTEREST = "forest"
DATE_OF_INTEREST = "1994-01-01"
@@ -39,17 +40,17 @@ NATION_OF_INTEREST = "CANADA"
ctx = SessionContext()
-df_part = ctx.read_parquet("data/part.parquet").select_columns("p_partkey", "p_name")
-df_lineitem = ctx.read_parquet("data/lineitem.parquet").select_columns(
+df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_name")
+df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
"l_shipdate", "l_partkey", "l_suppkey", "l_quantity"
)
-df_partsupp = ctx.read_parquet("data/partsupp.parquet").select_columns(
+df_partsupp = ctx.read_parquet(get_data_path("partsupp.parquet")).select_columns(
"ps_partkey", "ps_suppkey", "ps_availqty"
)
-df_supplier = ctx.read_parquet("data/supplier.parquet").select_columns(
+df_supplier = ctx.read_parquet(get_data_path("supplier.parquet")).select_columns(
"s_suppkey", "s_address", "s_name", "s_nationkey"
)
-df_nation = ctx.read_parquet("data/nation.parquet").select_columns(
+df_nation = ctx.read_parquet(get_data_path("nation.parquet")).select_columns(
"n_nationkey", "n_name"
)
@@ -62,7 +63,7 @@ interval = pa.scalar((0, 0, 365), type=pa.month_day_nano_interval())
# Filter down dataframes
df_nation = df_nation.filter(col("n_name") == lit(NATION_OF_INTEREST))
df_part = df_part.filter(
- F.substr(col("p_name"), lit(0), lit(len(COLOR_OF_INTEREST) + 1))
+ F.substring(col("p_name"), lit(0), lit(len(COLOR_OF_INTEREST) + 1))
== lit(COLOR_OF_INTEREST)
)
@@ -90,7 +91,7 @@ df = df.join(df_supplier, (["ps_suppkey"], ["s_suppkey"]), "inner")
df = df.join(df_nation, (["s_nationkey"], ["n_nationkey"]), "inner")
# Restrict to the requested data per the problem statement
-df = df.select_columns("s_name", "s_address")
+df = df.select_columns("s_name", "s_address").distinct()
df = df.sort(col("s_name").sort())
diff --git a/examples/tpch/q21_suppliers_kept_orders_waiting.py b/examples/tpch/q21_suppliers_kept_orders_waiting.py
index 2f58d6e..9f59804 100644
--- a/examples/tpch/q21_suppliers_kept_orders_waiting.py
+++ b/examples/tpch/q21_suppliers_kept_orders_waiting.py
@@ -27,6 +27,7 @@ as part of their TPC Benchmark H Specification revision 2.18.0.
"""
from datafusion import SessionContext, col, lit, functions as F
+from util import get_data_path
NATION_OF_INTEREST = "SAUDI ARABIA"
@@ -34,16 +35,16 @@ NATION_OF_INTEREST = "SAUDI ARABIA"
ctx = SessionContext()
-df_orders = ctx.read_parquet("data/orders.parquet").select_columns(
+df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns(
"o_orderkey", "o_orderstatus"
)
-df_lineitem = ctx.read_parquet("data/lineitem.parquet").select_columns(
+df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
"l_orderkey", "l_receiptdate", "l_commitdate", "l_suppkey"
)
-df_supplier = ctx.read_parquet("data/supplier.parquet").select_columns(
+df_supplier = ctx.read_parquet(get_data_path("supplier.parquet")).select_columns(
"s_suppkey", "s_name", "s_nationkey"
)
-df_nation = ctx.read_parquet("data/nation.parquet").select_columns(
+df_nation = ctx.read_parquet(get_data_path("nation.parquet")).select_columns(
"n_nationkey", "n_name"
)
@@ -107,7 +108,7 @@ df = df.join(df_suppliers_of_interest, (["suppkey"], ["s_suppkey"]), "inner")
df = df.aggregate([col("s_name")], [F.count(col("o_orderkey")).alias("numwait")])
# Return in descending order
-df = df.sort(col("numwait").sort(ascending=False))
+df = df.sort(col("numwait").sort(ascending=False), col("s_name").sort())
df = df.limit(100)
diff --git a/examples/tpch/q22_global_sales_opportunity.py b/examples/tpch/q22_global_sales_opportunity.py
index d2d0c5a..dfde19c 100644
--- a/examples/tpch/q22_global_sales_opportunity.py
+++ b/examples/tpch/q22_global_sales_opportunity.py
@@ -27,27 +27,28 @@ as part of their TPC Benchmark H Specification revision 2.18.0.
"""
from datafusion import SessionContext, WindowFrame, col, lit, functions as F
+from util import get_data_path
-NATION_CODE = 13
+NATION_CODES = [13, 31, 23, 29, 30, 18, 17]
# Load the dataframes we need
ctx = SessionContext()
-df_customer = ctx.read_parquet("data/customer.parquet").select_columns(
+df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns(
"c_phone", "c_acctbal", "c_custkey"
)
-df_orders = ctx.read_parquet("data/orders.parquet").select_columns("o_custkey")
+df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns("o_custkey")
# The nation code is a two digit number, but we need to convert it to a string literal
-nation_code = lit(str(NATION_CODE))
+nation_codes = F.make_array(*[lit(str(n)) for n in NATION_CODES])
# Use the substring operation to extract the first two characters of the phone number
-df = df_customer.with_column("cntrycode", F.substr(col("c_phone"), lit(0), lit(3)))
+df = df_customer.with_column("cntrycode", F.substring(col("c_phone"), lit(0), lit(3)))
# Limit our search to customers with some balance and in the country code above
df = df.filter(col("c_acctbal") > lit(0.0))
-df = df.filter(nation_code == col("cntrycode"))
+df = df.filter(~F.array_position(nation_codes, col("cntrycode")).is_null())
# Compute the average balance. By default, the window frame is from unbounded preceding to the
# current row. We want our frame to cover the entire data frame.
@@ -56,6 +57,7 @@ df = df.with_column(
"avg_balance", F.window("avg", [col("c_acctbal")],
window_frame=window_frame)
)
+df.show()
# Limit results to customers with above average balance
df = df.filter(col("c_acctbal") > col("avg_balance"))
diff --git a/examples/tpch/util.py b/examples/tpch/util.py
new file mode 100644
index 0000000..191fa60
--- /dev/null
+++ b/examples/tpch/util.py
@@ -0,0 +1,33 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Common utilities for running TPC-H examples.
+"""
+
+import os
+from pathlib import Path
+
+def get_data_path(filename: str) -> str:
+ path = os.path.dirname(os.path.abspath(__file__))
+
+ return os.path.join(path, "data", filename)
+
+def get_answer_file(answer_file: str) -> str:
+ path = os.path.dirname(os.path.abspath(__file__))
+
+ return os.path.join(path, "../../benchmarks/tpch/data/answers",
f"{answer_file}.out")
diff --git a/src/functions.rs b/src/functions.rs
index 975025b..a4bd986 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -477,6 +477,7 @@ expr_fn!(sqrt, num);
expr_fn!(starts_with, string prefix, "Returns true if string starts with prefix.");
expr_fn!(strpos, string substring, "Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)");
expr_fn!(substr, string position);
+expr_fn!(substring, string position length);
expr_fn!(tan, num);
expr_fn!(tanh, num);
expr_fn!(
@@ -717,6 +718,7 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(strpos))?;
m.add_wrapped(wrap_pyfunction!(r#struct))?; // Use raw identifier since struct is a keyword
m.add_wrapped(wrap_pyfunction!(substr))?;
+ m.add_wrapped(wrap_pyfunction!(substring))?;
m.add_wrapped(wrap_pyfunction!(sum))?;
m.add_wrapped(wrap_pyfunction!(tan))?;
m.add_wrapped(wrap_pyfunction!(tanh))?;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]