andylam-db commented on code in PR #44084:
URL: https://github.com/apache/spark/pull/44084#discussion_r1411126704


##########
sql/core/src/test/scala/org/apache/spark/sql/crossdbms/CrossDbmsQueryTestSuite.scala:
##########
@@ -0,0 +1,350 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.crossdbms
+
+import java.io.File
+import java.util.Locale
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SQLQueryTestSuite
+import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DescribeColumn, 
DescribeRelation, Distinct, Generate, Join, LogicalPlan, Sample, Sort}
+import org.apache.spark.sql.catalyst.util.stringToFile
+import org.apache.spark.sql.execution.command.{DescribeColumnCommand, 
DescribeCommandBase}
+import org.apache.spark.util.Utils
+
+// scalastyle:off line.size.limit
+/**
+ * See SQLQueryTestSuite.scala for more information. This class builds off of 
that to allow us
+ * to generate golden files with other DBMS to perform cross-checking for 
correctness. Note that the
+ * input directory path is currently limited because most, if not all, of our 
current SQL query
+ * tests will not be compatible with other DBMSes. There will be more work in 
the future, such as
+ * some kind of conversion, to increase coverage.
+ *
+ * If your SQL query test is not compatible with other DBMSes, please add it 
to the `ignoreList` at
+ * the bottom of this file.
+ *
+ * You need to have a database server up before running this test.
+ * For example, for postgres:
+ * 1. Install PostgreSQL.
+ *   a. On a mac: `brew install postgresql@13`
+ * 2. After installing PostgreSQL, start the database server, then create a 
role named pg with
+ *    superuser permissions: `createuser -s pg`` OR `psql> CREATE role pg 
superuser``
+ *
+ * To run the entire test suite:
+ * {{{
+ *   build/sbt "sql/testOnly 
org.apache.spark.sql.crossdbms.CrossDbmsQueryTestSuite"
+ * }}}
+ *
+ * To re-generate golden files for entire suite, run:
+ * {{{
+ *   SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly 
org.apache.spark.sql.crossdbms.CrossDbmsQueryTestSuite"
+ * }}}
+ *
+ * To re-generate golden file for a single test, run:
+ * {{{
+ *   SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly 
org.apache.spark.sql.crossdbms.CrossDbmsQueryTestSuite -- -z describe.sql"
+ * }}}
+ *
+ * To specify a DBMS to use (the default is postgres):
+ * {{{
+ *   REF_DBMS=mysql SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly 
org.apache.spark.sql.crossdbms.CrossDbmsQueryTestSuite"
+ * }}}
+ */
+// scalastyle:on line.size.limit
+class CrossDbmsQueryTestSuite extends SQLQueryTestSuite with Logging {
+
+  // Note: the below two functions have to be functions instead of variables 
because the superclass
+  // runs the test first before the subclass variables can be instantiated.
+  private def crossDbmsToGenerateGoldenFiles: String = {
+    val userInputDbms = System.getenv("REF_DBMS")
+    if (userInputDbms != null && userInputDbms.nonEmpty) {
+      assert(CrossDbmsQueryTestSuite.SUPPORTED_DBMS.contains(userInputDbms),
+        s"$userInputDbms is not currently supported.")
+      userInputDbms
+    } else {
+      CrossDbmsQueryTestSuite.DEFAULT_DBMS
+    }
+  }
+  private def customConnectionUrl: String = 
System.getenv("REF_DBMS_CONNECTION_URL")
+
+  override protected def runQueries(
+    queries: Seq[String],
+    testCase: TestCase,
+    configSet: Seq[(String, String)]): Unit = {
+    val localSparkSession = spark.newSession()
+
+    var runner: Option[SQLQueryTestRunner] = None
+    val outputs: Seq[QueryTestOutput] = queries.map { sql =>
+      val output = {
+        // Use the runner when generating golden files, and Spark when running 
the test against
+        // the already generated golden files.
+        if (regenerateGoldenFiles) {
+          if (runner.isEmpty) {
+            val connectionUrl = if (customConnectionUrl != null && 
customConnectionUrl.nonEmpty) {
+              Some(customConnectionUrl)
+            } else {
+              None
+            }
+            runner = Some(CrossDbmsQueryTestSuite.DBMS_TO_CONNECTION_MAPPING(
+              crossDbmsToGenerateGoldenFiles)(connectionUrl))
+          }
+          val sparkDf = spark.sql(sql)
+          val output = runner.map(_.runQuery(sql)).get
+          // Use Spark analyzed plan to check if the query result is already 
semantically sorted
+          val result = if 
(isSemanticallySorted(sparkDf.queryExecution.analyzed)) {
+            output
+          } else {
+            // Sort the answer manually if it isn't sorted.
+            output.sorted
+          }
+          result
+        } else {
+          
handleExceptions(getNormalizedQueryExecutionResult(localSparkSession, sql))._2
+        }
+      }
+      // We do some query canonicalization now.
+      val executionOutput = ExecutionOutput(
+        sql = sql,
+        // Don't care about the schema for this test. Only care about 
correctness.
+        schema = None,
+        output = normalizeTestResults(output.mkString("\n")))
+      if (testCase.isInstanceOf[CTETest]) {
+        expandCTEQueryAndCompareResult(localSparkSession, sql, executionOutput)
+      }
+      executionOutput
+    }
+    runner.foreach(_.cleanUp())
+
+    if (regenerateGoldenFiles) {
+      val goldenOutput = {
+        s"-- Automatically generated by ${getClass.getSimpleName}\n" +
+          outputs.mkString("\n\n\n") + "\n"
+      }
+      val resultFile = new File(testCase.resultFile)
+      val parent = resultFile.getParentFile
+      if (!parent.exists()) {
+        assert(parent.mkdirs(), "Could not create directory: " + parent)
+      }
+      stringToFile(resultFile, goldenOutput)
+    }
+
+    readGoldenFileAndCompareResults(testCase.resultFile, outputs, 
ExecutionOutput)
+  }
+
+  override def createScalaTestCase(testCase: TestCase): Unit = {
+    if (ignoreList.exists(t =>
+      
testCase.name.toLowerCase(Locale.ROOT).contains(t.toLowerCase(Locale.ROOT)))) {
+      ignore(testCase.name) {
+        /* Do nothing */
+      }
+    } else {
+      testCase match {
+        case _: RegularTestCase =>
+          // Create a test case to run this case.
+          test(testCase.name) {
+            runSqlTestCase(testCase, listTestCases)
+          }
+        case _ =>
+          ignore(s"Ignoring test cases that are not [[RegularTestCase]] for 
now") {
+            /* Do nothing */
+          }
+      }
+    }
+  }
+
+  override protected def resultFileForInputFile(file: File): String = {
+    val defaultResultsDir = new File(baseResourcePath, "results")
+    val goldenFilePath = new File(
+      defaultResultsDir, 
s"$crossDbmsToGenerateGoldenFiles-results").getAbsolutePath
+    file.getAbsolutePath.replace(inputFilePath, goldenFilePath) + ".out"
+  }
+
+  override lazy val listTestCases: Seq[TestCase] = {
+    listFilesRecursively(new File(inputFilePath)).flatMap { file =>
+      var resultFile = resultFileForInputFile(file)
+      // JDK-4511638 changes 'toString' result of Float/Double
+      // JDK-8282081 changes DataTimeFormatter 'F' symbol
+      if (Utils.isJavaVersionAtLeast21) {
+        if (new File(resultFile + ".java21").exists()) resultFile += ".java21"
+      }
+      val absPath = file.getAbsolutePath
+      val testCaseName = 
absPath.stripPrefix(inputFilePath).stripPrefix(File.separator)
+      RegularTestCase(testCaseName, absPath, resultFile) :: Nil
+    }.sortBy(_.name)
+  }
+
+  private def isSemanticallySorted(plan: LogicalPlan): Boolean = plan match {
+    case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => 
false
+    case _: DescribeCommandBase
+         | _: DescribeColumnCommand
+         | _: DescribeRelation
+         | _: DescribeColumn => true
+    case PhysicalOperation(_, _, Sort(_, true, _)) => true
+    case _ => plan.children.iterator.exists(isSemanticallySorted)
+  }
+
+  // Ignore all tests for now due to likely incompatibility.

Review Comment:
   Unfortunately, many of the golden files in the existing `postgres` directory 
have also been heavily modified to make them compatible with Spark. Disclaimer: I 
didn't go through all of them.
   
   I think an `ignoreList` (although it's cumbersomely long) makes sense, because 
it means that newly written SQL tests are automatically covered by this test 
suite. If a new SQL test isn't Postgres-compatible, the developer must either 
modify the test (if that makes sense) or deliberately add it to the ignore 
list.



##########
sql/core/src/test/scala/org/apache/spark/sql/crossdbms/JdbcConnection.scala:
##########
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.crossdbms
+
+import java.sql.{DriverManager, ResultSet}
+import java.util.Properties
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.execution.datasources.jdbc.DriverRegistry
+
+private[sql] trait JdbcConnection {
+  /**
+   * Runs the given query.
+   * @return A Seq[String] representing the output.
+   */
+  def runQuery(query: String): Seq[String]
+
+  /**
+   * Drop the table with the given table name.
+   */
+  def dropTable(tableName: String): Unit
+
+  /**
+   * Create a table with the given table name and schema.
+   */
+  def createTable(tableName: String, schemaString: String): Unit
+
+  /**
+   * Load data from the given Spark Dataframe into the table with given name.
+   */
+  def loadData(df: DataFrame, tableName: String): Unit
+
+  /**
+   * Close the connection.
+   */
+  def close(): Unit
+}
+
+private[sql] case class PostgresConnection(connection_url: Option[String] = 
None)
+  extends JdbcConnection {
+
+  DriverRegistry.register("org.postgresql.Driver")
+  private final val DEFAULT_USER = "pg"
+  private final val DEFAULT_CONNECTION_URL =
+    s"jdbc:postgresql://localhost:5432/postgres?user=$DEFAULT_USER"
+  private val url = connection_url.getOrElse(DEFAULT_CONNECTION_URL)
+  private val conn = DriverManager.getConnection(url)
+  private val stmt = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, 
ResultSet.CONCUR_READ_ONLY)
+
+  def runQuery(query: String): Seq[String] = {
+    try {
+      val isResultSet = stmt.execute(query)
+      val rows = ArrayBuffer[Row]()
+      if (isResultSet) {
+        val rs = stmt.getResultSet
+        val metadata = rs.getMetaData
+        while (rs.next()) {
+          val row = Row.fromSeq((1 to metadata.getColumnCount).map(i => 
rs.getObject(i)))
+          rows.append(row)
+        }
+      }
+      rows.map(_.mkString("\t")).toSeq

Review Comment:
   Tabs are also used as the column separator in `hiveResultString`, which is used in 
`SQLQuerySuite`. I followed that convention to reduce hassle; let me know if it's a big deal.



##########
sql/core/src/test/scala/org/apache/spark/sql/crossdbms/JdbcConnection.scala:
##########
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.crossdbms
+
+import java.sql.{DriverManager, ResultSet}
+import java.util.Properties
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.execution.datasources.jdbc.DriverRegistry
+
+private[sql] trait JdbcConnection {
+  /**
+   * Runs the given query.
+   * @return A Seq[String] representing the output.
+   */
+  def runQuery(query: String): Seq[String]
+
+  /**
+   * Drop the table with the given table name.
+   */
+  def dropTable(tableName: String): Unit
+
+  /**
+   * Create a table with the given table name and schema.
+   */
+  def createTable(tableName: String, schemaString: String): Unit
+
+  /**
+   * Load data from the given Spark Dataframe into the table with given name.
+   */
+  def loadData(df: DataFrame, tableName: String): Unit
+
+  /**
+   * Close the connection.
+   */
+  def close(): Unit
+}
+
+private[sql] case class PostgresConnection(connection_url: Option[String] = 
None)
+  extends JdbcConnection {
+
+  DriverRegistry.register("org.postgresql.Driver")
+  private final val DEFAULT_USER = "pg"
+  private final val DEFAULT_CONNECTION_URL =
+    s"jdbc:postgresql://localhost:5432/postgres?user=$DEFAULT_USER"
+  private val url = connection_url.getOrElse(DEFAULT_CONNECTION_URL)
+  private val conn = DriverManager.getConnection(url)
+  private val stmt = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, 
ResultSet.CONCUR_READ_ONLY)
+
+  def runQuery(query: String): Seq[String] = {
+    try {
+      val isResultSet = stmt.execute(query)
+      val rows = ArrayBuffer[Row]()
+      if (isResultSet) {
+        val rs = stmt.getResultSet
+        val metadata = rs.getMetaData
+        while (rs.next()) {
+          val row = Row.fromSeq((1 to metadata.getColumnCount).map(i => 
rs.getObject(i)))
+          rows.append(row)
+        }
+      }
+      rows.map(_.mkString("\t")).toSeq
+    } catch {
+      case e: Throwable => Seq(e.toString)
+    }
+  }
+
+  def dropTable(tableName: String): Unit = {
+    val dropTableSql = s"DROP TABLE IF EXISTS $tableName"

Review Comment:
   Hmm, I wasn't too concerned about this (and didn't think it through much) 
because this code is only used in testing. Should I be? 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to