Github user mgaido91 commented on a diff in the pull request:

    https://github.com/apache/incubator-livy/pull/117#discussion_r221890669
  
    --- Diff: 
thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/AuthBridgeServer.scala
 ---
    @@ -0,0 +1,296 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.livy.thriftserver.auth
    +
    +import java.io.IOException
    +import java.net.InetAddress
    +import java.security.{PrivilegedAction, PrivilegedExceptionAction}
    +import java.util
    +import javax.security.auth.callback.{Callback, CallbackHandler, 
NameCallback, PasswordCallback, UnsupportedCallbackException}
    +import javax.security.sasl.{AuthorizeCallback, RealmCallback, SaslServer}
    +
    +import org.apache.commons.codec.binary.Base64
    +import org.apache.hadoop.fs.FileSystem
    +import org.apache.hadoop.security.{SaslRpcServer, UserGroupInformation}
    +import org.apache.hadoop.security.SaslRpcServer.AuthMethod
    +import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod
    +import org.apache.hadoop.security.token.SecretManager.InvalidToken
    +import org.apache.thrift.{TException, TProcessor}
    +import org.apache.thrift.protocol.TProtocol
    +import org.apache.thrift.transport.{TSaslServerTransport, TSocket, 
TTransport, TTransportException, TTransportFactory}
    +
    +import org.apache.livy.Logging
    +
    +/**
    + * The class is taken from Hive's `HadoopThriftAuthBridge.Server`. It 
bridges Thrift's SASL
    + * transports to Hadoop's SASL callback handlers and authentication 
classes.
    + *
    + * This class is based on Hive's one.
    + */
    +class AuthBridgeServer(private val secretManager: 
LivyDelegationTokenSecretManager) {
    +  private val ugi = try {
    +      UserGroupInformation.getCurrentUser
    +    } catch {
    +      case ioe: IOException => throw new TTransportException(ioe)
    +    }
    +
    +  /**
    +   * Create a TTransportFactory that, upon connection of a client socket,
    +   * negotiates a Kerberized SASL transport. The resulting 
TTransportFactory
    +   * can be passed as both the input and output transport factory when
    +   * instantiating a TThreadPoolServer, for example.
    +   *
    +   * @param saslProps Map of SASL properties
    +   */
    +  @throws[TTransportException]
    +  def createTransportFactory(saslProps: util.Map[String, String]): 
TTransportFactory = {
    +    val transFactory: TSaslServerTransport.Factory = 
createSaslServerTransportFactory(saslProps)
    +    new TUGIAssumingTransportFactory(transFactory, ugi)
    +  }
    +
    +  /**
    +   * Create a TSaslServerTransport.Factory that, upon connection of a 
client
    +   * socket, negotiates a Kerberized SASL transport.
    +   *
    +   * @param saslProps Map of SASL properties
    +   */
    +  @throws[TTransportException]
    +  def createSaslServerTransportFactory(
    +      saslProps: util.Map[String, String]): TSaslServerTransport.Factory = 
{
    +    // Parse out the kerberos principal, host, realm.
    +    val kerberosName: String = ugi.getUserName
    +    val names: Array[String] = 
SaslRpcServer.splitKerberosName(kerberosName)
    +    if (names.length != 3) {
    +      throw new TTransportException(s"Kerberos principal should have 3 
parts: $kerberosName")
    +    }
    +    val transFactory: TSaslServerTransport.Factory = new 
TSaslServerTransport.Factory
    +    transFactory.addServerDefinition(AuthMethod.KERBEROS.getMechanismName,
    +      names(0), names(1), // two parts of kerberos principal
    +      saslProps,
    +      new SaslRpcServer.SaslGssCallbackHandler)
    +    transFactory.addServerDefinition(AuthMethod.TOKEN.getMechanismName,
    +      null,
    +      SaslRpcServer.SASL_DEFAULT_REALM,
    +      saslProps,
    +      new SaslDigestCallbackHandler(secretManager))
    +    transFactory
    +  }
    +
    +  /**
    +   * Wrap a TTransportFactory in such a way that, before processing any 
RPC, it
    +   * assumes the UserGroupInformation of the user authenticated by
    +   * the SASL transport.
    +   */
    +  def wrapTransportFactory(transFactory: TTransportFactory): 
TTransportFactory = {
    +    new TUGIAssumingTransportFactory(transFactory, ugi)
    +  }
    +
    +  /**
    +   * Wrap a TProcessor in such a way that, before processing any RPC, it
    +   * assumes the UserGroupInformation of the user authenticated by
    +   * the SASL transport.
    +   */
    +  def wrapProcessor(processor: TProcessor): TProcessor = {
    +    new TUGIAssumingProcessor(processor, secretManager, true)
    +  }
    +
    +  /**
    +   * Wrap a TProcessor to capture the client information like connecting 
userid, ip etc
    +   */
    +  def wrapNonAssumingProcessor(processor: TProcessor): TProcessor = {
    +    new TUGIAssumingProcessor(processor, secretManager, false)
    +  }
    +
    +  def getRemoteAddress: InetAddress = AuthBridgeServer.remoteAddress.get
    +
    +  def getRemoteUser: String = AuthBridgeServer.remoteUser.get
    +
    +  def getUserAuthMechanism: String = AuthBridgeServer.userAuthMechanism.get
    +
    +}
    +
    +/**
    + * A TransportFactory that wraps another one, but assumes a specified UGI
    + * before calling through.
    + *
    + * This is used on the server side to assume the server's Principal when 
accepting
    + * clients.
    + *
    + * This class is derived from Hive's one.
    + */
    +private[auth] class TUGIAssumingTransportFactory(
    +    val wrapped: TTransportFactory, val ugi: UserGroupInformation) extends 
TTransportFactory {
    +  assert(wrapped != null)
    +  assert(ugi != null)
    +
    +  override def getTransport(trans: TTransport): TTransport = {
    +    ugi.doAs(new PrivilegedAction[TTransport]() {
    +      override def run: TTransport = wrapped.getTransport(trans)
    +    })
    +  }
    +}
    +
    +/**
    + * CallbackHandler for SASL DIGEST-MD5 mechanism.
    + */
    +// This code is pretty much completely based on Hadoop's 
SaslRpcServer.SaslDigestCallbackHandler -
    +// the only reason we could not use that Hadoop class as-is was because it 
needs a
    +// Server.Connection.
    +sealed class SaslDigestCallbackHandler(
    +    val secretManager: LivyDelegationTokenSecretManager) extends 
CallbackHandler with Logging {
    +  @throws[InvalidToken]
    +  private def getPassword(tokenId: LivyDelegationTokenIdentifier): 
Array[Char] = {
    +    encodePassword(secretManager.retrievePassword(tokenId))
    +  }
    +
    +  private def encodePassword(password: Array[Byte]): Array[Char] = {
    +    new String(Base64.encodeBase64(password)).toCharArray
    +  }
    +
    +  @throws[InvalidToken]
    +  @throws[UnsupportedCallbackException]
    +  override def handle(callbacks: Array[Callback]): Unit = {
    +    var nc: NameCallback = null
    +    var pc: PasswordCallback = null
    +    callbacks.foreach {
    +      case ac: AuthorizeCallback =>
    +        val authid: String = ac.getAuthenticationID
    +        val authzid: String = ac.getAuthorizationID
    +        if (authid == authzid) ac.setAuthorized(true)
    +        else ac.setAuthorized(false)
    +        if (ac.isAuthorized) {
    +          if (logger.isDebugEnabled) {
    +            val username = SaslRpcServer.getIdentifier(authzid, 
secretManager).getUser.getUserName
    +            debug(s"SASL server DIGEST-MD5 callback: setting canonicalized 
client ID: $username")
    +          }
    +          ac.setAuthorizedID(authzid)
    +        }
    +      case c: NameCallback => nc = c
    +      case c: PasswordCallback => pc = c
    +      case _: RealmCallback => // Do nothing.
    +      case other =>
    +        throw new UnsupportedCallbackException(other, "Unrecognized SASL 
DIGEST-MD5 Callback")
    +    }
    +    if (pc != null) {
    +      val tokenIdentifier = SaslRpcServer.getIdentifier(nc.getDefaultName, 
secretManager)
    +      val password: Array[Char] = getPassword(tokenIdentifier)
    +      if (logger.isDebugEnabled) {
    +        debug("SASL server DIGEST-MD5 callback: setting password for 
client: " +
    +          tokenIdentifier.getUser)
    +      }
    +      pc.setPassword(password)
    +    }
    +  }
    +}
    +
    +/**
    + * Processor that pulls the SaslServer object out of the transport, and 
assumes the remote user's
    + * UGI before calling through to the original processor.
    + *
    + * This is used on the server side to set the UGI for each specific call.
    --- End diff --
    
    Most of the operations here are needed as this is setting the `remoteUser` 
which is used in many places later. But the UGI impersonation may not be needed 
indeed.


---

Reply via email to