http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/47d3ee6b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/cli/ThriftBinaryCLIService.scala ---------------------------------------------------------------------- diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/cli/ThriftBinaryCLIService.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/cli/ThriftBinaryCLIService.scala new file mode 100644 index 0000000..e16313d --- /dev/null +++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/cli/ThriftBinaryCLIService.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/

package org.apache.livy.thriftserver.cli

import java.net.InetSocketAddress
import java.util
import java.util.concurrent._
import javax.net.ssl.SSLServerSocket

import org.apache.hive.service.cli.HiveSQLException
import org.apache.hive.service.server.ThreadFactoryWithGarbageCleanup
import org.apache.thrift.TProcessorFactory
import org.apache.thrift.protocol.{TBinaryProtocol, TProtocol}
import org.apache.thrift.server.{ServerContext, TServer, TServerEventHandler, TThreadPoolServer}
import org.apache.thrift.transport.{TServerSocket, TSSLTransportFactory, TTransport, TTransportFactory}

import org.apache.livy.LivyConf
import org.apache.livy.thriftserver.LivyCLIService
import org.apache.livy.thriftserver.auth.AuthFactory

/**
 * Thrift CLI service that serves the HiveServer2 protocol over a plain or SSL TCP socket
 * (binary transport) using a `TThreadPoolServer`.
 *
 * This class is ported from Hive. We cannot reuse Hive's one because we need to use the
 * `LivyCLIService`, `LivyConf` and `AuthFactory` instead of Hive's one.
 *
 * @param cliService the service every Thrift call is delegated to
 * @param oomHook    handed to the worker pool (`ThreadPoolExecutorWithOomHook`); presumably run
 *                   when a worker thread hits an OutOfMemoryError - TODO confirm
 */
