This is an automated email from the ASF dual-hosted git repository.

Wei-hao-Li pushed a commit to branch IoTDBLocal
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 0d161fe98abcb87d9d526fc772eb9d326bec32af
Author: Weihao Li <[email protected]>
AuthorDate: Fri Jun 26 00:30:21 2026 +0800

    fix beforeStart of UDSF+UDAF
    
    Signed-off-by: Weihao Li <[email protected]>
---
 .../relational/it/db/it/udf/IoTDBLocalIT.java      |  6 --
 .../relational/aggregation/AccumulatorFactory.java | 17 ------
 .../UserDefinedAggregateFunctionAccumulator.java   | 33 ++++++++++-
 .../GroupedUserDefinedAggregateAccumulator.java    | 20 +++++++
 .../udf/UserDefineScalarFunctionTransformer.java   | 69 +++++++++++-----------
 5 files changed, 87 insertions(+), 58 deletions(-)

diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/udf/IoTDBLocalIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/udf/IoTDBLocalIT.java
index d6fb2339ae2..c50d360ba66 100644
--- a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/udf/IoTDBLocalIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/udf/IoTDBLocalIT.java
@@ -78,12 +78,6 @@ public class IoTDBLocalIT {
         "CLEAR ATTRIBUTE CACHE",
       };
 
