Github user liancheng commented on a diff in the pull request:

    https://github.com/apache/spark/pull/4720#discussion_r25142433
  
    --- Diff: 
sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
 ---
    @@ -0,0 +1,403 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.hive.thriftserver
    +
    +import java.io.File
    +import java.sql.{Date, DriverManager, Statement}
    +
    +import scala.collection.mutable.ArrayBuffer
    +import scala.concurrent.duration._
    +import scala.concurrent.{Await, Promise}
    +import scala.sys.process.{Process, ProcessLogger}
    +import scala.util.{Random, Try}
    +
    +import org.apache.hadoop.hive.conf.HiveConf.ConfVars
    +import org.apache.hive.jdbc.HiveDriver
    +import org.apache.hive.service.auth.PlainSaslHelper
    +import org.apache.hive.service.cli.GetInfoType
    +import org.apache.hive.service.cli.thrift.TCLIService.Client
    +import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient
    +import org.apache.thrift.protocol.TBinaryProtocol
    +import org.apache.thrift.transport.TSocket
    +import org.scalatest.{BeforeAndAfterAll, FunSuite}
    +
    +import org.apache.spark.Logging
    +import org.apache.spark.sql.catalyst.util
    +import org.apache.spark.sql.hive.HiveShim
    +
    +object TestData {
    +  def getTestDataFilePath(name: String) = {
    +    
Thread.currentThread().getContextClassLoader.getResource(s"data/files/$name")
    +  }
    +
    +  val smallKv = getTestDataFilePath("small_kv.txt")
    +  val smallKvWithNull = getTestDataFilePath("small_kv_with_null.txt")
    +}
    +
    +class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
    +  override def mode = ServerMode.binary
    +
    +  private def withCLIServiceClient(f: ThriftCLIServiceClient => Unit): 
