asfgit closed pull request #22903: [SPARK-24196][SQL] Implement Spark's own GetSchemasOperation
URL: https://github.com/apache/spark/pull/22903
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is reproduced below for the sake of provenance:
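Some context before the diff: Hive's stock GetSchemasOperation answers schema listings from the Hive session's metastore client, so databases that exist only in Spark's session catalog can be missing from JDBC metadata calls against the Thrift server. The SparkGetSchemasOperation added here answers from Spark's catalog instead. On the client side, the operation backs java.sql.DatabaseMetaData.getSchemas; a minimal sketch of that call (the URL and user are placeholders, not values from this PR):

    import java.sql.DriverManager

    // Minimal client-side sketch; assumes a running HiveThriftServer2.
    val conn = DriverManager.getConnection("jdbc:hive2://localhost:10000", "user", "")
    try {
      // getSchemas(catalog, schemaPattern) is served by GetSchemasOperation,
      // which after this change is SparkGetSchemasOperation.
      val rs = conn.getMetaData.getSchemas(null, "%")
      while (rs.next()) {
        println(rs.getString("TABLE_SCHEM") + " / " + rs.getString("TABLE_CATALOG"))
      }
    } finally {
      conn.close()
    }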
diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetSchemasOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetSchemasOperation.java
index d6f6280f1c398..3516bc2ba242c 100644
--- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetSchemasOperation.java
+++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetSchemasOperation.java
@@ -41,7 +41,7 @@
.addStringColumn("TABLE_SCHEM", "Schema name.")
.addStringColumn("TABLE_CATALOG", "Catalog name.");
- private RowSet rowSet;
+ protected RowSet rowSet;
protected GetSchemasOperation(HiveSession parentSession,
String catalogName, String schemaName) {
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala
new file mode 100644
index 0000000000000..d585049c28e33
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveOperationType
+import org.apache.hive.service.cli._
+import org.apache.hive.service.cli.operation.GetSchemasOperation
+import org.apache.hive.service.cli.operation.MetadataOperation.DEFAULT_HIVE_CATALOG
+import org.apache.hive.service.cli.session.HiveSession
+
+import org.apache.spark.sql.SQLContext
+
+/**
+ * Spark's own GetSchemasOperation
+ *
+ * @param sqlContext SQLContext to use
+ * @param parentSession a HiveSession from SessionManager
+ * @param catalogName catalog name. null if not applicable.
+ * @param schemaName database name, null or a concrete database name
+ */
+private[hive] class SparkGetSchemasOperation(
+ sqlContext: SQLContext,
+ parentSession: HiveSession,
+ catalogName: String,
+ schemaName: String)
+ extends GetSchemasOperation(parentSession, catalogName, schemaName) {
+
+ override def runInternal(): Unit = {
+ setState(OperationState.RUNNING)
+ // Always use the latest class loader provided by executionHive's state.
+ val executionHiveClassLoader = sqlContext.sharedState.jarClassLoader
+ Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
+
+ if (isAuthV2Enabled) {
+ val cmdStr = s"catalog : $catalogName, schemaPattern : $schemaName"
+ authorizeMetaGets(HiveOperationType.GET_TABLES, null, cmdStr)
+ }
+
+ try {
+ val schemaPattern = convertSchemaPattern(schemaName)
+ sqlContext.sessionState.catalog.listDatabases(schemaPattern).foreach { dbName =>
+ rowSet.addRow(Array[AnyRef](dbName, DEFAULT_HIVE_CATALOG))
+ }
+ setState(OperationState.FINISHED)
+ } catch {
+ case e: HiveSQLException =>
+ setState(OperationState.ERROR)
+ throw e
+ }
+ }
+}
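Two notes on the new class above: rowSet is the field whose visibility the first hunk widened from private to protected, precisely so this subclass can append (TABLE_SCHEM, TABLE_CATALOG) rows directly; and the rows come from Spark's session catalog rather than a Hive metastore client. The same listing is reachable from the public API, as in this standalone sketch (the database name is a placeholder):

    import org.apache.spark.sql.SparkSession

    // Standalone sketch: inspect the catalog SparkGetSchemasOperation reads from.
    val spark = SparkSession.builder().master("local[1]").appName("schemas-sketch").getOrCreate()
    spark.sql("CREATE DATABASE IF NOT EXISTS demo_db")  // placeholder database
    spark.catalog.listDatabases().show()                // lists "default" and "demo_db"
    spark.stop()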
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
index bf7c01f60fb5c..85b6c7134755b 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
@@ -21,13 +21,13 @@ import java.util.{Map => JMap}
import java.util.concurrent.ConcurrentHashMap
import org.apache.hive.service.cli._
-import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager}
+import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, GetSchemasOperation, Operation, OperationManager}
import org.apache.hive.service.cli.session.HiveSession
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.hive.HiveUtils
-import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation}
+import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation, SparkGetSchemasOperation}
import org.apache.spark.sql.internal.SQLConf
/**
@@ -63,6 +63,19 @@ private[thriftserver] class SparkSQLOperationManager()
operation
}
+ override def newGetSchemasOperation(
+ parentSession: HiveSession,
+ catalogName: String,
+ schemaName: String): GetSchemasOperation = synchronized {
+ val sqlContext = sessionToContexts.get(parentSession.getSessionHandle)
+ require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" +
+ " initialized or had already closed.")
+ val operation = new SparkGetSchemasOperation(sqlContext, parentSession, catalogName, schemaName)
+ handleToOperation.put(operation.getHandle, operation)
+ logDebug(s"Created GetSchemasOperation with session=$parentSession.")
+ operation
+ }
+
def setConfMap(conf: SQLConf, confMap: java.util.Map[String, String]): Unit = {
val iterator = confMap.entrySet().iterator()
while (iterator.hasNext) {
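For orientation: newGetSchemasOperation is the hook HiveServer2's session layer invokes when a client sends a GetSchemas request, so overriding it here is all that is needed to reroute the metadata call to the Spark-backed operation. The handleToOperation bookkeeping follows the same pattern as newExecuteStatementOperation earlier in this class.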
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
index 70eb28cdd0c64..f9509aed4aaab 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
@@ -818,6 +818,22 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test {
}
}
+ def withDatabase(dbNames: String*)(fs: (Statement => Unit)*) {
+ val user = System.getProperty("user.name")
+ val connections = fs.map { _ => DriverManager.getConnection(jdbcUri, user, "") }
+ val statements = connections.map(_.createStatement())
+
+ try {
+ statements.zip(fs).foreach { case (s, f) => f(s) }
+ } finally {
+ dbNames.foreach { name =>
+ statements(0).execute(s"DROP DATABASE IF EXISTS $name")
+ }
+ statements.foreach(_.close())
+ connections.foreach(_.close())
+ }
+ }
+
def withJdbcStatement(tableNames: String*)(f: Statement => Unit) {
withMultipleConnectionJdbcStatement(tableNames: _*)(f)
}
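The withDatabase helper added above mirrors the existing withMultipleConnectionJdbcStatement pattern: one connection per supplied function, with the listed databases dropped in the finally block. A minimal usage sketch ("db_example" is a placeholder name):

    withDatabase("db_example") { statement =>
      statement.execute("CREATE DATABASE db_example")
      // ... run assertions against the new database here ...
    }  // db_example is dropped by the helper even if the body throws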
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala
new file mode 100644
index 0000000000000..9a997ae01df9d
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import java.util.Properties
+
+import org.apache.hive.jdbc.{HiveConnection, HiveQueryResultSet, Utils => JdbcUtils}
+import org.apache.hive.service.auth.PlainSaslHelper
+import org.apache.hive.service.cli.thrift._
+import org.apache.thrift.protocol.TBinaryProtocol
+import org.apache.thrift.transport.TSocket
+
+class SparkMetadataOperationSuite extends HiveThriftJdbcTest {
+
+ override def mode: ServerMode.Value = ServerMode.binary
+
+ test("Spark's own GetSchemasOperation(SparkGetSchemasOperation)") {
+ def testGetSchemasOperation(
+ catalog: String,
+ schemaPattern: String)(f: HiveQueryResultSet => Unit): Unit = {
+ val rawTransport = new TSocket("localhost", serverPort)
+ val connection = new HiveConnection(s"jdbc:hive2://localhost:$serverPort", new Properties)
+ val user = System.getProperty("user.name")
+ val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
+ val client = new TCLIService.Client(new TBinaryProtocol(transport))
+ transport.open()
+ var rs: HiveQueryResultSet = null
+ try {
+ val openResp = client.OpenSession(new TOpenSessionReq)
+ val sessHandle = openResp.getSessionHandle
+ val schemaReq = new TGetSchemasReq(sessHandle)
+
+ if (catalog != null) {
+ schemaReq.setCatalogName(catalog)
+ }
+
+ if (schemaPattern == null) {
+ schemaReq.setSchemaName("%")
+ } else {
+ schemaReq.setSchemaName(schemaPattern)
+ }
+
+ val schemaResp = client.GetSchemas(schemaReq)
+ JdbcUtils.verifySuccess(schemaResp.getStatus)
+
+ rs = new HiveQueryResultSet.Builder(connection)
+ .setClient(client)
+ .setSessionHandle(sessHandle)
+ .setStmtHandle(schemaResp.getOperationHandle)
+ .build()
+ f(rs)
+ } finally {
+ rs.close()
+ connection.close()
+ transport.close()
+ rawTransport.close()
+ }
+ }
+
+ def checkResult(dbNames: Seq[String], rs: HiveQueryResultSet): Unit = {
+ if (dbNames.nonEmpty) {
+ for (i <- dbNames.indices) {
+ assert(rs.next())
+ assert(rs.getString("TABLE_SCHEM") === dbNames(i))
+ }
+ } else {
+ assert(!rs.next())
+ }
+ }
+
+ withDatabase("db1", "db2") { statement =>
+ Seq("CREATE DATABASE db1", "CREATE DATABASE db2").foreach(statement.execute)
+
+ testGetSchemasOperation(null, "%") { rs =>
+ checkResult(Seq("db1", "db2"), rs)
+ }
+ testGetSchemasOperation(null, "db1") { rs =>
+ checkResult(Seq("db1"), rs)
+ }
+ testGetSchemasOperation(null, "db_not_exist") { rs =>
+ checkResult(Seq.empty, rs)
+ }
+ testGetSchemasOperation(null, "db*") { rs =>
+ checkResult(Seq("db1", "db2"), rs)
+ }
+ }
+ }
+}
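A side note on the last assertion: the suite checks both the JDBC wildcard "%" and the Spark-style "db*". This works because the inherited convertSchemaPattern rewrites JDBC wildcards before the pattern reaches listDatabases, and Spark's catalog pattern matching treats "*" as a wildcard of its own. A rough, assumption-level emulation of the two dialects with plain regex (not the real implementation, which lives in Hive's MetadataOperation and Spark's StringUtils):

    // JDBC metadata patterns use '%' (any sequence) and '_' (one character);
    // Spark's catalog patterns treat '*' as "any sequence".
    val dbs = Seq("db1", "db2", "default")
    def likeToRegex(p: String): String = p.replace("%", ".*").replace("_", ".")
    def globToRegex(p: String): String = p.replace("*", ".*")
    println(dbs.filter(_.matches(likeToRegex("%"))))    // List(db1, db2, default)
    println(dbs.filter(_.matches(globToRegex("db*")))) // List(db1, db2)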