This is an automated email from the ASF dual-hosted git repository.

jhyde pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 740f2ee251 [CALCITE-6408] Not-null ThreadLocal
740f2ee251 is described below

commit 740f2ee2511cf27421ff10bcf3f63e138c42d059
Author: Julian Hyde <[email protected]>
AuthorDate: Thu May 9 11:40:20 2024 -0700

    [CALCITE-6408] Not-null ThreadLocal
    
    Make various ThreadLocal instances non-nullable. They must
    have an initializer, but the caller can use the value without
    checking whether it is null.
---
 .../org/apache/calcite/jdbc/CalcitePrepare.java    |  13 +-
 .../java/org/apache/calcite/prepare/Prepare.java   |   2 +-
 .../apache/calcite/rel/rules/DateRangeRules.java   |   9 +-
 .../main/java/org/apache/calcite/runtime/Hook.java |  15 +-
 .../org/apache/calcite/runtime/SqlFunctions.java   |  15 +-
 .../org/apache/calcite/runtime/XmlFunctions.java   |  30 ++--
 .../apache/calcite/sql/parser/SqlParserUtil.java   |   7 +-
 .../calcite/sql/type/SqlTypeCoercionRule.java      |   8 +-
 .../calcite/sql/validate/SqlValidatorImpl.java     |   7 +-
 .../org/apache/calcite/util/TryThreadLocal.java    | 152 +++++++++++++++++----
 .../calcite/util/format/FormatElementEnum.java     |  11 +-
 .../apache/calcite/sql/type/SqlTypeUtilTest.java   |  35 +++--
 .../java/org/apache/calcite/test/JdbcTest.java     |   7 +-
 .../java/org/apache/calcite/util/UtilTest.java     |  57 ++++++++
 .../org/apache/calcite/test/DiffRepository.java    |   4 +-
 .../calcite/test/catalog/CountingFactory.java      |   5 +-
 .../main/java/org/apache/calcite/util/Smalls.java  |   6 +-
 17 files changed, 269 insertions(+), 114 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/jdbc/CalcitePrepare.java b/core/src/main/java/org/apache/calcite/jdbc/CalcitePrepare.java
index bda806be90..ec37fd4c7a 100644
--- a/core/src/main/java/org/apache/calcite/jdbc/CalcitePrepare.java
+++ b/core/src/main/java/org/apache/calcite/jdbc/CalcitePrepare.java
@@ -45,6 +45,7 @@ import 
org.apache.calcite.sql.validate.CyclicDefinitionException;
 import org.apache.calcite.sql.validate.SqlValidator;
 import org.apache.calcite.tools.RelRunner;
 import org.apache.calcite.util.ImmutableIntList;
+import org.apache.calcite.util.TryThreadLocal;
 
 import com.fasterxml.jackson.annotation.JsonIgnore;
 import com.google.common.collect.ImmutableList;
