This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-benchmarks.git
The following commit(s) were added to refs/heads/main by this push:
new 55642ef fix: Specify schema when converting TPC-H csv to parquet (#3)
55642ef is described below
commit 55642ef7025b701967ba1f12f6071828ef0b096b
Author: Andy Grove <[email protected]>
AuthorDate: Tue May 21 13:11:37 2024 -0600
fix: Specify schema when converting TPC-H csv to parquet (#3)
* specify schema when reading csv files
* use snappy compression
* fix
---
tpch/tpchgen.py | 115 +++++++++++++++++++++++++++++++++++++++++++++++++++++---
1 file changed, 109 insertions(+), 6 deletions(-)
diff --git a/tpch/tpchgen.py b/tpch/tpchgen.py
index a0cd406..b7920bf 100644
--- a/tpch/tpchgen.py
+++ b/tpch/tpchgen.py
@@ -19,11 +19,100 @@ import argparse
import concurrent.futures
from datafusion import SessionContext
import os
+import pyarrow
import subprocess
import time
table_names = ["customer", "lineitem", "nation", "orders", "part", "partsupp",
"region", "supplier"]
+# schema definition copied from DataFusion Python tpch example
+all_schemas = {}
+
+all_schemas["customer"] = [
+ ("C_CUSTKEY", pyarrow.int32()),
+ ("C_NAME", pyarrow.string()),
+ ("C_ADDRESS", pyarrow.string()),
+ ("C_NATIONKEY", pyarrow.int32()),
+ ("C_PHONE", pyarrow.string()),
+ ("C_ACCTBAL", pyarrow.decimal128(15, 2)),
+ ("C_MKTSEGMENT", pyarrow.string()),
+ ("C_COMMENT", pyarrow.string()),
+]
+
+all_schemas["lineitem"] = [
+ ("L_ORDERKEY", pyarrow.int32()),
+ ("L_PARTKEY", pyarrow.int32()),
+ ("L_SUPPKEY", pyarrow.int32()),
+ ("L_LINENUMBER", pyarrow.int32()),
+ ("L_QUANTITY", pyarrow.decimal128(15, 2)),
+ ("L_EXTENDEDPRICE", pyarrow.decimal128(15, 2)),
+ ("L_DISCOUNT", pyarrow.decimal128(15, 2)),
+ ("L_TAX", pyarrow.decimal128(15, 2)),
+ ("L_RETURNFLAG", pyarrow.string()),
+ ("L_LINESTATUS", pyarrow.string()),
+ ("L_SHIPDATE", pyarrow.date32()),
+ ("L_COMMITDATE", pyarrow.date32()),
+ ("L_RECEIPTDATE", pyarrow.date32()),
+ ("L_SHIPINSTRUCT", pyarrow.string()),
+ ("L_SHIPMODE", pyarrow.string()),
+ ("L_COMMENT", pyarrow.string()),
+]
+
+all_schemas["nation"] = [
+ ("N_NATIONKEY", pyarrow.int32()),
+ ("N_NAME", pyarrow.string()),
+ ("N_REGIONKEY", pyarrow.int32()),
+ ("N_COMMENT", pyarrow.string()),
+]
+
+all_schemas["orders"] = [
+ ("O_ORDERKEY", pyarrow.int32()),
+ ("O_CUSTKEY", pyarrow.int32()),
+ ("O_ORDERSTATUS", pyarrow.string()),
+ ("O_TOTALPRICE", pyarrow.decimal128(15, 2)),
+ ("O_ORDERDATE", pyarrow.date32()),
+ ("O_ORDERPRIORITY", pyarrow.string()),
+ ("O_CLERK", pyarrow.string()),
+ ("O_SHIPPRIORITY", pyarrow.int32()),
+ ("O_COMMENT", pyarrow.string()),
+]
+
+all_schemas["part"] = [
+ ("P_PARTKEY", pyarrow.int32()),
+ ("P_NAME", pyarrow.string()),
+ ("P_MFGR", pyarrow.string()),
+ ("P_BRAND", pyarrow.string()),
+ ("P_TYPE", pyarrow.string()),
+ ("P_SIZE", pyarrow.int32()),
+ ("P_CONTAINER", pyarrow.string()),
+ ("P_RETAILPRICE", pyarrow.decimal128(15, 2)),
+ ("P_COMMENT", pyarrow.string()),
+]
+
+all_schemas["partsupp"] = [
+ ("PS_PARTKEY", pyarrow.int32()),
+ ("PS_SUPPKEY", pyarrow.int32()),
+ ("PS_AVAILQTY", pyarrow.int32()),
+ ("PS_SUPPLYCOST", pyarrow.decimal128(15, 2)),
+ ("PS_COMMENT", pyarrow.string()),
+]
+
+all_schemas["region"] = [
+ ("R_REGIONKEY", pyarrow.int32()),
+ ("R_NAME", pyarrow.string()),
+ ("R_COMMENT", pyarrow.string()),
+]
+
+all_schemas["supplier"] = [
+ ("S_SUPPKEY", pyarrow.int32()),
+ ("S_NAME", pyarrow.string()),
+ ("S_ADDRESS", pyarrow.string()),
+ ("S_NATIONKEY", pyarrow.int32()),
+ ("S_PHONE", pyarrow.string()),
+ ("S_ACCTBAL", pyarrow.decimal128(15, 2)),
+ ("S_COMMENT", pyarrow.string()),
+]
+
def run(cmd: str):
print(f"Executing: {cmd}")
subprocess.run(cmd, shell=True, check=True)
@@ -33,10 +122,24 @@ def run_and_log_output(cmd: str, log_file: str):
with open(log_file, "w") as file:
subprocess.run(cmd, shell=True, check=True, stdout=file, stderr=subprocess.STDOUT)
-def convert_tbl_to_parquet(ctx: SessionContext, tbl_filename: str, file_extension: str, parquet_filename: str):
+def convert_tbl_to_parquet(ctx: SessionContext, table: str, tbl_filename: str, file_extension: str, parquet_filename: str):
print(f"Converting {tbl_filename} to {parquet_filename} ...")
- df = ctx.read_csv(tbl_filename, has_header=False, file_extension=file_extension, delimiter="|")
- df.write_parquet(parquet_filename)
+
+ # schema manipulation code copied from DataFusion Python tpch example
+ table_schema = [(r[0].lower(), r[1]) for r in all_schemas[table]]
+
+ # Pre-collect the output columns so we can ignore the null field we
+ # add to handle the trailing | in the file
+ output_cols = [r[0] for r in table_schema]
+
+ # The trailing | requires an extra field during processing
+ table_schema.append(("some_null", pyarrow.null()))
+
+ schema = pyarrow.schema(table_schema)
+
+ df = ctx.read_csv(tbl_filename, schema=schema, has_header=False, file_extension=file_extension, delimiter="|")
+ df = df.select_columns(*output_cols)
+ df.write_parquet(parquet_filename, compression="snappy")
def generate_tpch(scale_factor: int, partitions: int):
start_time = time.time()
@@ -47,7 +150,7 @@ def generate_tpch(scale_factor: int, partitions: int):
# convert to parquet
ctx = SessionContext()
for table in table_names:
- convert_tbl_to_parquet(ctx, f"data/{table}.tbl", "tbl",
f"data/{table}.parquet")
+ convert_tbl_to_parquet(ctx, table, f"data/{table}.tbl", "tbl",
f"data/{table}.parquet")
else:
@@ -77,10 +180,10 @@ def generate_tpch(scale_factor: int, partitions: int):
run(f"mkdir -p data/{table}.parquet")
if table == "nation" or table == "region":
# nation and region are special cases and do not generate multiple files
- convert_tbl_to_parquet(ctx, f"data/{table}.tbl", "tbl", f"data/{table}.parquet/part1.parquet")
+ convert_tbl_to_parquet(ctx, table, f"data/{table}.tbl", "tbl", f"data/{table}.parquet/part1.parquet")
else:
for part in range(1, partitions + 1):
- convert_tbl_to_parquet(ctx, f"data/{table}.tbl.{part}",
f"tbl.{part}", f"data/{table}.parquet/part{part}.parquet")
+ convert_tbl_to_parquet(ctx, table,
f"data/{table}.tbl.{part}", f"tbl.{part}",
f"data/{table}.parquet/part{part}.parquet")
end_time = time.time()
print(f"Finished in {round(end_time - start_time, 2)} seconds")
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]