Github user vanzin commented on a diff in the pull request: https://github.com/apache/incubator-livy/pull/117#discussion_r221764487 --- Diff: thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/AuthBridgeServer.scala --- @@ -0,0 +1,296 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.thriftserver.auth + +import java.io.IOException +import java.net.InetAddress +import java.security.{PrivilegedAction, PrivilegedExceptionAction} +import java.util +import javax.security.auth.callback.{Callback, CallbackHandler, NameCallback, PasswordCallback, UnsupportedCallbackException} +import javax.security.sasl.{AuthorizeCallback, RealmCallback, SaslServer} + +import org.apache.commons.codec.binary.Base64 +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.security.{SaslRpcServer, UserGroupInformation} +import org.apache.hadoop.security.SaslRpcServer.AuthMethod +import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod +import org.apache.hadoop.security.token.SecretManager.InvalidToken +import org.apache.thrift.{TException, TProcessor} +import org.apache.thrift.protocol.TProtocol +import org.apache.thrift.transport.{TSaslServerTransport, TSocket, TTransport, TTransportException, TTransportFactory} + +import org.apache.livy.Logging + +/** + * The class is taken from Hive's `HadoopThriftAuthBridge.Server`. It bridges Thrift's SASL + * transports to Hadoop's SASL callback handlers and authentication classes. + * + * This class is based on Hive's one. + */ +class AuthBridgeServer(private val secretManager: LivyDelegationTokenSecretManager) { + private val ugi = try { + UserGroupInformation.getCurrentUser + } catch { + case ioe: IOException => throw new TTransportException(ioe) + } + + /** + * Create a TTransportFactory that, upon connection of a client socket, + * negotiates a Kerberized SASL transport. The resulting TTransportFactory + * can be passed as both the input and output transport factory when + * instantiating a TThreadPoolServer, for example. + * + * @param saslProps Map of SASL properties + */ + @throws[TTransportException] + def createTransportFactory(saslProps: util.Map[String, String]): TTransportFactory = { + val transFactory: TSaslServerTransport.Factory = createSaslServerTransportFactory(saslProps) + new TUGIAssumingTransportFactory(transFactory, ugi) + } + + /** + * Create a TSaslServerTransport.Factory that, upon connection of a client + * socket, negotiates a Kerberized SASL transport. + * + * @param saslProps Map of SASL properties + */ + @throws[TTransportException] + def createSaslServerTransportFactory( + saslProps: util.Map[String, String]): TSaslServerTransport.Factory = { + // Parse out the kerberos principal, host, realm. + val kerberosName: String = ugi.getUserName + val names: Array[String] = SaslRpcServer.splitKerberosName(kerberosName) + if (names.length != 3) { + throw new TTransportException(s"Kerberos principal should have 3 parts: $kerberosName") + } + val transFactory: TSaslServerTransport.Factory = new TSaslServerTransport.Factory + transFactory.addServerDefinition(AuthMethod.KERBEROS.getMechanismName, + names(0), names(1), // two parts of kerberos principal + saslProps, + new SaslRpcServer.SaslGssCallbackHandler) + transFactory.addServerDefinition(AuthMethod.TOKEN.getMechanismName, + null, + SaslRpcServer.SASL_DEFAULT_REALM, + saslProps, + new SaslDigestCallbackHandler(secretManager)) + transFactory + } + + /** + * Wrap a TTransportFactory in such a way that, before processing any RPC, it + * assumes the UserGroupInformation of the user authenticated by + * the SASL transport. + */ + def wrapTransportFactory(transFactory: TTransportFactory): TTransportFactory = { + new TUGIAssumingTransportFactory(transFactory, ugi) + } + + /** + * Wrap a TProcessor in such a way that, before processing any RPC, it + * assumes the UserGroupInformation of the user authenticated by + * the SASL transport. + */ + def wrapProcessor(processor: TProcessor): TProcessor = { + new TUGIAssumingProcessor(processor, secretManager, true) + } + + /** + * Wrap a TProcessor to capture the client information like connecting userid, ip etc + */ + def wrapNonAssumingProcessor(processor: TProcessor): TProcessor = { + new TUGIAssumingProcessor(processor, secretManager, false) + } + + def getRemoteAddress: InetAddress = AuthBridgeServer.remoteAddress.get + + def getRemoteUser: String = AuthBridgeServer.remoteUser.get + + def getUserAuthMechanism: String = AuthBridgeServer.userAuthMechanism.get + +} + +/** + * A TransportFactory that wraps another one, but assumes a specified UGI + * before calling through. + * + * This is used on the server side to assume the server's Principal when accepting + * clients. + * + * This class is derived from Hive's one. + */ +private[auth] class TUGIAssumingTransportFactory( + val wrapped: TTransportFactory, val ugi: UserGroupInformation) extends TTransportFactory { + assert(wrapped != null) + assert(ugi != null) + + override def getTransport(trans: TTransport): TTransport = { + ugi.doAs(new PrivilegedAction[TTransport]() { + override def run: TTransport = wrapped.getTransport(trans) + }) + } +} + +/** + * CallbackHandler for SASL DIGEST-MD5 mechanism. + */ +// This code is pretty much completely based on Hadoop's SaslRpcServer.SaslDigestCallbackHandler - --- End diff -- Put this inside the javadoc comment?
---