Repository: zeppelin
Updated Branches:
  refs/heads/master a691e94c5 -> 80997e8e6


[ZEPPELIN-1212] User impersonation support in JDBC

### What is this PR for?
Add impersonation support to the JDBC interpreters, in addition to Kerberos 
authentication, to improve auditability in all JDBC interpreters.

### What type of PR is it?
[Bug Fix | Improvement]

### What is the Jira issue?
* [ZEPPELIN-1212](https://issues.apache.org/jira/browse/ZEPPELIN-1212)

### How should this be tested?
In the JDBC interpreter settings, add the following properties:

 - zeppelin.jdbc.auth.type = KERBEROS
 - zeppelin.jdbc.principal = principal value
 - zeppelin.jdbc.keytab.location = keytab location
 - enable shiro authentication via shiro.ini

Now try running any Hive query (say, `show tables`); it should return 
valid results or errors, depending on the user's permissions.

### Questions:
* Do the license files need an update? n/a
* Are there breaking changes for older versions? n/a
* Does this need documentation? n/a

Author: Prabhjyot Singh <[email protected]>

Closes #1205 from prabhjyotsingh/ZEPPELIN-1212 and squashes the following 
commits:

e22b681 [Prabhjyot Singh] Fix CI
66824a0 [Prabhjyot Singh] ZEPPELIN-1212 User impersonation support in JDBC 
interpreter for Hive and Phoenix(Others)


Project: http://git-wip-us.apache.org/repos/asf/zeppelin/repo
Commit: http://git-wip-us.apache.org/repos/asf/zeppelin/commit/80997e8e
Tree: http://git-wip-us.apache.org/repos/asf/zeppelin/tree/80997e8e
Diff: http://git-wip-us.apache.org/repos/asf/zeppelin/diff/80997e8e

Branch: refs/heads/master
Commit: 80997e8e6d77ca9f4811d8a801f5805c505fc7f4
Parents: a691e94
Author: Prabhjyot Singh <[email protected]>
Authored: Wed Jul 20 14:04:26 2016 +0530
Committer: Prabhjyot Singh <[email protected]>
Committed: Fri Jul 22 13:02:16 2016 +0530

----------------------------------------------------------------------
 .../apache/zeppelin/jdbc/JDBCInterpreter.java   | 100 +++++++++++++------
 .../jdbc/security/JDBCSecurityImpl.java         |  26 +++--
 .../zeppelin/jdbc/JDBCInterpreterTest.java      |  20 ++--
 3 files changed, 98 insertions(+), 48 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/zeppelin/blob/80997e8e/jdbc/src/main/java/org/apache/zeppelin/jdbc/JDBCInterpreter.java
----------------------------------------------------------------------
diff --git a/jdbc/src/main/java/org/apache/zeppelin/jdbc/JDBCInterpreter.java 
b/jdbc/src/main/java/org/apache/zeppelin/jdbc/JDBCInterpreter.java
index 818ae69..d5f6236 100644
--- a/jdbc/src/main/java/org/apache/zeppelin/jdbc/JDBCInterpreter.java
+++ b/jdbc/src/main/java/org/apache/zeppelin/jdbc/JDBCInterpreter.java
@@ -17,23 +17,20 @@ package org.apache.zeppelin.jdbc;
 import static org.apache.commons.lang.StringUtils.containsIgnoreCase;
 
 import java.io.IOException;
+import java.security.PrivilegedExceptionAction;
 import java.sql.Connection;
 import java.sql.DriverManager;
 import java.sql.ResultSet;
 import java.sql.ResultSetMetaData;
 import java.sql.SQLException;
 import java.sql.Statement;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Properties;
-import java.util.Set;
+import java.util.*;
 
 import org.apache.commons.lang3.StringUtils;
+import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.zeppelin.interpreter.Interpreter;
 import org.apache.zeppelin.interpreter.InterpreterContext;
+import org.apache.zeppelin.interpreter.InterpreterException;
 import org.apache.zeppelin.interpreter.InterpreterResult;
 import org.apache.zeppelin.interpreter.InterpreterResult.Code;
 import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
@@ -170,19 +167,11 @@ public class JDBCInterpreter extends Interpreter {
 
     logger.debug("propertiesMap: {}", propertiesMap);
 
-    Connection connection = null;
-    SqlCompleter sqlCompleter = null;
     if 
(!StringUtils.isAnyEmpty(property.getProperty("zeppelin.jdbc.auth.type"))) {
       JDBCSecurityImpl.createSecureConfiguration(property);
     }
     for (String propertyKey : propertiesMap.keySet()) {
-      try {
-        connection = getConnection(propertyKey);
-        sqlCompleter = createSqlCompleter(connection);
-      } catch (Exception e) {
-        sqlCompleter = createSqlCompleter(null);
-      }
-      propertyKeySqlCompleterMap.put(propertyKey, sqlCompleter);
+      propertyKeySqlCompleterMap.put(propertyKey, createSqlCompleter(null));
     }
   }
 
@@ -203,7 +192,8 @@ public class JDBCInterpreter extends Interpreter {
     return completer;
   }
 
-  public Connection getConnection(String propertyKey) throws 
ClassNotFoundException, SQLException {
+  public Connection getConnection(String propertyKey, String user)
+      throws ClassNotFoundException, SQLException, InterpreterException {
     Connection connection = null;
     if (propertyKey == null || propertiesMap.get(propertyKey) == null) {
       return null;
@@ -219,22 +209,70 @@ public class JDBCInterpreter extends Interpreter {
       }
     }
     if (null == connection) {
-      Properties properties = propertiesMap.get(propertyKey);
+      final Properties properties = propertiesMap.get(propertyKey);
       logger.info(properties.getProperty(DRIVER_KEY));
       Class.forName(properties.getProperty(DRIVER_KEY));
-      String url = properties.getProperty(URL_KEY);
-      connection = DriverManager.getConnection(url, properties);
+      final String url = properties.getProperty(URL_KEY);
+
+      UserGroupInformation.AuthenticationMethod authType = 
JDBCSecurityImpl.getAuthtype(property);
+      switch (authType) {
+          case KERBEROS:
+            if (user == null) {
+              connection = DriverManager.getConnection(url, properties);
+            } else {
+              if ("hive".equalsIgnoreCase(propertyKey)) {
+                connection = DriverManager.getConnection(url + 
";hive.server2.proxy.user=" + user,
+                    properties);
+              } else {
+                UserGroupInformation ugi = null;
+                try {
+                  ugi = UserGroupInformation.createProxyUser(user,
+                      UserGroupInformation.getCurrentUser());
+                } catch (Exception e) {
+                  logger.error("Error in createProxyUser", e);
+                  StringBuilder stringBuilder = new StringBuilder();
+                  stringBuilder.append(e.getMessage()).append("\n");
+                  stringBuilder.append(e.getCause());
+                  throw new InterpreterException(stringBuilder.toString());
+                }
+                try {
+                  connection = ugi.doAs(new 
PrivilegedExceptionAction<Connection>() {
+                    @Override
+                    public Connection run() throws Exception {
+                      return DriverManager.getConnection(url, properties);
+                    }
+                  });
+                } catch (Exception e) {
+                  logger.error("Error in doAs", e);
+                  StringBuilder stringBuilder = new StringBuilder();
+                  stringBuilder.append(e.getMessage()).append("\n");
+                  stringBuilder.append(e.getCause());
+                  throw new InterpreterException(stringBuilder.toString());
+                }
+              }
+            }
+            break;
+
+          default:
+            connection = DriverManager.getConnection(url, properties);
+      }
+
     }
+    propertyKeySqlCompleterMap.put(propertyKey, 
createSqlCompleter(connection));
     return connection;
   }
 
-  public Statement getStatement(String propertyKey, String paragraphId)
-      throws SQLException, ClassNotFoundException {
+  public Statement getStatement(String propertyKey, String paragraphId,
+                                InterpreterContext interpreterContext)
+      throws SQLException, ClassNotFoundException, InterpreterException {
     Connection connection;
-    if (paragraphIdConnectionMap.containsKey(paragraphId)) {
-      connection = paragraphIdConnectionMap.get(paragraphId);
+
+    if (paragraphIdConnectionMap.containsKey(paragraphId +
+        interpreterContext.getAuthenticationInfo().getUser())) {
+      connection = paragraphIdConnectionMap.get(paragraphId +
+          interpreterContext.getAuthenticationInfo().getUser());
     } else {
-      connection = getConnection(propertyKey);
+      connection = getConnection(propertyKey, 
interpreterContext.getAuthenticationInfo().getUser());
     }
 
     if (connection == null) {
@@ -243,11 +281,13 @@ public class JDBCInterpreter extends Interpreter {
 
     Statement statement = connection.createStatement();
     if (isStatementClosed(statement)) {
-      connection = getConnection(propertyKey);
+      connection = getConnection(propertyKey, 
interpreterContext.getAuthenticationInfo().getUser());
       statement = connection.createStatement();
     }
-    paragraphIdConnectionMap.put(paragraphId, connection);
-    paragraphIdStatementMap.put(paragraphId, statement);
+    paragraphIdConnectionMap.put(paragraphId + 
interpreterContext.getAuthenticationInfo().getUser(),
+        connection);
+    paragraphIdStatementMap.put(paragraphId + 
interpreterContext.getAuthenticationInfo().getUser(),
+        statement);
 
     return statement;
   }
@@ -303,7 +343,7 @@ public class JDBCInterpreter extends Interpreter {
 
     try {
 
-      Statement statement = getStatement(propertyKey, paragraphId);
+      Statement statement = getStatement(propertyKey, paragraphId, 
interpreterContext);
 
       if (statement == null) {
         return new InterpreterResult(Code.ERROR, "Prefix not found.");
@@ -419,7 +459,7 @@ public class JDBCInterpreter extends Interpreter {
 
     String paragraphId = context.getParagraphId();
     try {
-      paragraphIdStatementMap.get(paragraphId).cancel();
+      paragraphIdStatementMap.get(paragraphId + 
context.getAuthenticationInfo().getUser()).cancel();
     } catch (SQLException e) {
       logger.error("Error while cancelling...", e);
     }

http://git-wip-us.apache.org/repos/asf/zeppelin/blob/80997e8e/jdbc/src/main/java/org/apache/zeppelin/jdbc/security/JDBCSecurityImpl.java
----------------------------------------------------------------------
diff --git 
a/jdbc/src/main/java/org/apache/zeppelin/jdbc/security/JDBCSecurityImpl.java 
b/jdbc/src/main/java/org/apache/zeppelin/jdbc/security/JDBCSecurityImpl.java
index 03d957d..8cc2735 100644
--- a/jdbc/src/main/java/org/apache/zeppelin/jdbc/security/JDBCSecurityImpl.java
+++ b/jdbc/src/main/java/org/apache/zeppelin/jdbc/security/JDBCSecurityImpl.java
@@ -25,6 +25,7 @@ import org.slf4j.LoggerFactory;
 import java.io.IOException;
 import java.util.Properties;
 
+import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod;
 import static 
org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod.KERBEROS;
 import static 
org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod.SIMPLE;
 
@@ -39,17 +40,7 @@ public class JDBCSecurityImpl {
    * @param properties
    */
   public static void createSecureConfiguration(Properties properties) {
-    UserGroupInformation.AuthenticationMethod authType;
-    try {
-      authType = UserGroupInformation
-          
.AuthenticationMethod.valueOf(properties.getProperty("zeppelin.jdbc.auth.type")
-              .trim().toUpperCase());
-    } catch (Exception e) {
-      LOGGER.error(String.format("Invalid auth.type detected with value %s, 
defaulting " +
-          "auth.type to SIMPLE", 
properties.getProperty("zeppelin.jdbc.auth.type").trim()));
-      authType = SIMPLE;
-    }
-
+    AuthenticationMethod authType = getAuthtype(properties);
 
     switch (authType) {
         case KERBEROS:
@@ -69,4 +60,17 @@ public class JDBCSecurityImpl {
     }
   }
 
+  public static AuthenticationMethod getAuthtype(Properties properties) {
+    AuthenticationMethod authType;
+    try {
+      authType = 
AuthenticationMethod.valueOf(properties.getProperty("zeppelin.jdbc.auth.type")
+          .trim().toUpperCase());
+    } catch (Exception e) {
+      LOGGER.error(String.format("Invalid auth.type detected with value %s, 
defaulting " +
+          "auth.type to SIMPLE", 
properties.getProperty("zeppelin.jdbc.auth.type")));
+      authType = SIMPLE;
+    }
+    return authType;
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/zeppelin/blob/80997e8e/jdbc/src/test/java/org/apache/zeppelin/jdbc/JDBCInterpreterTest.java
----------------------------------------------------------------------
diff --git 
a/jdbc/src/test/java/org/apache/zeppelin/jdbc/JDBCInterpreterTest.java 
b/jdbc/src/test/java/org/apache/zeppelin/jdbc/JDBCInterpreterTest.java
index b8e0220..bd5bae6 100644
--- a/jdbc/src/test/java/org/apache/zeppelin/jdbc/JDBCInterpreterTest.java
+++ b/jdbc/src/test/java/org/apache/zeppelin/jdbc/JDBCInterpreterTest.java
@@ -40,6 +40,7 @@ import org.apache.zeppelin.jdbc.JDBCInterpreter;
 import org.apache.zeppelin.scheduler.FIFOScheduler;
 import org.apache.zeppelin.scheduler.ParallelScheduler;
 import org.apache.zeppelin.scheduler.Scheduler;
+import org.apache.zeppelin.user.AuthenticationInfo;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -50,6 +51,7 @@ import com.mockrunner.jdbc.BasicJDBCTestCaseAdapter;
 public class JDBCInterpreterTest extends BasicJDBCTestCaseAdapter {
 
   static String jdbcConnection;
+  InterpreterContext interpreterContext;
 
   private static String getJdbcConnection() throws IOException {
     if(null == jdbcConnection) {
@@ -84,6 +86,8 @@ public class JDBCInterpreterTest extends 
BasicJDBCTestCaseAdapter {
     PreparedStatement insertStatement = connection.prepareStatement("insert 
into test_table(id, name) values ('a', 'a_name'),('b', 'b_name'),('c', ?);");
     insertStatement.setString(1, null);
     insertStatement.execute();
+    interpreterContext = new InterpreterContext("", "1", "", "", new 
AuthenticationInfo(), null, null, null, null,
+        null, null);
   }
 
 
@@ -126,24 +130,24 @@ public class JDBCInterpreterTest extends 
BasicJDBCTestCaseAdapter {
 
     String sqlQuery = "(fake) select * from test_table";
 
-    InterpreterResult interpreterResult = t.interpret(sqlQuery, new 
InterpreterContext("", "1", "", "", null, null, null, null, null, null, null));
+    InterpreterResult interpreterResult = t.interpret(sqlQuery, 
interpreterContext);
 
     // if prefix not found return ERROR and Prefix not found.
     assertEquals(InterpreterResult.Code.ERROR, interpreterResult.code());
     assertEquals("Prefix not found.", interpreterResult.message());
   }
-  
+
   @Test
   public void testDefaultProperties() throws SQLException {
     JDBCInterpreter jdbcInterpreter = new 
JDBCInterpreter(getJDBCTestProperties());
-    
+
     assertEquals("org.postgresql.Driver", 
jdbcInterpreter.getProperty(DEFAULT_DRIVER));
     assertEquals("jdbc:postgresql://localhost:5432/", 
jdbcInterpreter.getProperty(DEFAULT_URL));
     assertEquals("gpadmin", jdbcInterpreter.getProperty(DEFAULT_USER));
     assertEquals("", jdbcInterpreter.getProperty(DEFAULT_PASSWORD));
     assertEquals("1000", jdbcInterpreter.getProperty(COMMON_MAX_LINE));
   }
-  
+
   @Test
   public void testSelectQuery() throws SQLException, IOException {
     Properties properties = new Properties();
@@ -158,7 +162,7 @@ public class JDBCInterpreterTest extends 
BasicJDBCTestCaseAdapter {
 
     String sqlQuery = "select * from test_table WHERE ID in ('a', 'b')";
 
-    InterpreterResult interpreterResult = t.interpret(sqlQuery, new 
InterpreterContext("", "1", "", "", null, null, null, null, null, null, null));
+    InterpreterResult interpreterResult = t.interpret(sqlQuery, 
interpreterContext);
 
     assertEquals(InterpreterResult.Code.SUCCESS, interpreterResult.code());
     assertEquals(InterpreterResult.Type.TABLE, interpreterResult.type());
@@ -179,7 +183,7 @@ public class JDBCInterpreterTest extends 
BasicJDBCTestCaseAdapter {
 
     String sqlQuery = "select * from test_table WHERE ID = 'c'";
 
-    InterpreterResult interpreterResult = t.interpret(sqlQuery, new 
InterpreterContext("", "1", "", "", null, null, null, null, null, null, null));
+    InterpreterResult interpreterResult = t.interpret(sqlQuery, 
interpreterContext);
 
     assertEquals(InterpreterResult.Code.SUCCESS, interpreterResult.code());
     assertEquals(InterpreterResult.Type.TABLE, interpreterResult.type());
@@ -202,7 +206,7 @@ public class JDBCInterpreterTest extends 
BasicJDBCTestCaseAdapter {
 
     String sqlQuery = "select * from test_table";
 
-    InterpreterResult interpreterResult = t.interpret(sqlQuery, new 
InterpreterContext("", "1", "", "", null, null, null, null, null, null, null));
+    InterpreterResult interpreterResult = t.interpret(sqlQuery, 
interpreterContext);
 
     assertEquals(InterpreterResult.Code.SUCCESS, interpreterResult.code());
     assertEquals(InterpreterResult.Type.TABLE, interpreterResult.type());
@@ -244,6 +248,8 @@ public class JDBCInterpreterTest extends 
BasicJDBCTestCaseAdapter {
     JDBCInterpreter jdbcInterpreter = new JDBCInterpreter(properties);
     jdbcInterpreter.open();
 
+    jdbcInterpreter.interpret("", interpreterContext);
+
     List<InterpreterCompletion> completionList = 
jdbcInterpreter.completion("SEL", 0);
     
     InterpreterCompletion correctCompletionKeyword = new 
InterpreterCompletion("SELECT", "SELECT");

Reply via email to