dtenedor commented on code in PR #39449:
URL: https://github.com/apache/spark/pull/39449#discussion_r1087097719


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala:
##########
@@ -23,9 +23,82 @@ import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.errors.QueryErrorsBase
-import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}
+import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, 
StringType}
 import org.apache.spark.unsafe.types.UTF8String
 
+/**
+ * This trait has common methods, which can be used by Classes implementing 
Masking udf functions

Review Comment:
   these are technically built-in functions, rather than UDFs; maybe just 
delete `udf` from this line



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala:
##########
@@ -23,9 +23,82 @@ import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.errors.QueryErrorsBase
-import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}
+import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, 
StringType}
 import org.apache.spark.unsafe.types.UTF8String
 
+/**
+ * This trait has common methods, which can be used by Classes implementing 
Masking udf functions
+ * Eg : Mask, MaskFirstN, etc
+ */
+private[sql] trait Maskable extends ExpectsInputTypes with QueryErrorsBase {
+  protected val upperCharExpr: Expression
+  protected val lowerCharExpr: Expression
+  protected val digitCharExpr: Expression
+  protected val otherCharExpr: Expression
+
+  @transient

Review Comment:
   maybe add a comment saying (1) these are lazy vals in order to cache them 
for efficiency, and (2) they must all be constant expression trees, which we 
enforce in each of the derived expressions? (If we violate (2), we have a 
correctness bug.)



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala:
##########
@@ -236,40 +265,271 @@ case class Mask(
   }
 
   /**
-   * Returns the [[DataType]] of the result of evaluating this expression. It 
is invalid to query
-   * the dataType of an unresolved expression (i.e., when `resolved` == false).
+   * Returns a Seq of the children of this node. Children should not change. 
Immutability required
+   * for containsChild optimization
    */
-  override def dataType: DataType = StringType
+  override def children: Seq[Expression] =
+    Seq(input, upperCharExpr, lowerCharExpr, digitCharExpr, otherCharExpr)
+
+  override protected def withNewChildrenInternal(newChildren: 
IndexedSeq[Expression]): Mask =
+    copy(
+      input = newChildren(0),
+      upperCharExpr = newChildren(1),
+      lowerCharExpr = newChildren(2),
+      digitCharExpr = newChildren(3),
+      otherCharExpr = newChildren(4))
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage =
+    """_FUNC_(input[, charCount, upperChar, lowerChar, digitChar, otherChar]) 
- masks the first n characters of given string value.
+       The function masks the first n characters of the value with 'X' or 'x', 
and numbers with 'n'.
+       This can be useful for creating copies of tables with sensitive 
information removed.

Review Comment:
   Also mention in this comment what happens if the string does not contain 
enough characters? (For that row, we clamp the requested number of characters 
to the length of the input string.)



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala:
##########
@@ -23,9 +23,82 @@ import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.errors.QueryErrorsBase
-import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}
+import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, 
StringType}
 import org.apache.spark.unsafe.types.UTF8String
 
+/**
+ * This trait has common methods, which can be used by Classes implementing 
Masking udf functions
+ * Eg : Mask, MaskFirstN, etc
+ */
+private[sql] trait Maskable extends ExpectsInputTypes with QueryErrorsBase {
+  protected val upperCharExpr: Expression
+  protected val lowerCharExpr: Expression
+  protected val digitCharExpr: Expression
+  protected val otherCharExpr: Expression
+
+  @transient
+  protected lazy val upperChar = upperCharExpr.eval()
+  @transient
+  protected lazy val lowerChar = lowerCharExpr.eval()
+  @transient
+  protected lazy val digitChar = digitCharExpr.eval()
+  @transient
+  protected lazy val otherChar = otherCharExpr.eval()
+
+  protected def validateAdditionalFields(): Seq[Option[TypeCheckResult]] = 
Seq(None)

Review Comment:
   please add a comment saying what this is supposed to check?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala:
##########
@@ -283,7 +543,32 @@ object Mask {
         transformChar(_, maskUpper, maskLower, maskDigit, maskOther).toChar
       }
     }
-    org.apache.spark.unsafe.types.UTF8String.fromString(transformedString)
+    UTF8String.fromString(transformedString)
+  }
+
+  def mask_first_n(
+      input: Any,
+      charCount: Int,
+      maskUpper: Any,
+      maskLower: Any,
+      maskDigit: Any,
+      maskOther: Any): UTF8String = {
+
+    val transformedString = if (input == null) {
+      null
+    } else {
+      val inputStr = input.asInstanceOf[UTF8String]
+      val stringSize = inputStr.numChars
+      val endIdx = if (stringSize < charCount) stringSize else charCount
+      inputStr.toString.zipWithIndex.map { case (ch, i) =>

Review Comment:
   `zipWithIndex` is slow, and this expression will be evaluated for every 
input row. Can we instead split the string based on `endIdx` and then just call 
`transformChar` on every char of the first token?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala:
##########
@@ -236,40 +265,271 @@ case class Mask(
   }
 
   /**
-   * Returns the [[DataType]] of the result of evaluating this expression. It 
is invalid to query
-   * the dataType of an unresolved expression (i.e., when `resolved` == false).
+   * Returns a Seq of the children of this node. Children should not change. 
Immutability required
+   * for containsChild optimization
    */
-  override def dataType: DataType = StringType
+  override def children: Seq[Expression] =
+    Seq(input, upperCharExpr, lowerCharExpr, digitCharExpr, otherCharExpr)
+
+  override protected def withNewChildrenInternal(newChildren: 
IndexedSeq[Expression]): Mask =
+    copy(
+      input = newChildren(0),
+      upperCharExpr = newChildren(1),
+      lowerCharExpr = newChildren(2),
+      digitCharExpr = newChildren(3),
+      otherCharExpr = newChildren(4))
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage =
+    """_FUNC_(input[, charCount, upperChar, lowerChar, digitChar, otherChar]) 
- masks the first n characters of given string value.
+       The function masks the first n characters of the value with 'X' or 'x', 
and numbers with 'n'.
+       This can be useful for creating copies of tables with sensitive 
information removed.
+       Error behavior: null value as replacement argument will throw 
AnalysisError.
+      """,
+  arguments = """
+    Arguments:
+      * input      - string value to mask. Supported types: STRING, VARCHAR, 
CHAR
+      * charCount  - number of characters to be masked. Default value: 4
+      * upperChar  - character to replace upper-case characters with. Specify 
NULL to retain original character. Default value: 'X'
+      * lowerChar  - character to replace lower-case characters with. Specify 
NULL to retain original character. Default value: 'x'
+      * digitChar  - character to replace digit characters with. Specify NULL 
to retain original character. Default value: 'n'
+      * otherChar  - character to replace all other characters with. Specify 
NULL to retain original character. Default value: NULL
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_('abcd-EFGH-8765-4321');
+        xxxx-EFGH-8765-4321
+      > SELECT _FUNC_('abcd-EFGH-8765-4321', 9);
+        xxxx-XXXX-8765-4321
+      > SELECT _FUNC_('abcd-EFGH-8765-@$#', 14);
+        xxxx-XXXX-nnnn-@$#
+      > SELECT _FUNC_('abcd-EFGH-8765-@$#', 15, 'x', 'X', 'n', 'o');
+        XXXXoxxxxonnnno@$#
+      > SELECT _FUNC_('abcd-EFGH-8765-@$#', 20, 'x', 'X', 'n', 'o');
+        XXXXoxxxxonnnnoooo
+      > SELECT _FUNC_('AbCD123-@$#', 10,'Q', 'q', 'd', 'o');
+        QqQQdddooo#
+      > SELECT _FUNC_('AbCD123-@$#', 10, NULL, 'q', 'd', 'o');
+        AqCDdddooo#
+      > SELECT _FUNC_('AbCD123-@$#', 10, NULL, NULL, 'd', 'o');
+        AbCDdddooo#
+      > SELECT _FUNC_('AbCD123-@$#', 10, NULL, NULL, NULL, 'o');
+        AbCD123ooo#
+      > SELECT _FUNC_(NULL);
+        NULL
+      > SELECT _FUNC_(NULL, 1, NULL, NULL, 'o');
+        NULL
+  """,
+  since = "3.4.0",
+  group = "string_funcs")
+// scalastyle:on line.size.limit
+case class MaskFirstN(
+    input: Expression,
+    charCountExpr: Expression,
+    override val upperCharExpr: Expression,
+    override val lowerCharExpr: Expression,
+    override val digitCharExpr: Expression,
+    override val otherCharExpr: Expression)
+    extends SeptenaryExpression
+    with Maskable {
+
+  def this(input: Expression) =
+    this(
+      input,
+      Literal(Mask.DEFAULT_CHAR_COUNT),
+      Literal(Mask.MASKED_UPPERCASE),
+      Literal(Mask.MASKED_LOWERCASE),
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(input: Expression, charCountExpr: Expression) =
+    this(
+      input,
+      charCountExpr,
+      Literal(Mask.MASKED_UPPERCASE),
+      Literal(Mask.MASKED_LOWERCASE),
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(input: Expression, charCountExpr: Expression, upperCharExpr: 
Expression) =
+    this(
+      input,
+      charCountExpr,
+      upperCharExpr,
+      Literal(Mask.MASKED_LOWERCASE),
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(
+      input: Expression,
+      charCountExpr: Expression,
+      upperCharExpr: Expression,
+      lowerCharExpr: Expression) =
+    this(
+      input,
+      charCountExpr,
+      upperCharExpr,
+      lowerCharExpr,
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(
+      input: Expression,
+      charCountExpr: Expression,
+      upperCharExpr: Expression,
+      lowerCharExpr: Expression,
+      digitCharExpr: Expression) =
+    this(
+      input,
+      charCountExpr,
+      upperCharExpr,
+      lowerCharExpr,
+      digitCharExpr,
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  @transient
+  private lazy val charCount = {
+    val value = charCountExpr.eval().asInstanceOf[Int]
+    if (value < 0) 0 else value

Review Comment:
   ```suggestion
       .max(0)
   ```



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala:
##########
@@ -77,13 +150,12 @@ import org.apache.spark.unsafe.types.UTF8String
 // scalastyle:on line.size.limit
 case class Mask(

Review Comment:
   in the `Arguments` comment, can you update this part to mention that the 
expressions must be constant?
   
   ```
         * upperChar  - character to replace upper-case characters with. 
Specify NULL to retain original character. Default value: 'X'
         * lowerChar  - character to replace lower-case characters with. 
Specify NULL to retain original character. Default value: 'x'
         * digitChar  - character to replace digit characters with. Specify 
NULL to retain original character. Default value: 'n'
         * otherChar  - character to replace all other characters with. Specify 
NULL to retain original character. Default value: NULL
   ```



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala:
##########
@@ -283,7 +543,32 @@ object Mask {
         transformChar(_, maskUpper, maskLower, maskDigit, maskOther).toChar
       }
     }
-    org.apache.spark.unsafe.types.UTF8String.fromString(transformedString)
+    UTF8String.fromString(transformedString)
+  }
+
+  def mask_first_n(
+      input: Any,
+      charCount: Int,
+      maskUpper: Any,
+      maskLower: Any,
+      maskDigit: Any,
+      maskOther: Any): UTF8String = {
+
+    val transformedString = if (input == null) {
+      null
+    } else {
+      val inputStr = input.asInstanceOf[UTF8String]
+      val stringSize = inputStr.numChars
+      val endIdx = if (stringSize < charCount) stringSize else charCount

Review Comment:
   ```suggestion
         val endIdx = stringSize.min(charCount)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to