lihaosky commented on code in PR #26583:
URL: https://github.com/apache/flink/pull/26583#discussion_r2108131196
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/utils/SqlValidatorUtils.java:
##########
@@ -49,6 +59,75 @@ public static void adjustTypeForMapConstructor(
         }
     }
+    public static boolean throwValidationSignatureErrorOrReturnFalse(
+            SqlCallBinding callBinding, boolean throwOnFailure) {
+        if (throwOnFailure) {
+            throw callBinding.newValidationSignatureError();
+        } else {
+            return false;
+        }
+    }
+    @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
+    public static boolean throwExceptionOrReturnFalse(
+            Optional<RuntimeException> e, boolean throwOnFailure) {
+        if (e.isPresent()) {
+            if (throwOnFailure) {
+                throw e.get();
+            } else {
+                return false;
+            }
+        } else {
+            return true;
+        }
+    }
+    /**
+     * Checks whether the heading operands are in the form {@code (ROW, DESCRIPTOR, DESCRIPTOR ...,
+     * other params)}, returning whether successful, and throwing if any columns are not found.
+     *
+     * @param callBinding The call binding
+     * @param descriptorStartPos The position of the first descriptor operand
+     * @param descriptorCount The number of descriptors following the first operand (e.g. the table)
+     * @return true if validation passes; throws if any columns are not found
+     */
+    public static boolean checkTableAndDescriptorOperands(
+            SqlCallBinding callBinding, int descriptorStartPos, int descriptorCount) {

Review Comment:
   There will be two descriptors in `ml_evaluate`:
   ```
   SELECT * FROM ML_EVALUATE(TABLE `eval_data`, MODEL `classifier_model`, DESCRIPTOR(label), DESCRIPTOR(f1, f2))
   ```

##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLPredictTableFunction.java:
##########
@@ -61,9 +78,16 @@ public boolean argumentMustBeScalar(int ordinal) {
     @Override
     protected RelDataType inferRowType(SqlOperatorBinding opBinding) {
-        // TODO: FLINK-37780 output type based on table schema and model output schema
-        // model output schema to be available after integrated with SqlExplicitModelCall
-        return opBinding.getOperandType(1);
+        final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
+        final RelDataType inputRowType = opBinding.getOperandType(0);
+        final RelDataType modelOutputRowType = opBinding.getOperandType(1);
+        return typeFactory
+                .builder()
+                .kind(inputRowType.getStructKind())
+                .addAll(inputRowType.getFieldList())

Review Comment:
   Do you mean we need to make the field names unique? I'm following `SqlWindowTableFunction`, which doesn't check whether the input table already has a column named `window_start`, etc.
   I'm on the fence here.

##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLPredictTableFunction.java:
##########
@@ -112,5 +140,92 @@ public String getAllowedSignatures(SqlOperator op, String opName) {
         return opName
                 + "(TABLE table_name, MODEL model_name, DESCRIPTOR(input_columns), [MAP[]]";
     }
+    private static Optional<RuntimeException> checkModelSignature(SqlCallBinding callBinding) {
+        SqlValidator validator = callBinding.getValidator();
+        // Check second operand is SqlModelCall
+        if (!(callBinding.operand(1) instanceof SqlModelCall)) {
+            return Optional.of(
+                    new ValidationException("Second operand must be a model identifier."));
+        }
+        // Get descriptor columns
+        SqlCall descriptorCall = (SqlCall) callBinding.operand(2);
+        List<SqlNode> descriptCols = descriptorCall.getOperandList();
+        // Get model input size
+        SqlModelCall modelCall = (SqlModelCall) callBinding.operand(1);
+        RelDataType modelInputType = modelCall.getInputType(validator);
+        // Check sizes match
+        if (descriptCols.size() != modelInputType.getFieldCount()) {
+            return Optional.of(
+                    new ValidationException(
+                            String.format(
+                                    "Number of descriptor input columns (%d) does not match model input size (%d)",
+                                    descriptCols.size(), modelInputType.getFieldCount())));
+        }
+        // Check types match
+        final RelDataType tableType = validator.getValidatedNodeType(callBinding.operand(0));
+        final SqlNameMatcher matcher = validator.getCatalogReader().nameMatcher();
+        for (int i = 0; i < descriptCols.size(); i++) {
+            SqlIdentifier columnName = (SqlIdentifier) descriptCols.get(i);
+            String descriptColName =
+                    columnName.isSimple()
+                            ? columnName.getSimple()
+                            : Util.last(columnName.names);

Review Comment:
   `validateCall` turns the descriptor column name into a qualified name by prepending the table name, e.g. it turns `col` into `mytable`.`col`.

##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLPredictTableFunction.java:
##########
@@ -87,21 +111,25 @@ public List<String> paramNames() {
     @Override
     public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
-        // TODO: FLINK-37780 Check operand types after integrated with SqlExplicitModelCall in
-        // validator
-        return false;
+        if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 2, 1)) {

Review Comment:
   `validateCall` doesn't check the position and count of the descriptors, nor does it check that the first parameter is a table. `validateCall` can be used by both `ml_predict` and `ml_evaluate`.

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
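Taken together, the first and last review comments describe an operand-layout check that is parameterized by where the descriptors start and how many there are. `ML_PREDICT` calls `checkTableAndDescriptorOperands(callBinding, 2, 1)` in the diff. `ML_EVALUATE`, with its two descriptors, would presumably use `(2, 2)`. The sketch below is a minimal, Calcite-free illustration of only that positional check, under a simplified operand model. `OperandLayoutSketch`, `OperandKind`, and `checkLayout` are hypothetical names, not Flink or Calcite APIs. The helper in the PR works on a `SqlCallBinding` and also resolves the descriptor columns against the table, which this sketch does not attempt.

```java
import java.util.List;

/** Hypothetical, Calcite-free model of the table/descriptor operand-layout check. */
public class OperandLayoutSketch {

    /** Simplified stand-in for the kind of each SQL operand (assumption, not a Calcite type). */
    enum OperandKind {
        TABLE,
        MODEL,
        DESCRIPTOR,
        OTHER
    }

    /**
     * Returns true if operand 0 is a table and operands [descriptorStartPos,
     * descriptorStartPos + descriptorCount) are all descriptors.
     */
    static boolean checkLayout(
            List<OperandKind> operands, int descriptorStartPos, int descriptorCount) {
        // The first operand must be a table.
        if (operands.isEmpty() || operands.get(0) != OperandKind.TABLE) {
            return false;
        }
        int end = descriptorStartPos + descriptorCount;
        // Descriptors cannot overlap the table operand or run past the last operand.
        if (descriptorStartPos < 1 || descriptorCount < 1 || end > operands.size()) {
            return false;
        }
        // Every operand in the expected range must be a descriptor.
        for (int i = descriptorStartPos; i < end; i++) {
            if (operands.get(i) != OperandKind.DESCRIPTOR) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        // ML_PREDICT(TABLE t, MODEL m, DESCRIPTOR(...))
        List<OperandKind> predict =
                List.of(OperandKind.TABLE, OperandKind.MODEL, OperandKind.DESCRIPTOR);
        // ML_EVALUATE(TABLE t, MODEL m, DESCRIPTOR(label), DESCRIPTOR(f1, f2))
        List<OperandKind> evaluate =
                List.of(
                        OperandKind.TABLE,
                        OperandKind.MODEL,
                        OperandKind.DESCRIPTOR,
                        OperandKind.DESCRIPTOR);
        System.out.println(checkLayout(predict, 2, 1)); // true
        System.out.println(checkLayout(evaluate, 2, 2)); // true
        System.out.println(checkLayout(predict, 2, 2)); // false: only one descriptor
    }
}
```

Keeping the start position and count as parameters is what lets one helper serve both functions. A check hard-coded to a single descriptor at position 2 would reject the `ML_EVALUATE` form quoted in the first comment.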