Hi,

I have encountered a weird and potentially dangerous behaviour of Spark
concerning partial overwrites of partitioned data. I am not sure whether this
is a bug or just an abstraction leak. I have checked the Spark section of
Stack Overflow and have not found any relevant questions or answers.

A full minimal working example is attached. Tested on Databricks Runtime
7.3 LTS ML (Spark 3.0.1). Short summary:

Write a dataframe partitioned by a column using saveAsTable. Filter out part
of the dataframe, change some values (simulating a new increment of data) and
write it again, overwriting a subset of partitions using insertInto. This
operation either fails with a schema mismatch or corrupts the data.

Reason: on the first write, the ordering of the columns is changed (the
partition column is placed at the end). The second write does not take this
into account: Spark inserts values into the columns based on their position,
not their name. If the columns have different types, the write fails; if not,
the values end up in the wrong columns, corrupting the data.

My question: is this a bug or intended behaviour? Can anything be done to
prevent it? The issue can be avoided by doing a select with the schema loaded
from the target table, as sketched below. However, a user who is not aware of
this could run into errors in the data that are hard to track down.
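
For illustration, here is a minimal sketch of that workaround (it reuses the
table and dataframe names from the attached notebook):

    # Reorder the increment's columns to match the target table before the
    # positional insertInto; the partition column "y" is last in the table schema.
    target_cols = spark.table(table_name).columns  # ['x', 'z', 'y']
    (
        df_increment
        .select(*target_cols)
        .write
        .mode("overwrite")
        .insertInto(table_name)
    )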

Best regards,
Oldřich Vlašic
# Databricks notebook source
import pyspark.sql.functions as F

# only overwrite the partitions that are present in the incoming data
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")

# COMMAND ----------

print(spark.version)
# 3.0.1

# COMMAND ----------

table_name = "insert_into_mve_301"
spark.sql(f"DROP TABLE IF EXISTS {table_name}")

# COMMAND ----------

df_data = (
    spark.range(10_000)
    .withColumnRenamed("id", "x")
    .crossJoin(
        spark.range(5).withColumnRenamed("id", "y")  # 5 values of the partition column
    )
    .withColumn("z", (F.rand() > 0.5).cast("integer"))  # random 0/1 values
    .repartitionByRange(5, "y")
)

# COMMAND ----------

df_data.filter(F.col("x") < 5).show()

"""
+---+---+---+
|  x|  y|  z|
+---+---+---+
|  0|  0|  0|
|  1|  0|  0|
|  2|  0|  1|
|  3|  0|  1|
|  4|  0|  0|
|  0|  1|  1|
|  1|  1|  1|
|  2|  1|  1|
|  3|  1|  0|
|  4|  1|  1|
|  0|  2|  1|
|  1|  2|  0|
|  2|  2|  0|
|  3|  2|  0|
|  4|  2|  1|
|  0|  3|  0|
|  1|  3|  1|
|  2|  3|  0|
|  3|  3|  0|
|  4|  3|  0|
+---+---+---+
only showing top 20 rows
"""

# COMMAND ----------

(
    df_data
    .write
    .mode("overwrite")
    .partitionBy("y")
    .format("parquet")
#     .option("path", "dbfs:/ov_test/foo_01")
    .saveAsTable(table_name)
)

# COMMAND ----------
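
# Quick check (sketch): the table created above stores the partition column "y"
# at the end of its schema, unlike the source dataframe df_data.
print(df_data.columns)                  # ['x', 'y', 'z']
print(spark.table(table_name).columns)  # ['x', 'z', 'y']

# COMMAND ----------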

df_increment = (
    df_data.filter(F.col("y") < 2)  # rewrite only some partitions
    .withColumn("z", F.lit(42))  # change value so that we can see it
)

df_increment.filter(F.col("x") < 5).show()

"""
+---+---+---+
|  x|  y|  z|
+---+---+---+
|  0|  0| 42|
|  1|  0| 42|
|  2|  0| 42|
|  3|  0| 42|
|  4|  0| 42|
|  0|  1| 42|
|  1|  1| 42|
|  2|  1| 42|
|  3|  1| 42|
|  4|  1| 42|
+---+---+---+
"""

# COMMAND ----------

(
    df_increment  # columns are ordered x, y, z; the table expects x, z, y
    .write
    .mode("overwrite")
    .insertInto(table_name)  # insertInto matches columns by position, not by name
)

# COMMAND ----------

(
    spark
    .table(table_name)
    .filter(F.col("y") == 42)  # we inserted 42 into column "z", yet it shows up in "y"
    .limit(10)
    .show()
)

"""
+---+---+---+
|  x|  z|  y|
+---+---+---+
|  0|  1| 42|
|  1|  1| 42|
|  2|  1| 42|
|  3|  1| 42|
|  4|  1| 42|
|  5|  1| 42|
|  6|  1| 42|
|  7|  1| 42|
|  8|  1| 42|
|  9|  1| 42|
+---+---+---+
"""