wForget commented on code in PR #3214:
URL: https://github.com/apache/datafusion-comet/pull/3214#discussion_r2704353956


##########
spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala:
##########
@@ -228,4 +230,299 @@ class CometParquetWriterSuite extends CometTestBase {
       }
     }
   }
+
+  // ===== Complex Type Tests =====
+
+  /**
+   * Round-trips `inputDf` through the Comet native Parquet writer and verifies the output.
+   *
+   * The data is first written with Spark's writer (Comet native write disabled), read back, and
+   * rewritten to `outputPath` with the Comet native writer. The output is then read once with
+   * Comet native scan disabled and once with it enabled, and each read is compared against the
+   * rows that were handed to the Comet writer.
+   *
+   * @param inputDf
+   *   data to round-trip
+   * @param outputPath
+   *   destination for the Parquet files produced by the Comet native writer
+   * @param expectedRows
+   *   number of rows the written output must contain
+   */
+  private def writeComplexTypeData(
+      inputDf: DataFrame,
+      outputPath: String,
+      expectedRows: Int): Unit = {
+    withTempPath { inputDir =>
+      val inputPath = new File(inputDir, "input.parquet").getAbsolutePath
+
+      // First write the input data without Comet to get proper Arrow arrays when reading
+      withSQLConf(
+        CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "false",
+        SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Denver") {
+        inputDf.write.parquet(inputPath)
+      }
+
+      // Now read and write with Comet native writer
+      // Use auto scan mode so native_iceberg_compat is used (which supports complex types)
+      // instead of native_comet. This overrides the COMET_PARQUET_SCAN_IMPL env var set by CI.
+      withSQLConf(
+        CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+        CometConf.COMET_NATIVE_SCAN_IMPL.key -> "auto",
+        CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.key -> "true",
+        SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
+        CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
+        CometConf.COMET_EXEC_ENABLED.key -> "true") {
+
+        val parquetDf = spark.read.parquet(inputPath)
+        parquetDf.write.parquet(outputPath)
+
+        // Rows fed to the Comet writer, materialized once so both reads below are checked
+        // against the same answer.
+        val writtenRows = parquetDf.collect().toSeq
+
+        // DataFrames are planned when an action runs, not when they are created, so a setting
+        // applied by withSQLConf only affects actions executed inside that block. Every action
+        // therefore runs inside the block whose scan setting it is meant to exercise;
+        // otherwise both reads would be planned with the same (outer) configuration.
+        withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false") {
+          val sparkDf = spark.read.parquet(outputPath)
+          assert(sparkDf.count() == expectedRows, s"Expected $expectedRows rows")
+          checkAnswer(sparkDf, writtenRows)
+        }
+        withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "true") {
+          checkAnswer(spark.read.parquet(outputPath), writtenRows)
+        }
+      }
+    }
+  }
+
+  test("parquet write with array type") {
+    withTempPath { dir =>
+      // Row 3 holds an empty array; the others hold arrays of varying length.
+      val rows = Seq((1, Seq(1, 2, 3)), (2, Seq(4, 5)), (3, Seq[Int]()), (4, Seq(6, 7, 8, 9)))
+      writeComplexTypeData(
+        rows.toDF("id", "values"),
+        new File(dir, "output.parquet").getAbsolutePath,
+        expectedRows = rows.size)
+    }
+  }
+
+  test("parquet write with struct type") {
+    withTempPath { dir =>
+      // The (name, age) tuple in each row is encoded as a struct column.
+      val people = Seq((1, ("Alice", 30)), (2, ("Bob", 25)), (3, ("Charlie", 35)))
+      val target = new File(dir, "output.parquet").getAbsolutePath
+      writeComplexTypeData(people.toDF("id", "person"), target, people.length)
+    }
+  }
+
+  test("parquet write with map type") {
+    withTempPath { dir =>
+      // Maps with one, two, three and zero entries.
+      val entries = Seq(
+        1 -> Map("a" -> 1, "b" -> 2),
+        2 -> Map("c" -> 3),
+        3 -> Map.empty[String, Int],
+        4 -> Map("d" -> 4, "e" -> 5, "f" -> 6))
+      writeComplexTypeData(
+        entries.toDF("id", "properties"),
+        new File(dir, "output.parquet").getAbsolutePath,
+        entries.size)
+    }
+  }
+
+  test("parquet write with array of structs") {
+    withTempPath { dir =>
+      val destination = new File(dir, "output.parquet").getAbsolutePath
+      // Arrays of (name, age) structs, including an empty array in the last row.
+      val groups: Seq[(Int, Seq[(String, Int)])] = Seq(
+        (1, Seq("Alice" -> 30, "Bob" -> 25)),
+        (2, Seq("Charlie" -> 35)),
+        (3, Seq.empty))
+      writeComplexTypeData(groups.toDF("id", "people"), destination, groups.length)
+    }
+  }
+
+  test("parquet write with struct containing array") {
+    withTempPath { dir =>
+      // Each row is a struct with a string field and a non-empty int-array field.
+      val query = """
+        SELECT
+          1 as id,
+          named_struct('name', 'Team A', 'scores', array(95, 87, 92)) as team
+        UNION ALL SELECT
+          2 as id,
+          named_struct('name', 'Team B', 'scores', array(88, 91)) as team
+        UNION ALL SELECT
+          3 as id,
+          named_struct('name', 'Team C', 'scores', array(100)) as team
+      """
+      writeComplexTypeData(
+        inputDf = spark.sql(query),
+        outputPath = new File(dir, "output.parquet").getAbsolutePath,
+        expectedRows = 3)
+    }
+  }
+
+  test("parquet write with map with struct values") {
+    withTempPath { dir =>
+      val target = new File(dir, "output.parquet")
+      // Map values are (name, age) structs; row 1 has two entries, row 2 has one.
+      val employeesSql = """
+        SELECT
+          1 as id,
+          map('emp1', named_struct('name', 'Alice', 'age', 30),
+              'emp2', named_struct('name', 'Bob', 'age', 25)) as employees
+        UNION ALL SELECT
+          2 as id,
+          map('emp3', named_struct('name', 'Charlie', 'age', 35)) as employees
+      """
+      writeComplexTypeData(spark.sql(employeesSql), target.getAbsolutePath, expectedRows = 2)
+    }
+  }
+
+  test("parquet write with deeply nested types") {
+    withTempPath { dir =>
+      // Three levels of nesting: an array of maps whose values are int arrays.
+      val nestedDf = spark.sql("""
+        SELECT
+          1 as id,
+          array(
+            map('key1', array(1, 2, 3), 'key2', array(4, 5)),
+            map('key3', array(6, 7, 8, 9))
+          ) as nested_data
+        UNION ALL SELECT
+          2 as id,
+          array(
+            map('key4', array(10, 11))
+          ) as nested_data
+      """)
+      val outputFile = new File(dir, "output.parquet").getAbsolutePath
+      writeComplexTypeData(nestedDf, outputFile, expectedRows = 2)
+    }
+  }
+
+  test("parquet write with nullable complex types") {
+    withTempPath { dir =>
+      // Row 1: null elements/fields/values inside each container.
+      // Row 2: the containers themselves are null.
+      // Row 3: no nulls at all.
+      val nullsSql = """
+        SELECT
+          1 as id,
+          array(1, null, 3) as arr_with_nulls,
+          named_struct('a', 1, 'b', cast(null as int)) as struct_with_nulls,
+          map('x', 1, 'y', cast(null as int)) as map_with_nulls
+        UNION ALL SELECT
+          2 as id,
+          cast(null as array<int>) as arr_with_nulls,
+          cast(null as struct<a:int, b:int>) as struct_with_nulls,
+          cast(null as map<string, int>) as map_with_nulls
+        UNION ALL SELECT
+          3 as id,
+          array(4, 5, 6) as arr_with_nulls,
+          named_struct('a', 7, 'b', 8) as struct_with_nulls,
+          map('z', 9) as map_with_nulls
+      """
+      val out = new File(dir, "output.parquet").getAbsolutePath
+      writeComplexTypeData(spark.sql(nullsSql), out, expectedRows = 3)
+    }
+  }
+
+  test("parquet write with decimal types within complex types") {
+    withTempPath { dir =>
+      // decimal(10,2) values as array elements, a struct field and a map value.
+      val decimalDf = spark.sql("""
+        SELECT
+          1 as id,
+          array(cast(1.23 as decimal(10,2)), cast(4.56 as decimal(10,2))) as decimal_arr,
+          named_struct('amount', cast(99.99 as decimal(10,2))) as decimal_struct,
+          map('price', cast(19.99 as decimal(10,2))) as decimal_map
+        UNION ALL SELECT
+          2 as id,
+          array(cast(7.89 as decimal(10,2))) as decimal_arr,
+          named_struct('amount', cast(0.01 as decimal(10,2))) as decimal_struct,
+          map('price', cast(0.50 as decimal(10,2))) as decimal_map
+      """)
+      writeComplexTypeData(
+        decimalDf,
+        new File(dir, "output.parquet").getAbsolutePath,
+        expectedRows = 2)
+    }
+  }
+
+  test("parquet write with temporal types within complex types") {
+    withTempPath { dir =>
+      val outputLocation = new File(dir, "output.parquet").getAbsolutePath
+      // Dates inside an array; timestamps inside a struct field and as map values.
+      val temporalSql = """
+        SELECT
+          1 as id,
+          array(date'2024-01-15', date'2024-02-20') as date_arr,
+          named_struct('ts', timestamp'2024-01-15 10:30:00') as ts_struct,
+          map('event', timestamp'2024-03-01 14:00:00') as ts_map
+        UNION ALL SELECT
+          2 as id,
+          array(date'2024-06-30') as date_arr,
+          named_struct('ts', timestamp'2024-07-04 12:00:00') as ts_struct,
+          map('event', timestamp'2024-12-25 00:00:00') as ts_map
+      """
+      writeComplexTypeData(spark.sql(temporalSql), outputLocation, expectedRows = 2)
+    }
+  }
+
+  test("parquet write with empty arrays and maps") {
+    withTempPath { dir =>
+      // Rows 1 and 3 hold only empty containers; row 2 holds non-empty ones.
+      val noInts = Seq.empty[Int]
+      val noEntries = Map.empty[String, Int]
+      val rows = Seq(
+        (1, noInts, noEntries),
+        (2, Seq(1, 2), Map("a" -> 1)),
+        (3, noInts, noEntries))
+      writeComplexTypeData(
+        rows.toDF("id", "arr", "mp"),
+        new File(dir, "output.parquet").getAbsolutePath,
+        rows.length)
+    }
+  }
+
+  test("parquet write complex types fuzz test") {
+    withTempPath { dir =>
+      val outputPath = new File(dir, "output.parquet").getAbsolutePath
+
+      withTempPath { inputDir =>
+        val inputPath = new File(inputDir, "input.parquet").getAbsolutePath
+
+        // Generate test data with complex types enabled
+        val schema = FuzzDataGenerator.generateSchema(
+          SchemaGenOptions(generateArray = true, generateStruct = true, 
generateMap = true))
+        val df = FuzzDataGenerator.generateDataFrame(
+          new Random(42),
+          spark,
+          schema,
+          500,
+          DataGenOptions(generateNegativeZero = false))
+
+        // Write input data without Comet
+        withSQLConf(
+          CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "false",
+          SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Denver") {
+          df.write.parquet(inputPath)
+        }
+
+        // Write with Comet native writer
+        // Use auto scan mode so native_iceberg_compat is used (which supports 
complex types)
+        // instead of native_comet. This overrides the COMET_PARQUET_SCAN_IMPL 
env var set by CI.
+        withSQLConf(
+          CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+          CometConf.COMET_NATIVE_SCAN_IMPL.key -> "auto",
+          CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.key -> "true",
+          SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
+          
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> 
"true",
+          CometConf.COMET_EXEC_ENABLED.key -> "true") {
+
+          val inputDf = spark.read.parquet(inputPath)
+          inputDf.write.parquet(outputPath)
+
+          // Verify round-trip
+          var sparkDf: DataFrame = null
+          var cometDf: DataFrame = null
+          withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false") {
+            sparkDf = spark.read.parquet(outputPath)
+          }
+          withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "true") {
+            cometDf = spark.read.parquet(outputPath)
+          }
+
+          assert(sparkDf.count() == 500, "Expected 500 rows")
+          checkAnswer(sparkDf, cometDf)

Review Comment:
   ditto



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to