Github user vanzin commented on a diff in the pull request:

    https://github.com/apache/incubator-livy/pull/117#discussion_r221764487
  
    --- Diff: thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/AuthBridgeServer.scala ---
    @@ -0,0 +1,296 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.livy.thriftserver.auth
    +
    +import java.io.IOException
    +import java.net.InetAddress
    +import java.security.{PrivilegedAction, PrivilegedExceptionAction}
    +import java.util
    +import javax.security.auth.callback.{Callback, CallbackHandler, 
NameCallback, PasswordCallback, UnsupportedCallbackException}
    +import javax.security.sasl.{AuthorizeCallback, RealmCallback, SaslServer}
    +
    +import org.apache.commons.codec.binary.Base64
    +import org.apache.hadoop.fs.FileSystem
    +import org.apache.hadoop.security.{SaslRpcServer, UserGroupInformation}
    +import org.apache.hadoop.security.SaslRpcServer.AuthMethod
    +import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod
    +import org.apache.hadoop.security.token.SecretManager.InvalidToken
    +import org.apache.thrift.{TException, TProcessor}
    +import org.apache.thrift.protocol.TProtocol
    +import org.apache.thrift.transport.{TSaslServerTransport, TSocket, 
TTransport, TTransportException, TTransportFactory}
    +
    +import org.apache.livy.Logging
    +
    +/**
    + * The class is taken from Hive's `HadoopThriftAuthBridge.Server`. It bridges Thrift's SASL
    + * transports to Hadoop's SASL callback handlers and authentication classes.
    + *
    + * This class is based on Hive's one.
    + */
    +class AuthBridgeServer(private val secretManager: 
LivyDelegationTokenSecretManager) {
    +  private val ugi = try {
    +      UserGroupInformation.getCurrentUser
    +    } catch {
    +      case ioe: IOException => throw new TTransportException(ioe)
    +    }
    +
    +  /**
    +   * Create a TTransportFactory that, upon connection of a client socket,
    +   * negotiates a Kerberized SASL transport. The resulting 
TTransportFactory
    +   * can be passed as both the input and output transport factory when
    +   * instantiating a TThreadPoolServer, for example.
    +   *
    +   * @param saslProps Map of SASL properties
    +   */
    +  @throws[TTransportException]
    +  def createTransportFactory(saslProps: util.Map[String, String]): 
TTransportFactory = {
    +    val transFactory: TSaslServerTransport.Factory = 
createSaslServerTransportFactory(saslProps)
    +    new TUGIAssumingTransportFactory(transFactory, ugi)
    +  }
    +
    +  /**
    +   * Create a TSaslServerTransport.Factory that, upon connection of a 
client
    +   * socket, negotiates a Kerberized SASL transport.
    +   *
    +   * @param saslProps Map of SASL properties
    +   */
    +  @throws[TTransportException]
    +  def createSaslServerTransportFactory(
    +      saslProps: util.Map[String, String]): TSaslServerTransport.Factory = 
{
    +    // Parse out the kerberos principal, host, realm.
    +    val kerberosName: String = ugi.getUserName
    +    val names: Array[String] = 
SaslRpcServer.splitKerberosName(kerberosName)
    +    if (names.length != 3) {
    +      throw new TTransportException(s"Kerberos principal should have 3 
parts: $kerberosName")
    +    }
    +    val transFactory: TSaslServerTransport.Factory = new 
TSaslServerTransport.Factory
    +    transFactory.addServerDefinition(AuthMethod.KERBEROS.getMechanismName,
    +      names(0), names(1), // two parts of kerberos principal
    +      saslProps,
    +      new SaslRpcServer.SaslGssCallbackHandler)
    +    transFactory.addServerDefinition(AuthMethod.TOKEN.getMechanismName,
    +      null,
    +      SaslRpcServer.SASL_DEFAULT_REALM,
    +      saslProps,
    +      new SaslDigestCallbackHandler(secretManager))
    +    transFactory
    +  }
    +
    +  /**
    +   * Wrap a TTransportFactory in such a way that, before processing any 
RPC, it
    +   * assumes the UserGroupInformation of the user authenticated by
    +   * the SASL transport.
    +   */
    +  def wrapTransportFactory(transFactory: TTransportFactory): 
TTransportFactory = {
    +    new TUGIAssumingTransportFactory(transFactory, ugi)
    +  }
    +
    +  /**
    +   * Wrap a TProcessor in such a way that, before processing any RPC, it
    +   * assumes the UserGroupInformation of the user authenticated by
    +   * the SASL transport.
    +   */
    +  def wrapProcessor(processor: TProcessor): TProcessor = {
    +    new TUGIAssumingProcessor(processor, secretManager, true)
    +  }
    +
    +  /**
    +   * Wrap a TProcessor to capture the client information like connecting 
userid, ip etc
    +   */
    +  def wrapNonAssumingProcessor(processor: TProcessor): TProcessor = {
    +    new TUGIAssumingProcessor(processor, secretManager, false)
    +  }
    +
    +  def getRemoteAddress: InetAddress = AuthBridgeServer.remoteAddress.get
    +
    +  def getRemoteUser: String = AuthBridgeServer.remoteUser.get
    +
    +  def getUserAuthMechanism: String = AuthBridgeServer.userAuthMechanism.get
    +
    +}
    +
    +/**
    + * A TransportFactory that wraps another one, but assumes a specified UGI
    + * before calling through.
    + *
    + * This is used on the server side to assume the server's Principal when 
accepting
    + * clients.
    + *
    + * This class is derived from Hive's one.
    + */
    +private[auth] class TUGIAssumingTransportFactory(
    +    val wrapped: TTransportFactory, val ugi: UserGroupInformation) extends 
TTransportFactory {
    +  assert(wrapped != null)
    +  assert(ugi != null)
    +
    +  override def getTransport(trans: TTransport): TTransport = {
    +    ugi.doAs(new PrivilegedAction[TTransport]() {
    +      override def run: TTransport = wrapped.getTransport(trans)
    +    })
    +  }
    +}
    +
    +/**
    + * CallbackHandler for SASL DIGEST-MD5 mechanism.
    + */
    +// This code is pretty much completely based on Hadoop's SaslRpcServer.SaslDigestCallbackHandler -
    --- End diff --
    
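    Could this "based on Hadoop's SaslRpcServer.SaslDigestCallbackHandler" note go inside the javadoc comment just above it, instead of sitting in a separate `//` comment?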


---

Reply via email to