xi-db commented on code in PR #46672:
URL: https://github.com/apache/spark/pull/46672#discussion_r2185476955


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala:
##########
@@ -561,4 +565,314 @@ case class TryAesDecrypt(
   override protected def withNewChildInternal(newChild: Expression): 
Expression =
     this.copy(replacement = newChild)
 }
+
+/**
+ * A function that compresses input using Zstandard.
+ * If any argument is NULL, the return value is NULL.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr[, level[, streaming_mode]]) - Returns a compressed value of 
`expr` using Zstandard with the specified compression `level`.
+      The default level is 3. Uses single-pass mode by default.
+  """,
+  arguments = """
+    Arguments:
+      * expr - The binary value to compress.
+      * level - Optional integer argument that represents the compression 
level. The compression level controls the trade-off between compression speed 
and compression ratio.
+                Valid values: between 1 and 22 inclusive, where 1 means 
fastest but lowest compression ratio, and 22 means slowest but highest 
compression ratio.
+                The default level is 3 if not specified.
+      * streaming_mode - Optional boolean argument that represents whether to 
use streaming mode. If true, the function will compress the input in streaming 
mode.
+                        The default value is false.
+  """,
+  examples = """
+    Examples:
+      > SELECT base64(_FUNC_(repeat("Apache Spark ", 10)));
+       KLUv/SCCpQAAaEFwYWNoZSBTcGFyayABABLS+QU=
+      > SELECT base64(_FUNC_(repeat("Apache Spark ", 10), 5));
+       KLUv/SCCpQAAaEFwYWNoZSBTcGFyayABABLS+QU=
+      > SELECT base64(_FUNC_(repeat("Apache Spark ", 10), 3, true));
+       KLUv/QBYpAAAaEFwYWNoZSBTcGFyayABABLS+QUBAAA=
+  """,
+  since = "4.0.0",
+  group = "misc_funcs")
+// scalastyle:on line.size.limit
+case class ZstdCompress(input: Expression, level: Expression, streamingMode: 
Expression)
+    extends TernaryExpression with ImplicitCastInputTypes {
+  override def nullIntolerant: Boolean = true
+
+  def this(input: Expression, level: Expression) =
+    this(input, level, Literal(false))
+
+  def this(input: Expression) =
+    this(input, Literal(3))
+
+  override def prettyName: String = "zstd_compress"
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression,
+      newSecond: Expression,
+      newThird: Expression): Expression = {
+    copy(newFirst, newSecond, newThird)
+  }
+
+  override def first: Expression = input
+  override def second: Expression = level
+  override def third: Expression = streamingMode
+
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(BinaryType, IntegerType, BooleanType)
+
+  override def dataType: DataType = BinaryType
+
+  override protected def nullSafeEval(input1: Any, input2: Any, input3: Any): 
Any = {
+    ZstdUtils.compress(
+      input1.asInstanceOf[Array[Byte]],
+      input2.asInstanceOf[Number].intValue(),
+      input3.asInstanceOf[Boolean]
+    )
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
+    val clz = ZstdUtils.getClass.getName.stripSuffix("$")
+    defineCodeGen(
+      ctx,
+      ev,
+      f = (input, level, streamingMode) => {
+        s"$clz.compress($input, $level, $streamingMode)"
+      }
+    )
+  }
+}
+
+// scalastyle:off line.size.limit
+/**
+ * A function that decompresses input using Zstandard. On decompression failure, 
it throws an exception.
+ * If the argument is NULL, the return value is NULL.
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr) - Returns the decompressed value of `expr` using Zstandard.
+      Supports data compressed in both single-pass mode and streaming mode.
+      On decompression failure, it throws an exception.
+  """,
+  arguments = """
+    Arguments:
+      * expr - The binary value to decompress.
+  """,
+  examples = """
+    Examples:
+      > SELECT 
string(_FUNC_(unbase64("KLUv/SCCpQAAaEFwYWNoZSBTcGFyayABABLS+QU=")));
+       Apache Spark Apache Spark Apache Spark Apache Spark Apache Spark Apache 
Spark Apache Spark Apache Spark Apache Spark Apache Spark
+  """,
+  since = "4.0.0",
+  group = "misc_funcs")
+// scalastyle:on line.size.limit
+case class ZstdDecompress(input: Expression)
+    extends UnaryExpression with ImplicitCastInputTypes {
+  override def nullIntolerant: Boolean = true
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)
+
+  override def prettyName: String = "zstd_decompress"
+
+  override def child: Expression = input
+
+  override protected def withNewChildInternal(newChild: Expression): 
Expression = copy(newChild)
+
+  override def dataType: DataType = BinaryType
+
+  override protected def nullSafeEval(input: Any): Any = {
+    val result = ZstdUtils.decompress(input.asInstanceOf[Array[Byte]])

Review Comment:
   In `ZstdUtils`, the `decompress` method returns null when the input is 
invalid, and we then throw a `QueryExecutionErrors.zstdDecompressError()` error 
to the user. Do you think that's OK?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to