Unit = {
    +    // Transport creation logics below mimics 
HiveConnection.createBinaryTransport
    +    val rawTransport = new TSocket("localhost", serverPort)
    +    val user = System.getProperty("user.name")
    +    val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", 
rawTransport)
    +    val protocol = new TBinaryProtocol(transport)
    +    val client = new ThriftCLIServiceClient(new Client(protocol))
    +
    +    transport.open()
    +    try f(client) finally transport.close()
    +  }
    +
    +  test("GetInfo Thrift API") {
    +    withCLIServiceClient { client =>
    +      val user = System.getProperty("user.name")
    +      val sessionHandle = client.openSession(user, "")
    +
    +      assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") {
    +        client.getInfo(sessionHandle, 
GetInfoType.CLI_DBMS_NAME).getStringValue
    +      }
    +
    +      assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") {
    +        client.getInfo(sessionHandle, 
GetInfoType.CLI_SERVER_NAME).getStringValue
    +      }
    +
    +      assertResult(true, "Spark version shouldn't be \"Unknown\"") {
    +        val version = client.getInfo(sessionHandle, 
GetInfoType.CLI_DBMS_VER).getStringValue
    +        logInfo(s"Spark version: $version")
    +        version != "Unknown"
    +      }
    +    }
    +  }
    +
    +  test("JDBC query execution") {
    +    withJdbcStatement { statement =>
    +      val queries = Seq(
    +        "SET spark.sql.shuffle.partitions=3",
    +        "DROP TABLE IF EXISTS test",
    +        "CREATE TABLE test(key INT, val STRING)",
    +        s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO 
TABLE test",
    +        "CACHE TABLE test")
    +
    +      queries.foreach(statement.execute)
    +
    +      assertResult(5, "Row count mismatch") {
    +        val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test")
    +        resultSet.next()
    +        resultSet.getInt(1)
    +      }
    +    }
    +  }
    +
    +  test("Checks Hive version") {
    +    withJdbcStatement { statement =>
    +      val resultSet = statement.executeQuery("SET spark.sql.hive.version")
    +      resultSet.next()
    +      assert(resultSet.getString(1) === 
s"spark.sql.hive.version=${HiveShim.version}")
    +    }
    +  }
    +
    +  test("SPARK-3004 regression: result set containing NULL") {
    +    withJdbcStatement { statement =>
    +      val queries = Seq(
    +        "DROP TABLE IF EXISTS test_null",
    +        "CREATE TABLE test_null(key INT, val STRING)",
    +        s"LOAD DATA LOCAL INPATH '${TestData.smallKvWithNull}' OVERWRITE 
INTO TABLE test_null")
    +
    +      queries.foreach(statement.execute)
    +
    +      val resultSet = statement.executeQuery("SELECT * FROM test_null 
WHERE key IS NULL")
    +
    +      (0 until 5).foreach { _ =>
    +        resultSet.next()
    +        assert(resultSet.getInt(1) === 0)
    +        assert(resultSet.wasNull())
    +      }
    +
    +      assert(!resultSet.next())
    +    }
    +  }
    +
    +  test("SPARK-4292 regression: result set iterator issue") {
    +    withJdbcStatement { statement =>
    +      val queries = Seq(
    +        "DROP TABLE IF EXISTS test_4292",
    +        "CREATE TABLE test_4292(key INT, val STRING)",
    +        s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO 
TABLE test_4292")
    +
    +      queries.foreach(statement.execute)
    +
    +      val resultSet = statement.executeQuery("SELECT key FROM test_4292")
    +
    +      Seq(238, 86, 311, 27, 165).foreach { key =>
    +        resultSet.next()
    +        assert(resultSet.getInt(1) === key)
    +      }
    +
    +      statement.executeQuery("DROP TABLE IF EXISTS test_4292")
    +    }
    +  }
    +
    +  test("SPARK-4309 regression: Date type support") {
    +    withJdbcStatement { statement =>
    +      val queries = Seq(
    +        "DROP TABLE IF EXISTS test_date",
    +        "CREATE TABLE test_date(key INT, value STRING)",
    +        s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO 
TABLE test_date")
    +
    +      queries.foreach(statement.execute)
    +
    +      assertResult(Date.valueOf("2011-01-01")) {
    +        val resultSet = statement.executeQuery(
    +          "SELECT CAST('2011-01-01' as date) FROM test_date LIMIT 1")
    +        resultSet.next()
    +        resultSet.getDate(1)
    +      }
    +    }
    +  }
    +
    +  test("SPARK-4407 regression: Complex type support") {
    +    withJdbcStatement { statement =>
    +      val queries = Seq(
    +        "DROP TABLE IF EXISTS test_map",
    +        "CREATE TABLE test_map(key INT, value STRING)",
    +        s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO 
TABLE test_map")
    +
    +      queries.foreach(statement.execute)
    +
    +      assertResult("""{238:"val_238"}""") {
    +        val resultSet = statement.executeQuery("SELECT MAP(key, value) 
FROM test_map LIMIT 1")
    +        resultSet.next()
    +        resultSet.getString(1)
    +      }
    +
    +      assertResult("""["238","val_238"]""") {
    +        val resultSet = statement.executeQuery(
    +          "SELECT ARRAY(CAST(key AS STRING), value) FROM test_map LIMIT 1")
    +        resultSet.next()
    +        resultSet.getString(1)
    +      }
    +    }
    +  }
    +}
    +
    +class HiveThriftHttpServerSuite extends HiveThriftJdbcTest {
    +  override def mode = ServerMode.http
    +
    +  test("JDBC query execution") {
    +    withJdbcStatement { statement =>
    +      val queries = Seq(
    +        "SET spark.sql.shuffle.partitions=3",
    +        "DROP TABLE IF EXISTS test",
    +        "CREATE TABLE test(key INT, val STRING)",
    +        s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO 
TABLE test",
    +        "CACHE TABLE test")
    +
    +      queries.foreach(statement.execute)
    +
    +      assertResult(5, "Row count mismatch") {
    +        val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test")
    +        resultSet.next()
    +        resultSet.getInt(1)
    +      }
    +    }
    +  }
    +
    +  test("Checks Hive version") {
    +    withJdbcStatement { statement =>
    +      val resultSet = statement.executeQuery("SET spark.sql.hive.version")
    +      resultSet.next()
    +      assert(resultSet.getString(1) === 
s"spark.sql.hive.version=${HiveShim.version}")
    +    }
    +  }
    +}
    +
    +object ServerMode extends Enumeration {
    +  val binary, http = Value
    +}
    +
    +abstract class HiveThriftJdbcTest extends HiveThriftServer2Test {
    +  Class.forName(classOf[HiveDriver].getCanonicalName)
    +
    +  private def jdbcUri = if (mode == ServerMode.http) {
    +    s"""jdbc:hive2://localhost:$serverPort/
    +       |default?
    +       |hive.server2.transport.mode=http;
    +       |hive.server2.thrift.http.path=cliservice
    +     """.stripMargin.split("\n").mkString.trim
    +  } else {
    +    s"jdbc:hive2://localhost:$serverPort/"
    +  }
    +
    +  protected def withJdbcStatement(f: Statement => Unit): Unit = {
    +    val connection = DriverManager.getConnection(jdbcUri, user, "")
    +    val statement = connection.createStatement()
    +
    +    try f(statement) finally {
    +      statement.close()
    +      connection.close()
    +    }
    +  }
    +}
    +
    +abstract class HiveThriftServer2Test extends FunSuite with 
BeforeAndAfterAll with Logging {
    +  def mode: ServerMode.Value
    +
    +  private val CLASS_NAME = 
HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$")
    +  private val LOG_FILE_MARK = s"starting $CLASS_NAME, logging to "
    +
    +  private val startScript = 
"../../sbin/start-thriftserver.sh".split("/").mkString(File.separator)
    +  private val stopScript = 
"../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator)
    +
    +  private var listeningPort: Int = _
    +  protected def serverPort: Int = listeningPort
    +
    +  protected def user = System.getProperty("user.name")
    +
    +  private var warehousePath: File = _
    +  private var metastorePath: File = _
    +  private def metastoreJdbcUri = 
s"jdbc:derby:;databaseName=$metastorePath;create=true"
    +
    +  private var logPath: File = _
    +  private var logTailingProcess: Process = _
    +  private var diagnosisBuffer: ArrayBuffer[String] = 
ArrayBuffer.empty[String]
    +
    +  private def serverStartCommand(port: Int) = {
    +    val portConf = if (mode == ServerMode.binary) {
    +      ConfVars.HIVE_SERVER2_THRIFT_PORT
    +    } else {
    +      ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT
    +    }
    +
    +    s"""$startScript
    +       |  --master local
    +       |  --hiveconf hive.root.logger=INFO,console
    +       |  --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
    +       |  --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
    +       |  --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost
    +       |  --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode
    +       |  --hiveconf $portConf=$port
    +       |  --driver-class-path ${sys.props("java.class.path")}
    +       |  --conf spark.ui.enabled=false
    +     """.stripMargin.split("\\s+").toSeq
    +  }
    +
    +  private def startThriftServer(port: Int, attempt: Int) = {
    +    warehousePath = util.getTempFilePath("warehouse")
    +    metastorePath = util.getTempFilePath("metastore")
    +    logPath = null
    +    logTailingProcess = null
    +
    +    val command = serverStartCommand(port)
    +
    +    diagnosisBuffer ++=
    +      s"""
    +         |### Attempt $attempt ###
    +         |HiveThriftServer2 command line: $command
    +         |Listening port: $port
    +         |System user: $user
    +       """.stripMargin.split("\n")
    +
    +    logInfo(s"Trying to start HiveThriftServer2: port=$port, mode=$mode, 
attempt=$attempt")
    +
    +    logPath = Process(command, None, "SPARK_TESTING" -> 
"0").lines.collectFirst {
    +      case line if line.contains(LOG_FILE_MARK) => new 
File(line.drop(LOG_FILE_MARK.length))
    +    }.getOrElse {
    +      throw new RuntimeException("Failed to find HiveThriftServer2 log 
file.")
    +    }
    --- End diff --
    
    The `start-thriftserver.sh` script delegates to `spark-daemon.sh` and 
exits quickly (after about 1s). I found that using `ProcessBuilder.lines` is 
much simpler and more straightforward than using a `ProcessLogger`.
    
    `start-thriftserver.sh` itself may fail if a server instance has already 
been started from the same Spark home directory. However, right now we only 
start a single instance at a time.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to