This is an automated email from the ASF dual-hosted git repository.
lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 0fc3bb44f2 [GLUTEN-10046][Flink] Improve function validation (#10047)
0fc3bb44f2 is described below
commit 0fc3bb44f24ef409e6ff2be60256df99f68d454a
Author: lgbo <[email protected]>
AuthorDate: Fri Jul 18 14:24:06 2025 +0800
[GLUTEN-10046][Flink] Improve function validation (#10047)
* improve function validation
* update
---
...RexCallConverter.java => ValidationResult.java} | 29 +++++---
.../rexnode/functions/BaseRexCallConverters.java | 57 +++++++--------
.../functions/BasicCompareOperatorConverters.java | 34 ++++++---
.../rexnode/functions/ModRexCallConverter.java | 16 ++++-
.../gluten/rexnode/functions/RexCallConverter.java | 3 +-
.../rexnode/functions/RexCallConverterFactory.java | 30 +++++++-
...verters.java => SubstractRexCallConverter.java} | 81 +++++++---------------
7 files changed, 136 insertions(+), 114 deletions(-)
diff --git a/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/RexCallConverter.java b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/ValidationResult.java
similarity index 59%
copy from gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/RexCallConverter.java
copy to gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/ValidationResult.java
index 794f4c4eec..416e147c67 100644
--- a/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/RexCallConverter.java
+++ b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/ValidationResult.java
@@ -14,17 +14,30 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.rexnode.functions;
+package org.apache.gluten.rexnode;
-import org.apache.gluten.rexnode.RexConversionContext;
+public class ValidationResult {
+ private final boolean ok;
+ private final String message;
-import io.github.zhztheplayer.velox4j.expression.TypedExpr;
+ public boolean isOk() {
+ return ok;
+ }
-import org.apache.calcite.rex.RexCall;
+ public String getMessage() {
+ return message;
+ }
-public interface RexCallConverter {
- // Let the Converter decide how to build the arguments.
- TypedExpr toTypedExpr(RexCall callNode, RexConversionContext context);
+ public ValidationResult(boolean ok, String message) {
+ this.ok = ok;
+ this.message = message;
+ }
- boolean isSupported(RexCall callNode, RexConversionContext context);
+ public static ValidationResult success() {
+ return new ValidationResult(true, "Validation successful");
+ }
+
+ public static ValidationResult failure(String message) {
+ return new ValidationResult(false, message);
+ }
}
diff --git a/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/BaseRexCallConverters.java b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/BaseRexCallConverters.java
index 8a7c41ab1d..8ab9177386 100644
--- a/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/BaseRexCallConverters.java
+++ b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/BaseRexCallConverters.java
@@ -19,16 +19,16 @@ package org.apache.gluten.rexnode.functions;
import org.apache.gluten.rexnode.RexConversionContext;
import org.apache.gluten.rexnode.RexNodeConverter;
import org.apache.gluten.rexnode.TypeUtils;
+import org.apache.gluten.rexnode.ValidationResult;
import io.github.zhztheplayer.velox4j.expression.CallTypedExpr;
import io.github.zhztheplayer.velox4j.expression.TypedExpr;
-import io.github.zhztheplayer.velox4j.type.BigIntType;
-import io.github.zhztheplayer.velox4j.type.TimestampType;
import io.github.zhztheplayer.velox4j.type.Type;
import org.apache.calcite.rex.RexCall;
import java.util.List;
+import java.util.stream.Collectors;
abstract class BaseRexCallConverter implements RexCallConverter {
protected final String functionName;
@@ -46,10 +46,18 @@ abstract class BaseRexCallConverter implements RexCallConverter {
}
@Override
- public boolean isSupported(RexCall callNode, RexConversionContext context) {
+ public ValidationResult isSuitable(RexCall callNode, RexConversionContext context) {
// Default implementation assumes all RexCall nodes are supported.
// Subclasses can override this method to provide specific support checks.
- return true;
+ return ValidationResult.success();
+ }
+
+ protected String getFunctionProtoTypeName(RexCall callNode) {
+ String operandTypeNames =
+ callNode.getOperands().stream()
+ .map(arg -> arg.getType().toString())
+ .collect(Collectors.joining(", "));
+ return String.format("(%s) -> %s", operandTypeNames, callNode.getType().toString());
}
}
@@ -72,9 +80,18 @@ class BasicArithmeticOperatorRexCallConverter extends BaseRexCallConverter {
}
@Override
- public boolean isSupported(RexCall callNode, RexConversionContext context) {
- return callNode.getOperands().stream()
- .allMatch(param -> TypeUtils.isNumericType(RexNodeConverter.toType(param.getType())));
+ public ValidationResult isSuitable(RexCall callNode, RexConversionContext context) {
+ boolean typesValidate =
+ callNode.getOperands().stream()
+ .allMatch(param -> TypeUtils.isNumericType(RexNodeConverter.toType(param.getType())));
+ if (!typesValidate) {
+ String message =
+ String.format(
+ "Arithmetic operation '%s' requires numeric operands, but found: %s",
+ functionName, getFunctionProtoTypeName(callNode));
+ return ValidationResult.failure(message);
+ }
+ return ValidationResult.success();
}
@Override
@@ -86,29 +103,3 @@ class BasicArithmeticOperatorRexCallConverter extends BaseRexCallConverter {
return new CallTypedExpr(resultType, alignedParams, functionName);
}
}
-
-class SubtractRexCallConverter extends BaseRexCallConverter {
-
- public SubtractRexCallConverter() {
- super("subtract");
- }
-
- @Override
- public TypedExpr toTypedExpr(RexCall callNode, RexConversionContext context) {
- List<TypedExpr> params = getParams(callNode, context);
-
- if (params.get(0).getReturnType() instanceof TimestampType
- && params.get(1).getReturnType() instanceof BigIntType) {
-
- Type bigIntType = new BigIntType();
- TypedExpr castExpr = new CallTypedExpr(bigIntType, List.of(params.get(0)), "cast");
-
- List<TypedExpr> newParams = List.of(castExpr, params.get(1));
- return new CallTypedExpr(bigIntType, newParams, functionName);
- }
-
- List<TypedExpr> alignedParams = TypeUtils.promoteTypeForArithmeticExpressions(params);
- Type resultType = getResultType(callNode);
- return new CallTypedExpr(resultType, alignedParams, functionName);
- }
-}
diff --git a/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/BasicCompareOperatorConverters.java b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/BasicCompareOperatorConverters.java
index 0c0d6b1642..584ccd9fe3 100644
--- a/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/BasicCompareOperatorConverters.java
+++ b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/BasicCompareOperatorConverters.java
@@ -19,6 +19,7 @@ package org.apache.gluten.rexnode.functions;
import org.apache.gluten.rexnode.RexConversionContext;
import org.apache.gluten.rexnode.RexNodeConverter;
import org.apache.gluten.rexnode.TypeUtils;
+import org.apache.gluten.rexnode.ValidationResult;
import io.github.zhztheplayer.velox4j.expression.CallTypedExpr;
import io.github.zhztheplayer.velox4j.expression.CastTypedExpr;
@@ -38,10 +39,19 @@ class StringCompareRexCallConverter extends BaseRexCallConverter {
}
@Override
- public boolean isSupported(RexCall callNode, RexConversionContext context) {
+ public ValidationResult isSuitable(RexCall callNode, RexConversionContext context) {
// This converter supports string comparison functions.
- return callNode.getOperands().stream()
- .allMatch(param -> TypeUtils.isStringType(RexNodeConverter.toType(param.getType())));
+ boolean typesValidate =
+ callNode.getOperands().stream()
+ .allMatch(param -> TypeUtils.isStringType(RexNodeConverter.toType(param.getType())));
+ if (!typesValidate) {
+ String message =
+ String.format(
+ "String comparison operation requires all operands to be string types, but found: %s",
+ getFunctionProtoTypeName(callNode));
+ return ValidationResult.failure(message);
+ }
+ return ValidationResult.success();
}
@Override
@@ -59,18 +69,24 @@ class StringNumberCompareRexCallConverter extends BaseRexCallConverter {
}
@Override
- public boolean isSupported(RexCall callNode, RexConversionContext context) {
+ public ValidationResult isSuitable(RexCall callNode, RexConversionContext context) {
// This converter supports string and numeric comparison functions.
List<Type> paramTypes =
callNode.getOperands().stream()
.map(param -> RexNodeConverter.toType(param.getType()))
.collect(Collectors.toList());
- if ((TypeUtils.isNumericType(paramTypes.get(0)) && TypeUtils.isStringType(paramTypes.get(1)))
- || (TypeUtils.isStringType(paramTypes.get(0))
- && TypeUtils.isNumericType(paramTypes.get(1)))) {
- return true;
+ boolean typesValidate =
+ (TypeUtils.isNumericType(paramTypes.get(0)) && TypeUtils.isStringType(paramTypes.get(1)))
+ || (TypeUtils.isStringType(paramTypes.get(0))
+ && TypeUtils.isNumericType(paramTypes.get(1)));
+ if (!typesValidate) {
+ String message =
+ String.format(
+ "String and numeric comparison operation requires one string and one numeric operand, but found: %s",
+ getFunctionProtoTypeName(callNode));
+ return ValidationResult.failure(message);
}
- return false;
+ return ValidationResult.success();
}
@Override
diff --git a/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/ModRexCallConverter.java b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/ModRexCallConverter.java
index d2dc8a1788..9d96a1dd44 100644
--- a/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/ModRexCallConverter.java
+++ b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/ModRexCallConverter.java
@@ -19,6 +19,7 @@ package org.apache.gluten.rexnode.functions;
import org.apache.gluten.rexnode.RexConversionContext;
import org.apache.gluten.rexnode.RexNodeConverter;
import org.apache.gluten.rexnode.TypeUtils;
+import org.apache.gluten.rexnode.ValidationResult;
import io.github.zhztheplayer.velox4j.expression.CallTypedExpr;
import io.github.zhztheplayer.velox4j.expression.TypedExpr;
@@ -36,10 +37,19 @@ public class ModRexCallConverter extends BaseRexCallConverter {
}
@Override
- public boolean isSupported(RexCall callNode, RexConversionContext context) {
+ public ValidationResult isSuitable(RexCall callNode, RexConversionContext context) {
// Modulus operation is supported for numeric types.
- return callNode.getOperands().size() == 2
- && TypeUtils.isNumericType(RexNodeConverter.toType(callNode.getType()));
+ boolean typesValidate =
+ callNode.getOperands().size() == 2
+ && TypeUtils.isNumericType(RexNodeConverter.toType(callNode.getType()));
+ if (!typesValidate) {
+ String message =
+ String.format(
+ "Modulus operation requires exactly two numeric operands, but found: %s",
+ getFunctionProtoTypeName(callNode));
+ return ValidationResult.failure(message);
+ }
+ return ValidationResult.success();
}
@Override
diff --git a/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/RexCallConverter.java b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/RexCallConverter.java
index 794f4c4eec..ffa54e524b 100644
--- a/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/RexCallConverter.java
+++ b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/RexCallConverter.java
@@ -17,6 +17,7 @@
package org.apache.gluten.rexnode.functions;
import org.apache.gluten.rexnode.RexConversionContext;
+import org.apache.gluten.rexnode.ValidationResult;
import io.github.zhztheplayer.velox4j.expression.TypedExpr;
@@ -26,5 +27,5 @@ public interface RexCallConverter {
// Let the Converter decide how to build the arguments.
TypedExpr toTypedExpr(RexCall callNode, RexConversionContext context);
- boolean isSupported(RexCall callNode, RexConversionContext context);
+ ValidationResult isSuitable(RexCall callNode, RexConversionContext context);
}
diff --git a/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/RexCallConverterFactory.java b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/RexCallConverterFactory.java
index 67a226546d..eac71dca7e 100644
--- a/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/RexCallConverterFactory.java
+++ b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/RexCallConverterFactory.java
@@ -17,9 +17,11 @@
package org.apache.gluten.rexnode.functions;
import org.apache.gluten.rexnode.RexConversionContext;
+import org.apache.gluten.rexnode.ValidationResult;
import org.apache.calcite.rex.RexCall;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
@@ -71,16 +73,38 @@ public class RexCallConverterFactory {
throw new RuntimeException("Function not supported: " + operatorName);
}
+ List<String> failureMessages = new ArrayList<>();
List<RexCallConverter> converterList =
builders.stream()
.map(RexCallConverterBuilder::build)
- .filter(c -> c.isSupported(callNode, context))
+ .filter(
+ c -> {
+ ValidationResult validationResult = c.isSuitable(callNode, context);
+ if (!validationResult.isOk()) {
+ failureMessages.add(
+ c.getClass().getName() + ": " + validationResult.getMessage());
+ return false;
+ } else {
+ return true;
+ }
+ })
.collect(Collectors.toList());
if (converterList.size() > 1) {
- throw new RuntimeException("Multiple converters found for: " + operatorName);
+ String converterClasses =
+ converterList.stream()
+ .map(converter -> converter.getClass().getName())
+ .collect(Collectors.joining(", "));
+ String message =
+ String.format(
+ "Multiple converters found for: %s. Converters: %s.", operatorName, converterClasses);
+ throw new RuntimeException(message);
} else if (converterList.isEmpty()) {
- throw new RuntimeException("No suitable converter found for: " + operatorName);
+ String message =
+ String.format(
+ "No suitable converter found for: %s. Reason:\n%s",
+ operatorName, String.join("\n", failureMessages));
+ throw new RuntimeException(message);
}
return converterList.get(0);
diff --git a/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/BaseRexCallConverters.java b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/SubstractRexCallConverter.java
similarity index 53%
copy from gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/BaseRexCallConverters.java
copy to gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/SubstractRexCallConverter.java
index 8a7c41ab1d..2d5a1e5c7f 100644
--- a/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/BaseRexCallConverters.java
+++ b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/SubstractRexCallConverter.java
@@ -19,6 +19,7 @@ package org.apache.gluten.rexnode.functions;
import org.apache.gluten.rexnode.RexConversionContext;
import org.apache.gluten.rexnode.RexNodeConverter;
import org.apache.gluten.rexnode.TypeUtils;
+import org.apache.gluten.rexnode.ValidationResult;
import io.github.zhztheplayer.velox4j.expression.CallTypedExpr;
import io.github.zhztheplayer.velox4j.expression.TypedExpr;
@@ -29,63 +30,7 @@ import io.github.zhztheplayer.velox4j.type.Type;
import org.apache.calcite.rex.RexCall;
import java.util.List;
-
-abstract class BaseRexCallConverter implements RexCallConverter {
- protected final String functionName;
-
- public BaseRexCallConverter(String functionName) {
- this.functionName = functionName;
- }
-
- protected List<TypedExpr> getParams(RexCall callNode, RexConversionContext context) {
- return RexNodeConverter.toTypedExpr(callNode.getOperands(), context);
- }
-
- protected Type getResultType(RexCall callNode) {
- return RexNodeConverter.toType(callNode.getType());
- }
-
- @Override
- public boolean isSupported(RexCall callNode, RexConversionContext context) {
- // Default implementation assumes all RexCall nodes are supported.
- // Subclasses can override this method to provide specific support checks.
- return true;
- }
-}
-
-class DefaultRexCallConverter extends BaseRexCallConverter {
- public DefaultRexCallConverter(String functionName) {
- super(functionName);
- }
-
- @Override
- public TypedExpr toTypedExpr(RexCall callNode, RexConversionContext context) {
- List<TypedExpr> params = getParams(callNode, context);
- Type resultType = getResultType(callNode);
- return new CallTypedExpr(resultType, params, functionName);
- }
-}
-
-class BasicArithmeticOperatorRexCallConverter extends BaseRexCallConverter {
- public BasicArithmeticOperatorRexCallConverter(String functionName) {
- super(functionName);
- }
-
- @Override
- public boolean isSupported(RexCall callNode, RexConversionContext context) {
- return callNode.getOperands().stream()
- .allMatch(param -> TypeUtils.isNumericType(RexNodeConverter.toType(param.getType())));
- }
-
- @Override
- public TypedExpr toTypedExpr(RexCall callNode, RexConversionContext context) {
- List<TypedExpr> params = getParams(callNode, context);
- // If types are different, align them
- List<TypedExpr> alignedParams = TypeUtils.promoteTypeForArithmeticExpressions(params);
- Type resultType = getResultType(callNode);
- return new CallTypedExpr(resultType, alignedParams, functionName);
- }
-}
+import java.util.stream.Collectors;
class SubtractRexCallConverter extends BaseRexCallConverter {
@@ -93,6 +38,28 @@ class SubtractRexCallConverter extends BaseRexCallConverter {
super("subtract");
}
+ @Override
+ public ValidationResult isSuitable(RexCall callNode, RexConversionContext context) {
+ // Subtraction operation is supported for numeric types.
+ List<Type> paramTypes =
+ callNode.getOperands().stream()
+ .map(param -> RexNodeConverter.toType(param.getType()))
+ .collect(Collectors.toList());
+ boolean validate =
+ callNode.getOperands().size() == 2
+ && (paramTypes.stream().allMatch(TypeUtils::isNumericType)
+ || (paramTypes.get(0) instanceof TimestampType
+ && paramTypes.get(1) instanceof BigIntType));
+ if (!validate) {
+ String message =
+ String.format(
+ "Subtraction operation requires exactly two numeric operands or timestamp - numeric, but found: %s",
+ getFunctionProtoTypeName(callNode));
+ return ValidationResult.failure(message);
+ }
+ return ValidationResult.success();
+ }
+
@Override
public TypedExpr toTypedExpr(RexCall callNode, RexConversionContext context) {
List<TypedExpr> params = getParams(callNode, context);
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]