This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-benchmarks.git


The following commit(s) were added to refs/heads/main by this push:
     new 55642ef  fix: Specify schema when converting TPC-H csv to parquet (#3)
55642ef is described below

commit 55642ef7025b701967ba1f12f6071828ef0b096b
Author: Andy Grove <[email protected]>
AuthorDate: Tue May 21 13:11:37 2024 -0600

    fix: Specify schema when converting TPC-H csv to parquet (#3)
    
    * specify schema when reading csv files
    
    * use snappy compression
    
    * fix
---
 tpch/tpchgen.py | 115 +++++++++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 109 insertions(+), 6 deletions(-)

diff --git a/tpch/tpchgen.py b/tpch/tpchgen.py
index a0cd406..b7920bf 100644
--- a/tpch/tpchgen.py
+++ b/tpch/tpchgen.py
@@ -19,11 +19,100 @@ import argparse
 import concurrent.futures
 from datafusion import SessionContext
 import os
+import pyarrow
 import subprocess
 import time
 
 table_names = ["customer", "lineitem", "nation", "orders", "part", "partsupp", 
"region", "supplier"]
 
+# schema definition copied from DataFusion Python tpch example
+all_schemas = {}
+
+all_schemas["customer"] = [
+    ("C_CUSTKEY", pyarrow.int32()),
+    ("C_NAME", pyarrow.string()),
+    ("C_ADDRESS", pyarrow.string()),
+    ("C_NATIONKEY", pyarrow.int32()),
+    ("C_PHONE", pyarrow.string()),
+    ("C_ACCTBAL", pyarrow.decimal128(15, 2)),
+    ("C_MKTSEGMENT", pyarrow.string()),
+    ("C_COMMENT", pyarrow.string()),
+]
+
+all_schemas["lineitem"] = [
+    ("L_ORDERKEY", pyarrow.int32()),
+    ("L_PARTKEY", pyarrow.int32()),
+    ("L_SUPPKEY", pyarrow.int32()),
+    ("L_LINENUMBER", pyarrow.int32()),
+    ("L_QUANTITY", pyarrow.decimal128(15, 2)),
+    ("L_EXTENDEDPRICE", pyarrow.decimal128(15, 2)),
+    ("L_DISCOUNT", pyarrow.decimal128(15, 2)),
+    ("L_TAX", pyarrow.decimal128(15, 2)),
+    ("L_RETURNFLAG", pyarrow.string()),
+    ("L_LINESTATUS", pyarrow.string()),
+    ("L_SHIPDATE", pyarrow.date32()),
+    ("L_COMMITDATE", pyarrow.date32()),
+    ("L_RECEIPTDATE", pyarrow.date32()),
+    ("L_SHIPINSTRUCT", pyarrow.string()),
+    ("L_SHIPMODE", pyarrow.string()),
+    ("L_COMMENT", pyarrow.string()),
+]
+
+all_schemas["nation"] = [
+    ("N_NATIONKEY", pyarrow.int32()),
+    ("N_NAME", pyarrow.string()),
+    ("N_REGIONKEY", pyarrow.int32()),
+    ("N_COMMENT", pyarrow.string()),
+]
+
+all_schemas["orders"] = [
+    ("O_ORDERKEY", pyarrow.int32()),
+    ("O_CUSTKEY", pyarrow.int32()),
+    ("O_ORDERSTATUS", pyarrow.string()),
+    ("O_TOTALPRICE", pyarrow.decimal128(15, 2)),
+    ("O_ORDERDATE", pyarrow.date32()),
+    ("O_ORDERPRIORITY", pyarrow.string()),
+    ("O_CLERK", pyarrow.string()),
+    ("O_SHIPPRIORITY", pyarrow.int32()),
+    ("O_COMMENT", pyarrow.string()),
+]
+
+all_schemas["part"] = [
+    ("P_PARTKEY", pyarrow.int32()),
+    ("P_NAME", pyarrow.string()),
+    ("P_MFGR", pyarrow.string()),
+    ("P_BRAND", pyarrow.string()),
+    ("P_TYPE", pyarrow.string()),
+    ("P_SIZE", pyarrow.int32()),
+    ("P_CONTAINER", pyarrow.string()),
+    ("P_RETAILPRICE", pyarrow.decimal128(15, 2)),
+    ("P_COMMENT", pyarrow.string()),
+]
+
+all_schemas["partsupp"] = [
+    ("PS_PARTKEY", pyarrow.int32()),
+    ("PS_SUPPKEY", pyarrow.int32()),
+    ("PS_AVAILQTY", pyarrow.int32()),
+    ("PS_SUPPLYCOST", pyarrow.decimal128(15, 2)),
+    ("PS_COMMENT", pyarrow.string()),
+]
+
+all_schemas["region"] = [
+    ("R_REGIONKEY", pyarrow.int32()),
+    ("R_NAME", pyarrow.string()),
+    ("R_COMMENT", pyarrow.string()),
+]
+
+all_schemas["supplier"] = [
+    ("S_SUPPKEY", pyarrow.int32()),
+    ("S_NAME", pyarrow.string()),
+    ("S_ADDRESS", pyarrow.string()),
+    ("S_NATIONKEY", pyarrow.int32()),
+    ("S_PHONE", pyarrow.string()),
+    ("S_ACCTBAL", pyarrow.decimal128(15, 2)),
+    ("S_COMMENT", pyarrow.string()),
+]
+
 def run(cmd: str):
     print(f"Executing: {cmd}")
     subprocess.run(cmd, shell=True, check=True)
@@ -33,10 +122,24 @@ def run_and_log_output(cmd: str, log_file: str):
     with open(log_file, "w") as file:
         subprocess.run(cmd, shell=True, check=True, stdout=file, stderr=subprocess.STDOUT)
 
-def convert_tbl_to_parquet(ctx: SessionContext, tbl_filename: str, file_extension: str, parquet_filename: str):
+def convert_tbl_to_parquet(ctx: SessionContext, table: str, tbl_filename: str, file_extension: str, parquet_filename: str):
     print(f"Converting {tbl_filename} to {parquet_filename} ...")
-    df = ctx.read_csv(tbl_filename, has_header=False, file_extension=file_extension, delimiter="|")
-    df.write_parquet(parquet_filename)
+
+    # schema manipulation code copied from DataFusion Python tpch example
+    table_schema = [(r[0].lower(), r[1]) for r in all_schemas[table]]
+
+    # Pre-collect the output columns so we can ignore the null field we add
+    # in to handle the trailing | in the file
+    output_cols = [r[0] for r in table_schema]
+
+    # The trailing | in each row requires an extra field during processing
+    table_schema.append(("some_null", pyarrow.null()))
+
+    schema = pyarrow.schema(table_schema)
+
+    df = ctx.read_csv(tbl_filename, schema=schema, has_header=False, file_extension=file_extension, delimiter="|")
+    df = df.select_columns(*output_cols)
+    df.write_parquet(parquet_filename, compression="snappy")
 
 def generate_tpch(scale_factor: int, partitions: int):
     start_time = time.time()
@@ -47,7 +150,7 @@ def generate_tpch(scale_factor: int, partitions: int):
         # convert to parquet
         ctx = SessionContext()
         for table in table_names:
-            convert_tbl_to_parquet(ctx, f"data/{table}.tbl", "tbl", 
f"data/{table}.parquet")
+            convert_tbl_to_parquet(ctx, table, f"data/{table}.tbl", "tbl", 
f"data/{table}.parquet")
 
     else:
 
@@ -77,10 +180,10 @@ def generate_tpch(scale_factor: int, partitions: int):
             run(f"mkdir -p data/{table}.parquet")
             if table == "nation" or table == "region":
                 # nation and region are special cases and do not generate multiple files
-                convert_tbl_to_parquet(ctx, f"data/{table}.tbl", "tbl", f"data/{table}.parquet/part1.parquet")
+                convert_tbl_to_parquet(ctx, table, f"data/{table}.tbl", "tbl", f"data/{table}.parquet/part1.parquet")
             else:
                 for part in range(1, partitions + 1):
-                    convert_tbl_to_parquet(ctx, f"data/{table}.tbl.{part}", f"tbl.{part}", f"data/{table}.parquet/part{part}.parquet")
+                    convert_tbl_to_parquet(ctx, table, f"data/{table}.tbl.{part}", f"tbl.{part}", f"data/{table}.parquet/part{part}.parquet")
 
     end_time = time.time()
     print(f"Finished in {round(end_time - start_time, 2)} seconds")


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
