HyukjinKwon commented on a change in pull request #26973: [SPARK-30323][SQL] 
Support filters pushdown in CSV datasource
URL: https://github.com/apache/spark/pull/26973#discussion_r366676649
 
 

 ##########
 File path: 
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala
 ##########
 @@ -267,4 +269,63 @@ class UnivocityParserSuite extends SparkFunSuite with 
SQLHelper {
     assert(convertedValue.isInstanceOf[UTF8String])
     assert(convertedValue == expected)
   }
+
+  test("skipping rows using pushdown filters") {
+    def check(
+        input: String = "1,a",
+        dataSchema: String = "i INTEGER, s STRING",
+        requiredSchema: String = "i INTEGER",
+        filters: Seq[Filter],
+        expected: Seq[InternalRow]): Unit = {
+      def getSchema(str: String): StructType = str match {
+        case "" => new StructType()
+        case _ => StructType.fromDDL(str)
+      }
+      Seq(false, true).foreach { columnPruning =>
+        val options = new CSVOptions(Map.empty[String, String], columnPruning, 
"GMT")
+        val parser = new UnivocityParser(
+          getSchema(dataSchema),
+          getSchema(requiredSchema),
+          options,
+          filters)
+        val actual = parser.parse(input)
+        assert(actual === expected)
+      }
+    }
+
+    check(filters = Seq(), expected = Seq(InternalRow(1)))
+    check(filters = Seq(EqualTo("i", 1)), expected = Seq(InternalRow(1)))
+    check(filters = Seq(EqualTo("i", 2)), expected = Seq())
+    check(requiredSchema = "s STRING", filters = Seq(StringStartsWith("s", 
"b")), expected = Seq())
+    check(
+      requiredSchema = "i INTEGER, s STRING",
+      filters = Seq(StringStartsWith("s", "a")),
+      expected = Seq(InternalRow(1, UTF8String.fromString("a"))))
+    check(
+      input = "1,a,3.14",
+      dataSchema = "i INTEGER, s STRING, d DOUBLE",
+      requiredSchema = "i INTEGER, d DOUBLE",
+      filters = Seq(EqualTo("d", 3.14)),
+      expected = Seq(InternalRow(1, 3.14)))
+
+    try {
+      check(filters = Seq(EqualTo("invalid attr", 1)), expected = Seq())
+      fail("Expected to throw an exception for the invalid input")
+    } catch {
+      case e: IllegalArgumentException =>
+        assert(e.getMessage.contains("invalid attr does not exist"))
+    }
+
+    try {
+      check(
 
 Review comment:
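   I think it's more idiomatic to use ScalaTest's `intercept` here. It fails the test automatically if no exception (or an exception of a different type) is thrown, so the manual `fail(...)` call and the `catch` clause are no longer needed:
   
   ```scala
         val errMsg = intercept[IllegalArgumentException] {
           ...
         }.getMessage
         assert(errMsg.contains("..."))
   ```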

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to