cloud-fan commented on code in PR #53924:
URL: https://github.com/apache/spark/pull/53924#discussion_r2735623758
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/vectorExpressions.scala:
##########
@@ -196,3 +196,135 @@ case class VectorL2Distance(left: Expression, right:
Expression)
copy(left = newChildren(0), right = newChildren(1))
}
}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(vector, degree) - Returns the Lp norm of a float vector using the
specified degree.
+ Degree defaults to 2.0 (Euclidean norm) if unspecified. Supported values:
1.0 (L1 norm),
+ 2.0 (L2 norm), float('inf') (infinity norm).
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(3.0F, 4.0F), 2.0F);
+ 5.0
+ > SELECT _FUNC_(array(3.0F, 4.0F), 1.0F);
+ 7.0
+ > SELECT _FUNC_(array(3.0F, 4.0F), float('inf'));
+ 4.0
+ """,
+ since = "4.2.0",
+ group = "vector_funcs"
+)
+// scalastyle:on line.size.limit
+case class VectorNorm(vector: Expression, degree: Expression)
+ extends RuntimeReplaceable with QueryErrorsBase {
+
+ def this(vector: Expression) = this(vector, Literal(2.0f))
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ (vector.dataType, degree.dataType) match {
+ case (ArrayType(FloatType, _), FloatType) =>
+ TypeCheckResult.TypeCheckSuccess
+ case (ArrayType(FloatType, _), _) =>
+ DataTypeMismatch(
+ errorSubClass = "UNEXPECTED_INPUT_TYPE",
+ messageParameters = Map(
+ "paramIndex" -> ordinalNumber(1),
+ "requiredType" -> toSQLType(FloatType),
+ "inputSql" -> toSQLExpr(degree),
+ "inputType" -> toSQLType(degree.dataType)))
+ case _ =>
+ DataTypeMismatch(
+ errorSubClass = "UNEXPECTED_INPUT_TYPE",
+ messageParameters = Map(
+ "paramIndex" -> ordinalNumber(0),
+ "requiredType" -> toSQLType(ArrayType(FloatType)),
+ "inputSql" -> toSQLExpr(vector),
+ "inputType" -> toSQLType(vector.dataType)))
+ }
+ }
+
+ override lazy val replacement: Expression = StaticInvoke(
Review Comment:
minor: `StaticInvoke` has a `returnNullable` flag, which defaults to true for
safety. If the function never returns null for non-null inputs, we can set
`returnNullable = false`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]