-  public static void main(String[] args) {
-    for (String sql : SETUP_SQLS) {
-      System.out.println(sql + ";");
-    }
-  }
-
   @BeforeClass
   public static void setUp() throws Exception {
     
EnvFactory.getEnv().getConfig().getCommonConfig().setEnforceStrongPassword(false);
diff --git a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/AccumulatorFactory.java b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/AccumulatorFactory.java
index d2b40c742d5..f9d35cf67f5 100644
--- a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/AccumulatorFactory.java
+++ b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/AccumulatorFactory.java
@@ -67,7 +67,6 @@ import org.apache.iotdb.commons.queryengine.plan.udf.TableUDFUtils;
 import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer;
 import org.apache.iotdb.udf.api.IoTDBLocal;
 import org.apache.iotdb.udf.api.customizer.parameter.FunctionArguments;
-import org.apache.iotdb.udf.api.exception.UDFException;
 import org.apache.iotdb.udf.api.relational.AggregateFunction;
 
 import com.google.common.collect.ImmutableList;
@@ -287,14 +286,6 @@ public class AccumulatorFactory {
     FunctionArguments functionArguments =
         new FunctionArguments(
             UDFDataTypeTransformer.transformToUDFDataTypeList(inputDataTypes), 
inputAttributes);
-    try {
-      aggregateFunction.beforeStart(functionArguments, ioTDBLocal);
-    } catch (UDFException e) {
-      throw new RuntimeException(
-          "Error occurs when starting user-defined aggregate function "
-              + aggregateFunction.getClass().getName(),
-          e);
-    }
     return new UserDefinedAggregateFunctionAccumulator(
         aggregateFunction.analyze(functionArguments),
         aggregateFunction,
@@ -313,14 +304,6 @@ public class AccumulatorFactory {
     FunctionArguments functionArguments =
         new FunctionArguments(
             UDFDataTypeTransformer.transformToUDFDataTypeList(inputDataTypes), 
inputAttributes);
-    try {
-      aggregateFunction.beforeStart(functionArguments, ioTDBLocal);
-    } catch (UDFException e) {
-      throw new RuntimeException(
-          "Error occurs when starting user-defined aggregate function "
-              + aggregateFunction.getClass().getName(),
-          e);
-    }
     return new GroupedUserDefinedAggregateAccumulator(
         aggregateFunction,
         functionArguments,
diff --git a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/UserDefinedAggregateFunctionAccumulator.java b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/UserDefinedAggregateFunctionAccumulator.java
index 0e1d38f54e4..7bedcbbb030 100644
--- a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/UserDefinedAggregateFunctionAccumulator.java
+++ b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/UserDefinedAggregateFunctionAccumulator.java
@@ -24,6 +24,7 @@ import org.apache.iotdb.udf.api.IoTDBLocal;
 import org.apache.iotdb.udf.api.State;
 import org.apache.iotdb.udf.api.customizer.analysis.AggregateFunctionAnalysis;
 import org.apache.iotdb.udf.api.customizer.parameter.FunctionArguments;
+import org.apache.iotdb.udf.api.exception.UDFException;
 import org.apache.iotdb.udf.api.relational.AggregateFunction;
 import org.apache.iotdb.udf.api.utils.ResultValue;
 
@@ -52,6 +53,7 @@ public class UserDefinedAggregateFunctionAccumulator implements TableAccumulator
   private final List<Type> inputDataTypes;
   private final State state;
   private final IoTDBLocal ioTDBLocal;
+  private boolean init;
 
   public UserDefinedAggregateFunctionAccumulator(
       AggregateFunctionAnalysis analysis,
@@ -59,6 +61,16 @@ public class UserDefinedAggregateFunctionAccumulator implements TableAccumulator
       FunctionArguments functionArguments,
       List<Type> inputDataTypes,
       IoTDBLocal ioTDBLocal) {
+    this(analysis, aggregateFunction, functionArguments, inputDataTypes, 
ioTDBLocal, false);
+  }
+
+  private UserDefinedAggregateFunctionAccumulator(
+      AggregateFunctionAnalysis analysis,
+      AggregateFunction aggregateFunction,
+      FunctionArguments functionArguments,
+      List<Type> inputDataTypes,
+      IoTDBLocal ioTDBLocal,
+      boolean init) {
     checkArgument(ioTDBLocal != null, "IoTDBLocal must not be null for UDAF");
     this.analysis = analysis;
     this.aggregateFunction = aggregateFunction;
@@ -66,6 +78,22 @@ public class UserDefinedAggregateFunctionAccumulator implements TableAccumulator
     this.inputDataTypes = inputDataTypes;
     this.state = aggregateFunction.createState();
     this.ioTDBLocal = ioTDBLocal;
+    this.init = init;
+  }
+
+  private void initIfNeeded() {
+    if (init) {
+      return;
+    }
+    init = true;
+    try {
+      aggregateFunction.beforeStart(functionArguments, ioTDBLocal);
+    } catch (UDFException e) {
+      throw new RuntimeException(
+          "Error occurs when starting user-defined aggregate function "
+              + aggregateFunction.getClass().getName(),
+          e);
+    }
   }
 
   @Override
@@ -76,11 +104,12 @@ public class UserDefinedAggregateFunctionAccumulator implements TableAccumulator
   @Override
   public TableAccumulator copy() {
     return new UserDefinedAggregateFunctionAccumulator(
-        analysis, aggregateFunction, functionArguments, inputDataTypes, 
ioTDBLocal);
+        analysis, aggregateFunction, functionArguments, inputDataTypes, 
ioTDBLocal, true);
   }
 
   @Override
   public void addInput(Column[] arguments, AggregationMask mask) {
+    initIfNeeded();
     RecordIterator iterator =
         mask.isSelectAll()
             ? new RecordIterator(
@@ -93,6 +122,7 @@ public class UserDefinedAggregateFunctionAccumulator implements TableAccumulator
 
   @Override
   public void addIntermediate(Column argument) {
+    initIfNeeded();
     checkArgument(
         argument instanceof BinaryColumn
             || (argument instanceof RunLengthEncodedColumn
@@ -118,6 +148,7 @@ public class UserDefinedAggregateFunctionAccumulator implements TableAccumulator
 
   @Override
   public void evaluateFinal(ColumnBuilder columnBuilder) {
+    initIfNeeded();
     ResultValue resultValue = new ResultValue(columnBuilder);
     aggregateFunction.outputFinal(state, resultValue, ioTDBLocal);
   }
diff --git a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/grouped/GroupedUserDefinedAggregateAccumulator.java b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/grouped/GroupedUserDefinedAggregateAccumulator.java
index 5cffef01105..a530ce4f7de 100644
--- a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/grouped/GroupedUserDefinedAggregateAccumulator.java
+++ b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/grouped/GroupedUserDefinedAggregateAccumulator.java
@@ -27,6 +27,7 @@ import org.apache.iotdb.calc.i18n.CalcMessages;
 import org.apache.iotdb.udf.api.IoTDBLocal;
 import org.apache.iotdb.udf.api.State;
 import org.apache.iotdb.udf.api.customizer.parameter.FunctionArguments;
+import org.apache.iotdb.udf.api.exception.UDFException;
 import org.apache.iotdb.udf.api.relational.AggregateFunction;
 import org.apache.iotdb.udf.api.utils.ResultValue;
 
@@ -53,6 +54,7 @@ public class GroupedUserDefinedAggregateAccumulator implements GroupedAccumulato
   private final ObjectBigArray<State> stateArray;
   private final List<Type> inputDataTypes;
   private final IoTDBLocal ioTDBLocal;
+  private boolean init = false;
 
   public GroupedUserDefinedAggregateAccumulator(
       AggregateFunction aggregateFunction,
@@ -67,6 +69,21 @@ public class GroupedUserDefinedAggregateAccumulator implements GroupedAccumulato
     this.ioTDBLocal = ioTDBLocal;
   }
 
+  private void initIfNeeded() {
+    if (init) {
+      return;
+    }
+    init = true;
+    try {
+      aggregateFunction.beforeStart(functionArguments, ioTDBLocal);
+    } catch (UDFException e) {
+      throw new RuntimeException(
+          "Error occurs when starting user-defined aggregate function "
+              + aggregateFunction.getClass().getName(),
+          e);
+    }
+  }
+
   @Override
   public long getEstimatedSize() {
     return INSTANCE_SIZE;
@@ -88,6 +105,7 @@ public class GroupedUserDefinedAggregateAccumulator implements GroupedAccumulato
 
   @Override
   public void addInput(int[] groupIds, Column[] arguments, AggregationMask 
mask) {
+    initIfNeeded();
     RecordIterator iterator =
         mask.isSelectAll()
             ? new RecordIterator(
@@ -115,6 +133,7 @@ public class GroupedUserDefinedAggregateAccumulator implements GroupedAccumulato
 
   @Override
   public void addIntermediate(int[] groupIds, Column argument) {
+    initIfNeeded();
     checkArgument(
         argument instanceof BinaryColumn
             || (argument instanceof RunLengthEncodedColumn
@@ -146,6 +165,7 @@ public class GroupedUserDefinedAggregateAccumulator implements GroupedAccumulato
 
   @Override
   public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) {
+    initIfNeeded();
     ResultValue resultValue = new ResultValue(columnBuilder);
     aggregateFunction.outputFinal(getOrCreateState(groupId), resultValue, 
ioTDBLocal);
   }
diff --git a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/transformation/dag/column/udf/UserDefineScalarFunctionTransformer.java b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/transformation/dag/column/udf/UserDefineScalarFunctionTransformer.java
index 07d92d967f8..6955dd56663 100644
--- a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/transformation/dag/column/udf/UserDefineScalarFunctionTransformer.java
+++ b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/transformation/dag/column/udf/UserDefineScalarFunctionTransformer.java
@@ -37,11 +37,15 @@ import org.apache.tsfile.read.common.type.Type;
 import java.util.List;
 import java.util.stream.Collectors;
 
+import static com.google.common.base.Preconditions.checkArgument;
+
 public class UserDefineScalarFunctionTransformer extends 
MultiColumnTransformer {
 
   private final ScalarFunction scalarFunction;
+  private final FunctionArguments parameters;
   private final List<Type> inputTypes;
   private final IoTDBLocal ioTDBLocal;
+  private boolean init = false;
 
   public UserDefineScalarFunctionTransformer(
       Type returnType,
@@ -51,15 +55,35 @@ public class UserDefineScalarFunctionTransformer extends MultiColumnTransformer
       ColumnTransformerBuilder.Context context) {
     super(returnType, childrenTransformers);
     this.scalarFunction = scalarFunction;
+    this.parameters = parameters;
     this.ioTDBLocal = createIoTDBLocal(context);
     this.inputTypes =
         
childrenTransformers.stream().map(ColumnTransformer::getType).collect(Collectors.toList());
+  }
+
+  private static IoTDBLocal createIoTDBLocal(ColumnTransformerBuilder.Context 
context) {
+    IoTDBLocalFactory factory = context.getIoTDBLocalFactory();
+    String fragmentInstanceId = context.getFragmentInstanceId();
+    String outerGlobalQueryId = context.getOuterGlobalQueryId();
+    long outerLocalQueryId = context.getOuterLocalQueryId();
+    checkArgument(factory != null, "IoTDBLocalFactory must not be null for UDF 
execution");
+    checkArgument(
+        fragmentInstanceId != null, "fragmentInstanceId must not be null for 
UDF execution");
+    checkArgument(
+        outerGlobalQueryId != null, "outerGlobalQueryId must not be null for 
UDF execution");
+    checkArgument(
+        outerLocalQueryId >= 0, "outerLocalQueryId must not be negative for 
UDF execution");
+    return factory.create(
+        context.getSessionInfo(), fragmentInstanceId, outerLocalQueryId, 
outerGlobalQueryId);
+  }
+
+  private void initIfNeeded() {
+    if (init) {
+      return;
+    }
+    init = true;
     try {
-      if (ioTDBLocal != null) {
-        scalarFunction.beforeStart(parameters, ioTDBLocal);
-      } else {
-        scalarFunction.beforeStart(parameters);
-      }
+      scalarFunction.beforeStart(parameters, ioTDBLocal);
     } catch (UDFException e) {
       throw new RuntimeException(
           "Error occurs when starting user-defined scalar function "
@@ -68,31 +92,14 @@ public class UserDefineScalarFunctionTransformer extends MultiColumnTransformer
     }
   }
 
-  private static IoTDBLocal createIoTDBLocal(ColumnTransformerBuilder.Context 
context) {
-    IoTDBLocalFactory factory = context.getIoTDBLocalFactory();
-    if (factory == null
-        || context.getFragmentInstanceId() == null
-        || context.getOuterGlobalQueryId() == null
-        || context.getOuterLocalQueryId() < 0) {
-      return null;
-    }
-    return factory.create(
-        context.getSessionInfo(),
-        context.getFragmentInstanceId(),
-        context.getOuterLocalQueryId(),
-        context.getOuterGlobalQueryId());
-  }
-
   @Override
   protected void doTransform(
       List<Column> childrenColumns, ColumnBuilder builder, int positionCount) {
+    initIfNeeded();
     RecordIterator iterator = new RecordIterator(childrenColumns, inputTypes, 
positionCount);
     while (iterator.hasNext()) {
       try {
-        Object result =
-            ioTDBLocal != null
-                ? scalarFunction.evaluate(iterator.next(), ioTDBLocal)
-                : scalarFunction.evaluate(iterator.next());
+        Object result = scalarFunction.evaluate(iterator.next(), ioTDBLocal);
         if (result == null) {
           builder.appendNull();
         } else {
@@ -110,6 +117,7 @@ public class UserDefineScalarFunctionTransformer extends MultiColumnTransformer
   @Override
   protected void doTransform(
       List<Column> childrenColumns, ColumnBuilder builder, int positionCount, 
boolean[] selection) {
+    initIfNeeded();
     RecordIterator iterator = new RecordIterator(childrenColumns, inputTypes, 
positionCount);
     int i = 0;
     while (iterator.hasNext()) {
@@ -119,10 +127,7 @@ public class UserDefineScalarFunctionTransformer extends MultiColumnTransformer
           builder.appendNull();
           continue;
         }
-        Object result =
-            ioTDBLocal != null
-                ? scalarFunction.evaluate(input, ioTDBLocal)
-                : scalarFunction.evaluate(input);
+        Object result = scalarFunction.evaluate(input, ioTDBLocal);
         if (result == null) {
           builder.appendNull();
         } else {
@@ -140,12 +145,8 @@ public class UserDefineScalarFunctionTransformer extends MultiColumnTransformer
   @Override
   public void close() {
     super.close();
-    if (ioTDBLocal != null) {
-      ioTDBLocal.close();
-      scalarFunction.beforeDestroy(ioTDBLocal);
-    } else {
-      scalarFunction.beforeDestroy();
-    }
+    scalarFunction.beforeDestroy(ioTDBLocal);
+    ioTDBLocal.close();
   }
 
   @Override

Reply via email to