(Branching from the previous discussion, as Micah pointed out another
interesting aspect)
Consider the list
[[0, 1], None, [2, None, 3], [4, 5, 6], [], [7, 8, 9], None, [10]]
for the schema
optional group column1 (LIST) {
  repeated group list {
    optional int32 element;
  }
}
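For context, this is how I read the repetition and definition levels for this
column (max definition level 3, max repetition level 1); the triples below are
written out by hand from my understanding of the Dremel encoding, not taken
from either writer's output:

# (repetition level, definition level, value), max definition level = 3
levels = [
    (0, 3, 0), (1, 3, 1),                # [0, 1]
    (0, 0, None),                        # None (null list)
    (0, 3, 2), (1, 2, None), (1, 3, 3),  # [2, None, 3] (one null element)
    (0, 3, 4), (1, 3, 5), (1, 3, 6),     # [4, 5, 6]
    (0, 1, None),                        # [] (empty list)
    (0, 3, 7), (1, 3, 8), (1, 3, 9),     # [7, 8, 9]
    (0, 0, None),                        # None (null list)
    (0, 3, 10),                          # [10]
]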
When looking at the row group statistics, pyarrow 4 seems to report a null
count of 1, while spark 3 reports a null count of 4 (see the attached script
for writing and reading the statistics).
I am a bit lost as to which is the intended result. Isn't spark using the
official Java implementation? A null count of 4 seems a bit odd for the
example above.
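For what it is worth, I can reproduce both numbers directly from the Python
list if I assume one convention counts only null leaf elements while the other
counts every leaf-column slot that carries no value (null lists, the empty
list, and the null element). This is only my guess at the two conventions, not
a statement about either implementation:

items = [[0, 1], None, [2, None, 3], [4, 5, 6], [], [7, 8, 9], None, [10]]

# null elements inside non-null lists only
leaf_nulls = sum(
    sum(1 for v in lst if v is None) for lst in items if lst is not None
)

# every leaf-column slot without a value: null lists, empty lists, null elements
slot_nulls = sum(
    1 if lst is None or not lst else sum(1 for v in lst if v is None)
    for lst in items
)

print(leaf_nulls)  # 1, the number pyarrow 4 reports
print(slot_nulls)  # 4, the number spark 3 reports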
Best,
Jorge
import os
import shutil

import pyarrow as pa
import pyarrow.parquet

PYARROW_PATH = "fixtures/pyarrow4"
PYSPARK_PATH = "fixtures/pyspark3"


def case_nested():
    items = [[0, 1], None, [2, None, 3], [4, 5, 6], [], [7, 8, 9], None, [10]]
    fields = [
        pa.field("list_int64", pa.list_(pa.int64())),
    ]
    schema = pa.schema(fields)
    return (
        {
            "list_int64": items,
        },
        schema,
        "nested_nullable_10.parquet",
    )
def write_pyarrow(case, page_version=1, use_dictionary=False, use_compression=False):
    data, schema, path = case()
    compression_path = "/snappy" if use_compression else ""
    if use_dictionary:
        base_path = f"{PYARROW_PATH}/v{page_version}/dict{compression_path}"
    else:
        base_path = f"{PYARROW_PATH}/v{page_version}/non_dict{compression_path}"

    t = pa.table(data, schema=schema)
    os.makedirs(base_path, exist_ok=True)
    pa.parquet.write_table(
        t,
        f"{base_path}/{path}",
        version=f"{page_version}.0",
        data_page_version=f"{page_version}.0",
        write_statistics=True,
        compression="snappy" if use_compression else None,
        use_dictionary=use_dictionary,
    )
    return f"{base_path}/{path}"
def write_pyspark(case, page_version=1, use_dictionary=False, use_compression=False):
    data, _, path = case()
    compression_path = "/snappy" if use_compression else ""
    if use_dictionary:
        base_path = f"{PYSPARK_PATH}/v{page_version}/dict{compression_path}"
    else:
        base_path = f"{PYSPARK_PATH}/v{page_version}/non_dict{compression_path}"
    os.makedirs(base_path, exist_ok=True)

    # turn the column-oriented dict into rows for createDataFrame
    length = len(list(data.values())[0])
    rows = [[x[i] for x in data.values()] for i in range(length)]

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.config(
        "spark.sql.parquet.compression.codec",
        "snappy" if use_compression else "uncompressed",
    ).getOrCreate()
    df = spark.createDataFrame(rows, list(data.keys()))
    df.repartition(1).write.parquet(f"{base_path}/{path}", mode="overwrite")

    # spark writes a directory with a single part file; replace the directory
    # with that file so both writers end up with a plain .parquet file
    f = next(f for f in os.listdir(f"{base_path}/{path}") if f.endswith(".parquet"))
    os.rename(f"{base_path}/{path}/{f}", "_temp")
    shutil.rmtree(f"{base_path}/{path}")
    os.rename("_temp", f"{base_path}/{path}")
    return f"{base_path}/{path}"
def _read_column(file):
    return (
        pyarrow.parquet.read_metadata(
            f"fixtures/{file}/v1/non_dict/nested_nullable_10.parquet"
        )
        .row_group(0)
        .column(0)
    )


write_pyspark(case_nested, 1, False, False)
write_pyarrow(case_nested, 1, False, False)

meta = _read_column("pyarrow4")
print(meta.statistics.null_count)  # prints 1 with pyarrow 4

meta = _read_column("pyspark3")
print(meta.statistics.null_count)  # prints 4 with spark 3
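
In case it helps with comparing the two files further, here is a small
follow-up snippet (reusing the _read_column helper above, and only pyarrow's
metadata API) that also prints the number of values recorded for each column
chunk next to the null count:

for name in ("pyarrow4", "pyspark3"):
    meta = _read_column(name)
    stats = meta.statistics
    print(
        name,
        "num_values:", meta.num_values,
        "null_count:", stats.null_count,
        "min/max:", (stats.min, stats.max) if stats.has_min_max else None,
    )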