Repository: flink
Updated Branches:
  refs/heads/master d0debf4a8 -> 4afca4b3a


[FLINK-7656] [runtime] Switch to user classloader before calling initializeOnMaster and finalizeOnMaster.

This closes #4690.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/4afca4b3
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/4afca4b3
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/4afca4b3

Branch: refs/heads/master
Commit: 4afca4b3a13b61c2754bc839c77ba4d4eb1d2da2
Parents: d0debf4
Author: Fabian Hueske <fhue...@apache.org>
Authored: Wed Sep 20 16:26:47 2017 +0200
Committer: Fabian Hueske <fhue...@apache.org>
Committed: Fri Sep 22 18:32:29 2017 +0200

----------------------------------------------------------------------
 .../runtime/jobgraph/OutputFormatVertex.java    | 50 ++++++++-----
 .../runtime/jobgraph/JobTaskVertexTest.java     | 76 ++++++++++++++++----
 2 files changed, 98 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/4afca4b3/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/OutputFormatVertex.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/OutputFormatVertex.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/OutputFormatVertex.java
index c9ac564..77f207c 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/OutputFormatVertex.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/OutputFormatVertex.java
@@ -73,15 +73,24 @@ public class OutputFormatVertex extends JobVertex {
                catch (Throwable t) {
                        throw new Exception("Instantiating the OutputFormat (" 
+ formatDescription + ") failed: " + t.getMessage(), t);
                }
+
+               // set user classloader before calling user code
+               final ClassLoader prevContextCl = 
Thread.currentThread().getContextClassLoader();
+               Thread.currentThread().setContextClassLoader(loader);
+
                try {
-                       outputFormat.configure(cfg.getStubParameters());
-               }
-               catch (Throwable t) {
-                       throw new Exception("Configuring the OutputFormat (" + 
formatDescription + ") failed: " + t.getMessage(), t);
-               }
-               
-               if (outputFormat instanceof InitializeOnMaster) {
-                       ((InitializeOnMaster) 
outputFormat).initializeGlobal(getParallelism());
+                       // configure output format
+                       try {
+                               outputFormat.configure(cfg.getStubParameters());
+                       } catch (Throwable t) {
+                               throw new Exception("Configuring the 
OutputFormat (" + formatDescription + ") failed: " + t.getMessage(), t);
+                       }
+                       if (outputFormat instanceof InitializeOnMaster) {
+                               ((InitializeOnMaster) 
outputFormat).initializeGlobal(getParallelism());
+                       }
+               } finally {
+                       // restore previous classloader
+                       
Thread.currentThread().setContextClassLoader(prevContextCl);
                }
        }
        
@@ -107,15 +116,24 @@ public class OutputFormatVertex extends JobVertex {
                catch (Throwable t) {
                        throw new Exception("Instantiating the OutputFormat (" 
+ formatDescription + ") failed: " + t.getMessage(), t);
                }
+
+               // set user classloader before calling user code
+               final ClassLoader prevContextCl = 
Thread.currentThread().getContextClassLoader();
+               Thread.currentThread().setContextClassLoader(loader);
+
                try {
-                       outputFormat.configure(cfg.getStubParameters());
-               }
-               catch (Throwable t) {
-                       throw new Exception("Configuring the OutputFormat (" + 
formatDescription + ") failed: " + t.getMessage(), t);
-               }
-               
-               if (outputFormat instanceof FinalizeOnMaster) {
-                       ((FinalizeOnMaster) 
outputFormat).finalizeGlobal(getParallelism());
+                       // configure output format
+                       try {
+                               outputFormat.configure(cfg.getStubParameters());
+                       } catch (Throwable t) {
+                               throw new Exception("Configuring the 
OutputFormat (" + formatDescription + ") failed: " + t.getMessage(), t);
+                       }
+                       if (outputFormat instanceof FinalizeOnMaster) {
+                               ((FinalizeOnMaster) 
outputFormat).finalizeGlobal(getParallelism());
+                       }
+               } finally {
+                       // restore previous classloader
+                       
Thread.currentThread().setContextClassLoader(prevContextCl);
                }
        }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/4afca4b3/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java
index 970437a..794c5c6 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java
@@ -18,12 +18,14 @@
 
 package org.apache.flink.runtime.jobgraph;
 
+import org.apache.flink.api.common.io.FinalizeOnMaster;
 import org.apache.flink.api.common.io.GenericInputFormat;
 import org.apache.flink.api.common.io.InitializeOnMaster;
 import org.apache.flink.api.common.io.InputFormat;
 import org.apache.flink.api.common.io.OutputFormat;
 import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
 import org.apache.flink.api.java.io.DiscardingOutputFormat;
+import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.io.GenericInputSplit;
 import org.apache.flink.core.io.InputSplit;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
@@ -33,6 +35,8 @@ import org.apache.flink.util.InstantiationUtil;
 import org.junit.Test;
 
 import java.io.IOException;
+import java.net.URL;
+import java.net.URLClassLoader;
 
 import static org.junit.Assert.*;
 
@@ -84,10 +88,10 @@ public class JobTaskVertexTest {
        @Test
        public void testOutputFormatVertex() {
                try {
-                       final TestingOutputFormat outputFormat = new 
TestingOutputFormat();
+                       final OutputFormat outputFormat = new 
TestingOutputFormat();
                        final OutputFormatVertex of = new 
OutputFormatVertex("Name");
                        new 
TaskConfig(of.getConfiguration()).setStubWrapper(new 
UserCodeObjectWrapper<OutputFormat<?>>(outputFormat));
-                       final ClassLoader cl = getClass().getClassLoader();
+                       final ClassLoader cl = new TestClassLoader();
                        
                        try {
                                of.initializeOnMaster(cl);
@@ -97,19 +101,29 @@ public class JobTaskVertexTest {
                        }
                        
                        OutputFormatVertex copy = InstantiationUtil.clone(of);
+                       ClassLoader ctxCl = 
Thread.currentThread().getContextClassLoader();
                        try {
                                copy.initializeOnMaster(cl);
                                fail("Did not throw expected exception.");
                        } catch (TestException e) {
                                // all good
                        }
+                       assertEquals("Previous classloader was not restored.", 
ctxCl, Thread.currentThread().getContextClassLoader());
+
+                       try {
+                               copy.finalizeOnMaster(cl);
+                               fail("Did not throw expected exception.");
+                       } catch (TestException e) {
+                               // all good
+                       }
+                       assertEquals("Previous classloader was not restored.", 
ctxCl, Thread.currentThread().getContextClassLoader());
                }
                catch (Exception e) {
                        e.printStackTrace();
                        fail(e.getMessage());
                }
        }
-       
+
        @Test
        public void testInputFormatVertex() {
                try {
@@ -134,17 +148,8 @@ public class JobTaskVertexTest {
        
        // 
--------------------------------------------------------------------------------------------
        
-       private static final class TestingOutputFormat extends 
DiscardingOutputFormat<Object> implements InitializeOnMaster {
-               @Override
-               public void initializeGlobal(int parallelism) throws 
IOException {
-                       throw new TestException();
-               }
-       }
-       
        private static final class TestException extends IOException {}
        
-       // 
--------------------------------------------------------------------------------------------
-       
        private static final class TestSplit extends GenericInputSplit {
                
                public TestSplit(int partitionNumber, int 
totalNumberOfPartitions) {
@@ -169,4 +174,51 @@ public class JobTaskVertexTest {
                        return new GenericInputSplit[] { new TestSplit(0, 1) };
                }
        }
+
+       private static final class TestingOutputFormat extends 
DiscardingOutputFormat<Object> implements InitializeOnMaster, FinalizeOnMaster {
+
+               private boolean isConfigured = false;
+
+               @Override
+               public void initializeGlobal(int parallelism) throws 
IOException {
+                       if (!isConfigured) {
+                               throw new IllegalStateException("OutputFormat 
was not configured before initializeGlobal was called.");
+                       }
+                       if (!(Thread.currentThread().getContextClassLoader() 
instanceof TestClassLoader)) {
+                               throw new IllegalStateException("Context 
ClassLoader was not correctly switched.");
+                       }
+                       // notify we have been here.
+                       throw new TestException();
+               }
+
+               @Override
+               public void finalizeGlobal(int parallelism) throws IOException {
+                       if (!isConfigured) {
+                               throw new IllegalStateException("OutputFormat 
was not configured before finalizeGlobal was called.");
+                       }
+                       if (!(Thread.currentThread().getContextClassLoader() 
instanceof TestClassLoader)) {
+                               throw new IllegalStateException("Context 
ClassLoader was not correctly switched.");
+                       }
+                       // notify we have been here.
+                       throw new TestException();
+               }
+
+               @Override
+               public void configure(Configuration parameters) {
+                       if (isConfigured) {
+                               throw new IllegalStateException("OutputFormat 
is already configured.");
+                       }
+                       if (!(Thread.currentThread().getContextClassLoader() 
instanceof TestClassLoader)) {
+                               throw new IllegalStateException("Context 
ClassLoader was not correctly switched.");
+                       }
+                       isConfigured = true;
+               }
+
+       }
+
+       private static class TestClassLoader extends URLClassLoader {
+               public TestClassLoader() {
+                       super(new URL[0], 
Thread.currentThread().getContextClassLoader());
+               }
+       }
 }

Reply via email to