Repository: incubator-zeppelin
Updated Branches:
  refs/heads/master d169284fd -> ca67374c2
ZEPPELIN-297 Dependency should be loaded in pyspark

This PR fixes https://issues.apache.org/jira/browse/ZEPPELIN-297

* [x] fix, by setting the context classloader with classes from the dependency loader when initializing the py4j GatewayServer
* [x] test

Author: Lee moon soo <[email protected]>

Closes #298 from Leemoonsoo/ZEPPELIN-297 and squashes the following commits:

0de89fe [Lee moon soo] Add logging
1e8f52a [Lee moon soo] Add test
163acfa [Lee moon soo] Set classloader for gatewayserver with classes from dependency loader


Project: http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/commit/ca67374c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/tree/ca67374c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/diff/ca67374c

Branch: refs/heads/master
Commit: ca67374c2b0fe26f6bc6992910bb11d5658258e0
Parents: d169284
Author: Lee moon soo <[email protected]>
Authored: Mon Sep 14 14:04:40 2015 +0900
Committer: Lee moon soo <[email protected]>
Committed: Mon Sep 14 15:13:53 2015 +0900

----------------------------------------------------------------------
 .../zeppelin/spark/PySparkInterpreter.java      | 59 ++++++++++++++++++++
 .../zeppelin/rest/ZeppelinSparkClusterTest.java | 45 ++++++++++++++-
 2 files changed, 103 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/blob/ca67374c/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
----------------------------------------------------------------------
diff --git a/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java b/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
index 0e58729..d0e5fec 100644
--- a/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
+++ b/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
@@ -25,7 +25,10 @@ import java.io.IOException;
 import java.io.OutputStreamWriter;
 import java.io.PipedInputStream;
 import java.io.PipedOutputStream;
+import java.net.MalformedURLException;
 import java.net.ServerSocket;
+import java.net.URL;
+import java.net.URLClassLoader;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
@@ -51,6 +54,7 @@ import org.apache.zeppelin.interpreter.InterpreterResult;
 import org.apache.zeppelin.interpreter.InterpreterResult.Code;
 import org.apache.zeppelin.interpreter.LazyOpenInterpreter;
 import org.apache.zeppelin.interpreter.WrappedInterpreter;
+import org.apache.zeppelin.spark.dep.DependencyContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -125,6 +129,44 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
 
   @Override
   public void open() {
+    DepInterpreter depInterpreter = getDepInterpreter();
+
+    // load libraries from Dependency Interpreter
+    URL [] urls = new URL[0];
+
+    if (depInterpreter != null) {
+      DependencyContext depc = depInterpreter.getDependencyContext();
+      if (depc != null) {
+        List<File> files = depc.getFiles();
+        List<URL> urlList = new LinkedList<URL>();
+        if (files != null) {
+          for (File f : files) {
+            try {
+              urlList.add(f.toURI().toURL());
+            } catch (MalformedURLException e) {
+              logger.error("Error", e);
+            }
+          }
+
+          urls = urlList.toArray(urls);
+        }
+      }
+    }
+
+    ClassLoader oldCl = Thread.currentThread().getContextClassLoader();
+    try {
+      URLClassLoader newCl = new URLClassLoader(urls, oldCl);
+      Thread.currentThread().setContextClassLoader(newCl);
+      createGatewayServerAndStartScript();
+    } catch (Exception e) {
+      logger.error("Error", e);
+      throw new InterpreterException(e);
+    } finally {
+      Thread.currentThread().setContextClassLoader(oldCl);
+    }
+  }
+
+  private void createGatewayServerAndStartScript() {
     // create python script
     createPythonScript();
@@ -400,6 +442,23 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
     }
   }
 
+  private DepInterpreter getDepInterpreter() {
+    InterpreterGroup intpGroup = getInterpreterGroup();
+    if (intpGroup == null) return null;
+    synchronized (intpGroup) {
+      for (Interpreter intp : intpGroup) {
+        if (intp.getClassName().equals(DepInterpreter.class.getName())) {
+          Interpreter p = intp;
+          while (p instanceof WrappedInterpreter) {
+            p = ((WrappedInterpreter) p).getInnerInterpreter();
+          }
+          return (DepInterpreter) p;
+        }
+      }
+    }
+    return null;
+  }
+
   @Override
   public void onProcessComplete(int exitValue) {
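The heart of the change above is a swap-and-restore of the thread context classloader around createGatewayServerAndStartScript(): the py4j GatewayServer is brought up while the context classloader can see the jars fetched by the dependency loader, and the original loader is restored in a finally block. A minimal standalone sketch of the same pattern, with illustrative names only (runWithExtraJars is not part of the patch):

  import java.net.URL;
  import java.net.URLClassLoader;

  public class ContextClassLoaderSwap {
    // Run a task with a context classloader that also sees the given jar
    // URLs; restore the previous loader even if the task throws.
    static void runWithExtraJars(URL[] jars, Runnable task) {
      ClassLoader oldCl = Thread.currentThread().getContextClassLoader();
      try {
        // oldCl is the parent, so everything already visible stays visible.
        URLClassLoader newCl = new URLClassLoader(jars, oldCl);
        Thread.currentThread().setContextClassLoader(newCl);
        task.run();  // e.g. create the gateway server and start the script
      } finally {
        Thread.currentThread().setContextClassLoader(oldCl);
      }
    }

    public static void main(String[] args) {
      runWithExtraJars(new URL[0], () ->
          System.out.println(Thread.currentThread().getContextClassLoader()));
    }
  }

Because the new URLClassLoader delegates to the old one as its parent, the swap only widens what is loadable; it never hides previously visible classes.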
http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/blob/ca67374c/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java
----------------------------------------------------------------------
diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java b/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java
index aa2476a..d5006ee 100644
--- a/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java
+++ b/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java
@@ -18,8 +18,12 @@ package org.apache.zeppelin.rest;
 
 import static org.junit.Assert.assertEquals;
 
+import java.io.File;
 import java.io.IOException;
+import java.util.List;
 
+import org.apache.commons.io.FileUtils;
+import org.apache.zeppelin.interpreter.InterpreterSetting;
 import org.apache.zeppelin.notebook.Note;
 import org.apache.zeppelin.notebook.Paragraph;
 import org.apache.zeppelin.scheduler.Job.Status;
@@ -75,7 +79,6 @@ public class ZeppelinSparkClusterTest extends AbstractTestRestApi {
   public void pySparkTest() throws IOException {
     // create new note
     Note note = ZeppelinServer.notebook.createNote();
-
     int sparkVersion = getSparkVersionNumber(note);
 
     if (isPyspark() && sparkVersion >= 12) {   // pyspark supported from 1.2.1
@@ -129,6 +132,46 @@ public class ZeppelinSparkClusterTest extends AbstractTestRestApi {
     ZeppelinServer.notebook.removeNote(note.id());
   }
 
+  @Test
+  public void pySparkDepLoaderTest() throws IOException {
+    // create new note
+    Note note = ZeppelinServer.notebook.createNote();
+
+    if (isPyspark() && getSparkVersionNumber(note) >= 14) {
+      // restart spark interpreter
+      List<InterpreterSetting> settings =
+          ZeppelinServer.notebook.getBindedInterpreterSettings(note.id());
+
+      for (InterpreterSetting setting : settings) {
+        if (setting.getGroup().equals("spark")) {
+          ZeppelinServer.notebook.getInterpreterFactory().restart(setting.id());
+          break;
+        }
+      }
+
+      // load dep
+      Paragraph p0 = note.addParagraph();
+      p0.setText("%dep z.load(\"com.databricks:spark-csv_2.11:1.2.0\")");
+      note.run(p0.getId());
+      waitForFinish(p0);
+
+      // write test csv file
+      File tmpFile = File.createTempFile("test", "csv");
+      FileUtils.write(tmpFile, "a,b\n1,2");
+
+      // load data using libraries from dep loader
+      Paragraph p1 = note.addParagraph();
+      p1.setText("%pyspark\n" +
+          "from pyspark.sql import SQLContext\n" +
+          "print(sqlContext.read.format('com.databricks.spark.csv')" +
+"').count())"); + note.run(p1.getId()); + + waitForFinish(p1); + assertEquals("2\n", p1.getResult().message()); + } + } + /** * Get spark version number as a numerical value. * eg. 1.1.x => 11, 1.2.x => 12, 1.3.x => 13 ...