@@ -70,8 +71,8 @@ import static java.util.Objects.requireNonNull;
  */
 public interface CalcitePrepare {
   Function0<CalcitePrepare> DEFAULT_FACTORY = CalcitePrepareImpl::new;
-  ThreadLocal<@Nullable Deque<Context>> THREAD_CONTEXT_STACK =
-      ThreadLocal.withInitial(ArrayDeque::new);
+  TryThreadLocal<Deque<Context>> THREAD_CONTEXT_STACK =
+      TryThreadLocal.withInitial(ArrayDeque::new);
 
   ParseResult parse(Context context, String sql);
 
@@ -193,7 +194,7 @@ public interface CalcitePrepare {
     }
 
     public static void push(Context context) {
-      final Deque<Context> stack = castNonNull(THREAD_CONTEXT_STACK.get());
+      final Deque<Context> stack = THREAD_CONTEXT_STACK.get();
       final List<String> path = context.getObjectPath();
       if (path != null) {
         for (Context context1 : stack) {
@@ -207,11 +208,13 @@ public interface CalcitePrepare {
     }
 
     public static Context peek() {
-      return castNonNull(castNonNull(THREAD_CONTEXT_STACK.get()).peek());
+      final Deque<Context> stack = THREAD_CONTEXT_STACK.get();
+      return castNonNull(stack.peek());
     }
 
     public static void pop(Context context) {
-      Context x = castNonNull(THREAD_CONTEXT_STACK.get()).pop();
+      final Deque<Context> stack = THREAD_CONTEXT_STACK.get();
+      Context x = castNonNull(stack).pop();
       assert x == context;
     }
 
diff --git a/core/src/main/java/org/apache/calcite/prepare/Prepare.java 
b/core/src/main/java/org/apache/calcite/prepare/Prepare.java
index 59006c1337..f14c864c1c 100644
--- a/core/src/main/java/org/apache/calcite/prepare/Prepare.java
+++ b/core/src/main/java/org/apache/calcite/prepare/Prepare.java
@@ -390,7 +390,7 @@ public abstract class Prepare {
     // For now, don't trim if there are more than 3 joins. The projects
     // near the leaves created by trim migrate past joins and seem to
     // prevent join-reordering.
-    return castNonNull(THREAD_TRIM.get()) || RelOptUtil.countJoins(rootRel) < 2;
+    return THREAD_TRIM.get() || RelOptUtil.countJoins(rootRel) < 2;
   }
 
   protected abstract void init(Class runtimeContextClass);
diff --git 
a/core/src/main/java/org/apache/calcite/rel/rules/DateRangeRules.java 
b/core/src/main/java/org/apache/calcite/rel/rules/DateRangeRules.java
index d720624dd4..ce0df2b013 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/DateRangeRules.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/DateRangeRules.java
@@ -39,6 +39,7 @@ import org.apache.calcite.tools.RelBuilderFactory;
 import org.apache.calcite.util.DateString;
 import org.apache.calcite.util.TimestampString;
 import org.apache.calcite.util.TimestampWithTimeZoneString;
+import org.apache.calcite.util.TryThreadLocal;
 import org.apache.calcite.util.Util;
 
 import com.google.common.annotations.VisibleForTesting;
@@ -137,7 +138,7 @@ public abstract class DateRangeRules {
    * generate hundreds of ranges we'll later throw away. */
   static ImmutableSortedSet<TimeUnitRange> extractTimeUnits(RexNode e) {
     try (ExtractFinder finder = ExtractFinder.THREAD_INSTANCES.get()) {
-      assert requireNonNull(finder, "finder").timeUnits.isEmpty() && finder.opKinds.isEmpty()
+      assert finder.timeUnits.isEmpty() && finder.opKinds.isEmpty()
           : "previous user did not clean up";
       e.accept(finder);
       return ImmutableSortedSet.copyOf(finder.timeUnits);
@@ -190,7 +191,7 @@ public abstract class DateRangeRules {
      * If none of these, we cannot apply the rule. */
     private static boolean containsRoundingExpression(Filter filter) {
       try (ExtractFinder finder = ExtractFinder.THREAD_INSTANCES.get()) {
-        assert requireNonNull(finder, "finder").timeUnits.isEmpty() && finder.opKinds.isEmpty()
+        assert finder.timeUnits.isEmpty() && finder.opKinds.isEmpty()
             : "previous user did not clean up";
         filter.getCondition().accept(finder);
         return finder.timeUnits.contains(TimeUnitRange.YEAR)
@@ -239,8 +240,8 @@ public abstract class DateRangeRules {
         EnumSet.noneOf(TimeUnitRange.class);
     private final Set<SqlKind> opKinds = EnumSet.noneOf(SqlKind.class);
 
-    private static final ThreadLocal<@Nullable ExtractFinder> THREAD_INSTANCES =
-        ThreadLocal.withInitial(ExtractFinder::new);
+    private static final TryThreadLocal<ExtractFinder> THREAD_INSTANCES =
+        TryThreadLocal.withInitial(ExtractFinder::new);
 
     private ExtractFinder() {
       super(true);
diff --git a/core/src/main/java/org/apache/calcite/runtime/Hook.java 
b/core/src/main/java/org/apache/calcite/runtime/Hook.java
index 98f3cc3fd1..eced70ba2c 100644
--- a/core/src/main/java/org/apache/calcite/runtime/Hook.java
+++ b/core/src/main/java/org/apache/calcite/runtime/Hook.java
@@ -18,10 +18,10 @@ package org.apache.calcite.runtime;
 
 import org.apache.calcite.rel.RelRoot;
 import org.apache.calcite.util.Holder;
+import org.apache.calcite.util.TryThreadLocal;
 import org.apache.calcite.util.Util;
 
 import org.apiguardian.api.API;
-import org.checkerframework.checker.nullness.qual.Nullable;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -29,8 +29,6 @@ import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.function.Consumer;
 import java.util.function.Function;
 
-import static org.apache.calcite.linq4j.Nullness.castNonNull;
-
 /**
  * Collection of hooks that can be set by observers and are executed at various
  * parts of the query preparation process.
@@ -111,8 +109,8 @@ public enum Hook {
       new CopyOnWriteArrayList<>();
 
   @SuppressWarnings("ImmutableEnumChecker")
-  private final ThreadLocal<@Nullable List<Consumer<Object>>> threadHandlers =
-      ThreadLocal.withInitial(ArrayList::new);
+  private final TryThreadLocal<List<Consumer<Object>>> threadHandlers =
+      TryThreadLocal.withInitial(ArrayList::new);
 
   /** Adds a handler for this Hook.
    *
@@ -156,7 +154,7 @@ public enum Hook {
   /** Adds a handler for this thread. */
   public <T> Closeable addThread(final Consumer<T> handler) {
     //noinspection unchecked
-    castNonNull(threadHandlers.get()).add((Consumer<Object>) handler);
+    threadHandlers.get().add((Consumer<Object>) handler);
     return () -> removeThread(handler);
   }
 
@@ -182,8 +180,9 @@ public enum Hook {
   }
 
   /** Removes a thread handler from this Hook. */
+  @SuppressWarnings({"rawtypes", "UnusedReturnValue"})
   private boolean removeThread(Consumer handler) {
-    return castNonNull(threadHandlers.get()).remove(handler);
+    return threadHandlers.get().remove(handler);
   }
 
   // CHECKSTYLE: IGNORE 1
@@ -211,7 +210,7 @@ public enum Hook {
     for (Consumer<Object> handler : handlers) {
       handler.accept(arg);
     }
-    for (Consumer<Object> handler : castNonNull(threadHandlers.get())) {
+    for (Consumer<Object> handler : threadHandlers.get()) {
       handler.accept(arg);
     }
   }
diff --git a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java 
b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
index c6cdc64e65..e426b6d5fd 100644
--- a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
+++ b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
@@ -44,6 +44,7 @@ import org.apache.calcite.sql.fun.SqlLibraryOperators;
 import org.apache.calcite.util.NumberUtil;
 import org.apache.calcite.util.TimeWithTimeZoneString;
 import org.apache.calcite.util.TimestampWithTimeZoneString;
+import org.apache.calcite.util.TryThreadLocal;
 import org.apache.calcite.util.Unsafe;
 import org.apache.calcite.util.Util;
 import org.apache.calcite.util.format.FormatElement;
@@ -213,8 +214,8 @@ public class SqlFunctions {
    * <p>This is a straw man of an implementation whose main goal is to prove
    * that sequences can be parsed, validated and planned. A real application
    * will want persistent values for sequences, shared among threads. */
-  private static final ThreadLocal<@Nullable Map<String, AtomicLong>> 
THREAD_SEQUENCES =
-      ThreadLocal.withInitial(HashMap::new);
+  private static final TryThreadLocal<Map<String, AtomicLong>> 
THREAD_SEQUENCES =
+      TryThreadLocal.withInitial(HashMap::new);
 
   /** A byte string consisting of a single byte that is the ASCII space
    * character (0x20). */
@@ -5597,14 +5598,8 @@ public class SqlFunctions {
   }
 
   private static AtomicLong getAtomicLong(String key) {
-    final Map<String, AtomicLong> map =
-        requireNonNull(THREAD_SEQUENCES.get(), "THREAD_SEQUENCES.get()");
-    AtomicLong atomic = map.get(key);
-    if (atomic == null) {
-      atomic = new AtomicLong();
-      map.put(key, atomic);
-    }
-    return atomic;
+    final Map<String, AtomicLong> map = THREAD_SEQUENCES.get();
+    return map.computeIfAbsent(key, key_ -> new AtomicLong());
   }
 
   /** Support the ARRAYS_OVERLAP function. */
diff --git a/core/src/main/java/org/apache/calcite/runtime/XmlFunctions.java 
b/core/src/main/java/org/apache/calcite/runtime/XmlFunctions.java
index 4773b7ecc9..2a8af55e23 100644
--- a/core/src/main/java/org/apache/calcite/runtime/XmlFunctions.java
+++ b/core/src/main/java/org/apache/calcite/runtime/XmlFunctions.java
@@ -17,6 +17,7 @@
 package org.apache.calcite.runtime;
 
 import org.apache.calcite.util.SimpleNamespaceContext;
+import org.apache.calcite.util.TryThreadLocal;
 
 import org.apache.commons.lang3.StringUtils;
 
@@ -66,8 +67,8 @@ import static java.util.Objects.requireNonNull;
  */
 public class XmlFunctions {
 
-  private static final ThreadLocal<@Nullable XPathFactory> XPATH_FACTORY =
-      ThreadLocal.withInitial(() -> {
+  private static final TryThreadLocal<XPathFactory> XPATH_FACTORY =
+      TryThreadLocal.withInitial(() -> {
         final XPathFactory xPathFactory = XPathFactory.newInstance();
         try {
           xPathFactory.setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, 
true);
@@ -76,8 +77,9 @@ public class XmlFunctions {
         }
         return xPathFactory;
       });
-  private static final ThreadLocal<@Nullable TransformerFactory> 
TRANSFORMER_FACTORY =
-      ThreadLocal.withInitial(() -> {
+
+  private static final TryThreadLocal<TransformerFactory> TRANSFORMER_FACTORY =
+      TryThreadLocal.withInitial(() -> {
         final TransformerFactory transformerFactory = 
TransformerFactory.newInstance();
         transformerFactory.setErrorListener(new InternalErrorListener());
         try {
@@ -87,8 +89,9 @@ public class XmlFunctions {
         }
         return transformerFactory;
       });
-  private static final ThreadLocal<@Nullable DocumentBuilderFactory> 
DOCUMENT_BUILDER_FACTORY =
-      ThreadLocal.withInitial(() -> {
+
+  private static final TryThreadLocal<DocumentBuilderFactory> 
DOCUMENT_BUILDER_FACTORY =
+      TryThreadLocal.withInitial(() -> {
         final DocumentBuilderFactory documentBuilderFactory = 
DocumentBuilderFactory.newInstance();
         documentBuilderFactory.setXIncludeAware(false);
         documentBuilderFactory.setExpandEntityReferences(false);
@@ -117,7 +120,8 @@ public class XmlFunctions {
     }
     try {
       final Node documentNode = getDocumentNode(input);
-      XPathExpression xpathExpression = 
castNonNull(XPATH_FACTORY.get()).newXPath().compile(xpath);
+      XPathExpression xpathExpression =
+          XPATH_FACTORY.get().newXPath().compile(xpath);
       try {
         NodeList nodes = (NodeList) xpathExpression
             .evaluate(documentNode, XPathConstants.NODESET);
@@ -145,8 +149,8 @@ public class XmlFunctions {
     try {
       final Source xsltSource = new StreamSource(new StringReader(xslt));
       final Source xmlSource = new StreamSource(new StringReader(xml));
-      final Transformer transformer = castNonNull(TRANSFORMER_FACTORY.get())
-          .newTransformer(xsltSource);
+      final Transformer transformer =
+          TRANSFORMER_FACTORY.get().newTransformer(xsltSource);
       final StringWriter writer = new StringWriter();
       final StreamResult result = new StreamResult(writer);
       transformer.setErrorListener(new InternalErrorListener());
@@ -169,7 +173,7 @@ public class XmlFunctions {
       return null;
     }
     try {
-      XPath xPath = castNonNull(XPATH_FACTORY.get()).newXPath();
+      XPath xPath = XPATH_FACTORY.get().newXPath();
 
       if (namespace != null) {
         xPath.setNamespaceContext(extractNamespaceContext(namespace));
@@ -206,7 +210,7 @@ public class XmlFunctions {
       return null;
     }
     try {
-      XPath xPath = castNonNull(XPATH_FACTORY.get()).newXPath();
+      XPath xPath = XPATH_FACTORY.get().newXPath();
       if (namespace != null) {
         xPath.setNamespaceContext(extractNamespaceContext(namespace));
       }
@@ -247,7 +251,7 @@ public class XmlFunctions {
 
   private static String convertNodeToString(Node node) throws 
TransformerException {
     StringWriter writer = new StringWriter();
-    Transformer transformer = 
castNonNull(TRANSFORMER_FACTORY.get()).newTransformer();
+    Transformer transformer = TRANSFORMER_FACTORY.get().newTransformer();
     transformer.setErrorListener(new InternalErrorListener());
     transformer.setOutputProperty(OutputKeys.OMIT_XML_DECLARATION, "yes");
     transformer.transform(new DOMSource(node), new StreamResult(writer));
@@ -257,7 +261,7 @@ public class XmlFunctions {
   private static Node getDocumentNode(final String xml) {
     try {
       final DocumentBuilder documentBuilder =
-          castNonNull(DOCUMENT_BUILDER_FACTORY.get()).newDocumentBuilder();
+          DOCUMENT_BUILDER_FACTORY.get().newDocumentBuilder();
       final InputSource inputSource = new InputSource(new StringReader(xml));
       return documentBuilder.parse(inputSource);
     } catch (final ParserConfigurationException | SAXException | IOException 
e) {
diff --git 
a/core/src/main/java/org/apache/calcite/sql/parser/SqlParserUtil.java 
b/core/src/main/java/org/apache/calcite/sql/parser/SqlParserUtil.java
index 44fecaa053..c5611ab408 100644
--- a/core/src/main/java/org/apache/calcite/sql/parser/SqlParserUtil.java
+++ b/core/src/main/java/org/apache/calcite/sql/parser/SqlParserUtil.java
@@ -49,6 +49,7 @@ import org.apache.calcite.util.TimeString;
 import org.apache.calcite.util.TimeWithTimeZoneString;
 import org.apache.calcite.util.TimestampString;
 import org.apache.calcite.util.TimestampWithTimeZoneString;
+import org.apache.calcite.util.TryThreadLocal;
 import org.apache.calcite.util.Util;
 import org.apache.calcite.util.trace.CalciteTrace;
 
@@ -1214,11 +1215,11 @@ public final class SqlParserUtil {
   /** Pre-initialized {@link DateFormat} objects, to be used within the current
    * thread, because {@code DateFormat} is not thread-safe. */
   private static class Format {
-    private static final ThreadLocal<@Nullable Format> PER_THREAD =
-        ThreadLocal.withInitial(Format::new);
+    private static final TryThreadLocal<Format> PER_THREAD =
+        TryThreadLocal.withInitial(Format::new);
 
     private static Format get() {
-      return requireNonNull(PER_THREAD.get(), "PER_THREAD.get()");
+      return PER_THREAD.get();
     }
 
     final DateFormat timestamp =
diff --git 
a/core/src/main/java/org/apache/calcite/sql/type/SqlTypeCoercionRule.java 
b/core/src/main/java/org/apache/calcite/sql/type/SqlTypeCoercionRule.java
index a727a74c84..ab03b24ce7 100644
--- a/core/src/main/java/org/apache/calcite/sql/type/SqlTypeCoercionRule.java
+++ b/core/src/main/java/org/apache/calcite/sql/type/SqlTypeCoercionRule.java
@@ -16,11 +16,11 @@
  */
 package org.apache.calcite.sql.type;
 
+import org.apache.calcite.util.TryThreadLocal;
+
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 
-import org.checkerframework.checker.nullness.qual.Nullable;
-
 import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
@@ -78,8 +78,7 @@ public class SqlTypeCoercionRule implements 
SqlTypeMappingRule {
 
   private static final SqlTypeCoercionRule LENIENT_INSTANCE;
 
-  public static final ThreadLocal<@Nullable SqlTypeCoercionRule> 
THREAD_PROVIDERS =
-      ThreadLocal.withInitial(() -> SqlTypeCoercionRule.INSTANCE);
+  public static final TryThreadLocal<SqlTypeCoercionRule> THREAD_PROVIDERS;
 
   //~ Instance fields --------------------------------------------------------
 
@@ -352,6 +351,7 @@ public class SqlTypeCoercionRule implements 
SqlTypeMappingRule {
             .build());
 
     LENIENT_INSTANCE = new SqlTypeCoercionRule(coerceRules.map);
+    THREAD_PROVIDERS = TryThreadLocal.of(SqlTypeCoercionRule.INSTANCE);
   }
 
   //~ Methods ----------------------------------------------------------------
diff --git 
a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java 
b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java
index 706c8e6be0..7b3c050554 100644
--- a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java
+++ b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java
@@ -342,11 +342,8 @@ public class SqlValidatorImpl implements 
SqlValidatorWithHints {
 
     if (config.conformance().allowLenientCoercion()) {
       final SqlTypeCoercionRule rules =
-          requireNonNull(
-              config.typeCoercionRules() != null
-                  ? config.typeCoercionRules()
-                  : SqlTypeCoercionRule.THREAD_PROVIDERS.get(),
-              "rules");
+          first(config.typeCoercionRules(),
+              SqlTypeCoercionRule.instance());
 
       final ImmutableSet<SqlTypeName> arrayMapping =
           ImmutableSet.<SqlTypeName>builder()
diff --git a/core/src/main/java/org/apache/calcite/util/TryThreadLocal.java 
b/core/src/main/java/org/apache/calcite/util/TryThreadLocal.java
index 9b5ee2024e..12e702753c 100644
--- a/core/src/main/java/org/apache/calcite/util/TryThreadLocal.java
+++ b/core/src/main/java/org/apache/calcite/util/TryThreadLocal.java
@@ -16,38 +16,54 @@
  */
 package org.apache.calcite.util;
 
+import org.checkerframework.checker.nullness.qual.NonNull;
 import org.checkerframework.checker.nullness.qual.Nullable;
 
 import java.util.function.Supplier;
 
 import static org.apache.calcite.linq4j.Nullness.castNonNull;
 
+import static java.util.Objects.requireNonNull;
+
 /**
  * Thread-local variable that returns a handle that can be closed.
  *
  * @param <T> Value type
  */
-public class TryThreadLocal<T> extends ThreadLocal<@Nullable T> {
-  private final T initialValue;
-
-  /** Creates a TryThreadLocal.
+public abstract class TryThreadLocal<T> extends ThreadLocal<@Nullable T> {
+  /** Creates a TryThreadLocal with a fixed initial value.
    *
    * @param initialValue Initial value
    */
-  public static <T> TryThreadLocal<T> of(T initialValue) {
-    return new TryThreadLocal<>(initialValue);
+  public static <S> TryThreadLocal<S> of(S initialValue) {
+    return new FixedTryThreadLocal<>(initialValue);
   }
 
-  private TryThreadLocal(T initialValue) {
-    this.initialValue = initialValue;
+  /** Creates a TryThreadLocal with a fixed initial value
+   * whose values are never null.
+   *
+   * <p>The value returned from {@link #get} is never null;
+   * the initial value must not be null;
+   * you must not call {@link #set(Object)} with a null value.
+   *
+   * @param initialValue Initial value
+   */
+  public static <S> TryThreadLocal<@NonNull S> ofNonNull(S initialValue) {
+    return new NonNullFixedTryThreadLocal<>(
+        requireNonNull(initialValue, "initialValue"));
   }
 
-  // It is important that this method is final.
-  // This ensures that the sub-class does not choose a different initial
-  // value. Then the close logic can detect whether the previous value was
-  // equal to the initial value.
-  @Override protected final T initialValue() {
-    return initialValue;
+  /** Creates a TryThreadLocal with a supplier for the initial value.
+   *
+   * <p>The value returned from {@link #get} is never null;
+   * the supplier must never return null;
+   * you must not call {@link #set(Object)} with a null value.
+   *
+   * @param supplier Supplier
+   */
+  public static <S> TryThreadLocal<@NonNull S> withInitial(
+      Supplier<? extends @NonNull S> supplier) {
+    return new SuppliedTryThreadLocal<>(supplier);
   }
 
   @Override public T get() {
@@ -63,21 +79,11 @@ public class TryThreadLocal<T> extends 
ThreadLocal<@Nullable T> {
     return () -> restoreTo(previous);
   }
 
-  /** Sets the value back to a previous value.
-   *
-   * <p>If the previous value was {@link #initialValue}, calls
-   * {@link #remove()}. There's no way to tell whether {@link #set} has
-   * been called previously, but the effect is the same. */
-  protected void restoreTo(T previous) {
-    if (previous == initialValue) {
-      remove();
-    } else {
-      set(previous);
-    }
-  }
+  /** Sets the value back to a previous value. */
+  protected abstract void restoreTo(T previous);
 
   /** Performs an action with this ThreadLocal set to a particular value
-   * in this thread, and restores the previous value afterwards.
+   * in this thread, and restores the previous value afterward.
    *
    * <p>This method is named after the Standard ML {@code let} construct,
    * for example {@code let val x = 1 in x + 2 end}. */
@@ -96,7 +102,7 @@ public class TryThreadLocal<T> extends ThreadLocal<@Nullable 
T> {
   }
 
   /** Calls a Supplier with this ThreadLocal set to a particular value,
-   * in this thread, and restores the previous value afterwards.
+   * in this thread, and restores the previous value afterward.
    *
    * <p>This method is named after the Standard ML {@code let} construct,
    * for example {@code let val x = 1 in x + 2 end}. */
@@ -119,4 +125,94 @@ public class TryThreadLocal<T> extends 
ThreadLocal<@Nullable T> {
     /** Sets the value back; never throws. */
     @Override void close();
   }
+
+  /** Implementation of {@link org.apache.calcite.util.TryThreadLocal}
+   * with a fixed initial value.
+   *
+   * @param <T> Value type */
+  private static class FixedTryThreadLocal<T> extends TryThreadLocal<T> {
+    private final T initialValue;
+
+    protected FixedTryThreadLocal(T initialValue) {
+      this.initialValue = initialValue;
+    }
+
+    /**
+     * {@inheritDoc}
+     *
+     * <p>It is important that this method is final.
+     * This ensures that a subclass does not choose a different initial
+     * value. Then the close logic can detect whether the previous value was
+     * equal to the initial value.
+     */
+    @Override protected final T initialValue() {
+      return initialValue;
+    }
+
+    /** Sets the value back to a previous value.
+     *
+     * <p>If the previous value was {@link #initialValue}, calls
+     * {@link #remove()}. There's no way to tell whether {@link #set} has
+     * been called previously, but the effect is the same. */
+    @Override protected void restoreTo(T previous) {
+      if (previous == initialValue) {
+        remove();
+      } else {
+        set(previous);
+      }
+    }
+  }
+
+  /** Implementation of {@link org.apache.calcite.util.TryThreadLocal}
+   * with a fixed initial value whose values are never null.
+   *
+   * @param <T> Value type */
+  private static class NonNullFixedTryThreadLocal<T>
+      extends FixedTryThreadLocal<T> {
+    private NonNullFixedTryThreadLocal(T initialValue) {
+      super(requireNonNull(initialValue, "initialValue"));
+    }
+
+    @Override public void set(@Nullable T value) {
+      super.set(requireNonNull(value, "value"));
+    }
+
+    @Override public T get() {
+      return requireNonNull(super.get());
+    }
+  }
+
+  /** Implementation of {@link org.apache.calcite.util.TryThreadLocal}
+   * whose initial value comes from a supplier.
+   *
+   * @param <T> Value type */
+  private static class SuppliedTryThreadLocal<T> extends TryThreadLocal<T> {
+    private final Supplier<? extends @NonNull T> supplier;
+
+    SuppliedTryThreadLocal(Supplier<? extends @NonNull T> supplier) {
+      this.supplier = requireNonNull(supplier, "supplier");
+    }
+
+    @Override protected @NonNull T initialValue() {
+      return requireNonNull(supplier.get(), "supplier returned null");
+    }
+
+    @Override public void set(@Nullable T value) {
+      super.set(requireNonNull(value, "value"));
+    }
+
+    @Override protected void restoreTo(T previous) {
+      // If the thread had no value before they called 'push', should we call
+      // 'remove()' here? No, for two reasons.
+      //
+      // First, it's not possible to know whether there was a value.
+      // (ThreadLocal.isPresent() is package-private.)
+      //
+      // Second, it may be what the user wants. If each 'restoreTo' call invokes
+      // 'remove', then the next call to 'push' will invoke the supplier again.
+      // Sometimes the user doesn't want to pay the initialization cost multiple
+      // times, or to lose the state in the initialized object.
+      set(previous);
+    }
+  }
 }
diff --git 
a/core/src/main/java/org/apache/calcite/util/format/FormatElementEnum.java 
b/core/src/main/java/org/apache/calcite/util/format/FormatElementEnum.java
index f4f290903f..b78d88c91e 100644
--- a/core/src/main/java/org/apache/calcite/util/format/FormatElementEnum.java
+++ b/core/src/main/java/org/apache/calcite/util/format/FormatElementEnum.java
@@ -17,11 +17,10 @@
 package org.apache.calcite.util.format;
 
 import org.apache.calcite.avatica.util.DateTimeUtils;
+import org.apache.calcite.util.TryThreadLocal;
 
 import org.apache.commons.lang3.StringUtils;
 
-import org.checkerframework.checker.nullness.qual.Nullable;
-
 import java.text.DateFormat;
 import java.text.SimpleDateFormat;
 import java.time.LocalDate;
@@ -30,8 +29,6 @@ import java.util.Calendar;
 import java.util.Date;
 import java.util.Locale;
 
-import static org.apache.calcite.linq4j.Nullness.castNonNull;
-
 import static java.util.Objects.requireNonNull;
 
 /**
@@ -428,12 +425,12 @@ public enum FormatElementEnum implements FormatElement {
   /** Work space. Provides a value for each mutable data structure that might
    * be needed by a format element. Ensures thread-safety. */
   static class Work {
-    private static final ThreadLocal<@Nullable Work> THREAD_WORK =
-        ThreadLocal.withInitial(Work::new);
+    private static final TryThreadLocal<Work> THREAD_WORK =
+        TryThreadLocal.withInitial(Work::new);
 
     /** Returns an instance of Work for this thread. */
     static Work get() {
-      return castNonNull(THREAD_WORK.get());
+      return THREAD_WORK.get();
     }
 
     final Calendar calendar =
diff --git 
a/core/src/test/java/org/apache/calcite/sql/type/SqlTypeUtilTest.java 
b/core/src/test/java/org/apache/calcite/sql/type/SqlTypeUtilTest.java
index e918da54a2..8fc574795f 100644
--- a/core/src/test/java/org/apache/calcite/sql/type/SqlTypeUtilTest.java
+++ b/core/src/test/java/org/apache/calcite/sql/type/SqlTypeUtilTest.java
@@ -22,9 +22,11 @@ import org.apache.calcite.sql.SqlBasicTypeNameSpec;
 import org.apache.calcite.sql.SqlCollectionTypeNameSpec;
 import org.apache.calcite.sql.SqlIdentifier;
 import org.apache.calcite.sql.SqlRowTypeNameSpec;
+import org.apache.calcite.util.TryThreadLocal;
 
 import com.google.common.collect.ImmutableList;
 
+import org.hamcrest.Matcher;
 import org.junit.jupiter.api.Test;
 
 import java.util.Arrays;
@@ -113,24 +115,29 @@ class SqlTypeUtilTest {
     final SqlTypeCoercionRule defaultRules = SqlTypeCoercionRule.instance();
     builder.addAll(defaultRules.getTypeMapping());
     // Do the tweak, for example, if we want to add a rule to allow
-    // coerce BOOLEAN to TIMESTAMP.
+    // coercion of BOOLEAN to TIMESTAMP.
     builder.add(SqlTypeName.TIMESTAMP,
         builder.copyValues(SqlTypeName.TIMESTAMP)
             .add(SqlTypeName.BOOLEAN).build());
 
-    // Initialize a SqlTypeCoercionRules with the new builder mappings.
-    SqlTypeCoercionRule typeCoercionRules = 
SqlTypeCoercionRule.instance(builder.map);
-    assertThat(SqlTypeUtil.canCastFrom(f.sqlTimestampPrec3, f.sqlBoolean, 
true),
-        is(false));
-    assertThat(SqlTypeUtil.canCastFrom(f.sqlTimestampPrec3, f.sqlBoolean, 
defaultRules),
-        is(false));
-    SqlTypeCoercionRule.THREAD_PROVIDERS.set(typeCoercionRules);
-    assertThat(SqlTypeUtil.canCastFrom(f.sqlTimestampPrec3, f.sqlBoolean, 
true),
-        is(true));
-    assertThat(SqlTypeUtil.canCastFrom(f.sqlTimestampPrec3, f.sqlBoolean, 
typeCoercionRules),
-        is(true));
-    // Recover the mappings to default.
-    SqlTypeCoercionRule.THREAD_PROVIDERS.set(defaultRules);
+    // Try converting with both default rules and the new rule set.
+    checkConvert(defaultRules, is(false));
+    final SqlTypeCoercionRule typeCoercionRules =
+        SqlTypeCoercionRule.instance(builder.map);
+    try (TryThreadLocal.Memo ignored =
+             SqlTypeCoercionRule.THREAD_PROVIDERS.push(typeCoercionRules)) {
+      checkConvert(typeCoercionRules, is(true));
+    }
+  }
+
+  private void checkConvert(SqlTypeCoercionRule rules,
+      Matcher<Boolean> matcher) {
+    assertThat(
+        SqlTypeUtil.canCastFrom(f.sqlTimestampPrec3, f.sqlBoolean, true),
+        matcher);
+    assertThat(
+        SqlTypeUtil.canCastFrom(f.sqlTimestampPrec3, f.sqlBoolean, rules),
+        matcher);
   }
 
   @Test void testEqualAsCollectionSansNullability() {
diff --git a/core/src/test/java/org/apache/calcite/test/JdbcTest.java 
b/core/src/test/java/org/apache/calcite/test/JdbcTest.java
index 54c64f8ded..4eeb666d4f 100644
--- a/core/src/test/java/org/apache/calcite/test/JdbcTest.java
+++ b/core/src/test/java/org/apache/calcite/test/JdbcTest.java
@@ -8686,8 +8686,8 @@ public class JdbcTest {
 
   /** Factory for EMP and DEPT tables. */
   public static class EmpDeptTableFactory implements TableFactory<Table> {
-    public static final TryThreadLocal<@Nullable List<Employee>> THREAD_COLLECTION =
-        TryThreadLocal.of(null);
+    public static final TryThreadLocal<List<Employee>> THREAD_COLLECTION =
+        TryThreadLocal.of(Collections.emptyList());
 
     public Table create(
         SchemaPlus schema,
@@ -8703,9 +8703,6 @@ public class JdbcTest {
         break;
       case "MUTABLE_EMPLOYEES":
         List<Employee> employees = THREAD_COLLECTION.get();
-        if (employees == null) {
-          employees = Collections.emptyList();
-        }
         return JdbcFrontLinqBackTest.mutable(name, employees, false);
       case "DEPARTMENTS":
         clazz = Department.class;
diff --git a/core/src/test/java/org/apache/calcite/util/UtilTest.java b/core/src/test/java/org/apache/calcite/util/UtilTest.java
index 324465acfd..882b276e9b 100644
--- a/core/src/test/java/org/apache/calcite/util/UtilTest.java
+++ b/core/src/test/java/org/apache/calcite/util/UtilTest.java
@@ -48,6 +48,7 @@ import com.google.common.collect.Iterables;
 import com.google.common.collect.Lists;
 import com.google.common.primitives.Ints;
 
+import org.checkerframework.checker.nullness.qual.NonNull;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.hamcrest.Description;
 import org.hamcrest.FeatureMatcher;
@@ -102,6 +103,7 @@ import java.util.function.Function;
 import java.util.function.IntFunction;
 import java.util.function.ObjIntConsumer;
 import java.util.function.Predicate;
+import java.util.function.Supplier;
 import java.util.function.UnaryOperator;
 
 import static org.apache.calcite.test.Matchers.isLinux;
@@ -113,6 +115,7 @@ import static org.hamcrest.CoreMatchers.instanceOf;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.CoreMatchers.isA;
 import static org.hamcrest.CoreMatchers.not;
+import static org.hamcrest.CoreMatchers.notNullValue;
 import static org.hamcrest.CoreMatchers.nullValue;
 import static org.hamcrest.CoreMatchers.sameInstance;
 import static org.hamcrest.CoreMatchers.startsWith;
@@ -2463,6 +2466,21 @@ class UtilTest {
    * <a href="https://issues.apache.org/jira/browse/CALCITE-915">[CALCITE-915]
    * Tests do not unset ThreadLocal values on exit</a>. */
   @Test void testTryThreadLocal() {
+    final TryThreadLocal<String> local0 = TryThreadLocal.ofNonNull("foo");
+    assertThat(local0.get(), is("foo"));
+    TryThreadLocal.Memo memo0 = local0.push("bar");
+    assertThat(local0.get(), is("bar"));
+    local0.set("baz");
+    assertThat(local0.get(), is("baz"));
+    memo0.close();
+    assertThat(local0.get(), is("foo"));
+    try {
+      local0.set(null);
+      fail("expected err");
+    } catch (NullPointerException e) {
+      // ok
+    }
+
     final TryThreadLocal<String> local1 = TryThreadLocal.of("foo");
     assertThat(local1.get(), is("foo"));
     TryThreadLocal.Memo memo1 = local1.push("bar");
@@ -2471,6 +2489,7 @@ class UtilTest {
     assertThat(local1.get(), is("baz"));
     memo1.close();
     assertThat(local1.get(), is("foo"));
+    local1.set(null); // null values are allowed
 
     final TryThreadLocal<@Nullable String> local2 =
         TryThreadLocal.of(null);
@@ -2492,6 +2511,44 @@ class UtilTest {
       local2.set("z");
     }
     assertThat(local2.get(), is("x"));
+
+    final Supplier<@NonNull String> stringSupplier =
+        new Supplier<String>() {
+      final Random random = new Random();
+
+      @Override public String get() {
+        return "s" + random.nextInt(15);
+      }
+    };
+    final TryThreadLocal<String> local3 =
+        TryThreadLocal.withInitial(stringSupplier);
+    assertThat(local3.get(), startsWith("s"));
+    TryThreadLocal.Memo memo3 = local3.push("bar");
+    assertThat(local3.get(), is("bar"));
+    local3.set("baz");
+    assertThat(local3.get(), is("baz"));
+    memo3.close();
+    assertThat(local3.get(), startsWith("s"));
+    try {
+      local3.set(null);
+      fail("expected err");
+    } catch (NullPointerException e) {
+      assertThat(e, notNullValue());
+    }
+
+    @SuppressWarnings("DataFlowIssue")
+    final Supplier<@NonNull String> nullSupplier = () -> null;
+    final TryThreadLocal<String> local4 =
+        TryThreadLocal.withInitial(nullSupplier);
+    local4.set("abc");
+    assertThat(local4.get(), is("abc"));
+    local4.remove();
+    try {
+      final String s = local4.get();
+      fail("expected error, got " + s);
+    } catch (NullPointerException e) {
+      assertThat(e, notNullValue());
+    }
   }
 
   /** Tests
diff --git a/testkit/src/main/java/org/apache/calcite/test/DiffRepository.java b/testkit/src/main/java/org/apache/calcite/test/DiffRepository.java
index 35e86fad24..ef2fd76c39 100644
--- a/testkit/src/main/java/org/apache/calcite/test/DiffRepository.java
+++ b/testkit/src/main/java/org/apache/calcite/test/DiffRepository.java
@@ -173,7 +173,7 @@ public class DiffRepository {
   private static final LoadingCache<Key, DiffRepository> REPOSITORY_CACHE =
       CacheBuilder.newBuilder().build(CacheLoader.from(Key::toRepo));
 
-  private static final ThreadLocal<@Nullable DocumentBuilderFactory> DOCUMENT_BUILDER_FACTORY =
+  private static final ThreadLocal<DocumentBuilderFactory> DOCUMENT_BUILDER_FACTORY =
       ThreadLocal.withInitial(() -> {
         final DocumentBuilderFactory documentBuilderFactory = DocumentBuilderFactory.newInstance();
         documentBuilderFactory.setXIncludeAware(false);
@@ -225,7 +225,7 @@ public class DiffRepository {
     // Load the document.
     try {
       DocumentBuilder docBuilder =
-          Nullness.castNonNull(DOCUMENT_BUILDER_FACTORY.get()).newDocumentBuilder();
+          DOCUMENT_BUILDER_FACTORY.get().newDocumentBuilder();
       try (InputStream inputStream = refFile.openStream()) {
         // Parse the reference file.
         this.doc = docBuilder.parse(inputStream);
diff --git a/testkit/src/main/java/org/apache/calcite/test/catalog/CountingFactory.java b/testkit/src/main/java/org/apache/calcite/test/catalog/CountingFactory.java
index af3257f974..c27cfcfd05 100644
--- a/testkit/src/main/java/org/apache/calcite/test/catalog/CountingFactory.java
+++ b/testkit/src/main/java/org/apache/calcite/test/catalog/CountingFactory.java
@@ -26,6 +26,7 @@ import org.apache.calcite.sql.SqlFunction;
 import org.apache.calcite.sql2rel.InitializerContext;
 import org.apache.calcite.sql2rel.InitializerExpressionFactory;
 import org.apache.calcite.sql2rel.NullInitializerExpressionFactory;
+import org.apache.calcite.util.TryThreadLocal;
 
 import com.google.common.collect.ImmutableList;
 
@@ -39,8 +40,8 @@ import java.util.concurrent.atomic.AtomicInteger;
  * <p>If a column is in {@code defaultColumns}, returns 1 as the default
  * value. */
 public class CountingFactory extends NullInitializerExpressionFactory {
-  public static final ThreadLocal<AtomicInteger> THREAD_CALL_COUNT =
-      ThreadLocal.withInitial(AtomicInteger::new);
+  public static final TryThreadLocal<AtomicInteger> THREAD_CALL_COUNT =
+      TryThreadLocal.withInitial(AtomicInteger::new);
 
   private final List<String> defaultColumns;
 
diff --git a/testkit/src/main/java/org/apache/calcite/util/Smalls.java 
b/testkit/src/main/java/org/apache/calcite/util/Smalls.java
index 1b1ffeb386..53cdbb8a64 100644
--- a/testkit/src/main/java/org/apache/calcite/util/Smalls.java
+++ b/testkit/src/main/java/org/apache/calcite/util/Smalls.java
@@ -515,7 +515,7 @@ public class Smalls {
    * and named parameters. */
   public static class MyPlusFunction {
     public static final ThreadLocal<AtomicInteger> INSTANCE_COUNT =
-        ThreadLocal.withInitial(() -> new AtomicInteger(0));
+        ThreadLocal.withInitial(AtomicInteger::new);
 
     // Note: Not marked @Deterministic
     public MyPlusFunction() {
@@ -532,7 +532,7 @@ public class Smalls {
    * {@link org.apache.calcite.schema.FunctionContext} parameter. */
   public static class MyPlusInitFunction {
     public static final ThreadLocal<AtomicInteger> INSTANCE_COUNT =
-        ThreadLocal.withInitial(() -> new AtomicInteger(0));
+        ThreadLocal.withInitial(AtomicInteger::new);
     public static final ThreadLocal<String> THREAD_DIGEST =
         new ThreadLocal<>();
 
@@ -567,7 +567,7 @@ public class Smalls {
   /** As {@link MyPlusFunction} but declared to be deterministic. */
   public static class MyDeterministicPlusFunction {
     public static final ThreadLocal<AtomicInteger> INSTANCE_COUNT =
-        ThreadLocal.withInitial(() -> new AtomicInteger(0));
+        ThreadLocal.withInitial(AtomicInteger::new);
 
     @Deterministic public MyDeterministicPlusFunction() {
       INSTANCE_COUNT.get().incrementAndGet();


Reply via email to