xi-db commented on code in PR #46672: URL: https://github.com/apache/spark/pull/46672#discussion_r2185476955
########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala: ########## @@ -561,4 +565,314 @@ case class TryAesDecrypt( override protected def withNewChildInternal(newChild: Expression): Expression = this.copy(replacement = newChild) } + +/** + * A function that compress input using Zstandard. + * If either argument is NULL, the return value is NULL. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(expr[, level[, streaming_mode]]) - Returns a compressed value of `expr` using Zstandard with the specified compression `level`. + The default level is 3. Uses single-pass mode by default. + """, + arguments = """ + Arguments: + * expr - The binary value to compress. + * level - Optional integer argument that represents the compression level. The compression level controls the trade-off between compression speed and compression ratio. + Valid values: between 1 and 22 inclusive, where 1 means fastest but lowest compression ratio, and 22 means slowest but highest compression ratio. + The default level is 3 if not specified. + * streaming_mode - Optional boolean argument that represents whether to use streaming mode. If true, the function will compress the input in streaming mode. + The default value is false. 
+ """, + examples = """ + Examples: + > SELECT base64(_FUNC_(repeat("Apache Spark ", 10))); + KLUv/SCCpQAAaEFwYWNoZSBTcGFyayABABLS+QU= + > SELECT base64(_FUNC_(repeat("Apache Spark ", 10), 5)); + KLUv/SCCpQAAaEFwYWNoZSBTcGFyayABABLS+QU= + > SELECT base64(_FUNC_(repeat("Apache Spark ", 10), 3, true)); + KLUv/QBYpAAAaEFwYWNoZSBTcGFyayABABLS+QUBAAA= + """, + since = "4.0.0", + group = "misc_funcs") +// scalastyle:on line.size.limit +case class ZstdCompress(input: Expression, level: Expression, streamingMode: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true + + def this(input: Expression, level: Expression) = + this(input, level, Literal(false)) + + def this(input: Expression) = + this(input, Literal(3)) + + override def prettyName: String = "zstd_compress" + + override protected def withNewChildrenInternal( + newFirst: Expression, + newSecond: Expression, + newThird: Expression): Expression = { + copy(newFirst, newSecond, newThird) + } + + override def first: Expression = input + override def second: Expression = level + override def third: Expression = streamingMode + + override def inputTypes: Seq[AbstractDataType] = + Seq(BinaryType, IntegerType, BooleanType) + + override def dataType: DataType = BinaryType + + override protected def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = { + ZstdUtils.compress( + input1.asInstanceOf[Array[Byte]], + input2.asInstanceOf[Number].intValue(), + input3.asInstanceOf[Boolean] + ) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val clz = ZstdUtils.getClass.getName.stripSuffix("$") + defineCodeGen( + ctx, + ev, + f = (input, level, streamingMode) => { + s"$clz.compress($input, $level, $streamingMode)" + } + ) + } +} + +// scalastyle:off line.size.limit +/** + * A function that decompress input using Zstandard. On decompression failure, it throws an exception. 
+ * If either argument is NULL, the return value is NULL. + */ +@ExpressionDescription( + usage = """ + _FUNC_(expr) - Returns the decompressed value of `expr` using Zstandard. + Supports data compressed in both single-pass mode and streaming mode. + On decompression failure, it throws an exception. + """, + arguments = """ + Arguments: + * expr - The binary value to decompress. + """, + examples = """ + Examples: + > SELECT string(_FUNC_(unbase64("KLUv/SCCpQAAaEFwYWNoZSBTcGFyayABABLS+QU="))); + Apache Spark Apache Spark Apache Spark Apache Spark Apache Spark Apache Spark Apache Spark Apache Spark Apache Spark Apache Spark + """, + since = "4.0.0", + group = "misc_funcs") +// scalastyle:on line.size.limit +case class ZstdDecompress(input: Expression) + extends UnaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType) + + override def prettyName: String = "zstd_decompress" + + override def child: Expression = input + + override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) + + override def dataType: DataType = BinaryType + + override protected def nullSafeEval(input: Any): Any = { + val result = ZstdUtils.decompress(input.asInstanceOf[Array[Byte]]) Review Comment: In `ZstdUtils`, the `decompress` method returns null when the input is invalid, and we then throw a `QueryExecutionErrors.zstdDecompressError()` error to the user. Is that approach OK with you? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org