class ThriftBinaryCLIService(override val cliService: LivyCLIService, val oomHook: Runnable)
  extends ThriftCLIService(cliService, classOf[ThriftBinaryCLIService].getSimpleName) {

  /** The Thrift server; assigned by `initServer()` and cleared by `stopServer()`. */
  protected var server: TServer = _

  override lazy val hiveAuthFactory: AuthFactory = new AuthFactory(livyConf)

  /**
   * Builds (but does not start) the Thrift server.
   *
   * The server is assembled from:
   *  - the worker thread pool;
   *  - the transport and processor factories taken from the auth factory;
   *  - the server socket, which is SSL when `THRIFT_USE_SSL` is set;
   *  - the message-size and login-timeout limits;
   *  - the per-connection event handler.
   *
   * @throws RuntimeException wrapping any exception raised while building the server
   */
  protected def initServer(): Unit = {
    try {
      // Server thread pool
      val executorService = new ThreadPoolExecutorWithOomHook(
        minWorkerThreads,
        maxWorkerThreads,
        workerKeepAliveTime,
        TimeUnit.SECONDS,
        new SynchronousQueue[Runnable],
        new ThreadFactoryWithGarbageCleanup("LivyThriftserver-Handler-Pool"),
        oomHook)
      // Thrift configs
      val transportFactory: TTransportFactory = hiveAuthFactory.getAuthTransFactory
      val processorFactory: TProcessorFactory = hiveAuthFactory.getAuthProcFactory(this)
      val serverAddress = if (hiveHost == null || hiveHost.isEmpty) {
        new InetSocketAddress(portNum) // Wildcard bind
      } else {
        new InetSocketAddress(hiveHost, portNum)
      }
      val serverSocket: TServerSocket = if (livyConf.getBoolean(LivyConf.THRIFT_USE_SSL)) {
        createSslServerSocket(serverAddress)
      } else {
        new TServerSocket(serverAddress)
      }
      // Server args
      val maxMessageSize = livyConf.getInt(LivyConf.THRIFT_MAX_MESSAGE_SIZE)
      val requestTimeout = livyConf.getTimeAsMs(LivyConf.THRIFT_LOGIN_TIMEOUT).toInt
      val beBackoffSlotLength =
        livyConf.getTimeAsMs(LivyConf.THRIFT_LOGIN_BEBACKOFF_SLOT_LENGTH).toInt
      val sargs = new TThreadPoolServer.Args(serverSocket)
        .processorFactory(processorFactory)
        .transportFactory(transportFactory)
        .protocolFactory(new TBinaryProtocol.Factory)
        .inputProtocolFactory(
          new TBinaryProtocol.Factory(true, true, maxMessageSize, maxMessageSize))
        .requestTimeout(requestTimeout)
        .requestTimeoutUnit(TimeUnit.MILLISECONDS)
        .beBackoffSlotLength(beBackoffSlotLength)
        .beBackoffSlotLengthUnit(TimeUnit.MILLISECONDS)
        .executorService(executorService)
      // TCP Server
      server = new TThreadPoolServer(sargs)
      server.setServerEventHandler(newServerEventHandler())
      info(s"Starting ${classOf[ThriftBinaryCLIService].getSimpleName} on port $portNum " +
        s"with $minWorkerThreads...$maxWorkerThreads worker threads")
    } catch {
      case e: Exception => throw new RuntimeException("Failed to init thrift server", e)
    }
  }

  /**
   * Opens an SSL server socket bound to `serverAddress` and removes every protocol listed in
   * `THRIFT_SSL_PROTOCOL_BLACKLIST` (compared case-insensitively) from its enabled protocols.
   *
   * @throws IllegalArgumentException if no keystore (`SSL_KEYSTORE`) is configured
   */
  private def createSslServerSocket(serverAddress: InetSocketAddress): TServerSocket = {
    val sslVersionBlacklist: Set[String] =
      Option(livyConf.get(LivyConf.THRIFT_SSL_PROTOCOL_BLACKLIST))
        .map(_.split(",").map(_.trim.toLowerCase).toSet)
        .getOrElse(Set.empty[String])
    // Treat an unset keystore like an empty one so the intended error is raised (not an NPE).
    val keyStorePath = Option(livyConf.get(LivyConf.SSL_KEYSTORE)).map(_.trim).getOrElse("")
    if (keyStorePath.isEmpty) {
      throw new IllegalArgumentException(
        s"${LivyConf.SSL_KEYSTORE.key} Not configured for SSL connection")
    }
    val keyStorePassword = livyConf.get(LivyConf.SSL_KEYSTORE_PASSWORD)
    val params = new TSSLTransportFactory.TSSLTransportParameters
    params.setKeyStore(keyStorePath, keyStorePassword)
    val serverSocket =
      TSSLTransportFactory.getServerSocket(portNum, 0, serverAddress.getAddress, params)
    serverSocket.getServerSocket match {
      case sslServerSocket: SSLServerSocket =>
        val enabledProtocols = sslServerSocket.getEnabledProtocols.filter { protocol =>
          val blacklisted = sslVersionBlacklist.contains(protocol.toLowerCase)
          if (blacklisted) {
            debug(s"Disabling SSL Protocol: $protocol")
          }
          !blacklisted
        }
        sslServerSocket.setEnabledProtocols(enabledProtocols)
        // Render the array contents; interpolating the array directly prints its identity.
        val enabledDescription = sslServerSocket.getEnabledProtocols.mkString(", ")
        info(s"SSL Server Socket Enabled Protocols: $enabledDescription")
      case _ =>
    }
    serverSocket
  }

  /**
   * Creates the handler that gives every connection a fresh [[ThriftCLIServerContext]].
   *
   * The handler also publishes that context to `currentServerContext` before each call. When a
   * connection goes away while its session is still open, the handler closes the session if
   * `THRIFT_CLOSE_SESSION_ON_DISCONNECT` is enabled.
   */
  private def newServerEventHandler(): TServerEventHandler = new TServerEventHandler {
    override def createContext(input: TProtocol, output: TProtocol): ServerContext = {
      new ThriftCLIServerContext
    }

    override def deleteContext(
        serverContext: ServerContext,
        input: TProtocol,
        output: TProtocol): Unit = serverContext match {
      case context: ThriftCLIServerContext =>
        Option(context.getSessionHandle).foreach { sessionHandle =>
          info("Session disconnected without closing properly. ")
          try {
            if (livyConf.getBoolean(LivyConf.THRIFT_CLOSE_SESSION_ON_DISCONNECT)) {
              info("Closing the session: " + sessionHandle)
              cliService.closeSession(sessionHandle)
            }
          } catch {
            case e: HiveSQLException => warn("Failed to close session: " + e, e)
          }
        }
      case _ =>
    }

    override def preServe(): Unit = {}

    override def processContext(
        serverContext: ServerContext,
        input: TTransport,
        output: TTransport): Unit = {
      currentServerContext.set(serverContext)
    }
  }

  /**
   * Serves requests until the server is stopped. An `InterruptedException` is treated as a
   * shutdown; any other failure is logged and terminates the JVM with exit code -1.
   */
  override def run(): Unit = {
    try {
      server.serve()
    } catch {
      case t: InterruptedException =>
        // This is likely a shutdown
        info(s"Caught ${t.getClass.getSimpleName}. Shutting down thrift server.")
      case t: Throwable =>
        // Deliberately broad: the serving thread dying must take the process down.
        error(s"Exception caught by ${this.getClass.getSimpleName}. Exiting.", t)
        System.exit(-1)
    }
  }

  /** Stops the Thrift server, if one was built, and forgets it. */
  protected def stopServer(): Unit = {
    Option(server).foreach(_.stop())
    server = null
    info("Thrift server has stopped")
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/47d3ee6b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/cli/ThriftCLIService.scala ---------------------------------------------------------------------- diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/cli/ThriftCLIService.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/cli/ThriftCLIService.scala new file mode 100644 index 0000000..4a3276f --- /dev/null +++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/cli/ThriftCLIService.scala @@ -0,0 +1,745 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/

package org.apache.livy.thriftserver.cli

import java.io.IOException
import java.net.{InetAddress, UnknownHostException}
import java.util
import java.util.Collections
import javax.security.auth.login.LoginException

import scala.collection.JavaConverters._

import com.google.common.base.Preconditions.checkArgument
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.security.authentication.util.KerberosName
import org.apache.hadoop.security.authorize.ProxyUsers
import org.apache.hadoop.util.StringUtils
import org.apache.hive.service.{ServiceException, ServiceUtils}
import org.apache.hive.service.auth.{HiveAuthConstants, TSetIpAddressProcessor}
import org.apache.hive.service.auth.HiveAuthConstants.AuthTypes
import org.apache.hive.service.cli._
import org.apache.hive.service.rpc.thrift._
import org.apache.thrift.TException
import org.apache.thrift.server.ServerContext

import org.apache.livy.LivyConf
import org.apache.livy.thriftserver.{LivyCLIService, LivyThriftServer, SessionInfo, ThriftService}
import org.apache.livy.thriftserver.auth.AuthFactory

/**
 * Base implementation of the HiveServer2 Thrift interface (`TCLIService.Iface`) for Livy.
 *
 * Every RPC is forwarded to [[LivyCLIService]] and its outcome is wrapped into the matching
 * Thrift response. Subclasses supply the transport through `initServer()`, `run()` and
 * `stopServer()`.
 *
 * This class is ported from Hive. We cannot reuse Hive's one because we need to use the
 * `LivyCLIService`, `LivyConf` and `AuthFactory` instead of Hive's one.
 *
 * @param cliService  the service all Thrift calls are delegated to
 * @param serviceName name passed to the parent `ThriftService`
 */
abstract class ThriftCLIService(val cliService: LivyCLIService, val serviceName: String)
  extends ThriftService(serviceName) with TCLIService.Iface with Runnable {

  /** Authentication helper; provided by the concrete transport implementation. */
  def hiveAuthFactory: AuthFactory

  /**
   * Thread-local connection context, populated by the transport when it supports one (it may be
   * unset). Used to remember which session was opened on the current connection.
   */
  protected val currentServerContext = new ThreadLocal[ServerContext]
  protected var portNum: Int = 0
  protected var serverIPAddress: InetAddress = _
  protected var hiveHost: String = _
  private var isStarted: Boolean = false
  protected var isEmbedded: Boolean = false
  protected var livyConf: LivyConf = _
  protected var minWorkerThreads: Int = 0
  protected var maxWorkerThreads: Int = 0
  /** Idle worker keep-alive time, in seconds. */
  protected var workerKeepAliveTime: Long = 0L
  private var serverThread: Thread = _

  /**
   * Reads the bind host, port and worker-pool settings from `conf`. It also resolves the
   * address the server reports, which is the local host when no bind host is configured.
   *
   * @throws ServiceException if the configured bind host cannot be resolved
   */
  override def init(conf: LivyConf): Unit = {
    livyConf = conf
    hiveHost = livyConf.get(LivyConf.THRIFT_BIND_HOST)
    serverIPAddress = try {
      if (hiveHost == null || hiveHost.isEmpty) {
        InetAddress.getLocalHost
      } else {
        InetAddress.getByName(hiveHost)
      }
    } catch {
      case e: UnknownHostException => throw new ServiceException(e)
    }
    portNum = livyConf.getInt(LivyConf.THRIFT_SERVER_PORT)
    // The keep-alive setting is read in milliseconds but stored in seconds.
    workerKeepAliveTime = livyConf.getTimeAsMs(LivyConf.THRIFT_WORKER_KEEPALIVE_TIME) / 1000
    minWorkerThreads = livyConf.getInt(LivyConf.THRIFT_MIN_WORKER_THREADS)
    maxWorkerThreads = livyConf.getInt(LivyConf.THRIFT_MAX_WORKER_THREADS)
    super.init(livyConf)
  }

  /** Builds the transport-specific server; called by `start()` before the serving thread. */
  protected def initServer(): Unit

  /** Starts the parent service, then (once, unless embedded) serves on a "Thrift Server" thread. */
  override def start(): Unit = {
    super.start()
    if (!isStarted && !isEmbedded) {
      initServer()
      serverThread = new Thread(this)
      serverThread.setName("Thrift Server")
      serverThread.start()
      isStarted = true
    }
  }

  /** Stops the transport built by `initServer()`. */
  protected def stopServer(): Unit

  /** Interrupts the serving thread and stops the transport (if started), then the parent. */
  override def stop(): Unit = {
    if (isStarted && !isEmbedded) {
      if (serverThread != null) {
        serverThread.interrupt()
        serverThread = null
      }
      stopServer()
      isStarted = false
    }
    super.stop()
  }

  def getPortNumber: Int = portNum

  def getServerIPAddress: InetAddress = serverIPAddress

  @throws[TException]
  override def GetDelegationToken(req: TGetDelegationTokenReq): TGetDelegationTokenResp = {
    val resp = new TGetDelegationTokenResp
    if (!hiveAuthFactory.isSASLKerberosUser) {
      resp.setStatus(unsecureTokenErrorStatus)
    } else {
      try {
        val token = cliService.getDelegationToken(
          new SessionHandle(req.getSessionHandle), hiveAuthFactory, req.getOwner, req.getRenewer)
        resp.setDelegationToken(token)
        resp.setStatus(ThriftCLIService.OK_STATUS)
      } catch {
        case e: HiveSQLException =>
          error("Error obtaining delegation token", e)
          val tokenErrorStatus = HiveSQLException.toTStatus(e)
          tokenErrorStatus.setSqlState("42000")
          resp.setStatus(tokenErrorStatus)
      }
    }
    resp
  }

  @throws[TException]
  override def CancelDelegationToken(req: TCancelDelegationTokenReq): TCancelDelegationTokenResp = {
    val resp = new TCancelDelegationTokenResp
    if (!hiveAuthFactory.isSASLKerberosUser) {
      resp.setStatus(unsecureTokenErrorStatus)
    } else {
      try {
        cliService.cancelDelegationToken(
          new SessionHandle(req.getSessionHandle), hiveAuthFactory, req.getDelegationToken)
        resp.setStatus(ThriftCLIService.OK_STATUS)
      } catch {
        case e: HiveSQLException =>
          error("Error canceling delegation token", e)
          resp.setStatus(HiveSQLException.toTStatus(e))
      }
    }
    resp
  }

  @throws[TException]
  override def RenewDelegationToken(req: TRenewDelegationTokenReq): TRenewDelegationTokenResp = {
    val resp = new TRenewDelegationTokenResp
    if (!hiveAuthFactory.isSASLKerberosUser) {
      resp.setStatus(unsecureTokenErrorStatus)
    } else {
      try {
        cliService.renewDelegationToken(
          new SessionHandle(req.getSessionHandle), hiveAuthFactory, req.getDelegationToken)
        resp.setStatus(ThriftCLIService.OK_STATUS)
      } catch {
        case e: HiveSQLException =>
          error("Error renewing delegation token", e)
          resp.setStatus(HiveSQLException.toTStatus(e))
      }
    }
    resp
  }

  /** Error status returned by the delegation-token calls when Kerberos SASL is not in use. */
  private def unsecureTokenErrorStatus: TStatus = {
    val errorStatus = new TStatus(TStatusCode.ERROR_STATUS)
    errorStatus.setErrorMessage(
      "Delegation token only supported over remote client with kerberos authentication")
    errorStatus
  }

  /**
   * Runs `body` and returns the success status.
   *
   * If `body` throws any `Exception`, the helper logs `Error <action>: ` at warn level and
   * returns the exception converted to an error `TStatus`. This is the shared error handling of
   * the Thrift RPCs below.
   */
  private def statusOf(action: String)(body: => Unit): TStatus = {
    try {
      body
      ThriftCLIService.OK_STATUS
    } catch {
      case e: Exception =>
        warn(s"Error $action: ", e)
        HiveSQLException.toTStatus(e)
    }
  }

  @throws[TException]
  override def OpenSession(req: TOpenSessionReq): TOpenSessionResp = {
    info("Client protocol version: " + req.getClient_protocol)
    val resp = new TOpenSessionResp
    resp.setStatus(statusOf("opening session") {
      val sessionHandle = getSessionHandle(req, resp)
      resp.setSessionHandle(sessionHandle.toTSessionHandle)
      // Set the updated fetch size from the server into the configuration map for the client
      val configurationMap = new util.HashMap[String, String]
      configurationMap.put(
        LivyConf.THRIFT_RESULTSET_DEFAULT_FETCH_SIZE.key,
        Integer.toString(livyConf.getInt(LivyConf.THRIFT_RESULTSET_DEFAULT_FETCH_SIZE)))
      resp.setConfiguration(configurationMap)
      // Remember the session on this connection so it can be cleaned up on disconnect.
      Option(currentServerContext.get).foreach { context =>
        context.asInstanceOf[ThriftCLIServerContext].setSessionHandle(sessionHandle)
      }
    })
    resp
  }

  /**
   * Logs the client information sent by the client.
   *
   * An `ApplicationName` entry is forwarded to `cliService.setApplicationName`. An empty
   * configuration map is accepted; it is simply logged with no entries.
   */
  @throws[TException]
  override def SetClientInfo(req: TSetClientInfoReq): TSetClientInfoResp = {
    // TODO: We don't do anything for now, just log this for debugging.
    // We may be able to make use of this later, e.g. for workload management.
    if (!req.isSetConfiguration) {
      new TSetClientInfoResp(ThriftCLIService.OK_STATUS)
    } else {
      val sh = new SessionHandle(req.getSessionHandle)
      try {
        val described = req.getConfiguration.asScala.toSeq.map { case (key, value) =>
          if ("ApplicationName" == key) {
            cliService.setApplicationName(sh, value)
          }
          s"$key = $value"
        }
        info(s"Client information for $sh: " + described.mkString(", "))
        new TSetClientInfoResp(ThriftCLIService.OK_STATUS)
      } catch {
        case ex: Exception =>
          warn("Error setting application name", ex)
          new TSetClientInfoResp(HiveSQLException.toTStatus(ex))
      }
    }
  }

  /**
   * Returns the client's IP address.
   *
   * The source depends on the transport: in HTTP mode it comes from `SessionInfo` (set in
   * ThriftHttpServlet); with SASL on kerberized Hadoop it comes from the auth factory; otherwise
   * (NOSASL) it comes from `TSetIpAddressProcessor`.
   */
  private def getIpAddress: String = {
    val clientIpAddress = if (LivyThriftServer.isHTTPTransportMode(livyConf)) {
      SessionInfo.getIpAddress
    } else if (hiveAuthFactory.isSASLWithKerberizedHadoop) {
      hiveAuthFactory.getIpAddress
    } else {
      TSetIpAddressProcessor.getUserIpAddress
    }
    debug(s"Client's IP Address: $clientIpAddress")
    clientIpAddress
  }

  /**
   * Returns the effective username.
   *  1. If livy.server.thrift.allow.user.substitution = false: the username of the connecting
   *     user.
   *  2. If livy.server.thrift.allow.user.substitution = true: the username of the end user
   *     that the connecting user is trying to proxy for.
   *
   * Case 2 includes a check that the connecting user is allowed to proxy for the end user.
   */
  @throws[HiveSQLException]
  @throws[IOException]
  private def getUserName(req: TOpenSessionReq): String = {
    val username = if (LivyThriftServer.isHTTPTransportMode(livyConf)) {
      Option(SessionInfo.getUserName).getOrElse(req.getUsername)
    } else if (hiveAuthFactory.isSASLWithKerberizedHadoop) {
      Option(hiveAuthFactory.getRemoteUser)
        .orElse(Option(TSetIpAddressProcessor.getUserName))
        .getOrElse(req.getUsername)
    } else {
      Option(TSetIpAddressProcessor.getUserName).getOrElse(req.getUsername)
    }
    val effectiveClientUser =
      getProxyUser(getShortName(username), req.getConfiguration, getIpAddress)
    debug(s"Client's username: $effectiveClientUser")
    effectiveClientUser
  }

  /**
   * Shortens `userName`; returns null for a null input.
   *
   * For Kerberos users the short name comes from `KerberosName`. For other users, the part
   * before the domain match found by `ServiceUtils.indexOfDomainMatch` is kept.
   */
  @throws[IOException]
  private def getShortName(userName: String): String = {
    Option(userName).map { un =>
      if (hiveAuthFactory.isSASLKerberosUser) {
        // KerberosName.getShortName can only be used for kerberos user
        new KerberosName(un).getShortName
      } else {
        val indexOfDomainMatch = ServiceUtils.indexOfDomainMatch(un)
        if (indexOfDomainMatch <= 0) un else un.substring(0, indexOfDomainMatch)
      }
    }.orNull
  }

  /**
   * Opens a session for the request and returns its handle. Impersonation is used when doAs is
   * enabled and a user name was resolved. The negotiated protocol version is also stored in
   * `res`.
   */
  @throws[HiveSQLException]
  @throws[LoginException]
  @throws[IOException]
  private[thriftserver] def getSessionHandle(
      req: TOpenSessionReq, res: TOpenSessionResp): SessionHandle = {
    val userName = getUserName(req)
    val ipAddress = getIpAddress
    val protocol = getMinVersion(LivyCLIService.SERVER_VERSION, req.getClient_protocol)
    val sessionHandle =
      if (livyConf.getBoolean(LivyConf.THRIFT_ENABLE_DOAS) && (userName != null)) {
        cliService.openSessionWithImpersonation(
          protocol, userName, req.getPassword, ipAddress, req.getConfiguration, null)
      } else {
        cliService.openSession(protocol, userName, req.getPassword, ipAddress, req.getConfiguration)
      }
    res.setServerProtocolVersion(protocol)
    sessionHandle
  }

  /** Progress of an EXECUTE_STATEMENT operation; progress is not tracked, so always 0.0. */
  @throws[HiveSQLException]
  private def getProgressedPercentage(opHandle: OperationHandle): Double = {
    checkArgument(OperationType.EXECUTE_STATEMENT == opHandle.getOperationType)
    0.0
  }

  /**
   * Returns the oldest of `versions`, never newer than the last `TProtocolVersion` constant.
   */
  private def getMinVersion(versions: TProtocolVersion*): TProtocolVersion = {
    (TProtocolVersion.values.last +: versions).minBy(_.getValue)
  }

  @throws[TException]
  override def CloseSession(req: TCloseSessionReq): TCloseSessionResp = {
    val resp = new TCloseSessionResp
    resp.setStatus(statusOf("closing session") {
      cliService.closeSession(new SessionHandle(req.getSessionHandle))
      // The session was closed explicitly: nothing left to clean up on disconnect.
      Option(currentServerContext.get).foreach { ctx =>
        ctx.asInstanceOf[ThriftCLIServerContext].setSessionHandle(null)
      }
    })
    resp
  }

  @throws[TException]
  override def GetInfo(req: TGetInfoReq): TGetInfoResp = {
    val resp = new TGetInfoResp
    resp.setStatus(statusOf("getting info") {
      val getInfoValue = cliService.getInfo(
        new SessionHandle(req.getSessionHandle), GetInfoType.getGetInfoType(req.getInfoType))
      resp.setInfoValue(getInfoValue.toTGetInfoValue)
    })
    resp
  }

  @throws[TException]
  override def ExecuteStatement(req: TExecuteStatementReq): TExecuteStatementResp = {
    val resp = new TExecuteStatementResp
    resp.setStatus(statusOf("executing statement") {
      val sessionHandle = new SessionHandle(req.getSessionHandle)
      val statement = req.getStatement
      val confOverlay = req.getConfOverlay
      val queryTimeout = req.getQueryTimeout
      val operationHandle = if (req.isRunAsync) {
        cliService.executeStatementAsync(sessionHandle, statement, confOverlay, queryTimeout)
      } else {
        cliService.executeStatement(sessionHandle, statement, confOverlay, queryTimeout)
      }
      resp.setOperationHandle(operationHandle.toTOperationHandle)
    })
    resp
  }

  @throws[TException]
  override def GetTypeInfo(req: TGetTypeInfoReq): TGetTypeInfoResp = {
    val resp = new TGetTypeInfoResp
    resp.setStatus(statusOf("getting type info") {
      val operationHandle = cliService.getTypeInfo(new SessionHandle(req.getSessionHandle))
      resp.setOperationHandle(operationHandle.toTOperationHandle)
    })
    resp
  }

  @throws[TException]
  override def GetCatalogs(req: TGetCatalogsReq): TGetCatalogsResp = {
    val resp = new TGetCatalogsResp
    resp.setStatus(statusOf("getting catalogs") {
      val opHandle = cliService.getCatalogs(new SessionHandle(req.getSessionHandle))
      resp.setOperationHandle(opHandle.toTOperationHandle)
    })
    resp
  }

  @throws[TException]
  override def GetSchemas(req: TGetSchemasReq): TGetSchemasResp = {
    val resp = new TGetSchemasResp
    resp.setStatus(statusOf("getting schemas") {
      val opHandle = cliService.getSchemas(
        new SessionHandle(req.getSessionHandle), req.getCatalogName, req.getSchemaName)
      resp.setOperationHandle(opHandle.toTOperationHandle)
    })
    resp
  }

  @throws[TException]
  override def GetTables(req: TGetTablesReq): TGetTablesResp = {
    val resp = new TGetTablesResp
    resp.setStatus(statusOf("getting tables") {
      val opHandle = cliService.getTables(
        new SessionHandle(req.getSessionHandle),
        req.getCatalogName,
        req.getSchemaName,
        req.getTableName,
        req.getTableTypes)
      resp.setOperationHandle(opHandle.toTOperationHandle)
    })
    resp
  }

  @throws[TException]
  override def GetTableTypes(req: TGetTableTypesReq): TGetTableTypesResp = {
    val resp = new TGetTableTypesResp
    resp.setStatus(statusOf("getting table types") {
      val opHandle = cliService.getTableTypes(new SessionHandle(req.getSessionHandle))
      resp.setOperationHandle(opHandle.toTOperationHandle)
    })
    resp
  }

  @throws[TException]
  override def GetColumns(req: TGetColumnsReq): TGetColumnsResp = {
    val resp = new TGetColumnsResp
    resp.setStatus(statusOf("getting columns") {
      val opHandle = cliService.getColumns(
        new SessionHandle(req.getSessionHandle),
        req.getCatalogName,
        req.getSchemaName,
        req.getTableName,
        req.getColumnName)
      resp.setOperationHandle(opHandle.toTOperationHandle)
    })
    resp
  }

  @throws[TException]
  override def GetFunctions(req: TGetFunctionsReq): TGetFunctionsResp = {
    val resp = new TGetFunctionsResp
    resp.setStatus(statusOf("getting functions") {
      val opHandle = cliService.getFunctions(
        new SessionHandle(req.getSessionHandle),
        req.getCatalogName,
        req.getSchemaName,
        req.getFunctionName)
      resp.setOperationHandle(opHandle.toTOperationHandle)
    })
    resp
  }

  @throws[TException]
  override def GetOperationStatus(req: TGetOperationStatusReq): TGetOperationStatusResp = {
    val resp = new TGetOperationStatusResp
    val operationHandle = new OperationHandle(req.getOperationHandle)
    resp.setStatus(statusOf("getting operation status") {
      val operationStatus = cliService.getOperationStatus(operationHandle, req.isGetProgressUpdate)
      resp.setOperationState(operationStatus.state.toTOperationState)
      resp.setErrorMessage(operationStatus.state.getErrorMessage)
      resp.setOperationStarted(operationStatus.operationStarted)
      resp.setOperationCompleted(operationStatus.operationCompleted)
      resp.setHasResultSet(operationStatus.hasResultSet)
      // Progress is not tracked here: always report NOT_AVAILABLE with 0% progress.
      resp.setProgressUpdateResponse(new TProgressUpdateResp(
        Collections.emptyList[String],
        Collections.emptyList[util.List[String]],
        0.0D,
        TJobExecutionStatus.NOT_AVAILABLE,
        "",
        0L))
      val opException = operationStatus.operationException
      if (opException != null) {
        resp.setSqlState(opException.getSQLState)
        resp.setErrorCode(opException.getErrorCode)
        // NOTE(review): 29999 is presumably Hive's generic error code; for it the full stack
        // trace is sent back instead of just the message.
        if (opException.getErrorCode == 29999) {
          resp.setErrorMessage(StringUtils.stringifyException(opException))
        } else {
          resp.setErrorMessage(opException.getMessage)
        }
      } else if (OperationType.EXECUTE_STATEMENT == operationHandle.getOperationType) {
        resp.getProgressUpdateResponse.setProgressedPercentage(
          getProgressedPercentage(operationHandle))
      }
    })
    resp
  }

  @throws[TException]
  override def CancelOperation(req: TCancelOperationReq): TCancelOperationResp = {
    val resp = new TCancelOperationResp
    resp.setStatus(statusOf("cancelling operation") {
      cliService.cancelOperation(new OperationHandle(req.getOperationHandle))
    })
    resp
  }

  @throws[TException]
  override def CloseOperation(req: TCloseOperationReq): TCloseOperationResp = {
    val resp = new TCloseOperationResp
    resp.setStatus(statusOf("closing operation") {
      cliService.closeOperation(new OperationHandle(req.getOperationHandle))
    })
    resp
  }

  @throws[TException]
  override def GetResultSetMetadata(req: TGetResultSetMetadataReq): TGetResultSetMetadataResp = {
    val resp = new TGetResultSetMetadataResp
    resp.setStatus(statusOf("getting result set metadata") {
      val schema = cliService.getResultSetMetadata(new OperationHandle(req.getOperationHandle))
      resp.setSchema(schema.toTTableSchema)
    })
    resp
  }

  @throws[TException]
  override def FetchResults(req: TFetchResultsReq): TFetchResultsResp = {
    val resp = new TFetchResultsResp
    resp.setStatus(statusOf("fetching results") {
      // Cap the requested page size at the server-side maximum.
      val maxFetchSize = livyConf.getInt(LivyConf.THRIFT_RESULTSET_MAX_FETCH_SIZE)
      if (req.getMaxRows > maxFetchSize) {
        req.setMaxRows(maxFetchSize)
      }
      val rowSet = cliService.fetchResults(
        new OperationHandle(req.getOperationHandle),
        FetchOrientation.getFetchOrientation(req.getOrientation),
        req.getMaxRows,
        FetchType.getFetchType(req.getFetchType))
      resp.setResults(rowSet.toTRowSet)
      resp.setHasMoreRows(false)
    })
    resp
  }

  @throws[TException]
  override def GetPrimaryKeys(req: TGetPrimaryKeysReq): TGetPrimaryKeysResp = {
    val resp = new TGetPrimaryKeysResp
    resp.setStatus(statusOf("getting primary keys") {
      val opHandle = cliService.getPrimaryKeys(
        new SessionHandle(req.getSessionHandle),
        req.getCatalogName,
        req.getSchemaName,
        req.getTableName)
      resp.setOperationHandle(opHandle.toTOperationHandle)
    })
    resp
  }

  @throws[TException]
  override def GetCrossReference(req: TGetCrossReferenceReq): TGetCrossReferenceResp = {
    val resp = new TGetCrossReferenceResp
    resp.setStatus(statusOf("getting cross reference") {
      val opHandle = cliService.getCrossReference(
        new SessionHandle(req.getSessionHandle),
        req.getParentCatalogName,
        req.getParentSchemaName,
        req.getParentTableName,
        req.getForeignCatalogName,
        req.getForeignSchemaName,
        req.getForeignTableName)
      resp.setOperationHandle(opHandle.toTOperationHandle)
    })
    resp
  }

  /** Returns the query id of the operation; a `HiveSQLException` is rethrown as `TException`. */
  @throws[TException]
  override def GetQueryId(req: TGetQueryIdReq): TGetQueryIdResp = {
    try {
      new TGetQueryIdResp(cliService.getQueryId(req.getOperationHandle))
    } catch {
      case e: HiveSQLException => throw new TException(e)
    }
  }

  override def run(): Unit

  /**
   * Returns the user the session should run as.
   *
   * The proxy user is looked up first in the HTTP query string (HTTP mode only, set in
   * ThriftHttpServlet), then in the `HS2_PROXY_USER` session configuration entry. If neither is
   * present, `realUser` is returned. Otherwise substitution must be allowed, and when
   * authentication is enabled `realUser` must be authorized to impersonate the proxy user.
   *
   * @throws HiveSQLException if substitution is disabled or the proxy privilege check fails
   */
  @throws[HiveSQLException]
  private def getProxyUser(
      realUser: String,
      sessionConf: util.Map[String, String],
      ipAddress: String): String = {
    val fromQueryString: Option[String] = if (LivyThriftServer.isHTTPTransportMode(livyConf)) {
      val proxyUser = SessionInfo.getProxyUserName
      debug("Proxy user from query string: " + proxyUser)
      Option(proxyUser)
    } else {
      None
    }
    val requestedProxyUser = fromQueryString.orElse {
      Option(sessionConf)
        .filter(_.containsKey(HiveAuthConstants.HS2_PROXY_USER))
        .flatMap { conf =>
          val proxyUserFromThriftBody = conf.get(HiveAuthConstants.HS2_PROXY_USER)
          debug("Proxy user from thrift body: " + proxyUserFromThriftBody)
          Option(proxyUserFromThriftBody)
        }
    }
    requestedProxyUser match {
      case None => realUser
      case Some(proxyUser) =>
        // check whether substitution is allowed
        if (!livyConf.getBoolean(LivyConf.THRIFT_ALLOW_USER_SUBSTITUTION)) {
          throw new HiveSQLException("Proxy user substitution is not allowed")
        }
        val authType = livyConf.get(LivyConf.THRIFT_AUTHENTICATION)
        // If there's no authentication, then directly substitute the user
        if (!AuthTypes.NONE.toString.equalsIgnoreCase(authType)) {
          verifyProxyAccess(realUser, proxyUser, ipAddress)
          debug("Verified proxy user: " + proxyUser)
        }
        proxyUser
    }
  }

  /**
   * Checks through Hadoop `ProxyUsers` that `realUser` may impersonate `proxyUser` from
   * `ipAddress`. No check is done when both names are equal (ignoring case).
   *
   * @throws HiveSQLException (SQL state 08S01) if the check fails with an `IOException`
   */
  @throws[HiveSQLException]
  private def verifyProxyAccess(realUser: String, proxyUser: String, ipAddress: String): Unit = {
    try {
      val sessionUgi = if (UserGroupInformation.isSecurityEnabled) {
        UserGroupInformation.createProxyUser(
          new KerberosName(realUser).getServiceName, UserGroupInformation.getLoginUser)
      } else {
        UserGroupInformation.createRemoteUser(realUser)
      }
      if (!proxyUser.equalsIgnoreCase(realUser)) {
        ProxyUsers.refreshSuperUserGroupsConfiguration()
        ProxyUsers.authorize(UserGroupInformation.createProxyUser(proxyUser, sessionUgi), ipAddress)
      }
    } catch {
      case e: IOException =>
        throw new HiveSQLException(
          s"Failed to validate proxy privilege of $realUser for $proxyUser", "08S01", e)
    }
  }
}

object ThriftCLIService {
  /** Shared success status attached to every successful response. */
  private val OK_STATUS: TStatus = new TStatus(TStatusCode.SUCCESS_STATUS)
}

/** Per-connection server context remembering the session opened on that connection, if any. */
private[thriftserver] class ThriftCLIServerContext extends ServerContext {
  private var sessionHandle: SessionHandle = _

  def setSessionHandle(sessionHandle: SessionHandle): Unit = {
    this.sessionHandle = sessionHandle
  }

  def getSessionHandle: SessionHandle = sessionHandle
}

http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/47d3ee6b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/cli/ThriftHttpCLIService.scala ---------------------------------------------------------------------- diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/cli/ThriftHttpCLIService.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/cli/ThriftHttpCLIService.scala new file mode 100644 index 0000000..8a3d439 --- /dev/null +++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/cli/ThriftHttpCLIService.scala @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */
+
+package org.apache.livy.thriftserver.cli
+
+import java.util.concurrent.SynchronousQueue
+import java.util.concurrent.TimeUnit
+import javax.ws.rs.HttpMethod
+
+import org.apache.hive.service.rpc.thrift.TCLIService
+import org.apache.hive.service.server.ThreadFactoryWithGarbageCleanup
+import org.apache.thrift.protocol.TBinaryProtocol
+import org.eclipse.jetty.server.HttpConfiguration
+import org.eclipse.jetty.server.HttpConnectionFactory
+import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.server.ServerConnector
+import org.eclipse.jetty.server.handler.gzip.GzipHandler
+import org.eclipse.jetty.servlet.ServletContextHandler
+import org.eclipse.jetty.servlet.ServletHolder
+import org.eclipse.jetty.util.ssl.SslContextFactory
+import org.eclipse.jetty.util.thread.ExecutorThreadPool
+
+import org.apache.livy.LivyConf
+import org.apache.livy.thriftserver.LivyCLIService
+import org.apache.livy.thriftserver.auth.AuthFactory
+
+/**
+ * This class is ported from Hive. We cannot reuse Hive's one because we need to use the
+ * `LivyCLIService`, `LivyConf` and `AuthFactory` instead of Hive's one.
+ *
+ * It serves the Thrift CLI protocol over HTTP(S) with an embedded Jetty server.
+ */
+class ThriftHttpCLIService(
+    override val cliService: LivyCLIService,
+    val oomHook: Runnable)
+  extends ThriftCLIService(cliService, classOf[ThriftHttpCLIService].getSimpleName) {
+
+  // The Jetty server: assigned by initServer(), set back to null after a successful stopServer().
+  protected var server: Server = _
+
+  override lazy val hiveAuthFactory = new AuthFactory(livyConf)
+
+  /**
+   * Configure Jetty to serve http requests. Example of a client connection URL:
+   * http://localhost:10000/servlets/thrifths2/ A gateway may cause actual target URL to differ,
+   * e.g. http://gateway:port/livy/servlets/thrifths2/
+   *
+   * @throws RuntimeException wrapping any exception raised while configuring or starting Jetty
+   */
+  protected def initServer(): Unit = {
+    try {
+      // Server thread pool
+      // Start with minWorkerThreads, expand till maxWorkerThreads and reject subsequent requests
+      val executorService = new ThreadPoolExecutorWithOomHook(
+        minWorkerThreads,
+        maxWorkerThreads,
+        workerKeepAliveTime,
+        TimeUnit.SECONDS,
+        new SynchronousQueue[Runnable],
+        new ThreadFactoryWithGarbageCleanup("LivyThriftserver-HttpHandler-Pool"),
+        oomHook)
+      val threadPool = new ExecutorThreadPool(executorService)
+      // HTTP Server
+      server = new Server(threadPool)
+      val conf = new HttpConfiguration
+      // Configure header size
+      val requestHeaderSize = livyConf.getInt(LivyConf.THRIFT_HTTP_REQUEST_HEADER_SIZE)
+      val responseHeaderSize = livyConf.getInt(LivyConf.THRIFT_HTTP_RESPONSE_HEADER_SIZE)
+      conf.setRequestHeaderSize(requestHeaderSize)
+      conf.setResponseHeaderSize(responseHeaderSize)
+      val http = new HttpConnectionFactory(conf)
+      val useSsl = livyConf.getBoolean(LivyConf.THRIFT_USE_SSL)
+      val schemeName = if (useSsl) "https" else "http"
+      // Change connector if SSL is used
+      val connector = if (useSsl) {
+        // An unset keystore (null) is reported the same way as an empty one instead of an NPE.
+        val keyStorePath = Option(livyConf.get(LivyConf.SSL_KEYSTORE)).map(_.trim).getOrElse("")
+        val keyStorePassword = livyConf.get(LivyConf.SSL_KEYSTORE_PASSWORD)
+        if (keyStorePath.isEmpty) {
+          throw new IllegalArgumentException(
+            s"${LivyConf.SSL_KEYSTORE.key} Not configured for SSL connection")
+        }
+        val sslContextFactory = new SslContextFactory
+        // Trim entries so that a list such as "SSLv2, SSLv3" excludes both protocols.
+        val excludedProtocols = livyConf.get(LivyConf.THRIFT_SSL_PROTOCOL_BLACKLIST)
+          .split(",").map(_.trim).filter(_.nonEmpty)
+        info("HTTP Server SSL: adding excluded protocols: " +
+          excludedProtocols.mkString("[", ", ", "]"))
+        sslContextFactory.addExcludeProtocols(excludedProtocols: _*)
+        info("HTTP Server SSL: SslContextFactory.getExcludeProtocols = " +
+          sslContextFactory.getExcludeProtocols.mkString("[", ", ", "]"))
+        sslContextFactory.setKeyStorePath(keyStorePath)
+        sslContextFactory.setKeyStorePassword(keyStorePassword)
+        new ServerConnector(server, sslContextFactory, http)
+      } else {
+        new ServerConnector(server, http)
+      }
+      connector.setPort(portNum)
+      // Linux: yes, Windows:no
+      connector.setReuseAddress(true)
+      // The idle timeout is passed in milliseconds as a long; no narrowing cast is needed.
+      val maxIdleTime = livyConf.getTimeAsMs(LivyConf.THRIFT_HTTP_MAX_IDLE_TIME)
+      connector.setIdleTimeout(maxIdleTime)
+      server.addConnector(connector)
+      // Thrift configs
+      val processor = new TCLIService.Processor[TCLIService.Iface](this)
+      val protocolFactory = new TBinaryProtocol.Factory
+      // Set during the init phase of LivyThriftserver if auth mode is kerberos
+      // UGI for the livy/_HOST (kerberos) principal
+      val serviceUGI = cliService.getServiceUGI
+      // UGI for the http/_HOST (SPNego) principal
+      val httpUGI = cliService.getHttpUGI
+      val authType = livyConf.get(LivyConf.THRIFT_AUTHENTICATION)
+      val thriftHttpServlet = new ThriftHttpServlet(
+        processor,
+        protocolFactory,
+        authType,
+        serviceUGI,
+        httpUGI,
+        hiveAuthFactory,
+        livyConf)
+      // Context handler
+      val context = new ServletContextHandler(ServletContextHandler.SESSIONS)
+      context.setContextPath("/")
+      if (livyConf.getBoolean(LivyConf.THRIFT_XSRF_FILTER_ENABLED)) {
+        // Filtering does not work here currently, doing filter in ThriftHttpServlet
+        debug("XSRF filter enabled")
+      } else {
+        warn("XSRF filter disabled")
+      }
+      val httpPath = getHttpPath(livyConf.get(LivyConf.THRIFT_HTTP_PATH))
+      // NOTE(review): gzip compression is switched on by the XSRF flag, which looks copied from
+      // the block above; confirm whether a dedicated compression setting was intended.
+      if (livyConf.getBoolean(LivyConf.THRIFT_XSRF_FILTER_ENABLED)) {
+        val gzipHandler = new GzipHandler
+        gzipHandler.setHandler(context)
+        gzipHandler.addIncludedMethods(HttpMethod.POST)
+        gzipHandler.addIncludedMimeTypes(ThriftHttpCLIService.APPLICATION_THRIFT)
+        server.setHandler(gzipHandler)
+      } else {
+        server.setHandler(context)
+      }
+      context.addServlet(new ServletHolder(thriftHttpServlet), httpPath)
+      // TODO: check defaults: maxTimeout, keepalive, maxBodySize,
+      // bodyReceiveDuration, etc.
+      // Finally, start the server
+      server.start()
+      info(s"Started ${classOf[ThriftHttpCLIService].getSimpleName} in $schemeName mode on port " +
+        s"$portNum path=$httpPath with $minWorkerThreads...$maxWorkerThreads worker threads")
+    } catch {
+      case e: Exception => throw new RuntimeException("Failed to init HttpServer", e)
+    }
+  }
+
+  /**
+   * Blocks until the Jetty server terminates. An `InterruptedException` is treated as a
+   * shutdown; any other throwable is logged and terminates the JVM with exit code -1.
+   */
+  override def run(): Unit = {
+    try {
+      server.join()
+    } catch {
+      case t: InterruptedException =>
+        // This is likely a shutdown
+        info(s"Caught ${t.getClass.getSimpleName}. Shutting down thrift server.")
+      case t: Throwable =>
+        error(s"Exception caught by ${this.getClass.getSimpleName}. Exiting.", t)
+        System.exit(-1)
+    }
+  }
+
+  /**
+   * The config parameter can be like "path", "/path", "/path/", "path/*",
+   * "/path1/path2/*" and so on. httpPath should end up as "/*", "/path/*" or
+   * "/path1/../pathN/*"
+   */
+  private def getHttpPath(httpPath: String): String = {
+    Option(httpPath) match {
+      case None | Some("") => "/*"
+      case Some(path) =>
+        val withStartingSlash = if (path.startsWith("/")) path else s"/$path"
+        // Inspect the normalized value so that e.g. "*" yields "/*" rather than "/*/*".
+        if (withStartingSlash.endsWith("/")) {
+          s"$withStartingSlash*"
+        } else if (!withStartingSlash.endsWith("/*")) {
+          s"$withStartingSlash/*"
+        } else {
+          withStartingSlash
+        }
+    }
+  }
+
+  /** Stops the Jetty server if it is running; failures are logged, not rethrown. */
+  protected def stopServer(): Unit = {
+    if ((server != null) && server.isStarted) {
+      try {
+        server.stop()
+        server = null
+        info("Thrift HTTP server has been stopped")
+      } catch {
+        case e: Exception =>
+          error("Error stopping HTTP server: ", e)
+      }
+    }
+  }
+}
+
+object ThriftHttpCLIService {
+  // MIME type that the gzip handler is configured to compress.
+  private val APPLICATION_THRIFT = "application/x-thrift"
+}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/47d3ee6b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/cli/ThriftHttpServlet.scala
----------------------------------------------------------------------
diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/cli/ThriftHttpServlet.scala
b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/cli/ThriftHttpServlet.scala new file mode 100644 index 0000000..ae83917 --- /dev/null +++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/cli/ThriftHttpServlet.scala @@ -0,0 +1,503 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.livy.thriftserver.cli
+
+import java.io.IOException
+import java.security.{PrivilegedExceptionAction, SecureRandom}
+import javax.servlet.ServletException
+import javax.servlet.http.{Cookie, HttpServletRequest, HttpServletResponse}
+import javax.ws.rs.core.NewCookie
+
+import scala.collection.JavaConverters._
+
+import org.apache.commons.codec.binary.{Base64, StringUtils}
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.security.authentication.util.KerberosName
+import org.apache.hive.service.CookieSigner
+import org.apache.hive.service.auth.{HiveAuthConstants, HttpAuthenticationException, HttpAuthUtils}
+import org.apache.hive.service.auth.HiveAuthConstants.AuthTypes
+import org.apache.hive.service.cli.HiveSQLException
+import org.apache.thrift.TProcessor
+import org.apache.thrift.protocol.TProtocolFactory
+import org.apache.thrift.server.TServlet
+import org.ietf.jgss.{GSSContext, GSSCredential, GSSException, GSSManager, Oid}
+
+import org.apache.livy.{LivyConf, Logging}
+import org.apache.livy.thriftserver.SessionInfo
+import org.apache.livy.thriftserver.auth.{AuthenticationProvider, AuthFactory}
+
+/**
+ * This class is a porting of the parts we use from `ThriftHttpServlet` by Hive.
+ *
+ * Each POST is authenticated (signed cookie, delegation token, Kerberos/SPNego or user/password),
+ * the client identity is published in `SessionInfo` thread locals for the duration of the
+ * request, and the request is then handed to the Thrift processor.
+ */
+class ThriftHttpServlet(
+    processor: TProcessor,
+    protocolFactory: TProtocolFactory,
+    val authType: String,
+    val serviceUGI: UserGroupInformation,
+    val httpUGI: UserGroupInformation,
+    val authFactory: AuthFactory,
+    val livyConf: LivyConf) extends TServlet(processor, protocolFactory) with Logging {
+
+  private val isCookieAuthEnabled = livyConf.getBoolean(LivyConf.THRIFT_HTTP_COOKIE_AUTH_ENABLED)
+
+  // Class members for cookie based authentication.
+  private val signer: CookieSigner = if (isCookieAuthEnabled) {
+    // Generate the signer with a random secret. The secret is not logged: it is what makes the
+    // signatures of issued cookies verifiable.
+    val secret = ThriftHttpServlet.RAN.nextLong.toString
+    debug("Using the random number as the secret for cookie generation")
+    new CookieSigner(secret.getBytes())
+  } else {
+    null
+  }
+
+  private val cookieDomain = livyConf.get(LivyConf.THRIFT_HTTP_COOKIE_DOMAIN)
+  private val cookiePath = livyConf.get(LivyConf.THRIFT_HTTP_COOKIE_PATH)
+  private val cookieMaxAge =
+    (livyConf.getTimeAsMs(LivyConf.THRIFT_HTTP_COOKIE_MAX_AGE) / 1000).toInt
+  private val isCookieSecure = livyConf.getBoolean(LivyConf.THRIFT_USE_SSL)
+  private val isHttpOnlyCookie = livyConf.getBoolean(LivyConf.THRIFT_HTTP_COOKIE_IS_HTTPONLY)
+  private val xsrfFilterEnabled = livyConf.getBoolean(LivyConf.THRIFT_XSRF_FILTER_ENABLED)
+
+  /**
+   * Authenticates the request, records user name, proxy user, client IP and forwarded
+   * addresses in `SessionInfo` thread locals (cleared again in `finally`), optionally issues a
+   * new auth cookie and forwards the request to the Thrift processor. An
+   * `HttpAuthenticationException` is answered with HTTP 401.
+   */
+  @throws[IOException]
+  @throws[ServletException]
+  override protected def doPost(
+      request: HttpServletRequest, response: HttpServletResponse): Unit = {
+    var clientUserName: String = null
+    var requireNewCookie: Boolean = false
+
+    try {
+      if (xsrfFilterEnabled) {
+        val continueProcessing = ThriftHttpServlet.doXsrfFilter(request, response)
+        if (!continueProcessing) {
+          warn("Request did not have valid XSRF header, rejecting.")
+          return
+        }
+      }
+      // If the cookie based authentication is already enabled, parse the
+      // request and validate the request cookies.
+      if (isCookieAuthEnabled) {
+        clientUserName = validateCookie(request)
+        requireNewCookie = clientUserName == null
+        if (requireNewCookie) {
+          info("Could not validate cookie sent, will try to generate a new cookie")
+        }
+      }
+      // If the cookie based authentication is not enabled or the request does
+      // not have a valid cookie, use the kerberos or password based authentication
+      // depending on the server setup.
+      if (clientUserName == null) {
+        // For a kerberos setup
+        if (ThriftHttpServlet.isKerberosAuthMode(authType)) {
+          val delegationToken = request.getHeader(ThriftHttpServlet.HIVE_DELEGATION_TOKEN_HEADER)
+          // A non-empty delegation token header takes precedence over Kerberos authentication
+          if ((delegationToken != null) && (!delegationToken.isEmpty)) {
+            clientUserName = doTokenAuth(request, response)
+          } else {
+            clientUserName = doKerberosAuth(request)
+          }
+        } else {
+          // For password based authentication
+          clientUserName = doPasswdAuth(request, authType)
+        }
+      }
+      debug(s"Client username: $clientUserName")
+
+      // Set the thread local username to be used for doAs if true
+      SessionInfo.setUserName(clientUserName)
+
+      // find proxy user if any from query param
+      val doAsQueryParam = ThriftHttpServlet.getDoAsQueryParam(request.getQueryString)
+      if (doAsQueryParam != null) {
+        SessionInfo.setProxyUserName(doAsQueryParam)
+      }
+
+      val clientIpAddress = request.getRemoteAddr
+      debug("Client IP Address: " + clientIpAddress)
+      // Set the thread local ip address
+      SessionInfo.setIpAddress(clientIpAddress)
+
+      // get forwarded hosts address
+      val forwardedFor = request.getHeader(ThriftHttpServlet.X_FORWARDED_FOR)
+      if (forwardedFor != null) {
+        debug(s"${ThriftHttpServlet.X_FORWARDED_FOR}:$forwardedFor")
+        SessionInfo.setForwardedAddresses(forwardedFor.split(",").toList.asJava)
+      } else {
+        SessionInfo.setForwardedAddresses(List.empty.asJava)
+      }
+
+      // Generate new cookie and add it to the response
+      if (requireNewCookie && !authType.equalsIgnoreCase(AuthTypes.NOSASL.toString)) {
+        val cookieToken = HttpAuthUtils.createCookieToken(clientUserName)
+        val hs2Cookie = createCookie(signer.signCookie(cookieToken))
+
+        if (isHttpOnlyCookie) {
+          response.setHeader("SET-COOKIE", ThriftHttpServlet.getHttpOnlyCookieHeader(hs2Cookie))
+        } else {
+          response.addCookie(hs2Cookie)
+        }
+        info("Cookie added for clientUserName " + clientUserName)
+      }
+      super.doPost(request, response)
+    } catch {
+      case e: HttpAuthenticationException =>
+        error("Error: ", e)
+        // Send a 401 to the client
+        response.setStatus(HttpServletResponse.SC_UNAUTHORIZED)
+        if (ThriftHttpServlet.isKerberosAuthMode(authType)) {
+          response.addHeader(HttpAuthUtils.WWW_AUTHENTICATE, HttpAuthUtils.NEGOTIATE)
+        }
+        // scalastyle:off println
+        response.getWriter.println("Authentication Error: " + e.getMessage)
+        // scalastyle:on println
+    } finally {
+      // Clear the thread locals
+      SessionInfo.clearUserName()
+      SessionInfo.clearIpAddress()
+      SessionInfo.clearProxyUserName()
+      SessionInfo.clearForwardedAddresses()
+    }
+  }
+
+  /**
+   * Retrieves the client name from the request cookies. Only cookies named
+   * `ThriftHttpServlet.AUTH_COOKIE` (hive.server2.auth=<value>) are considered; each is verified
+   * with the signer, and the user name of the first one that validates is returned.
+   * @param cookies HTTP Request cookies.
+   * @return Client Username if the cookies contain a HS2 generated cookie that is currently
+   *         valid. Else, returns null.
+   */
+  private def getClientNameFromCookie(cookies: Array[Cookie]): String = {
+    // Compare the cookie *name*: a Cookie object never equals a String. The iterator is lazy,
+    // so no cookie after the first valid one is verified.
+    val validUserNames = cookies.iterator
+      .filter(_.getName == ThriftHttpServlet.AUTH_COOKIE)
+      .flatMap(cookie => verifyCookieValue(cookie))
+      .flatMap(token => userNameFromCookieToken(token))
+    if (validUserNames.hasNext) validUserNames.next() else null
+  }
+
+  /**
+   * Verifies the signature of an auth cookie and returns the signed token, or None when the
+   * signer returns null or rejects the cookie.
+   */
+  private def verifyCookieValue(cookie: Cookie): Option[String] = {
+    try {
+      Option(signer.verifyAndExtract(cookie.getValue))
+    } catch {
+      // NOTE: the signer may reject a tampered cookie, or one signed with the secret of an
+      // earlier servlet instance, with an IllegalArgumentException. Treat it as "no valid
+      // cookie" so the request falls back to regular authentication. The exception message is
+      // not logged because it may contain signature material.
+      case _: IllegalArgumentException =>
+        warn("Could not verify the signature of the auth cookie, ignoring it")
+        None
+    }
+  }
+
+  /** Extracts the user name from a verified cookie token, warning when it carries none. */
+  private def userNameFromCookieToken(token: String): Option[String] = {
+    val userName = HttpAuthUtils.getUserNameFromCookieToken(token)
+    if (userName == null) {
+      warn("Invalid cookie token " + token)
+      None
+    } else {
+      // We have found a valid cookie in the client request.
+      if (logger.isDebugEnabled()) {
+        debug("Validated the cookie for user " + userName)
+      }
+      Some(userName)
+    }
+  }
+
+  /**
+   * Convert cookie array to human readable cookie string
+   * @param cookies Cookie Array
+   * @return String containing all the cookies separated by a newline character.
+   * Each cookie is of the format [key]=[value]
+   */
+  private def toCookieStr(cookies: Array[Cookie]): String = {
+    cookies.map(c => s"${c.getName} = ${c.getValue} ;\n").mkString
+  }
+
+  /**
+   * Validate the request cookie. This function iterates over the request cookie headers
+   * and finds a cookie that represents a valid client/server session. If it finds one, it
+   * returns the client name associated with the session. Else, it returns null.
+   * @param request The HTTP Servlet Request send by the client
+   * @return Client Username if the request has valid HS2 cookie, else returns null
+   */
+  private def validateCookie(request: HttpServletRequest): String = {
+    // Find all the valid cookies associated with the request.
+    val cookies = request.getCookies
+
+    if (cookies == null) {
+      if (logger.isDebugEnabled()) {
+        debug("No valid cookies associated with the request " + request)
+      }
+      null
+    } else {
+      if (logger.isDebugEnabled()) {
+        debug("Received cookies: " + toCookieStr(cookies))
+      }
+      getClientNameFromCookie(cookies)
+    }
+  }
+
+  /**
+   * Generate a server side cookie given the cookie value as the input.
+   * @param str Input string token.
+   * @return The generated cookie.
+   */
+  private def createCookie(str: String): Cookie = {
+    if (logger.isDebugEnabled()) {
+      debug(s"Cookie name = ${ThriftHttpServlet.AUTH_COOKIE} value = $str")
+    }
+    val cookie = new Cookie(ThriftHttpServlet.AUTH_COOKIE, str)
+
+    cookie.setMaxAge(cookieMaxAge)
+    if (cookieDomain != null) {
+      cookie.setDomain(cookieDomain)
+    }
+    if (cookiePath != null) {
+      cookie.setPath(cookiePath)
+    }
+    cookie.setSecure(isCookieSecure)
+    cookie
+  }
+
+  /**
+   * Do the authentication (LDAP/PAM not yet supported). The user name always comes from the
+   * Authorization header; the password is checked by the configured provider unless the auth
+   * type is NOSASL.
+   * @throws HttpAuthenticationException if the header is malformed or the provider rejects it
+   */
+  private def doPasswdAuth(request: HttpServletRequest, authType: String): String = {
+    val userName = getUsername(request, authType)
+    // No-op when authType is NOSASL
+    if (!authType.equalsIgnoreCase(HiveAuthConstants.AuthTypes.NOSASL.toString)) {
+      try {
+        val provider = AuthenticationProvider.getAuthenticationProvider(authType, livyConf)
+        provider.Authenticate(userName, getPassword(request, authType))
+      } catch {
+        case e: Exception => throw new HttpAuthenticationException(e)
+      }
+    }
+    userName
+  }
+
+  /** Verifies the delegation token header and returns the user it was issued to. */
+  private def doTokenAuth(request: HttpServletRequest, response: HttpServletResponse): String = {
+    val tokenStr = request.getHeader(ThriftHttpServlet.HIVE_DELEGATION_TOKEN_HEADER)
+    try {
+      authFactory.verifyDelegationToken(tokenStr)
+    } catch {
+      case e: HiveSQLException => throw new HttpAuthenticationException(e)
+    }
+  }
+
+  /**
+   * Do the GSS-API kerberos authentication. We already have a logged in subject in the form of
+   * serviceUGI, which GSS-API will extract information from.
+   * In case of a SPNego request we use the httpUGI, for the authenticating service tickets.
+   */
+  private def doKerberosAuth(request: HttpServletRequest): String = {
+    // Try authenticating with the http/_HOST principal
+    if (httpUGI != null) {
+      try {
+        return httpUGI.doAs(new HttpKerberosServerAction(request, httpUGI, authType))
+      } catch {
+        case _: Exception =>
+          info("Failed to authenticate with http/_HOST kerberos principal, trying with " +
+            "livy/_HOST kerberos principal")
+      }
+    }
+    // Now try with livy/_HOST principal
+    try {
+      serviceUGI.doAs(new HttpKerberosServerAction(request, serviceUGI, authType))
+    } catch {
+      case e: Exception =>
+        error("Failed to authenticate with livy/_HOST kerberos principal")
+        throw new HttpAuthenticationException(e)
+    }
+  }
+
+  /** Returns the non-empty user name from the Authorization header. */
+  private def getUsername(request: HttpServletRequest, authType: String): String = {
+    val creds = getAuthHeaderTokens(request, authType)
+    // Username must be present
+    if (creds(0) == null || creds(0).isEmpty) {
+      throw new HttpAuthenticationException("Authorization header received " +
+        "from the client does not contain username.")
+    }
+    creds(0)
+  }
+
+  /** Returns the non-empty password from the Authorization header. */
+  private def getPassword(request: HttpServletRequest, authType: String): String = {
+    val creds = getAuthHeaderTokens(request, authType)
+    // Password must be present; a header without ':' has no password token at all.
+    if (creds.length < 2 || creds(1) == null || creds(1).isEmpty) {
+      throw new HttpAuthenticationException("Authorization header received " +
+        "from the client does not contain password.")
+    }
+    creds(1)
+  }
+
+  /**
+   * Decodes the Authorization payload into `user` and `password`. Only the first ':' separates
+   * them, so passwords containing ':' are preserved (RFC 7617).
+   */
+  private def getAuthHeaderTokens(request: HttpServletRequest, authType: String): Array[String] = {
+    val authHeaderBase64 = ThriftHttpServlet.getAuthHeader(request, authType)
+    val authHeaderString = StringUtils.newStringUtf8(
+      Base64.decodeBase64(authHeaderBase64.getBytes()))
+    authHeaderString.split(":", 2)
+  }
+}
+
+
+object ThriftHttpServlet extends Logging {
+  private val XSRF_HEADER_DEFAULT = "X-XSRF-HEADER"
+  private val XSRF_METHODS_TO_IGNORE_DEFAULT = Set("GET", "OPTIONS", "HEAD", "TRACE")
+  val AUTH_COOKIE = "hive.server2.auth"
+  val RAN: SecureRandom = new SecureRandom()
+  val HIVE_DELEGATION_TOKEN_HEADER: String = "X-Hive-Delegation-Token"
+  val X_FORWARDED_FOR: String = "X-Forwarded-For"
+
+  /**
+   * Generate httponly cookie from HS2 cookie
+   * @param cookie HS2 generated cookie
+   * @return The httponly cookie
+   */
+  private def getHttpOnlyCookieHeader(cookie: Cookie): String = {
+    val newCookie = new NewCookie(
+      cookie.getName,
+      cookie.getValue,
+      cookie.getPath,
+      cookie.getDomain,
+      cookie.getVersion,
+      cookie.getComment,
+      cookie.getMaxAge,
+      cookie.getSecure)
+    s"$newCookie; HttpOnly"
+  }
+
+  /** Returns the value of the `doAs` query parameter (name matched case-insensitively), or null. */
+  private def getDoAsQueryParam(queryString: String): String = {
+    if (logger.isDebugEnabled()) {
+      debug("URL query string:" + queryString)
+    }
+    Option(queryString).flatMap { qs =>
+      val params = javax.servlet.http.HttpUtils.parseQueryString(qs)
+      params.keySet().asScala.find(_.equalsIgnoreCase("doAs")).map(params.get(_).head)
+    }.orNull
+  }
+
+  /**
+   * Lets safe methods and requests carrying the XSRF header through; otherwise sends a 400 and
+   * returns false.
+   */
+  private def doXsrfFilter(request: HttpServletRequest, response: HttpServletResponse): Boolean = {
+    if (XSRF_METHODS_TO_IGNORE_DEFAULT.contains(request.getMethod) ||
+        request.getHeader(XSRF_HEADER_DEFAULT) != null) {
+      true
+    } else {
+      response.sendError(
+        HttpServletResponse.SC_BAD_REQUEST,
+        "Missing Required Header for Vulnerability Protection")
+      // scalastyle:off println
+      response.getWriter.println(
+        "XSRF filter denial, requests must contain header : " + XSRF_HEADER_DEFAULT)
+      // scalastyle:on println
+      false
+    }
+  }
+
+  /**
+   * Returns the base64 encoded auth header payload
+   * @throws HttpAuthenticationException if the header is missing or carries no payload after
+   *                                     the "Negotiate " / "Basic " prefix
+   */
+  @throws[HttpAuthenticationException]
+  private[cli] def getAuthHeader(request: HttpServletRequest, authType: String): String = {
+    val authHeader = request.getHeader(HttpAuthUtils.AUTHORIZATION)
+    // Each http request must have an Authorization header
+    if (authHeader == null || authHeader.isEmpty) {
+      throw new HttpAuthenticationException("Authorization header received " +
+        "from the client is empty.")
+    }
+
+    val beginIndex = if (isKerberosAuthMode(authType)) {
+      (HttpAuthUtils.NEGOTIATE + " ").length()
+    } else {
+      (HttpAuthUtils.BASIC + " ").length()
+    }
+    // Authorization header must have a payload. Checking the length before substring() avoids
+    // a StringIndexOutOfBoundsException, which doPost does not turn into a 401.
+    if (authHeader.length <= beginIndex) {
+      throw new HttpAuthenticationException("Authorization header received " +
+        "from the client does not contain any data.")
+    }
+    authHeader.substring(beginIndex)
+  }
+
+  private def isKerberosAuthMode(authType: String): Boolean = {
+    authType.equalsIgnoreCase(AuthTypes.KERBEROS.toString)
+  }
+}
+
+/**
+ * Accepts the client's Kerberos/SPNego service ticket from the Authorization header using the
+ * credentials of `serviceUGI`, and returns the client's short principal name.
+ */
+class HttpKerberosServerAction(
+    val request: HttpServletRequest,
+    val serviceUGI: UserGroupInformation,
+    authType: String) extends PrivilegedExceptionAction[String] {
+
+  @throws[HttpAuthenticationException]
+  override def run(): String = {
+    // Get own Kerberos credentials for accepting connection
+    val manager = GSSManager.getInstance()
+    var gssContext: Option[GSSContext] = None
+    val serverPrincipal = getPrincipalWithoutRealm(serviceUGI.getUserName)
+    try {
+      // This Oid for Kerberos GSS-API mechanism.
+      val kerberosMechOid = new Oid("1.2.840.113554.1.2.2")
+      // Oid for SPNego GSS-API mechanism.
+      val spnegoMechOid = new Oid("1.3.6.1.5.5.2")
+      // Oid for kerberos principal name
+      val krb5PrincipalOid = new Oid("1.2.840.113554.1.2.2.1")
+
+      // GSS name for server
+      val serverName = manager.createName(serverPrincipal, krb5PrincipalOid)
+
+      // GSS credentials for server
+      val serverCreds = manager.createCredential(serverName,
+        GSSCredential.DEFAULT_LIFETIME,
+        Array[Oid](kerberosMechOid, spnegoMechOid),
+        GSSCredential.ACCEPT_ONLY)
+
+      // Create a GSS context
+      gssContext = Some(manager.createContext(serverCreds))
+      // Get service ticket from the authorization header
+      val serviceTicketBase64 = ThriftHttpServlet.getAuthHeader(request, authType)
+      val inToken = Base64.decodeBase64(serviceTicketBase64.getBytes())
+      gssContext.get.acceptSecContext(inToken, 0, inToken.length)
+      // Authenticate or deny based on its context completion
+      if (!gssContext.get.isEstablished) {
+        throw new HttpAuthenticationException("Kerberos authentication failed: " +
+          "unable to establish context with the service ticket provided by the client.")
+      } else {
+        getPrincipalWithoutRealmAndHost(gssContext.get.getSrcName.toString)
+      }
+    } catch {
+      case e: GSSException =>
+        throw new HttpAuthenticationException("Kerberos authentication failed: ", e)
+    } finally {
+      // Always release the GSS context; failures while disposing are ignored.
+      gssContext.foreach { ctx =>
+        try {
+          ctx.dispose()
+        } catch {
+          case _: GSSException => // No-op
+        }
+      }
+    }
+  }
+
+  /** Returns "service/host" (or just "service") for a full Kerberos principal. */
+  private def getPrincipalWithoutRealm(fullPrincipal: String): String = {
+    val fullKerberosName = new KerberosName(fullPrincipal)
+    val serviceName = fullKerberosName.getServiceName
+    val hostName = fullKerberosName.getHostName
+    if (hostName != null) {
+      serviceName + "/" + hostName
+    } else {
+      serviceName
+    }
+  }
+
+  /** Maps a full principal to its short name via `KerberosName.getShortName`. */
+  private def getPrincipalWithoutRealmAndHost(fullPrincipal: String): String = {
+    try {
+      new KerberosName(fullPrincipal).getShortName
+    } catch {
+      case e: IOException => throw new HttpAuthenticationException(e)
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/47d3ee6b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/operation/GetCatalogsOperation.scala ---------------------------------------------------------------------- diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/operation/GetCatalogsOperation.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/operation/GetCatalogsOperation.scala new file mode 100644 index 0000000..57687b0 --- /dev/null +++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/operation/GetCatalogsOperation.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.livy.thriftserver.operation
+
+import org.apache.hive.service.cli._
+
+import org.apache.livy.Logging
+import org.apache.livy.thriftserver.serde.ThriftResultSet
+import org.apache.livy.thriftserver.types.{BasicDataType, Field, Schema}
+
+/**
+ * Metadata operation backing `GetCatalogs`. Catalogs are not supported, so no rows are ever
+ * added: the result is always empty, with the single `TABLE_CAT` column.
+ */
+class GetCatalogsOperation(sessionHandle: SessionHandle)
+  extends MetadataOperation(sessionHandle, OperationType.GET_CATALOGS) with Logging {
+
+  protected val rowSet = ThriftResultSet.apply(GetCatalogsOperation.SCHEMA, protocolVersion)
+
+  info("Starting GetCatalogsOperation")
+
+  /** Moves the operation from RUNNING to FINISHED (or ERROR if anything is thrown). */
+  @throws(classOf[HiveSQLException])
+  override def runInternal(): Unit = {
+    setState(OperationState.RUNNING)
+    info("Fetching catalog metadata")
+    try {
+      // catalogs are actually not supported in spark, so this is a no-op
+      setState(OperationState.FINISHED)
+      info("Fetching catalog metadata has been successfully finished")
+    } catch {
+      case e: Throwable =>
+        setState(OperationState.ERROR)
+        throw new HiveSQLException(e)
+    }
+  }
+
+  /** Returns the result schema after asserting that the operation is FINISHED. */
+  @throws(classOf[HiveSQLException])
+  override def getResultSetSchema: Schema = {
+    assertState(Seq(OperationState.FINISHED))
+    GetCatalogsOperation.SCHEMA
+  }
+}
+
+object GetCatalogsOperation {
+  val SCHEMA = Schema(Field("TABLE_CAT", BasicDataType("string"),
+    "Catalog name. NULL if not applicable."))
+}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/47d3ee6b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/operation/GetTableTypesOperation.scala
----------------------------------------------------------------------
diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/operation/GetTableTypesOperation.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/operation/GetTableTypesOperation.scala
new file mode 100644
index 0000000..aff7ace
--- /dev/null
+++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/operation/GetTableTypesOperation.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.livy.thriftserver.operation
+
+import org.apache.hive.service.cli._
+
+import org.apache.livy.Logging
+import org.apache.livy.thriftserver.serde.ThriftResultSet
+import org.apache.livy.thriftserver.types.{BasicDataType, Field, Schema}
+
+/**
+ * Metadata operation backing `GetTableTypes`: the result holds one row per supported table
+ * type, "TABLE" and "VIEW".
+ */
+class GetTableTypesOperation(sessionHandle: SessionHandle)
+  extends MetadataOperation(sessionHandle, OperationType.GET_TABLE_TYPES) with Logging {
+
+  protected val rowSet = ThriftResultSet.apply(GetTableTypesOperation.SCHEMA, protocolVersion)
+
+  info("Starting GetTableTypesOperation")
+
+  /**
+   * Appends the supported table types to the result set. The state goes RUNNING -> FINISHED,
+   * or to ERROR when anything is thrown (rethrown wrapped in a HiveSQLException).
+   */
+  @throws(classOf[HiveSQLException])
+  override def runInternal(): Unit = {
+    setState(OperationState.RUNNING)
+    info("Fetching table type metadata")
+    try {
+      for (tableType <- Seq("TABLE", "VIEW")) {
+        rowSet.addRow(Array(tableType))
+      }
+      setState(OperationState.FINISHED)
+      info("Fetching table type metadata has been successfully finished")
+    } catch {
+      case failure: Throwable =>
+        setState(OperationState.ERROR)
+        throw new HiveSQLException(failure)
+    }
+  }
+
+  /** Returns the single-column result schema after asserting the operation is FINISHED. */
+  @throws(classOf[HiveSQLException])
+  override def getResultSetSchema: Schema = {
+    assertState(Seq(OperationState.FINISHED))
+    GetTableTypesOperation.SCHEMA
+  }
+}
+
+object GetTableTypesOperation {
+  val SCHEMA = Schema(Field("TABLE_TYPE", BasicDataType("string"), "Table type name."))
+}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/47d3ee6b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/operation/GetTypeInfoOperation.scala
----------------------------------------------------------------------
diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/operation/GetTypeInfoOperation.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/operation/GetTypeInfoOperation.scala
new file mode 100644
index 0000000..a587445
--- /dev/null
+++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/operation/GetTypeInfoOperation.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or
more
+ * contributor agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.livy.thriftserver.operation
+
+import java.sql.{DatabaseMetaData, Types}
+
+import org.apache.hive.service.cli.{HiveSQLException, OperationState, OperationType, SessionHandle}
+
+import org.apache.livy.Logging
+import org.apache.livy.thriftserver.serde.ThriftResultSet
+import org.apache.livy.thriftserver.types.{BasicDataType, Field, Schema}
+
+/**
+ * Static description of one SQL type reported by GetTypeInfoOperation.
+ *
+ * @param name              type name (TYPE_NAME)
+ * @param sqlType           type code from `java.sql.Types` (DATA_TYPE)
+ * @param precision         maximum precision, if the type has one (PRECISION)
+ * @param caseSensitive     whether the type is case sensitive (CASE_SENSITIVE)
+ * @param searchable        a `DatabaseMetaData.typePredNone` / `typePredBasic` /
+ *                          `typeSearchable` code (SEARCHABLE)
+ * @param unsignedAttribute whether the type is unsigned (UNSIGNED_ATTRIBUTE)
+ * @param numPrecRadix      radix of the precision, if any (NUM_PREC_RADIX)
+ */
+sealed case class TypeInfo(name: String, sqlType: Int, precision: Option[Int],
+  caseSensitive: Boolean, searchable: Short, unsignedAttribute: Boolean, numPrecRadix: Option[Int])
+
+/**
+ * Metadata operation of type GET_TYPE_INFO. Its result holds one row per entry of
+ * `GetTypeInfoOperation.TYPES`, laid out according to `GetTypeInfoOperation.SCHEMA`.
+ */
+class GetTypeInfoOperation(sessionHandle: SessionHandle)
+  extends MetadataOperation(sessionHandle, OperationType.GET_TYPE_INFO) with Logging {
+
+  info("Starting GetTypeInfoOperation")
+
+  protected val rowSet = ThriftResultSet.apply(GetTypeInfoOperation.SCHEMA, protocolVersion)
+
+  /**
+   * Adds one row per supported type and moves the operation to FINISHED. Any failure moves the
+   * operation to ERROR and is rethrown wrapped in a HiveSQLException.
+   */
+  @throws(classOf[HiveSQLException])
+  override def runInternal(): Unit = {
+    setState(OperationState.RUNNING)
+    info("Fetching type info metadata")
+    try {
+      GetTypeInfoOperation.TYPES.foreach { t =>
+        // Exactly one value per SCHEMA column, in SCHEMA order (18 columns). Columns declared
+        // as "short" receive Short values. Optional values are unwrapped: Some(n) becomes a
+        // boxed Int and None becomes null, the same null encoding used for the always-empty
+        // columns of this row.
+        val data = Array[Any](
+          t.name, // TYPE_NAME
+          t.sqlType, // DATA_TYPE
+          t.precision.map(Int.box).orNull, // PRECISION
+          null, // LITERAL_PREFIX
+          null, // LITERAL_SUFFIX
+          null, // CREATE_PARAMS
+          DatabaseMetaData.typeNullable.toShort, // NULLABLE: all types are nullable
+          t.caseSensitive, // CASE_SENSITIVE
+          t.searchable, // SEARCHABLE
+          t.unsignedAttribute, // UNSIGNED_ATTRIBUTE
+          false, // FIXED_PREC_SCALE
+          false, // AUTO_INCREMENT
+          null, // LOCAL_TYPE_NAME
+          0.toShort, // MINIMUM_SCALE
+          0.toShort, // MAXIMUM_SCALE
+          null, // SQL_DATA_TYPE
+          null, // SQL_DATETIME_SUB
+          t.numPrecRadix.map(Int.box).orNull) // NUM_PREC_RADIX
+        rowSet.addRow(data)
+      }
+      setState(OperationState.FINISHED)
+      info("Fetching type info metadata has been successfully finished")
+    } catch {
+      case e: Throwable =>
+        setState(OperationState.ERROR)
+        throw new HiveSQLException(e)
+    }
+  }
+
+  /** Returns the result schema; only allowed once the operation is FINISHED. */
+  @throws(classOf[HiveSQLException])
+  override def getResultSetSchema: Schema = {
+    assertState(Seq(OperationState.FINISHED))
+    GetTypeInfoOperation.SCHEMA
+  }
+}
+
+object GetTypeInfoOperation {
+  /** Result columns; each row built by runInternal() must supply one value per field. */
+  val SCHEMA = Schema(
+    Field("TYPE_NAME", BasicDataType("string"), "Type name"),
+    Field("DATA_TYPE", BasicDataType("integer"), "SQL data type from java.sql.Types"),
+    Field("PRECISION", BasicDataType("integer"), "Maximum precision"),
+    Field("LITERAL_PREFIX", BasicDataType("string"),
+      "Prefix used to quote a literal (may be null)"),
+    Field("LITERAL_SUFFIX", BasicDataType("string"),
+      "Suffix used to quote a literal (may be null)"),
+    Field("CREATE_PARAMS", BasicDataType("string"),
+      "Parameters used in creating the type (may be null)"),
+    Field("NULLABLE", BasicDataType("short"), "Can you use NULL for this type"),
+    Field("CASE_SENSITIVE", BasicDataType("boolean"), "Is it case sensitive"),
+    Field("SEARCHABLE", BasicDataType("short"), "Can you use \"WHERE\" based on this type"),
+    Field("UNSIGNED_ATTRIBUTE", BasicDataType("boolean"), "Is it unsigned"),
+    Field("FIXED_PREC_SCALE", BasicDataType("boolean"), "Can it be a money value"),
+    Field("AUTO_INCREMENT", BasicDataType("boolean"),
+      "Can it be used for an auto-increment value"),
+    Field("LOCAL_TYPE_NAME", BasicDataType("string"),
+      "Localized version of type name (may be null)"),
+    Field("MINIMUM_SCALE", BasicDataType("short"), "Minimum scale supported"),
+    Field("MAXIMUM_SCALE", BasicDataType("short"), "Maximum scale supported"),
+    Field("SQL_DATA_TYPE", BasicDataType("integer"), "Unused"),
+    Field("SQL_DATETIME_SUB", BasicDataType("integer"), "Unused"),
+    Field("NUM_PREC_RADIX", BasicDataType("integer"), "Usually 2 or 10"))
+
+  import DatabaseMetaData._
+
+  /** Types reported by the operation; each entry becomes one result row. */
+  val TYPES = Seq(
+    TypeInfo("void", Types.NULL, None, false, typePredNone.toShort, true, None),
+    TypeInfo("boolean", Types.BOOLEAN, None, false, typePredBasic.toShort, true, None),
+    TypeInfo("byte", Types.TINYINT, Some(3), false, typePredBasic.toShort, false, Some(10)),
+    TypeInfo("short", Types.SMALLINT, Some(5), false, typePredBasic.toShort, false, Some(10)),
+    TypeInfo("integer", Types.INTEGER, Some(10), false, typePredBasic.toShort, false, Some(10)),
+    TypeInfo("long", Types.BIGINT, Some(19), false, typePredBasic.toShort, false, Some(10)),
+    TypeInfo("float", Types.FLOAT, Some(7), false, typePredBasic.toShort, false, Some(10)),
+    TypeInfo("double", Types.DOUBLE, Some(15), false, typePredBasic.toShort, false, Some(10)),
+    TypeInfo("date", Types.DATE, None, false, typePredBasic.toShort, true, None),
+    TypeInfo("timestamp", Types.TIMESTAMP, None, false, typePredBasic.toShort, true, None),
+    TypeInfo("string", Types.VARCHAR, None, true, typeSearchable.toShort, true, None),
+    TypeInfo("binary", Types.BINARY, None, false, typePredBasic.toShort, true, None),
+    TypeInfo("decimal", Types.DECIMAL, Some(38), false, typePredBasic.toShort, false, Some(10)),
+    TypeInfo("array", Types.ARRAY, None, false, typePredBasic.toShort, true, None),
+    TypeInfo("map", Types.OTHER, None, false, typePredNone.toShort, true, None),
+    TypeInfo("struct", Types.STRUCT, None, false, typePredBasic.toShort, true, None),
+    TypeInfo("udt", Types.OTHER, None, false, typePredNone.toShort, true, None))
+}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/47d3ee6b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/operation/MetadataOperation.scala
----------------------------------------------------------------------
diff --git
a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/operation/MetadataOperation.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/operation/MetadataOperation.scala new file mode 100644 index 0000000..4db3929 --- /dev/null +++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/operation/MetadataOperation.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.livy.thriftserver.operation
+
+import org.apache.hive.service.cli.{FetchOrientation, HiveSQLException, OperationState, OperationType, SessionHandle}
+
+import org.apache.livy.thriftserver.serde.ThriftResultSet
+
+/**
+ * Base class for metadata operations whose result is held in `rowSet`. Such operations always
+ * report a result set, close by only moving to CLOSED, and do not support cancellation.
+ */
+abstract class MetadataOperation(sessionHandle: SessionHandle, opType: OperationType)
+  extends Operation(sessionHandle, opType) {
+
+  // Metadata operations always produce a result set.
+  setHasResultSet(true)
+
+  /** Buffer holding the rows produced by the concrete operation. */
+  protected def rowSet: ThriftResultSet
+
+  /** Closing only transitions the operation to CLOSED. */
+  @throws[HiveSQLException]
+  override def close(): Unit = setState(OperationState.CLOSED)
+
+  /** Always throws: metadata operations cannot be cancelled. */
+  @throws[HiveSQLException]
+  override def cancel(stateAfterCancel: OperationState): Unit =
+    throw new UnsupportedOperationException("MetadataOperation.cancel()")
+
+  /**
+   * Returns the row buffer once the operation is FINISHED. Orientations not accepted by
+   * validateFetchOrientation() are rejected, and FETCH_FIRST first rewinds the buffer to
+   * offset 0. `maxRows` is not used here.
+   */
+  @throws(classOf[HiveSQLException])
+  override def getNextRowSet(orientation: FetchOrientation, maxRows: Long): ThriftResultSet = {
+    assertState(Seq(OperationState.FINISHED))
+    validateFetchOrientation(orientation)
+    val rewind = orientation == FetchOrientation.FETCH_FIRST
+    if (rewind) rowSet.setRowOffset(0)
+    rowSet
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/47d3ee6b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/operation/Operation.scala
----------------------------------------------------------------------
diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/operation/Operation.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/operation/Operation.scala
new file mode 100644
index 0000000..2a139e3
--- /dev/null
+++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/operation/Operation.scala
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.livy.thriftserver.operation
+
+import java.util
+import java.util.concurrent.Future
+
+import org.apache.hive.service.cli.FetchOrientation
+import org.apache.hive.service.cli.HiveSQLException
+import org.apache.hive.service.cli.OperationHandle
+import org.apache.hive.service.cli.OperationState
+import org.apache.hive.service.cli.OperationType
+import org.apache.hive.service.cli.SessionHandle
+import org.apache.hive.service.rpc.thrift.TProtocolVersion
+
+import org.apache.livy.Logging
+import org.apache.livy.thriftserver.serde.ThriftResultSet
+import org.apache.livy.thriftserver.types.Schema
+
+/**
+ * Base class of the operations executed within a Thrift CLI session.
+ *
+ * It owns the operation handle and tracks the operation lifecycle:
+ *  - the current state, with transitions validated by `OperationState.validateTransition`;
+ *  - whether the operation has a result set;
+ *  - the recorded failure, if any;
+ *  - the start and completion timestamps;
+ *  - the last access time, which drives `isTimedOut`.
+ * Subclasses provide the work in `runInternal()` together with `cancel`, `close`,
+ * `getResultSetSchema` and `getNextRowSet`.
+ *
+ * @param sessionHandle handle of the session the operation belongs to
+ * @param opType        kind of operation, used to build `opHandle`
+ */
+abstract class Operation(
+    val sessionHandle: SessionHandle,
+    val opType: OperationType) extends Logging {
+
+  @volatile private var state = OperationState.INITIALIZED
+  val opHandle = new OperationHandle(opType, sessionHandle.getProtocolVersion)
+
+  protected var resultSetPresent: Boolean = false
+  @volatile private var operationException: HiveSQLException = _
+  @volatile protected var backgroundHandle: Future[_] = _
+
+  private val beginTime = System.currentTimeMillis()
+  // Refreshed by setState() and assertState(); compared against the timeout in isTimedOut().
+  @volatile private var lastAccessTime = beginTime
+
+  // Epoch millis at which the operation entered RUNNING / ERROR, FINISHED or CANCELED
+  // (0 until then). Volatile like the other shared fields above: getStatus() and the getters
+  // below can presumably be called from a thread other than the one running the operation
+  // (see backgroundHandle).
+  @volatile protected var operationStart: Long = _
+  @volatile protected var operationComplete: Long = _
+
+  /** Future of the background execution, or null if none has been set. */
+  def getBackgroundHandle: Future[_] = backgroundHandle
+
+  protected def setBackgroundHandle(backgroundHandle: Future[_]): Unit = {
+    this.backgroundHandle = backgroundHandle
+  }
+
+  def shouldRunAsync: Boolean = false
+
+  def protocolVersion: TProtocolVersion = opHandle.getProtocolVersion
+
+  /** Snapshot of the state, timestamps, result-set flag and recorded failure. */
+  def getStatus: OperationStatus = {
+    // TODO: get and return also the task status
+    OperationStatus(state, operationStart, operationComplete, resultSetPresent, operationException)
+  }
+
+  def hasResultSet: Boolean = resultSetPresent
+
+  /** Records whether the operation produces a result set, both locally and on `opHandle`. */
+  protected def setHasResultSet(hasResultSet: Boolean): Unit = {
+    this.resultSetPresent = hasResultSet
+    opHandle.setHasResultSet(hasResultSet)
+  }
+
+  /**
+   * Moves the operation to `newState`, runs the `onNewState` hook and refreshes the last
+   * access time.
+   *
+   * @return the new state
+   * @throws HiveSQLException if `OperationState.validateTransition` rejects the transition
+   */
+  @throws[HiveSQLException]
+  protected def setState(newState: OperationState): OperationState = {
+    state.validateTransition(newState)
+    val prevState = state
+    state = newState
+    // NOTE(review): `state` is published before onNewState() records the start/completion
+    // time, so a concurrent getStatus() can briefly see the new state with the old timestamp.
+    // The order is kept because onNewState() overrides may rely on getState returning the
+    // new state.
+    onNewState(state, prevState)
+    this.lastAccessTime = System.currentTimeMillis()
+    state
+  }
+
+  /**
+   * Tells whether the operation has gone unaccessed for at least the given timeout.
+   *
+   * @param current          current time; presumably epoch millis, like the
+   *                         `System.currentTimeMillis` values stored in `lastAccessTime`
+   * @param operationTimeout 0 disables the check; a positive value applies only to operations
+   *                         in a terminal state; a negative value applies, by its magnitude,
+   *                         to operations in any state
+   */
+  def isTimedOut(current: Long, operationTimeout: Long): Boolean = {
+    if (operationTimeout == 0) {
+      false
+    } else if (operationTimeout > 0) {
+      // check only when it's in terminal state
+      state.isTerminal && lastAccessTime + operationTimeout <= current
+    } else {
+      lastAccessTime + (- operationTimeout) <= current
+    }
+  }
+
+  protected def setOperationException(operationException: HiveSQLException): Unit = {
+    this.operationException = operationException
+  }
+
+  /**
+   * Fails with a HiveSQLException unless the current state is one of `states`; on success,
+   * refreshes the last access time.
+   */
+  protected def assertState(states: Seq[OperationState]): Unit = {
+    if (!states.contains(state)) {
+      throw new HiveSQLException(s"Expected states: $states, but found $state")
+    }
+    this.lastAccessTime = System.currentTimeMillis()
+  }
+
+  def isDone: Boolean = state.isTerminal
+
+  /**
+   * Implemented by subclass of Operation class to execute specific behaviors.
+   */
+  @throws[HiveSQLException]
+  protected def runInternal(): Unit
+
+  // As of now, run does nothing other than calling runInternal. This may change in the future
+  // as additional operations before and after running may be added (as happens in Hive).
+  @throws[HiveSQLException]
+  def run(): Unit = runInternal()
+
+  @throws[HiveSQLException]
+  def cancel(stateAfterCancel: OperationState): Unit
+
+  @throws[HiveSQLException]
+  def close(): Unit
+
+  @throws[HiveSQLException]
+  def getResultSetSchema: Schema
+
+  @throws[HiveSQLException]
+  def getNextRowSet(orientation: FetchOrientation, maxRows: Long): ThriftResultSet
+
+  /**
+   * Verify if the given fetch orientation is part of the supported orientation types.
+   *
+   * @throws HiveSQLException (SQL state "HY106") if it is not in
+   *                          `Operation.DEFAULT_FETCH_ORIENTATION_SET`
+   */
+  @throws[HiveSQLException]
+  protected def validateFetchOrientation(orientation: FetchOrientation): Unit = {
+    if (!Operation.DEFAULT_FETCH_ORIENTATION_SET.contains(orientation)) {
+      throw new HiveSQLException(
+        s"The fetch type $orientation is not supported for this resultset", "HY106")
+    }
+  }
+
+  def getBeginTime: Long = beginTime
+
+  protected def getState: OperationState = state
+
+  /**
+   * Hook run by setState() after each transition. It records the start time on RUNNING and
+   * the completion time on ERROR, FINISHED or CANCELED.
+   */
+  protected def onNewState(newState: OperationState, prevState: OperationState): Unit = {
+    newState match {
+      case OperationState.RUNNING =>
+        markOperationStartTime()
+      case OperationState.ERROR | OperationState.FINISHED | OperationState.CANCELED =>
+        markOperationCompletedTime()
+      case _ => // Do nothing.
+    }
+  }
+
+  def getOperationComplete: Long = operationComplete
+
+  def getOperationStart: Long = operationStart
+
+  protected def markOperationStartTime(): Unit = {
+    operationStart = System.currentTimeMillis()
+  }
+
+  protected def markOperationCompletedTime(): Unit = {
+    operationComplete = System.currentTimeMillis()
+  }
+}
+
+object Operation {
+  // Orientations accepted by validateFetchOrientation(). This is a mutable java.util.EnumSet
+  // shared by every operation, so it must not be modified.
+  val DEFAULT_FETCH_ORIENTATION_SET: util.EnumSet[FetchOrientation] =
+    util.EnumSet.of(FetchOrientation.FETCH_NEXT, FetchOrientation.FETCH_FIRST)
+  val DEFAULT_FETCH_ORIENTATION = FetchOrientation.FETCH_NEXT
+}
+
+/** Point-in-time view of an operation, as returned by Operation.getStatus. */
+case class OperationStatus(
+    state: OperationState,
+    operationStarted: Long,
+    operationCompleted: Long,
+    hasResultSet: Boolean,
+    operationException: HiveSQLException)