http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/7d8fa69f/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyCLIService.scala ---------------------------------------------------------------------- diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyCLIService.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyCLIService.scala new file mode 100644 index 0000000..5289354 --- /dev/null +++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyCLIService.scala @@ -0,0 +1,432 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.livy.thriftserver + +import java.io.IOException +import java.util +import java.util.concurrent.{CancellationException, ExecutionException, TimeoutException, TimeUnit} +import javax.security.auth.login.LoginException + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.hive.common.log.ProgressMonitor +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.parse.ParseUtils +import org.apache.hadoop.hive.shims.Utils +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hive.service.{CompositeService, ServiceException} +import org.apache.hive.service.auth.HiveAuthFactory +import org.apache.hive.service.cli._ +import org.apache.hive.service.cli.operation.Operation +import org.apache.hive.service.rpc.thrift.{TOperationHandle, TProtocolVersion} + +import org.apache.livy.{LIVY_VERSION, Logging} + +class LivyCLIService(server: LivyThriftServer) + extends CompositeService(classOf[LivyCLIService].getName) with ICLIService with Logging { + import LivyCLIService._ + + private var sessionManager: LivyThriftSessionManager = _ + private var defaultFetchRows: Int = _ + private var serviceUGI: UserGroupInformation = _ + private var httpUGI: UserGroupInformation = _ + + override def init(hiveConf: HiveConf): Unit = { + sessionManager = new LivyThriftSessionManager(server, hiveConf) + addService(sessionManager) + defaultFetchRows = + hiveConf.getIntVar(ConfVars.HIVE_SERVER2_THRIFT_RESULTSET_DEFAULT_FETCH_SIZE) + // If the hadoop cluster is secure, do a kerberos login for the service from the keytab + if (UserGroupInformation.isSecurityEnabled) { + try { + serviceUGI = Utils.getUGI + } catch { + case e: IOException => + throw new ServiceException("Unable to login to kerberos with given principal/keytab", e) + case e: LoginException => + throw new ServiceException("Unable to login to kerberos with given principal/keytab", e) + } + // Also try 
creating a UGI object for the SPNego principal + val principal = hiveConf.getVar(ConfVars.HIVE_SERVER2_SPNEGO_PRINCIPAL) + val keyTabFile = hiveConf.getVar(ConfVars.HIVE_SERVER2_SPNEGO_KEYTAB) + if (principal.isEmpty || keyTabFile.isEmpty) { + info(s"SPNego httpUGI not created, SPNegoPrincipal: $principal, ketabFile: $keyTabFile") + } else try { + httpUGI = HiveAuthFactory.loginFromSpnegoKeytabAndReturnUGI(hiveConf) + info("SPNego httpUGI successfully created.") + } catch { + case e: IOException => + warn("SPNego httpUGI creation failed: ", e) + } + } + super.init(hiveConf) + } + + def getServiceUGI: UserGroupInformation = this.serviceUGI + + def getHttpUGI: UserGroupInformation = this.httpUGI + + def getSessionManager: LivyThriftSessionManager = sessionManager + + @throws[HiveSQLException] + override def getInfo(sessionHandle: SessionHandle, getInfoType: GetInfoType): GetInfoValue = { + getInfoType match { + case GetInfoType.CLI_SERVER_NAME => new GetInfoValue("Livy JDBC") + case GetInfoType.CLI_DBMS_NAME => new GetInfoValue("Livy JDBC") + case GetInfoType.CLI_DBMS_VER => new GetInfoValue(LIVY_VERSION) + // below values are copied from Hive + case GetInfoType.CLI_MAX_COLUMN_NAME_LEN => new GetInfoValue(128) + case GetInfoType.CLI_MAX_SCHEMA_NAME_LEN => new GetInfoValue(128) + case GetInfoType.CLI_MAX_TABLE_NAME_LEN => new GetInfoValue(128) + case GetInfoType.CLI_ODBC_KEYWORDS => + new GetInfoValue(ParseUtils.getKeywords(LivyCLIService.ODBC_KEYWORDS)) + case _ => throw new HiveSQLException(s"Unrecognized GetInfoType value: $getInfoType") + } + } + + @throws[HiveSQLException] + def openSession( + protocol: TProtocolVersion, + username: String, + password: String, + ipAddress: String, + configuration: util.Map[String, String]): SessionHandle = { + val sessionHandle = sessionManager.openSession( + protocol, username, password, ipAddress, configuration, false, null) + debug(sessionHandle + ": openSession()") + sessionHandle + } + + @throws[HiveSQLException] + def 
openSessionWithImpersonation( + protocol: TProtocolVersion, + username: String, + password: String, + ipAddress: String, + configuration: util.Map[String, String], + delegationToken: String): SessionHandle = { + val sessionHandle = sessionManager.openSession( + protocol, username, password, ipAddress, configuration, true, delegationToken) + debug(sessionHandle + ": openSession()") + sessionHandle + } + + @throws[HiveSQLException] + override def openSession( + username: String, + password: String, + configuration: util.Map[String, String]): SessionHandle = { + val sessionHandle = sessionManager.openSession( + SERVER_VERSION, username, password, null, configuration, false, null) + debug(sessionHandle + ": openSession()") + sessionHandle + } + + @throws[HiveSQLException] + override def openSessionWithImpersonation( + username: String, + password: String, + configuration: util.Map[String, String], delegationToken: String): SessionHandle = { + val sessionHandle = sessionManager.openSession( + SERVER_VERSION, username, password, null, configuration, true, delegationToken) + debug(sessionHandle + ": openSession()") + sessionHandle + } + + @throws[HiveSQLException] + override def closeSession(sessionHandle: SessionHandle): Unit = { + sessionManager.closeSession(sessionHandle) + debug(sessionHandle + ": closeSession()") + } + + @throws[HiveSQLException] + override def executeStatement( + sessionHandle: SessionHandle, + statement: String, + confOverlay: util.Map[String, String]): OperationHandle = { + executeStatement(sessionHandle, statement, confOverlay, 0) + } + + /** + * Execute statement on the server with a timeout. This is a blocking call. 
+ */ + @throws[HiveSQLException] + override def executeStatement( + sessionHandle: SessionHandle, + statement: String, + confOverlay: util.Map[String, String], + queryTimeout: Long): OperationHandle = { + val opHandle: OperationHandle = sessionManager.operationManager.executeStatement( + sessionHandle, statement, confOverlay, runAsync = false, queryTimeout) + debug(sessionHandle + ": executeStatement()") + opHandle + } + + @throws[HiveSQLException] + override def executeStatementAsync( + sessionHandle: SessionHandle, + statement: String, + confOverlay: util.Map[String, String]): OperationHandle = { + executeStatementAsync(sessionHandle, statement, confOverlay, 0) + } + + /** + * Execute statement asynchronously on the server with a timeout. This is a non-blocking call + */ + @throws[HiveSQLException] + override def executeStatementAsync( + sessionHandle: SessionHandle, + statement: String, + confOverlay: util.Map[String, String], + queryTimeout: Long): OperationHandle = { + val opHandle = sessionManager.operationManager.executeStatement( + sessionHandle, statement, confOverlay, runAsync = true, queryTimeout) + debug(sessionHandle + ": executeStatementAsync()") + opHandle + } + + @throws[HiveSQLException] + override def getTypeInfo(sessionHandle: SessionHandle): OperationHandle = { + debug(sessionHandle + ": getTypeInfo()") + sessionManager.operationManager.getTypeInfo(sessionHandle) + } + + @throws[HiveSQLException] + override def getCatalogs(sessionHandle: SessionHandle): OperationHandle = { + debug(sessionHandle + ": getCatalogs()") + sessionManager.operationManager.getCatalogs(sessionHandle) + } + + @throws[HiveSQLException] + override def getSchemas( + sessionHandle: SessionHandle, + catalogName: String, + schemaName: String): OperationHandle = { + // TODO + throw new HiveSQLException("Operation GET_SCHEMAS is not yet supported") + } + + @throws[HiveSQLException] + override def getTables( + sessionHandle: SessionHandle, + catalogName: String, + schemaName: 
String, + tableName: String, + tableTypes: util.List[String]): OperationHandle = { + // TODO + throw new HiveSQLException("Operation GET_TABLES is not yet supported") + } + + @throws[HiveSQLException] + override def getTableTypes(sessionHandle: SessionHandle): OperationHandle = { + debug(sessionHandle + ": getTableTypes()") + sessionManager.operationManager.getTableTypes(sessionHandle) + } + + @throws[HiveSQLException] + override def getColumns( + sessionHandle: SessionHandle, + catalogName: String, + schemaName: String, + tableName: String, + columnName: String): OperationHandle = { + // TODO + throw new HiveSQLException("Operation GET_COLUMNS is not yet supported") + } + + @throws[HiveSQLException] + override def getFunctions( + sessionHandle: SessionHandle, + catalogName: String, + schemaName: String, + functionName: String): OperationHandle = { + // TODO + throw new HiveSQLException("Operation GET_FUNCTIONS is not yet supported") + } + + @throws[HiveSQLException] + override def getPrimaryKeys( + sessionHandle: SessionHandle, + catalog: String, + schema: String, + table: String): OperationHandle = { + // TODO + throw new HiveSQLException("Operation GET_PRIMARY_KEYS is not yet supported") + } + + @throws[HiveSQLException] + override def getCrossReference( + sessionHandle: SessionHandle, + primaryCatalog: String, + primarySchema: String, + primaryTable: String, + foreignCatalog: String, + foreignSchema: String, + foreignTable: String): OperationHandle = { + // TODO + throw new HiveSQLException("Operation GET_CROSS_REFERENCE is not yet supported") + } + + @throws[HiveSQLException] + override def getOperationStatus( + opHandle: OperationHandle, + getProgressUpdate: Boolean): OperationStatus = { + val operation: Operation = sessionManager.operationManager.getOperation(opHandle) + /** + * If this is a background operation run asynchronously, + * we block for a duration determined by a step function, before we return + * However, if the background operation is 
complete, we return immediately. + */ + if (operation.shouldRunAsync) { + val maxTimeout: Long = HiveConf.getTimeVar( + getHiveConf, + HiveConf.ConfVars.HIVE_SERVER2_LONG_POLLING_TIMEOUT, + TimeUnit.MILLISECONDS) + val elapsed: Long = System.currentTimeMillis - operation.getBeginTime + // A step function to increase the polling timeout by 500 ms every 10 sec, + // starting from 500 ms up to HIVE_SERVER2_LONG_POLLING_TIMEOUT + val timeout: Long = Math.min(maxTimeout, (elapsed / TimeUnit.SECONDS.toMillis(10) + 1) * 500) + try { + operation.getBackgroundHandle.get(timeout, TimeUnit.MILLISECONDS) + } catch { + case e: TimeoutException => + // No Op, return to the caller since long polling timeout has expired + trace(opHandle + ": Long polling timed out") + case e: CancellationException => + // The background operation thread was cancelled + trace(opHandle + ": The background operation was cancelled", e) + case e: ExecutionException => + // Note: Hive ops do not use the normal Future failure path, so this will not happen + // in case of actual failure; the Future will just be done. 
+ // The background operation thread was aborted + warn(opHandle + ": The background operation was aborted", e) + case _: InterruptedException => + // No op, this thread was interrupted + // In this case, the call might return sooner than long polling timeout + } + } + val opStatus: OperationStatus = operation.getStatus + debug(opHandle + ": getOperationStatus()") + opStatus.setJobProgressUpdate(new JobProgressUpdate(ProgressMonitor.NULL)) + opStatus + } + + @throws[HiveSQLException] + override def cancelOperation(opHandle: OperationHandle): Unit = { + sessionManager.operationManager.cancelOperation(opHandle) + debug(opHandle + ": cancelOperation()") + } + + @throws[HiveSQLException] + override def closeOperation(opHandle: OperationHandle): Unit = { + sessionManager.operationManager.closeOperation(opHandle) + debug(opHandle + ": closeOperation") + } + + @throws[HiveSQLException] + override def getResultSetMetadata(opHandle: OperationHandle): TableSchema = { + debug(opHandle + ": getResultSetMetadata()") + sessionManager.operationManager.getOperation(opHandle).getResultSetSchema + } + + @throws[HiveSQLException] + override def fetchResults(opHandle: OperationHandle): RowSet = { + fetchResults( + opHandle, Operation.DEFAULT_FETCH_ORIENTATION, defaultFetchRows, FetchType.QUERY_OUTPUT) + } + + @throws[HiveSQLException] + override def fetchResults( + opHandle: OperationHandle, + orientation: FetchOrientation, + maxRows: Long, + fetchType: FetchType): RowSet = { + debug(opHandle + ": fetchResults()") + sessionManager.operationManager.fetchResults(opHandle, orientation, maxRows, fetchType) + } + + @throws[HiveSQLException] + override def getDelegationToken( + sessionHandle: SessionHandle, + authFactory: HiveAuthFactory, + owner: String, + renewer: String): String = { + throw new HiveSQLException("Operation not yet supported.") + } + + @throws[HiveSQLException] + override def setApplicationName(sh: SessionHandle, value: String): Unit = { + throw new 
HiveSQLException("Operation not yet supported.") + } + + override def cancelDelegationToken( + sessionHandle: SessionHandle, + authFactory: HiveAuthFactory, + tokenStr: String): Unit = { + throw new HiveSQLException("Operation not yet supported.") + } + + override def renewDelegationToken( + sessionHandle: SessionHandle, + authFactory: HiveAuthFactory, + tokenStr: String): Unit = { + throw new HiveSQLException("Operation not yet supported.") + } + + @throws[HiveSQLException] + override def getQueryId(opHandle: TOperationHandle): String = { + throw new HiveSQLException("Operation not yet supported.") + } +} + + +object LivyCLIService { + val SERVER_VERSION: TProtocolVersion = TProtocolVersion.values().last + + // scalastyle:off line.size.limit + // From https://docs.microsoft.com/en-us/sql/t-sql/language-elements/reserved-keywords-transact-sql#odbc-reserved-keywords + // scalastyle:on line.size.limit + private val ODBC_KEYWORDS = Set("ABSOLUTE", "ACTION", "ADA", "ADD", "ALL", "ALLOCATE", "ALTER", + "AND", "ANY", "ARE", "AS", "ASC", "ASSERTION", "AT", "AUTHORIZATION", "AVG", "BEGIN", + "BETWEEN", "BIT_LENGTH", "BIT", "BOTH", "BY", "CASCADE", "CASCADED", "CASE", "CAST", "CATALOG", + "CHAR_LENGTH", "CHAR", "CHARACTER_LENGTH", "CHARACTER", "CHECK", "CLOSE", "COALESCE", + "COLLATE", "COLLATION", "COLUMN", "COMMIT", "CONNECT", "CONNECTION", "CONSTRAINT", + "CONSTRAINTS", "CONTINUE", "CONVERT", "CORRESPONDING", "COUNT", "CREATE", "CROSS", + "CURRENT_DATE", "CURRENT_TIME", "CURRENT_TIMESTAMP", "CURRENT_USER", "CURRENT", "CURSOR", + "DATE", "DAY", "DEALLOCATE", "DEC", "DECIMAL", "DECLARE", "DEFAULT", "DEFERRABLE", "DEFERRED", + "DELETE", "DESC", "DESCRIBE", "DESCRIPTOR", "DIAGNOSTICS", "DISCONNECT", "DISTINCT", "DOMAIN", + "DOUBLE", "DROP", "ELSE", "END", "ESCAPE", "EXCEPT", "EXCEPTION", "EXEC", "EXECUTE", "EXISTS", + "EXTERNAL", "EXTRACT", "FALSE", "FETCH", "FIRST", "FLOAT", "FOR", "FOREIGN", "FORTRAN", + "FOUND", "FROM", "FULL", "GET", "GLOBAL", "GO", "GOTO", "GRANT", 
"GROUP", "HAVING", "HOUR", + "IDENTITY", "IMMEDIATE", "IN", "INCLUDE", "INDEX", "INDICATOR", "INITIALLY", "INNER", "INPUT", + "INSENSITIVE", "INSERT", "INT", "INTEGER", "INTERSECT", "INTERVAL", "INTO", "IS", "ISOLATION", + "JOIN", "KEY", "LANGUAGE", "LAST", "LEADING", "LEFT", "LEVEL", "LIKE", "LOCAL", "LOWER", + "MATCH", "MAX", "MIN", "MINUTE", "MODULE", "MONTH", "NAMES", "NATIONAL", "NATURAL", "NCHAR", + "NEXT", "NO", "NONE", "NOT", "NULL", "NULLIF", "NUMERIC", "OCTET_LENGTH", "OF", "ON", "ONLY", + "OPEN", "OPTION", "OR", "ORDER", "OUTER", "OUTPUT", "OVERLAPS", "PAD", "PARTIAL", "PASCAL", + "POSITION", "PRECISION", "PREPARE", "PRESERVE", "PRIMARY", "PRIOR", "PRIVILEGES", "PROCEDURE", + "PUBLIC", "READ", "REAL", "REFERENCES", "RELATIVE", "RESTRICT", "REVOKE", "RIGHT", "ROLLBACK", + "ROWS", "SCHEMA", "SCROLL", "SECOND", "SECTION", "SELECT", "SESSION_USER", "SESSION", "SET", + "SIZE", "SMALLINT", "SOME", "SPACE", "SQL", "SQLCA", "SQLCODE", "SQLERROR", "SQLSTATE", + "SQLWARNING", "SUBSTRING", "SUM", "SYSTEM_USER", "TABLE", "TEMPORARY", "THEN", "TIME", + "TIMESTAMP", "TIMEZONE_HOUR", "TIMEZONE_MINUTE", "TO", "TRAILING", "TRANSACTION", "TRANSLATE", + "TRANSLATION", "TRIM", "TRUE", "UNION", "UNIQUE", "UNKNOWN", "UPDATE", "UPPER", "USAGE", + "USER", "USING", "VALUE", "VALUES", "VARCHAR", "VARYING", "VIEW", "WHEN", "WHENEVER", "WHERE", + "WITH", "WORK", "WRITE", "YEAR", "ZONE").asJava +}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/7d8fa69f/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyExecuteStatementOperation.scala ---------------------------------------------------------------------- diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyExecuteStatementOperation.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyExecuteStatementOperation.scala new file mode 100644 index 0000000..142eebf --- /dev/null +++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyExecuteStatementOperation.scala @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.livy.thriftserver + +import java.security.PrivilegedExceptionAction +import java.util +import java.util.{Map => JMap} +import java.util.concurrent.{ConcurrentLinkedQueue, RejectedExecutionException} + +import scala.collection.mutable +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +import org.apache.hadoop.hive.serde2.thrift.ColumnBuffer +import org.apache.hadoop.hive.shims.Utils +import org.apache.hive.service.cli._ +import org.apache.hive.service.cli.operation.Operation + +import org.apache.livy.Logging +import org.apache.livy.thriftserver.SessionStates._ +import org.apache.livy.thriftserver.rpc.RpcClient +import org.apache.livy.thriftserver.types.DataTypeUtils._ + +class LivyExecuteStatementOperation( + sessionHandle: SessionHandle, + statement: String, + confOverlay: JMap[String, String], + runInBackground: Boolean = true, + sessionManager: LivyThriftSessionManager) + extends Operation(sessionHandle, confOverlay, OperationType.EXECUTE_STATEMENT) + with Logging { + + /** + * Contains the messages which have to be sent to the client. + */ + private val operationMessages = new ConcurrentLinkedQueue[String] + + // The initialization need to be lazy in order not to block when the instance is created + private lazy val rpcClient = { + val sessionState = sessionManager.livySessionState(sessionHandle) + if (sessionState == CREATION_IN_PROGRESS) { + operationMessages.offer( + "Livy session has not yet started. Please wait for it to be ready...") + } + // This call is blocking, we are waiting for the session to be ready. 
+ new RpcClient(sessionManager.getLivySession(sessionHandle)) + } + private var rowOffset = 0L + + private def statementId: String = getHandle.getHandleIdentifier.toString + + private def rpcClientValid: Boolean = + sessionManager.livySessionState(sessionHandle) == CREATION_SUCCESS && rpcClient.isValid + + override def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { + validateDefaultFetchOrientation(order) + assertState(util.Arrays.asList(OperationState.FINISHED)) + setHasResultSet(true) + + // maxRowsL here typically maps to java.sql.Statement.getFetchSize, which is an int + val maxRows = maxRowsL.toInt + val jsonSchema = rpcClient.fetchResultSchema(statementId).get() + val types = getInternalTypes(jsonSchema) + val livyColumnResultSet = rpcClient.fetchResult(statementId, types, maxRows).get() + + val thriftColumns = livyColumnResultSet.columns.map { col => + new ColumnBuffer(toHiveThriftType(col.dataType), col.getNulls, col.getColumnValues) + } + val result = new ColumnBasedSet(tableSchemaFromSparkJson(jsonSchema).toTypeDescriptors, + thriftColumns.toList.asJava, + rowOffset) + livyColumnResultSet.columns.headOption.foreach { c => + rowOffset += c.size + } + result + } + + override def runInternal(): Unit = { + setState(OperationState.PENDING) + setHasResultSet(true) // avoid no resultset for async run + + if (!runInBackground) { + execute() + } else { + val livyServiceUGI = Utils.getUGI + + // Runnable impl to call runInternal asynchronously, + // from a different thread + val backgroundOperation = new Runnable() { + + override def run(): Unit = { + val doAsAction = new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + try { + execute() + } catch { + case e: HiveSQLException => + setOperationException(e) + error("Error running hive query: ", e) + } + } + } + + try { + livyServiceUGI.doAs(doAsAction) + } catch { + case e: Exception => + setOperationException(new HiveSQLException(e)) + error("Error running hive query as user : 
" + + livyServiceUGI.getShortUserName, e) + } + } + } + try { + // This submit blocks if no background threads are available to run this operation + val backgroundHandle = sessionManager.submitBackgroundOperation(backgroundOperation) + setBackgroundHandle(backgroundHandle) + } catch { + case rejected: RejectedExecutionException => + setState(OperationState.ERROR) + throw new HiveSQLException("The background threadpool cannot accept" + + " new task for execution, please retry the operation", rejected) + case NonFatal(e) => + error(s"Error executing query in background", e) + setState(OperationState.ERROR) + throw e + } + } + } + + protected def execute(): Unit = { + if (logger.isDebugEnabled) { + debug(s"Running query '$statement' with id $statementId (session = " + + s"${sessionHandle.getSessionId})") + } + setState(OperationState.RUNNING) + + try { + rpcClient.executeSql(sessionHandle, statementId, statement).get() + } catch { + case e: Throwable => + val currentState = getStatus.getState + info(s"Error executing query, currentState $currentState, ", e) + setState(OperationState.ERROR) + throw new HiveSQLException(e) + } + setState(OperationState.FINISHED) + } + + def close(): Unit = { + info(s"Close $statementId") + cleanup(OperationState.CLOSED) + } + + override def cancel(state: OperationState): Unit = { + info(s"Cancel $statementId with state $state") + cleanup(state) + } + + def getResultSetSchema: TableSchema = { + val tableSchema = tableSchemaFromSparkJson(rpcClient.fetchResultSchema(statementId).get()) + // Workaround for operations returning an empty schema (eg. CREATE, INSERT, ...) 
+ if (tableSchema.getSize == 0) { + tableSchema.addStringColumn("Result", "") + } + tableSchema + } + + private def cleanup(state: OperationState) { + if (statementId != null && rpcClientValid) { + rpcClient.cleanupStatement(statementId).get() + } + setState(state) + } + + /** + * Returns the messages that should be sent to the client and removes them from the queue in + * order not to send them twice. + */ + def getOperationMessages: Seq[String] = { + def fetchNext(acc: mutable.ListBuffer[String]): Boolean = { + val m = operationMessages.poll() + if (m == null) { + false + } else { + acc += m + true + } + } + val res = new mutable.ListBuffer[String] + while (fetchNext(res)) {} + res + } +} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/7d8fa69f/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyOperationManager.scala ---------------------------------------------------------------------- diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyOperationManager.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyOperationManager.scala new file mode 100644 index 0000000..c71171a --- /dev/null +++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyOperationManager.scala @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.thriftserver + +import java.util +import java.util.{Map => JMap} +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.mutable + +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} +import org.apache.hive.service.cli._ +import org.apache.hive.service.cli.operation.{GetCatalogsOperation, GetTableTypesOperation, GetTypeInfoOperation, Operation} + +import org.apache.livy.Logging + +class LivyOperationManager(val livyThriftSessionManager: LivyThriftSessionManager) + extends Logging { + + private val handleToOperation = new ConcurrentHashMap[OperationHandle, Operation]() + private val sessionToOperationHandles = + new mutable.HashMap[SessionHandle, mutable.Set[OperationHandle]]() + + private def addOperation(operation: Operation, sessionHandle: SessionHandle): Unit = { + handleToOperation.put(operation.getHandle, operation) + sessionToOperationHandles.synchronized { + val set = sessionToOperationHandles.getOrElseUpdate(sessionHandle, + new mutable.HashSet[OperationHandle]) + set += operation.getHandle + } + } + + @throws[HiveSQLException] + private def removeOperation(operationHandle: OperationHandle): Operation = { + val operation = handleToOperation.remove(operationHandle) + if (operation == null) { + throw new HiveSQLException(s"Operation does not exist: $operationHandle") + } + val sessionHandle = operation.getSessionHandle + sessionToOperationHandles.synchronized { + sessionToOperationHandles(sessionHandle) -= operationHandle + if (sessionToOperationHandles(sessionHandle).isEmpty) { + sessionToOperationHandles.remove(sessionHandle) + } + } + operation + } + + def getOperations(sessionHandle: SessionHandle): Set[OperationHandle] = { + sessionToOperationHandles.synchronized { + sessionToOperationHandles(sessionHandle).toSet + } + } + + def 
getTimedOutOperations(sessionHandle: SessionHandle): Set[Operation] = { + val opHandles = getOperations(sessionHandle) + val currentTime = System.currentTimeMillis() + opHandles.flatMap { handle => + // Some operations may have finished and been removed since we got them. + Option(handleToOperation.get(handle)) + }.filter(_.isTimedOut(currentTime)) + } + + @throws[HiveSQLException] + def getOperation(operationHandle: OperationHandle): Operation = { + val operation = handleToOperation.get(operationHandle) + if (operation == null) { + throw new HiveSQLException(s"Invalid OperationHandle: $operationHandle") + } + operation + } + + def newExecuteStatementOperation( + sessionHandle: SessionHandle, + statement: String, + confOverlay: JMap[String, String], + runAsync: Boolean, + queryTimeout: Long): Operation = { + val op = new LivyExecuteStatementOperation( + sessionHandle, + statement, + confOverlay, + runAsync, + livyThriftSessionManager) + addOperation(op, sessionHandle) + debug(s"Created Operation for $statement with session=$sessionHandle, " + + s"runInBackground=$runAsync") + op + } + + def getOperationLogRowSet( + opHandle: OperationHandle, + orientation: FetchOrientation, + maxRows: Long): RowSet = { + val tableSchema = new TableSchema(LivyOperationManager.LOG_SCHEMA) + val session = livyThriftSessionManager.getSessionInfo(getOperation(opHandle).getSessionHandle) + val logs = RowSetFactory.create(tableSchema, session.protocolVersion, false) + + if (!livyThriftSessionManager.getHiveConf.getBoolVar( + ConfVars.HIVE_SERVER2_LOGGING_OPERATION_ENABLED)) { + warn("Try to get operation log when " + + ConfVars.HIVE_SERVER2_LOGGING_OPERATION_ENABLED.varname + + " is false, no log will be returned. ") + } else { + // Get the operation log. 
This is implemented only for LivyExecuteStatementOperation + val operation = getOperation(opHandle) + if (operation.isInstanceOf[LivyExecuteStatementOperation]) { + val op = getOperation(opHandle).asInstanceOf[LivyExecuteStatementOperation] + op.getOperationMessages.foreach { l => + logs.addRow(Array(l)) + } + } + } + logs + } + + @throws[HiveSQLException] + def executeStatement( + sessionHandle: SessionHandle, + statement: String, + confOverlay: util.Map[String, String], + runAsync: Boolean, + queryTimeout: Long): OperationHandle = { + executeOperation(sessionHandle, { + newExecuteStatementOperation(sessionHandle, statement, confOverlay, runAsync, queryTimeout) + }) + } + + @throws[HiveSQLException] + private def executeOperation( + sessionHandle: SessionHandle, + operationCreator: => Operation): OperationHandle = { + var opHandle: OperationHandle = null + try { + val operation = operationCreator + opHandle = operation.getHandle + operation.run() + opHandle + } catch { + case e: HiveSQLException => + if (opHandle != null) { + closeOperation(opHandle) + } + throw e + } + } + + @throws[HiveSQLException] + def getTypeInfo(sessionHandle: SessionHandle): OperationHandle = { + executeOperation(sessionHandle, { + val op = new GetTypeInfoOperation(sessionHandle) + addOperation(op, sessionHandle) + op + }) + } + + @throws[HiveSQLException] + def getCatalogs(sessionHandle: SessionHandle): OperationHandle = { + executeOperation(sessionHandle, { + val op = new GetCatalogsOperation(sessionHandle) + addOperation(op, sessionHandle) + op + }) + } + + @throws[HiveSQLException] + def getTableTypes(sessionHandle: SessionHandle): OperationHandle = { + executeOperation(sessionHandle, { + val op = new GetTableTypesOperation(sessionHandle) + addOperation(op, sessionHandle) + op + }) + } + + /** + * Cancel the running operation unless it is already in a terminal state + */ + @throws[HiveSQLException] + def cancelOperation(opHandle: OperationHandle, errMsg: String): Unit = { + val 
operation = getOperation(opHandle) + val opState = operation.getStatus.getState + if (opState.isTerminal) { + // Cancel should be a no-op + debug(s"$opHandle: Operation is already aborted in state - $opState") + } else { + debug(s"$opHandle: Attempting to cancel from state - $opState") + val operationState = OperationState.CANCELED + operationState.setErrorMessage(errMsg) + operation.cancel(operationState) + } + } + + @throws[HiveSQLException] + def cancelOperation(opHandle: OperationHandle): Unit = { + cancelOperation(opHandle, "") + } + + @throws[HiveSQLException] + def closeOperation(opHandle: OperationHandle): Unit = { + info("Closing operation: " + opHandle) + val operation = removeOperation(opHandle) + operation.close() + } + + @throws[HiveSQLException] + def fetchResults( + opHandle: OperationHandle, + orientation: FetchOrientation, + maxRows: Long, + fetchType: FetchType): RowSet = { + if (fetchType == FetchType.QUERY_OUTPUT) { + getOperation(opHandle).getNextRowSet(orientation, maxRows) + } else { + getOperationLogRowSet(opHandle, orientation, maxRows) + } + } +} + +object LivyOperationManager { + val LOG_SCHEMA: Schema = { + val schema = new Schema + val fieldSchema = new FieldSchema + fieldSchema.setName("operation_log") + fieldSchema.setType("string") + schema.addToFieldSchemas(fieldSchema) + schema + } +} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/7d8fa69f/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftServer.scala ---------------------------------------------------------------------- diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftServer.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftServer.scala new file mode 100644 index 0000000..d34c1c0 --- /dev/null +++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftServer.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more 
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.livy.thriftserver
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hive.service.server.HiveServer2
+import org.scalatra.ScalatraServlet
+
+import org.apache.livy.{LivyConf, Logging}
+import org.apache.livy.server.AccessManager
+import org.apache.livy.server.interactive.InteractiveSession
+import org.apache.livy.server.recovery.SessionStore
+import org.apache.livy.sessions.InteractiveSessionManager
+
+/**
+ * The main entry point for the Livy thrift server leveraging HiveServer2. Starts up a
+ * [[LivyThriftServer]] (a `HiveServer2` subclass) on a dedicated background thread.
+ */
+object LivyThriftServer extends Logging {
+
+  // Visible for testing
+  private[thriftserver] var thriftServerThread: Thread = _
+  // Assigned on the background startup thread and read from other threads (`getInstance`),
+  // hence volatile.
+  @volatile private var thriftServer: LivyThriftServer = _
+
+  /**
+   * Builds the HiveConf of the thrift server: every `hive.*` property already present (e.g. from
+   * a hive-site.xml in the classpath) is removed, then every `livy.hive*` entry of `livyConf` is
+   * copied with its `livy.` prefix stripped.
+   */
+  private def hiveConf(livyConf: LivyConf): HiveConf = {
+    val conf = new HiveConf()
+    // Remove all configs coming from hive-site.xml which may be in the classpath for the Spark
+    // applications to run.
+    conf.getAllProperties.asScala.filter(_._1.startsWith("hive.")).foreach { case (key, _) =>
+      conf.unset(key)
+    }
+    livyConf.asScala.foreach {
+      case nameAndValue if nameAndValue.getKey.startsWith("livy.hive") =>
+        conf.set(nameAndValue.getKey.stripPrefix("livy."), nameAndValue.getValue)
+      case _ => // Ignore
+    }
+    conf
+  }
+
+  /**
+   * Creates, initializes and starts the thrift server on a new "Livy-Thriftserver" thread.
+   * Startup errors are logged on that thread, not rethrown. If a server thread already exists,
+   * only an error is logged.
+   */
+  def start(
+      livyConf: LivyConf,
+      livySessionManager: InteractiveSessionManager,
+      sessionStore: SessionStore,
+      accessManager: AccessManager): Unit = synchronized {
+    if (thriftServerThread == null) {
+      info("Starting LivyThriftServer")
+      val runThriftServer = new Runnable {
+        override def run(): Unit = {
+          try {
+            thriftServer = new LivyThriftServer(
+              livyConf,
+              livySessionManager,
+              sessionStore,
+              accessManager)
+            thriftServer.init(hiveConf(livyConf))
+            thriftServer.start()
+            info("LivyThriftServer started")
+          } catch {
+            case e: Exception =>
+              error("Error starting LivyThriftServer", e)
+          }
+        }
+      }
+      thriftServerThread =
+        new Thread(new ThreadGroup("thriftserver"), runThriftServer, "Livy-Thriftserver")
+      thriftServerThread.start()
+    } else {
+      error("Livy Thriftserver is already started")
+    }
+  }
+
+  /** The server instance, once the startup thread has created it. */
+  private[thriftserver] def getInstance: Option[LivyThriftServer] = {
+    Option(thriftServer)
+  }
+
+  /**
+   * Used in testing. Waits for the startup thread to finish, then stops the server if one was
+   * created. Safe to call when startup failed or `start` was never called. Synchronized like
+   * `start` so the two cannot interleave.
+   */
+  def stopServer(): Unit = synchronized {
+    if (thriftServerThread != null) {
+      thriftServerThread.join()
+    }
+    thriftServerThread = null
+    // Null when the server could not be constructed (startup failure) or was never started.
+    Option(thriftServer).foreach(_.stop())
+    thriftServer = null
+  }
+}
+
+
+/**
+ * HiveServer2 whose CLI service is a [[LivyCLIService]], so that thrift requests are served by
+ * Livy interactive sessions.
+ */
+class LivyThriftServer(
+    private[thriftserver] val livyConf: LivyConf,
+    private[thriftserver] val livySessionManager: InteractiveSessionManager,
+    private[thriftserver] val sessionStore: SessionStore,
+    private val accessManager: AccessManager) extends HiveServer2 {
+  override def init(hiveConf: HiveConf): Unit = {
+    this.cliService = new LivyCLIService(this)
+    super.init(hiveConf)
+  }
+
+  private[thriftserver] def getSessionManager(): LivyThriftSessionManager = {
+    this.cliService.asInstanceOf[LivyCLIService].getSessionManager
+  }
+
+  /** Whether `user` may use `session`: true for its owner or a user with modify permissions. */
+  def isAllowedToUse(user: String, session: InteractiveSession): Boolean = {
+    session.owner == user || accessManager.checkModifyPermissions(user)
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/7d8fa69f/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala
----------------------------------------------------------------------
diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala
new file mode 100644
index 0000000..ec987c5
--- /dev/null
+++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala
@@ -0,0 +1,635 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.livy.thriftserver
+
+import java.lang.reflect.UndeclaredThrowableException
+import java.net.URI
+import java.security.PrivilegedExceptionAction
+import java.util
+import java.util.{Date, Map => JMap, UUID}
+import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.concurrent.{Await, Future}
+import scala.concurrent.duration.Duration
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.util.{Failure, Success, Try}
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
+import org.apache.hadoop.hive.shims.Utils
+import org.apache.hive.service.CompositeService
+import org.apache.hive.service.cli.{HiveSQLException, SessionHandle}
+import org.apache.hive.service.rpc.thrift.TProtocolVersion
+import org.apache.hive.service.server.ThreadFactoryWithGarbageCleanup
+
+import org.apache.livy.LivyConf
+import org.apache.livy.Logging
+import org.apache.livy.server.interactive.{CreateInteractiveRequest, InteractiveSession}
+import org.apache.livy.sessions.Spark
+import org.apache.livy.thriftserver.SessionStates._
+import org.apache.livy.thriftserver.rpc.RpcClient
+import org.apache.livy.utils.LivySparkUtils
+
+/**
+ * Session manager of the Livy thriftserver. It maps every thrift `SessionHandle` to the Livy
+ * interactive session backing it (an existing one the user asked for, or a new one created
+ * asynchronously), enforces Hive's connection limits and periodically closes idle sessions and
+ * timed-out operations.
+ */
+class LivyThriftSessionManager(val server: LivyThriftServer, hiveConf: HiveConf)
+  extends CompositeService(classOf[LivyThriftSessionManager].getName) with Logging {
+
+  private[thriftserver] val operationManager = new LivyOperationManager(this)
+  private val sessionHandleToLivySession =
+    new ConcurrentHashMap[SessionHandle, Future[InteractiveSession]]()
+  // A map which returns how many incoming connections are open for a Livy session.
+  // This map tracks only the sessions created by the Livy thriftserver and not those which have
+  // been created through the REST API, as those should not be stopped even though there are no
+  // more active connections.
+  // It is a non thread-safe mutable.HashMap: always access it holding this manager's monitor.
+  private val managedLivySessionActiveUsers = new mutable.HashMap[Int, Int]()
+
+  // Contains metadata about a session
+  private val sessionInfo = new ConcurrentHashMap[SessionHandle, SessionInfo]()
+
+  // Map the number of incoming connections for IP, user. It is used in order to check
+  // that the configured limits are not exceeded.
+  private val connectionsCount = new ConcurrentHashMap[String, AtomicLong]
+
+  // Timeout for a Spark session creation
+  private val maxSessionWait = Duration(
+    server.livyConf.getTimeAsMs(LivyConf.THRIFT_SESSION_CREATION_TIMEOUT),
+    scala.concurrent.duration.MILLISECONDS)
+
+  // Flag indicating whether the Spark version being used supports the USE database statement
+  val supportUseDatabase: Boolean = {
+    val sparkVersion = server.livyConf.get(LivyConf.LIVY_SPARK_VERSION)
+    val (sparkMajorVersion, _) = LivySparkUtils.formatSparkVersion(sparkVersion)
+    sparkMajorVersion > 1 || server.livyConf.getBoolean(LivyConf.ENABLE_HIVE_CONTEXT)
+  }
+
+  // Configs from Hive
+  private val userLimit = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_LIMIT_CONNECTIONS_PER_USER)
+  private val ipAddressLimit =
+    hiveConf.getIntVar(ConfVars.HIVE_SERVER2_LIMIT_CONNECTIONS_PER_IPADDRESS)
+  private val userIpAddressLimit =
+    hiveConf.getIntVar(ConfVars.HIVE_SERVER2_LIMIT_CONNECTIONS_PER_USER_IPADDRESS)
+  private val checkInterval = HiveConf.getTimeVar(
+    hiveConf, ConfVars.HIVE_SERVER2_SESSION_CHECK_INTERVAL, TimeUnit.MILLISECONDS)
+  private val sessionTimeout = HiveConf.getTimeVar(
+    hiveConf, ConfVars.HIVE_SERVER2_IDLE_SESSION_TIMEOUT, TimeUnit.MILLISECONDS)
+  private val checkOperation = HiveConf.getBoolVar(
+    hiveConf, ConfVars.HIVE_SERVER2_IDLE_SESSION_CHECK_OPERATION)
+
+  private var backgroundOperationPool: ThreadPoolExecutor = _
+
+  /**
+   * Returns the Livy session backing `sessionHandle`. If it is still being created, this waits
+   * up to the configured session-creation timeout.
+   *
+   * @throws the cause of the creation failure (or the failure itself when it has no cause), or a
+   *         `TimeoutException` when the session is not ready within the timeout.
+   */
+  def getLivySession(sessionHandle: SessionHandle): InteractiveSession = {
+    val future = sessionHandleToLivySession.get(sessionHandle)
+    assert(future != null, s"Looking for not existing session: $sessionHandle.")
+
+    val result = future.value.getOrElse(Try(Await.result(future, maxSessionWait)))
+    result match {
+      case Success(session) => session
+      // Creation failures are wrapped in ThriftSessionCreationException: rethrow the real cause.
+      // A timeout has no cause; rethrowing `e.getCause` there would mean `throw null`.
+      case Failure(e) => throw Option(e.getCause).getOrElse(e)
+    }
+  }
+
+  /** Id of the Livy session backing `sessionHandle`, if the handle is known and creation succeeded. */
+  def livySessionId(sessionHandle: SessionHandle): Option[Int] = {
+    Option(sessionHandleToLivySession.get(sessionHandle))
+      .flatMap(_.value)
+      .flatMap(_.toOption)
+      .map(_.id)
+  }
+
+  /** Creation state of the Livy session backing `sessionHandle` (which must be a known handle). */
+  def livySessionState(sessionHandle: SessionHandle): SessionStates = {
+    sessionHandleToLivySession.get(sessionHandle).value match {
+      case Some(Success(_)) => CREATION_SUCCESS
+      case Some(Failure(_)) => CREATION_FAILED
+      case None => CREATION_IN_PROGRESS
+    }
+  }
+
+  /**
+   * Registers a Livy session created by the thriftserver and starts tracking its active users
+   * (initially 0).
+   */
+  def onLivySessionOpened(livySession: InteractiveSession): Unit = {
+    server.livySessionManager.register(livySession)
+    synchronized {
+      managedLivySessionActiveUsers += livySession.id -> 0
+    }
+  }
+
+  /**
+   * Releases `livySession` for the closed thrift session `sessionHandle`. A Livy session created
+   * by the thriftserver is deleted when its last user leaves. Otherwise only the thrift session
+   * is unregistered from it, and a failure to do so is only logged.
+   */
+  def onUserSessionClosed(sessionHandle: SessionHandle, livySession: InteractiveSession): Unit = {
+    val closeSession = synchronized[Boolean] {
+      managedLivySessionActiveUsers.get(livySession.id) match {
+        case Some(1) =>
+          // it was the last user, so we can close the LivySession
+          managedLivySessionActiveUsers -= livySession.id
+          true
+        case Some(activeUsers) =>
+          managedLivySessionActiveUsers(livySession.id) = activeUsers - 1
+          false
+        case None =>
+          // This case can happen when we don't track the number of active users because the session
+          // has not been created in the thriftserver (ie. it has been created in the REST API).
+          false
+      }
+    }
+    if (closeSession) {
+      server.livySessionManager.delete(livySession)
+    } else {
+      // We unregister the session only if we don't close it, as it is unnecessary in that case
+      val rpcClient = new RpcClient(livySession)
+      try {
+        rpcClient.executeUnregisterSession(sessionHandle).get()
+      } catch {
+        case e: Exception => warn(s"Unable to unregister session $sessionHandle", e)
+      }
+    }
+  }
+
+  /**
+   * If the user specified an existing sessionId to use, the corresponding session is returned,
+   * otherwise a new session is created and returned.
+   *
+   * @throws IllegalArgumentException if the requested session does not exist or is not active.
+   * @throws IllegalAccessException if `username` is not allowed to use the requested session.
+   */
+  private def getOrCreateLivySession(
+      sessionHandle: SessionHandle,
+      sessionId: Option[Int],
+      username: String,
+      createLivySession: () => InteractiveSession): InteractiveSession = {
+    sessionId match {
+      case Some(id) =>
+        server.livySessionManager.get(id) match {
+          case None =>
+            warn(s"InteractiveSession $id doesn't exist.")
+            throw new IllegalArgumentException(s"Session $id doesn't exist.")
+          case Some(session) if !server.isAllowedToUse(username, session) =>
+            warn(s"$username has no modify permissions to InteractiveSession $id.")
+            throw new IllegalAccessException(
+              s"$username is not allowed to use InteractiveSession $id.")
+          case Some(session) =>
+            if (session.state.isActive) {
+              info(s"Reusing Session $id for $sessionHandle.")
+              session
+            } else {
+              warn(s"InteractiveSession $id is not active anymore.")
+              throw new IllegalArgumentException(s"Session $id is not active anymore.")
+            }
+        }
+      case None =>
+        createLivySession()
+    }
+  }
+
+  /**
+   * Performs the initialization of the new Thriftserver session:
+   *  - adds the Livy thriftserver JAR to the Spark application;
+   *  - registers the new Thriftserver session in the Spark application;
+   *  - runs the initialization statements.
+   */
+  private def initSession(
+      sessionHandle: SessionHandle,
+      livySession: InteractiveSession,
+      initStatements: List[String]): Unit = {
+    // Add the thriftserver jar to Spark application as we need to deserialize there the classes
+    // which handle the job submission.
+    // Note: if this is an already existing session, adding the JARs multiple times is not a
+    // problem as Spark ignores JARs which have already been added.
+    try {
+      livySession.addJar(LivyThriftSessionManager.thriftserverJarLocation(server.livyConf))
+    } catch {
+      // Ignore only a failure positively identified as "already uploaded". A failure without a
+      // cause, or whose cause has no message, is a real error and is propagated.
+      case e: java.util.concurrent.ExecutionException
+          if Option(e.getCause).flatMap(cause => Option(cause.getMessage))
+            .exists(_.contains("has already been uploaded")) =>
+        // We have already uploaded the jar to this session, we can ignore this error
+        debug(e.getMessage, e)
+    }
+
+    val rpcClient = new RpcClient(livySession)
+    rpcClient.executeRegisterSession(sessionHandle).get()
+    initStatements.foreach { statement =>
+      val statementId = UUID.randomUUID().toString
+      try {
+        rpcClient.executeSql(sessionHandle, statementId, statement).get()
+      } finally {
+        Try(rpcClient.cleanupStatement(statementId).get()).failed.foreach { e =>
+          error(s"Failed to close init operation $statementId", e)
+        }
+      }
+    }
+  }
+
+  /**
+   * Opens a thrift session and returns its handle immediately. The backing Livy session is
+   * obtained (the one requested through the `livy.server.sessionId` hiveconf, or a new one) and
+   * initialized asynchronously; creation failures surface as ThriftSessionCreationException in
+   * the stored future.
+   *
+   * @throws HiveSQLException if a configured connection limit would be exceeded.
+   */
+  def openSession(
+      protocol: TProtocolVersion,
+      username: String,
+      password: String,
+      ipAddress: String,
+      sessionConf: JMap[String, String],
+      withImpersonation: Boolean,
+      delegationToken: String): SessionHandle = {
+    val sessionHandle = new SessionHandle(protocol)
+    incrementConnections(username, ipAddress, SessionInfo.getForwardedAddresses)
+    sessionInfo.put(sessionHandle,
+      new SessionInfo(username, ipAddress, SessionInfo.getForwardedAddresses, protocol))
+    val (initStatements, createInteractiveRequest, sessionId) =
+      LivyThriftSessionManager.processSessionConf(sessionConf, supportUseDatabase)
+    val createLivySession = () => {
+      createInteractiveRequest.kind = Spark
+      val newSession = InteractiveSession.create(
+        server.livySessionManager.nextId(),
+        username,
+        None,
+        server.livyConf,
+        createInteractiveRequest,
+        server.sessionStore)
+      onLivySessionOpened(newSession)
+      newSession
+    }
+    val futureLivySession = Future {
+      val livyServiceUGI = Utils.getUGI
+      var livySession: InteractiveSession = null
+      try {
+        livyServiceUGI.doAs(new PrivilegedExceptionAction[InteractiveSession] {
+          override def run(): InteractiveSession = {
+            livySession =
+              getOrCreateLivySession(sessionHandle, sessionId, username, createLivySession)
+            // Qualified on purpose: a bare `synchronized` here would lock this anonymous action
+            // instead of the manager guarding managedLivySessionActiveUsers.
+            LivyThriftSessionManager.this.synchronized {
+              managedLivySessionActiveUsers.get(livySession.id).foreach { numUsers =>
+                managedLivySessionActiveUsers(livySession.id) = numUsers + 1
+              }
+            }
+            initSession(sessionHandle, livySession, initStatements)
+            livySession
+          }
+        })
+      } catch {
+        case e: UndeclaredThrowableException =>
+          throw new ThriftSessionCreationException(Option(livySession), e.getCause)
+        case e: Throwable =>
+          throw new ThriftSessionCreationException(Option(livySession), e)
+      }
+    }
+    sessionHandleToLivySession.put(sessionHandle, futureLivySession)
+    sessionHandle
+  }
+
+  /**
+   * Closes the thrift session `sessionHandle`. The handle is forgotten and the Livy session
+   * backing it is released (see `onUserSessionClosed`); if it is still being created, it is
+   * released once the creation completes. The connection counters are always decremented.
+   *
+   * @throws HiveSQLException if `sessionHandle` is unknown (never opened or already closed).
+   */
+  @throws[HiveSQLException]
+  def closeSession(sessionHandle: SessionHandle): Unit = {
+    val removedSession = sessionHandleToLivySession.remove(sessionHandle)
+    val removedSessionInfo = sessionInfo.remove(sessionHandle)
+    try {
+      if (removedSession == null) {
+        throw new HiveSQLException(s"Session $sessionHandle does not exist or is already closed")
+      }
+      val releaseLivySession: Try[InteractiveSession] => Unit = {
+        case Success(interactiveSession) =>
+          onUserSessionClosed(sessionHandle, interactiveSession)
+        case Failure(e: ThriftSessionCreationException) =>
+          e.livySession.foreach(onUserSessionClosed(sessionHandle, _))
+        case Failure(e) =>
+          // openSession wraps every creation failure, so this should never happen.
+          warn(s"Unexpected failure of the Livy session of $sessionHandle", e)
+      }
+      removedSession.value match {
+        case Some(result) => releaseLivySession(result)
+        case None => removedSession.onComplete(releaseLivySession)
+      }
+    } finally {
+      if (removedSessionInfo != null) {
+        decrementConnections(removedSessionInfo)
+      }
+    }
+  }
+
+  // Taken from Hive
+  override def init(hiveConf: HiveConf): Unit = {
+    createBackgroundOperationPool(hiveConf)
+    // Values are interpolated so that they end up in the message itself.
+    info(s"Connections limit are user: $userLimit ipaddress: $ipAddressLimit " +
+      s"user-ipaddress: $userIpAddressLimit")
+    super.init(hiveConf)
+  }
+
+  // Taken from Hive
+  private def createBackgroundOperationPool(hiveConf: HiveConf): Unit = {
+    val poolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS)
+    info("HiveServer2: Background operation thread pool size: " + poolSize)
+    val poolQueueSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_WAIT_QUEUE_SIZE)
+    info("HiveServer2: Background operation thread wait queue size: " + poolQueueSize)
+    val keepAliveTime = HiveConf.getTimeVar(
+      hiveConf, ConfVars.HIVE_SERVER2_ASYNC_EXEC_KEEPALIVE_TIME, TimeUnit.SECONDS)
+    info(s"HiveServer2: Background operation thread keepalive time: $keepAliveTime seconds")
+    // Create a thread pool with #poolSize threads
+    // Threads terminate when they are idle for more than the keepAliveTime
+    // A bounded blocking queue is used to queue incoming operations, if #operations > poolSize
+    val threadPoolName = "LivyServer2-Background-Pool"
+    val queue = new LinkedBlockingQueue[Runnable](poolQueueSize)
+    backgroundOperationPool = new ThreadPoolExecutor(
+      poolSize,
+      poolSize,
+      keepAliveTime,
+      TimeUnit.SECONDS,
+      queue,
+      new ThreadFactoryWithGarbageCleanup(threadPoolName))
+    backgroundOperationPool.allowCoreThreadTimeOut(true)
+  }
+
+  // Taken from Hive
+  override def start(): Unit = {
+    super.start()
+    if (checkInterval > 0) startTimeoutChecker()
+  }
+
+  private val timeoutCheckerLock: Object = new Object
+  @volatile private var shutdown: Boolean = false
+
+  // Taken from Hive. Periodically closes idle sessions and, for the others, timed-out operations.
+  private def startTimeoutChecker(): Unit = {
+    val interval: Long = Math.max(checkInterval, 3000L)
+    // minimum 3 seconds
+    val timeoutChecker: Runnable = new Runnable() {
+      override def run(): Unit = {
+        sleepFor(interval)
+        while (!shutdown) {
+          val current: Long = System.currentTimeMillis
+          val iterator = sessionHandleToLivySession.entrySet().iterator()
+          while (iterator.hasNext && !shutdown) {
+            val entry = iterator.next()
+            val sessionHandle = entry.getKey
+            entry.getValue.value.flatMap(_.toOption).foreach { livySession =>
+              if (sessionTimeout > 0 && livySession.lastActivity + sessionTimeout <= current &&
+                  (!checkOperation || getNoOperationTime(sessionHandle) > sessionTimeout)) {
+                warn(s"Session $sessionHandle is Timed-out (last access : " +
+                  new Date(livySession.lastActivity) + ") and will be closed")
+                try {
+                  closeSession(sessionHandle)
+                } catch {
+                  // Any escaping exception would kill this checker thread for good.
+                  case NonFatal(e) =>
+                    warn(s"Exception is thrown closing session $sessionHandle", e)
+                }
+              } else {
+                operationManager.getTimedOutOperations(sessionHandle).foreach { op =>
+                  try {
+                    warn(s"Operation ${op.getHandle} is timed-out and will be closed")
+                    operationManager.closeOperation(op.getHandle)
+                  } catch {
+                    case e: Exception =>
+                      warn("Exception is thrown closing timed-out operation: " + op.getHandle, e)
+                  }
+                }
+              }
+            }
+          }
+          sleepFor(interval)
+        }
+      }
+
+      private def sleepFor(interval: Long): Unit = {
+        timeoutCheckerLock.synchronized {
+          try {
+            timeoutCheckerLock.wait(interval)
+          } catch {
+            case e: InterruptedException =>
+            // Ignore, and break.
+          }
+        }
+      }
+    }
+    backgroundOperationPool.execute(timeoutChecker)
+  }
+
+  // Taken from Hive
+  private def shutdownTimeoutChecker(): Unit = {
+    shutdown = true
+    timeoutCheckerLock.synchronized { timeoutCheckerLock.notify() }
+  }
+
+  // Taken from Hive
+  override def stop(): Unit = {
+    super.stop()
+    shutdownTimeoutChecker()
+    if (backgroundOperationPool != null) {
+      backgroundOperationPool.shutdown()
+      val timeout =
+        hiveConf.getTimeVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_SHUTDOWN_TIMEOUT, TimeUnit.SECONDS)
+      try {
+        backgroundOperationPool.awaitTermination(timeout, TimeUnit.SECONDS)
+      } catch {
+        case e: InterruptedException =>
+          warn("HIVE_SERVER2_ASYNC_EXEC_SHUTDOWN_TIMEOUT = " + timeout +
+            " seconds has been exceeded. RUNNING background operations will be shut down", e)
+      }
+      backgroundOperationPool = null
+    }
+  }
+
+  // Taken from Hive
+  @throws[HiveSQLException]
+  private def incrementConnections(
+      username: String,
+      ipAddress: String,
+      forwardedAddresses: util.List[String]): Unit = {
+    val clientIpAddress: String = getOriginClientIpAddress(ipAddress, forwardedAddresses)
+    val violation = anyViolations(username, clientIpAddress)
+    // increment the counters only when there are no violations
+    if (violation.isEmpty) {
+      if (trackConnectionsPerUser(username)) incrementConnectionsCount(username)
+      if (trackConnectionsPerIpAddress(clientIpAddress)) incrementConnectionsCount(clientIpAddress)
+      if (trackConnectionsPerUserIpAddress(username, clientIpAddress)) {
+        incrementConnectionsCount(username + ":" + clientIpAddress)
+      }
+    } else {
+      error(violation.get)
+      throw new HiveSQLException(violation.get)
+    }
+  }
+
+  // Taken from Hive. The counter of a key seen for the first time starts at 1.
+  private def incrementConnectionsCount(key: String): Unit = {
+    // putIfAbsent keeps concurrent first connections from overwriting each other's counter.
+    connectionsCount.putIfAbsent(key, new AtomicLong)
+    connectionsCount.get(key).incrementAndGet()
+  }
+
+  // Taken from Hive. A key without a counter has no open connection: nothing to decrement.
+  private def decrementConnectionsCount(key: String): Unit = {
+    Option(connectionsCount.get(key)).foreach(_.decrementAndGet())
+  }
+
+  // Taken from Hive
+  private def getOriginClientIpAddress(ipAddress: String, forwardedAddresses: util.List[String]) = {
+    if (forwardedAddresses == null || forwardedAddresses.isEmpty) {
+      ipAddress
+    } else {
+      // order of forwarded ips per X-Forwarded-For http spec (client, proxy1, proxy2)
+      forwardedAddresses.get(0)
+    }
+  }
+
+  // Taken from Hive
+  private def anyViolations(username: String, ipAddress: String): Option[String] = {
+    val userAndAddress = username + ":" + ipAddress
+    if (trackConnectionsPerUser(username) && !withinLimits(username, userLimit)) {
+      Some(s"Connection limit per user reached (user: $username limit: $userLimit)")
+    } else if (trackConnectionsPerIpAddress(ipAddress) &&
+        !withinLimits(ipAddress, ipAddressLimit)) {
+      Some(s"Connection limit per ipaddress reached (ipaddress: $ipAddress limit: " +
+        s"$ipAddressLimit)")
+    } else if (trackConnectionsPerUserIpAddress(username, ipAddress) &&
+        !withinLimits(userAndAddress, userIpAddressLimit)) {
+      Some(s"Connection limit per user:ipaddress reached (user:ipaddress: $userAndAddress " +
+        s"limit: $userIpAddressLimit)")
+    } else {
+      None
+    }
+  }
+
+  // Taken from Hive
+  private def trackConnectionsPerUserIpAddress(username: String, ipAddress: String): Boolean = {
+    userIpAddressLimit > 0 && username != null && !username.isEmpty && ipAddress != null &&
+      !ipAddress.isEmpty
+  }
+
+  // Taken from Hive
+  private def trackConnectionsPerIpAddress(ipAddress: String): Boolean = {
+    ipAddressLimit > 0 && ipAddress != null && !ipAddress.isEmpty
+  }
+
+  // Taken from Hive
+  private def trackConnectionsPerUser(username: String): Boolean = {
+    userLimit > 0 && username != null && !username.isEmpty
+  }
+
+  // Taken from Hive
+  private def withinLimits(track: String, limit: Int): Boolean = {
+    !(connectionsCount.containsKey(track) && connectionsCount.get(track).intValue >= limit)
+  }
+
+  private def decrementConnections(sessionInfo: SessionInfo): Unit = {
+    val username = sessionInfo.username
+    val clientIpAddress = getOriginClientIpAddress(
+      sessionInfo.ipAddress, sessionInfo.forwardedAddresses)
+    if (trackConnectionsPerUser(username)) {
+      decrementConnectionsCount(username)
+    }
+    if (trackConnectionsPerIpAddress(clientIpAddress)) {
+      decrementConnectionsCount(clientIpAddress)
+    }
+    if (trackConnectionsPerUserIpAddress(username, clientIpAddress)) {
+      decrementConnectionsCount(username + ":" + clientIpAddress)
+    }
+  }
+
+  /** Submits `r` to the background operation thread pool. */
+  def submitBackgroundOperation(r: Runnable): util.concurrent.Future[_] = {
+    backgroundOperationPool.submit(r)
+  }
+
+  /**
+   * Milliseconds since the last activity of the Livy session backing `sessionHandle` when the
+   * thrift session has no operation, 0 otherwise.
+   */
+  def getNoOperationTime(sessionHandle: SessionHandle): Long = {
+    if (operationManager.getOperations(sessionHandle).isEmpty) {
+      System.currentTimeMillis() - getLivySession(sessionHandle).lastActivity
+    } else {
+      0
+    }
+  }
+
+  /** Handles of the currently open thrift sessions. */
+  def getSessions: Set[SessionHandle] = {
+    sessionInfo.keySet().asScala.toSet
+  }
+
+  /** Metadata of `sessionHandle`, or null if it is not open. */
+  def getSessionInfo(sessionHandle: SessionHandle): SessionInfo = {
+    sessionInfo.get(sessionHandle)
+  }
+}
+
+object LivyThriftSessionManager extends Logging {
+  // Users can explicitly set the Livy connection id they want to connect to using this hiveconf
+  // variable
+  private val livySessionIdConfigKey = "set:hiveconf:livy.server.sessionId"
+  private val livySessionConfRegexp = "set:hiveconf:livy.session.conf.(.*)".r
+  private val hiveVarPattern = "set:hivevar:(.*)".r
+  private val JAR_LOCATION = getClass.getProtectionDomain.getCodeSource.getLocation.toURI
+
+  /**
+   * Location of the thriftserver JAR to add to the Spark applications: the configured
+   * `THRIFT_SERVER_JAR_LOCATION` when set, otherwise the location this class was loaded from.
+   */
+  def thriftserverJarLocation(livyConf: LivyConf): URI = {
+    Option(livyConf.get(LivyConf.THRIFT_SERVER_JAR_LOCATION)).map(new URI(_))
+      .getOrElse(JAR_LOCATION)
+  }
+
+  // Parses `value` as an Int; logs a warning and returns None when it is not a valid integer.
+  private def convertConfValueToInt(key: String, value: String): Option[Int] = {
+    val res = Try(value.toInt).toOption
+    if (res.isEmpty) {
+      warn(s"Ignoring $key = $value as it is not a valid integer")
+    }
+    res
+  }
+
+  /**
+   * Splits the session configuration sent by the client into:
+   *  - the statements to run when the session is initialized (`use <db>`, `set hivevar:...`);
+   *  - the request used to create a new Livy session (`livy.session.*` hiveconfs);
+   *  - the id of the existing Livy session to use (`livy.server.sessionId` hiveconf), if valid.
+   */
+  private def processSessionConf(
+      sessionConf: JMap[String, String],
+      supportUseDatabase: Boolean): (List[String], CreateInteractiveRequest, Option[Int]) = {
+    if (null != sessionConf && !sessionConf.isEmpty) {
+      val statements = new mutable.ListBuffer[String]
+      val extraLivyConf = new mutable.ListBuffer[(String, String)]
+      val createInteractiveRequest = new CreateInteractiveRequest
+      sessionConf.asScala.foreach {
+        case (key, value) =>
+          key match {
+            case v if v.startsWith("use:") && supportUseDatabase =>
+              statements += s"use $value"
+            // Process session configs for Livy session creation request
+            case "set:hiveconf:livy.session.driverMemory" =>
+              createInteractiveRequest.driverMemory = Some(value)
+            case "set:hiveconf:livy.session.driverCores" =>
+              createInteractiveRequest.driverCores = convertConfValueToInt(key, value)
+            case "set:hiveconf:livy.session.executorMemory" =>
+              createInteractiveRequest.executorMemory = Some(value)
+            case "set:hiveconf:livy.session.executorCores" =>
+              createInteractiveRequest.executorCores = convertConfValueToInt(key, value)
+            case "set:hiveconf:livy.session.queue" =>
+              createInteractiveRequest.queue = Some(value)
+            case "set:hiveconf:livy.session.name" =>
+              createInteractiveRequest.name = Some(value)
+            case "set:hiveconf:livy.session.heartbeatTimeoutInSecond" =>
+              convertConfValueToInt(key, value).foreach { heartbeatTimeoutInSecond =>
+                createInteractiveRequest.heartbeatTimeoutInSecond = heartbeatTimeoutInSecond
+              }
+            case livySessionConfRegexp(livyConfKey) => extraLivyConf += (livyConfKey -> value)
+            // set the hivevars specified by the user
+            case hiveVarPattern(confKey) => statements += s"set hivevar:${confKey.trim}=$value"
+            case _ if key == livySessionIdConfigKey => // Ignore it, we handle it later
+            case _ =>
+              info(s"Ignoring key: $key = '$value'")
+          }
+      }
+      createInteractiveRequest.conf = extraLivyConf.toMap
+      val sessionId = Option(sessionConf.get(livySessionIdConfigKey)).flatMap { id =>
+        val res = Try(id.toInt).toOption
+        if (res.isEmpty) {
+          warn(s"Ignoring $livySessionIdConfigKey=$id as it is not an int.")
+        }
+        res
+      }
+      (statements.toList, createInteractiveRequest, sessionId)
+    } else {
+      (List(), new CreateInteractiveRequest, None)
+    }
+  }
+}
+
+/**
+ * Exception which happened during the session creation and/or initialization. It contains the
+ */ +class ThriftSessionCreationException(val livySession: Option[InteractiveSession], cause: Throwable) + extends Exception(cause) http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/7d8fa69f/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/SessionInfo.scala ---------------------------------------------------------------------- diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/SessionInfo.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/SessionInfo.scala new file mode 100644 index 0000000..4ebf867 --- /dev/null +++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/SessionInfo.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.thriftserver + +import java.util + +import org.apache.hive.service.rpc.thrift.TProtocolVersion + +import org.apache.livy.Logging + +case class SessionInfo(username: String, + ipAddress: String, + forwardedAddresses: util.List[String], + protocolVersion: TProtocolVersion) { + val creationTime: Long = System.currentTimeMillis() +} + +/** + * Mirrors Hive behavior which stores thread local information in its session manager. 
+ * Every value is stored per thread and reads as null until it is set on the current thread.
+ */
+object SessionInfo extends Logging {
+
+  private val threadLocalIpAddress = new ThreadLocal[String]
+  private val threadLocalForwardedAddresses = new ThreadLocal[util.List[String]]
+  private val threadLocalUserName = new ThreadLocal[String]
+  private val threadLocalProxyUserName = new ThreadLocal[String]
+
+  def setIpAddress(ipAddress: String): Unit = threadLocalIpAddress.set(ipAddress)
+
+  def clearIpAddress(): Unit = threadLocalIpAddress.remove()
+
+  def getIpAddress: String = threadLocalIpAddress.get
+
+  def setForwardedAddresses(ipAddress: util.List[String]): Unit =
+    threadLocalForwardedAddresses.set(ipAddress)
+
+  def clearForwardedAddresses(): Unit = threadLocalForwardedAddresses.remove()
+
+  def getForwardedAddresses: util.List[String] = threadLocalForwardedAddresses.get
+
+  def setUserName(userName: String): Unit = threadLocalUserName.set(userName)
+
+  def clearUserName(): Unit = threadLocalUserName.remove()
+
+  def getUserName: String = threadLocalUserName.get
+
+  /** Stores the proxy user name for the current thread, logging it at debug level. */
+  def setProxyUserName(userName: String): Unit = {
+    debug(s"setting proxy user name based on query param to: $userName")
+    threadLocalProxyUserName.set(userName)
+  }
+
+  def getProxyUserName: String = threadLocalProxyUserName.get
+
+  def clearProxyUserName(): Unit = threadLocalProxyUserName.remove()
+}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/7d8fa69f/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/SessionStates.scala
----------------------------------------------------------------------
diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/SessionStates.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/SessionStates.scala
new file mode 100644
index 0000000..c4fd248
--- /dev/null
+++
// b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/SessionStates.scala
// @@ -0,0 +1,26 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.livy.thriftserver

/**
 * States of a session creation, as named by the `CREATION_*` values.
 *
 * The string handed to each `Value(...)` is that value's name, i.e. what `toString` returns.
 */
object SessionStates extends Enumeration {
  type SessionStates = Value

  // Declaration order determines each value's `id`; keep it stable.
  val CREATION_SUCCESS: SessionStates = Value("success")
  val CREATION_FAILED: SessionStates = Value("failed")
  val CREATION_IN_PROGRESS: SessionStates = Value("in_progress")
}
// http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/7d8fa69f/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/rpc/RpcClient.scala
// ----------------------------------------------------------------------
// diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/rpc/RpcClient.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/rpc/RpcClient.scala
// new file mode 100644
// index 0000000..75bab0b
// --- /dev/null
// +++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/rpc/RpcClient.scala
// @@ -0,0 +1,289 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.thriftserver.rpc + +import java.lang.reflect.InvocationTargetException + +import scala.collection.immutable.HashMap +import scala.collection.mutable.ArrayBuffer +import scala.util.Try + +import org.apache.hive.service.cli.SessionHandle + +import org.apache.livy._ +import org.apache.livy.server.interactive.InteractiveSession +import org.apache.livy.thriftserver.serde.ColumnOrientedResultSet +import org.apache.livy.thriftserver.types.DataType +import org.apache.livy.utils.LivySparkUtils + +class RpcClient(livySession: InteractiveSession) extends Logging { + import RpcClient._ + + private val isSpark1 = { + val (sparkMajorVersion, _) = + LivySparkUtils.formatSparkVersion(livySession.livyConf.get(LivyConf.LIVY_SPARK_VERSION)) + sparkMajorVersion == 1 + } + private val defaultIncrementalCollect = + livySession.livyConf.getBoolean(LivyConf.THRIFT_INCR_COLLECT_ENABLED).toString + + private val rscClient = livySession.client.get + + def isValid: Boolean = rscClient.isAlive + + private def sessionId(sessionHandle: SessionHandle): String = { + sessionHandle.getSessionId.toString + } + + @throws[Exception] + def executeSql( + sessionHandle: SessionHandle, + statementId: String, + statement: String): JobHandle[_] = { + info(s"RSC client is executing SQL query: $statement, statementId = $statementId, session = " + + sessionHandle) + require(null != statementId, s"Invalid 
statementId specified. StatementId = $statementId") + require(null != statement, s"Invalid statement specified. StatementId = $statement") + livySession.recordActivity() + rscClient.submit(executeSqlJob(sessionId(sessionHandle), + statementId, + statement, + isSpark1, + defaultIncrementalCollect, + s"spark.${LivyConf.THRIFT_INCR_COLLECT_ENABLED}")) + } + + @throws[Exception] + def fetchResult(statementId: String, + types: Array[DataType], + maxRows: Int): JobHandle[ColumnOrientedResultSet] = { + info(s"RSC client is fetching result for statementId $statementId with $maxRows maxRows.") + require(null != statementId, s"Invalid statementId specified. StatementId = $statementId") + livySession.recordActivity() + rscClient.submit(fetchResultJob(statementId, types, maxRows)) + } + + @throws[Exception] + def fetchResultSchema(statementId: String): JobHandle[String] = { + info(s"RSC client is fetching result schema for statementId = $statementId") + require(null != statementId, s"Invalid statementId specified. statementId = $statementId") + livySession.recordActivity() + rscClient.submit(fetchResultSchemaJob(statementId)) + } + + @throws[Exception] + def cleanupStatement(statementId: String, cancelJob: Boolean = false): JobHandle[_] = { + info(s"Cleaning up remote session for statementId = $statementId") + require(null != statementId, s"Invalid statementId specified. statementId = $statementId") + livySession.recordActivity() + rscClient.submit(cleanupStatementJob(statementId)) + } + + /** + * Creates a new Spark context for the specified session and stores it in a shared variable so + * that any incoming session uses a different one: it is needed in order to avoid interactions + * between different users working on the same remote Livy session (eg. setting a property, + * changing database, etc.). 
+ */ + @throws[Exception] + def executeRegisterSession(sessionHandle: SessionHandle): JobHandle[_] = { + info(s"RSC client is executing register session $sessionHandle") + livySession.recordActivity() + rscClient.submit(registerSessionJob(sessionId(sessionHandle), isSpark1)) + } + + /** + * Removes the Spark session created for the specified session from the shared variable. + */ + @throws[Exception] + def executeUnregisterSession(sessionHandle: SessionHandle): JobHandle[_] = { + info(s"RSC client is executing unregister session $sessionHandle") + livySession.recordActivity() + rscClient.submit(unregisterSessionJob(sessionId(sessionHandle))) + } +} + +/** + * As remotely we don't have any class instance, all the job definitions are placed here in + * order to enforce that we are not accessing any class attribute + */ +object RpcClient { + // Maps a session ID to its SparkSession (or HiveContext/SQLContext according to the Spark + // version used) + val SESSION_SPARK_ENTRY_MAP = "livy.thriftserver.rpc_sessionIdToSparkSQLSession" + val STATEMENT_RESULT_ITER_MAP = "livy.thriftserver.rpc_statementIdToResultIter" + val STATEMENT_SCHEMA_MAP = "livy.thriftserver.rpc_statementIdToSchema" + + private def registerSessionJob(sessionId: String, isSpark1: Boolean): Job[_] = new Job[Boolean] { + override def call(jc: JobContext): Boolean = { + val spark: Any = if (isSpark1) { + Option(jc.hivectx()).getOrElse(jc.sqlctx()) + } else { + jc.sparkSession() + } + val sessionSpecificSpark = spark.getClass.getMethod("newSession").invoke(spark) + jc.sc().synchronized { + val existingMap = + Try(jc.getSharedObject[HashMap[String, AnyRef]](SESSION_SPARK_ENTRY_MAP)) + .getOrElse(new HashMap[String, AnyRef]()) + jc.setSharedObject(SESSION_SPARK_ENTRY_MAP, + existingMap + ((sessionId, sessionSpecificSpark))) + Try(jc.getSharedObject[HashMap[String, String]](STATEMENT_SCHEMA_MAP)) + .failed.foreach { _ => + jc.setSharedObject(STATEMENT_SCHEMA_MAP, new HashMap[String, String]()) + } + 
Try(jc.getSharedObject[HashMap[String, Iterator[_]]](STATEMENT_RESULT_ITER_MAP)) + .failed.foreach { _ => + jc.setSharedObject(STATEMENT_RESULT_ITER_MAP, new HashMap[String, Iterator[_]]()) + } + } + true + } + } + + private def unregisterSessionJob(sessionId: String): Job[_] = new Job[Boolean] { + override def call(jobContext: JobContext): Boolean = { + jobContext.sc().synchronized { + val existingMap = + jobContext.getSharedObject[HashMap[String, AnyRef]](SESSION_SPARK_ENTRY_MAP) + jobContext.setSharedObject(SESSION_SPARK_ENTRY_MAP, existingMap - sessionId) + } + true + } + } + + private def cleanupStatementJob(statementId: String): Job[_] = new Job[Boolean] { + override def call(jc: JobContext): Boolean = { + val sparkContext = jc.sc() + sparkContext.cancelJobGroup(statementId) + sparkContext.synchronized { + // Clear job group only if current job group is same as expected job group. + if (sparkContext.getLocalProperty("spark.jobGroup.id") == statementId) { + sparkContext.clearJobGroup() + } + val iterMap = jc.getSharedObject[HashMap[String, Iterator[_]]](STATEMENT_RESULT_ITER_MAP) + jc.setSharedObject(STATEMENT_RESULT_ITER_MAP, iterMap - statementId) + val schemaMap = jc.getSharedObject[HashMap[String, String]](STATEMENT_SCHEMA_MAP) + jc.setSharedObject(STATEMENT_SCHEMA_MAP, schemaMap - statementId) + } + true + } + } + + private def fetchResultSchemaJob(statementId: String): Job[String] = new Job[String] { + override def call(jobContext: JobContext): String = { + jobContext.getSharedObject[HashMap[String, String]](STATEMENT_SCHEMA_MAP)(statementId) + } + } + + private def fetchResultJob(statementId: String, + types: Array[DataType], + maxRows: Int): Job[ColumnOrientedResultSet] = new Job[ColumnOrientedResultSet] { + override def call(jobContext: JobContext): ColumnOrientedResultSet = { + val statementIterMap = + jobContext.getSharedObject[HashMap[String, Iterator[_]]](STATEMENT_RESULT_ITER_MAP) + val iter = statementIterMap(statementId) + + if (null == iter) { 
+ // Previous query execution failed. + throw new NoSuchElementException("No successful query executed for output") + } + + val resultSet = new ColumnOrientedResultSet(types) + val numOfColumns = types.length + if (!iter.hasNext) { + resultSet + } else { + var curRow = 0 + while (curRow < maxRows && iter.hasNext) { + val sparkRow = iter.next() + val row = ArrayBuffer[Any]() + var curCol: Integer = 0 + while (curCol < numOfColumns) { + row += sparkRow.getClass.getMethod("get", classOf[Int]).invoke(sparkRow, curCol) + curCol += 1 + } + resultSet.addRow(row.toArray.asInstanceOf[Array[Object]]) + curRow += 1 + } + resultSet + } + } + } + + private def executeSqlJob(sessionId: String, + statementId: String, + statement: String, + isSpark1: Boolean, + defaultIncrementalCollect: String, + incrementalCollectEnabledProp: String): Job[_] = new Job[Boolean] { + override def call(jc: JobContext): Boolean = { + val sparkContext = jc.sc() + sparkContext.synchronized { + sparkContext.setJobGroup(statementId, statement) + } + val spark = jc.getSharedObject[HashMap[String, AnyRef]](SESSION_SPARK_ENTRY_MAP)(sessionId) + try { + val result = spark.getClass.getMethod("sql", classOf[String]).invoke(spark, statement) + val schema = result.getClass.getMethod("schema").invoke(result) + val jsonString = schema.getClass.getMethod("json").invoke(schema).asInstanceOf[String] + + // Set the schema in the shared map + sparkContext.synchronized { + val existingMap = jc.getSharedObject[HashMap[String, String]](STATEMENT_SCHEMA_MAP) + jc.setSharedObject(STATEMENT_SCHEMA_MAP, existingMap + ((statementId, jsonString))) + } + + val incrementalCollect = { + if (isSpark1) { + spark.getClass.getMethod("getConf", classOf[String], classOf[String]) + .invoke(spark, + incrementalCollectEnabledProp, + defaultIncrementalCollect) + .asInstanceOf[String].toBoolean + } else { + val conf = spark.getClass.getMethod("conf").invoke(spark) + conf.getClass.getMethod("get", classOf[String], classOf[String]) + 
.invoke(conf, + incrementalCollectEnabledProp, + defaultIncrementalCollect) + .asInstanceOf[String].toBoolean + } + } + + val iter = if (incrementalCollect) { + val rdd = result.getClass.getMethod("rdd").invoke(result) + rdd.getClass.getMethod("toLocalIterator").invoke(rdd).asInstanceOf[Iterator[_]] + } else { + result.getClass.getMethod("collect").invoke(result).asInstanceOf[Array[_]].iterator + } + + // Set the iterator in the shared map + sparkContext.synchronized { + val existingMap = + jc.getSharedObject[HashMap[String, Iterator[_]]](STATEMENT_RESULT_ITER_MAP) + jc.setSharedObject(STATEMENT_RESULT_ITER_MAP, existingMap + ((statementId, iter))) + } + } catch { + case e: InvocationTargetException => throw e.getCause + } + + true + } + } +}