ppkarwasz commented on code in PR #2072:
URL: https://github.com/apache/logging-log4j2/pull/2072#discussion_r1418831243


##########
log4j-jndi-test/src/main/java/org/apache/logging/log4j/jndi/test/junit/JndiRule.java:
##########
@@ -16,42 +16,113 @@
  */
 package org.apache.logging.log4j.jndi.test.junit;
 
+import static java.util.Objects.requireNonNull;
+import static org.junit.Assert.assertNotNull;
+
 import java.util.Collections;
+import java.util.Hashtable;
 import java.util.Map;
+import java.util.Set;
+import java.util.Spliterators;
+import java.util.stream.Collectors;
+import java.util.stream.StreamSupport;
+import javax.annotation.Nullable;
 import javax.naming.Context;
+import javax.naming.NameClassPair;
+import javax.naming.NamingException;
+import javax.naming.spi.InitialContextFactoryBuilder;
+import javax.naming.spi.NamingManager;
+import org.apache.logging.log4j.jndi.JndiManager;
 import org.junit.rules.TestRule;
 import org.junit.runner.Description;
 import org.junit.runners.model.Statement;
-import org.springframework.mock.jndi.SimpleNamingContextBuilder;
+import org.osjava.sj.jndi.MemoryContext;
 
 /**
  * JUnit rule to create a mock {@link Context} and bind an object to a name.
  *
  * @since 2.8
  */
+@SuppressWarnings("BanJNDI")
 public class JndiRule implements TestRule {
 
-    private final Map<String, Object> initialBindings;
+    static {
+        final InitialContextFactoryBuilder factoryBuilder =
+                factoryBuilderEnv -> factoryEnv -> new MemoryContext(new Hashtable<>()) {};
+        try {
+            NamingManager.setInitialContextFactoryBuilder(factoryBuilder);
+        } catch (final NamingException error) {
+            throw new RuntimeException(error);
+        }
+    }
+
+    @Nullable
+    private final String managerName;
+
+    private final Map<String, Object> bindings;
 
     public JndiRule(final String name, final Object value) {
-        this.initialBindings = Collections.singletonMap(name, value);
+        this(null, Collections.singletonMap(name, value));
     }
 
-    public JndiRule(final Map<String, Object> initialBindings) {
-        this.initialBindings = initialBindings;
+    public JndiRule(@Nullable final String managerName, final String name, final Object value) {
+        this(managerName, Collections.singletonMap(name, value));
+    }
+
+    public JndiRule(final Map<String, Object> bindings) {
+        this(null, bindings);
+    }
+
+    public JndiRule(@Nullable final String managerName, final Map<String, Object> bindings) {
+        this.managerName = managerName;
+        this.bindings = requireNonNull(bindings, "bindings");
     }
 
     @Override
     public Statement apply(final Statement base, final Description description) {
         return new Statement() {
             @Override
             public void evaluate() throws Throwable {
-                final SimpleNamingContextBuilder builder = SimpleNamingContextBuilder.emptyActivatedContextBuilder();
-                for (final Map.Entry<String, Object> entry : initialBindings.entrySet()) {
-                    builder.bind(entry.getKey(), entry.getValue());
-                }
+                resetJndiManager();
                 base.evaluate();
             }
         };
     }
+
+    private void resetJndiManager() throws NamingException {
+        if (JndiManager.isJndiEnabled()) {
+            final Context context = getContext();
+            clearBindings(context);
+            addBindings(context);
+        }
+    }
+
+    private Context getContext() {
+        final JndiManager manager =
+                managerName == null ? JndiManager.getDefaultManager() : JndiManager.getDefaultManager(managerName);
+        @Nullable final Context context = manager.getContext();
+        assertNotNull(context);
+        return context;
+    }
+
+    private static void clearBindings(final Context context) throws NamingException {
+        final Set<NameClassPair> existingBindings = StreamSupport.stream(
+                        Spliterators.spliteratorUnknownSize(context.list("").asIterator(), 0), false)
+                .collect(Collectors.toSet());
+        existingBindings.forEach(binding -> {
+            try {
+                context.unbind(binding.getName());
+            } catch (NamingException error) {
+                throw new RuntimeException(error);
+            }
+        });
+    }
+
+    private void addBindings(final Context context) throws NamingException {
+        for (final Map.Entry<String, Object> entry : bindings.entrySet()) {
+            final String name = entry.getKey();
+            final Object object = entry.getValue();
+            context.bind(name, object);
+        }

Review Comment:
   Comparing this method to #2071, I see a potential problem: binding
`foo/bar` will fail if no context is bound to `foo`. Maybe this method
should also call `context.createSubcontext` for every intermediate subcontext
that does not exist yet?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to