dtenedor commented on code in PR #39449:
URL: https://github.com/apache/spark/pull/39449#discussion_r1084350176


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala:
##########
@@ -23,9 +23,54 @@ import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.errors.QueryErrorsBase
-import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}
+import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, 
StringType}
 import org.apache.spark.unsafe.types.UTF8String
 
+/**
+ * The trait does Input Data Type validation .

Review Comment:
   please also mention which kinds of expressions are intended to inherit from 
this trait, perhaps with a couple of examples?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala:
##########
@@ -257,19 +271,272 @@ case class Mask(
       otherChar = newChildren(4))
 }
 
-case class MaskArgument(maskChar: Char, ignore: Boolean)
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage =
+    """_FUNC_(input[, charCount, upperChar, lowerChar, digitChar, otherChar]) 
- masks the first n characters of given string value.
+       The function masks the first n characters of the value with 'X' or 'x', 
and numbers with 'n'.
+       This can be useful for creating copies of tables with sensitive 
information removed.
+       Error behavior: null value as replacement argument will throw 
AnalysisError.
+      """,
+  arguments = """
+    Arguments:
+      * input      - string value to mask. Supported types: STRING, VARCHAR, 
CHAR
+      * charCount  - number of characters to be masked. Default value: 4
+      * upperChar  - character to replace upper-case characters with. Specify 
NULL to retain original character. Default value: 'X'
+      * lowerChar  - character to replace lower-case characters with. Specify 
NULL to retain original character. Default value: 'x'
+      * digitChar  - character to replace digit characters with. Specify NULL 
to retain original character. Default value: 'n'
+      * otherChar  - character to replace all other characters with. Specify 
NULL to retain original character. Default value: NULL
+  """,
+  examples = """
+    Examples:

Review Comment:
   please also add a case where we request to mask more characters than are 
present in the input string. What is the behavior?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala:
##########
@@ -257,19 +271,272 @@ case class Mask(
       otherChar = newChildren(4))
 }
 
-case class MaskArgument(maskChar: Char, ignore: Boolean)
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage =
+    """_FUNC_(input[, charCount, upperChar, lowerChar, digitChar, otherChar]) 
- masks the first n characters of given string value.
+       The function masks the first n characters of the value with 'X' or 'x', 
and numbers with 'n'.
+       This can be useful for creating copies of tables with sensitive 
information removed.
+       Error behavior: null value as replacement argument will throw 
AnalysisError.
+      """,
+  arguments = """
+    Arguments:
+      * input      - string value to mask. Supported types: STRING, VARCHAR, 
CHAR
+      * charCount  - number of characters to be masked. Default value: 4
+      * upperChar  - character to replace upper-case characters with. Specify 
NULL to retain original character. Default value: 'X'
+      * lowerChar  - character to replace lower-case characters with. Specify 
NULL to retain original character. Default value: 'x'
+      * digitChar  - character to replace digit characters with. Specify NULL 
to retain original character. Default value: 'n'
+      * otherChar  - character to replace all other characters with. Specify 
NULL to retain original character. Default value: NULL
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_('abcd-EFGH-8765-4321');
+        xxxx-EFGH-8765-4321
+      > SELECT _FUNC_('abcd-EFGH-8765-4321', 9);
+        xxxx-XXXX-8765-4321
+      > SELECT _FUNC_('abcd-EFGH-8765-@$#', 14);
+        xxxx-XXXX-nnnn-@$#
+      > SELECT _FUNC_('abcd-EFGH-8765-@$#', 15, 'x', 'X', 'n', 'o');
+        XXXXoxxxxonnnno@$#
+      > SELECT _FUNC_('abcd-EFGH-8765-@$#', 20, 'x', 'X', 'n', 'o');
+        XXXXoxxxxonnnnoooo
+      > SELECT _FUNC_('AbCD123-@$#', 10,'Q', 'q', 'd', 'o');
+        QqQQdddooo#
+      > SELECT _FUNC_('AbCD123-@$#', 10, NULL, 'q', 'd', 'o');
+        AqCDdddooo#
+      > SELECT _FUNC_('AbCD123-@$#', 10, NULL, NULL, 'd', 'o');
+        AbCDdddooo#
+      > SELECT _FUNC_('AbCD123-@$#', 10, NULL, NULL, NULL, 'o');
+        AbCD123ooo#
+      > SELECT _FUNC_(NULL);
+        NULL
+      > SELECT _FUNC_(NULL, 1, NULL, NULL, 'o');
+        NULL
+  """,
+  since = "3.4.0",
+  group = "string_funcs")
+// scalastyle:on line.size.limit
+case class MaskFirstN(
+    input: Expression,
+    charCountExpr: Expression,
+    upperChar: Expression,
+    lowerChar: Expression,
+    digitChar: Expression,
+    otherChar: Expression)
+    extends SeptenaryExpression
+    with Maskable
+    with ExpectsInputTypes
+    with QueryErrorsBase {
+
+  def this(input: Expression) =
+    this(
+      input,
+      Literal(Mask.DEFAULT_CHAR_COUNT),
+      Literal(Mask.MASKED_UPPERCASE),
+      Literal(Mask.MASKED_LOWERCASE),
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(input: Expression, charCountExpr: Expression) =
+    this(
+      input,
+      charCountExpr,
+      Literal(Mask.MASKED_UPPERCASE),
+      Literal(Mask.MASKED_LOWERCASE),
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(input: Expression, charCountExpr: Expression, upperChar: 
Expression) =
+    this(
+      input,
+      charCountExpr,
+      upperChar,
+      Literal(Mask.MASKED_LOWERCASE),
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(
+      input: Expression,
+      charCountExpr: Expression,
+      upperChar: Expression,
+      lowerChar: Expression) =
+    this(
+      input,
+      charCountExpr,
+      upperChar,
+      lowerChar,
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(
+      input: Expression,
+      charCountExpr: Expression,
+      upperChar: Expression,
+      lowerChar: Expression,
+      digitChar: Expression) =
+    this(
+      input,
+      charCountExpr,
+      upperChar,
+      lowerChar,
+      digitChar,
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  @transient
+  private lazy val charCount = {
+    val value = charCountExpr.eval().asInstanceOf[Int]
+    if (value < 0) 0 else value
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult =
+    validateInputDataTypes(
+      super.checkInputDataTypes(),
+      Seq(
+        (upperChar, "upperChar"),
+        (lowerChar, "lowerChar"),
+        (digitChar, "digitChar"),
+        (otherChar, "otherChar")),
+      () =>
+        Seq(if (!charCountExpr.foldable) {
+          Some(
+            DataTypeMismatch(
+              errorSubClass = "NON_FOLDABLE_INPUT",
+              messageParameters = Map(
+                "inputName" -> "charCount",
+                "inputType" -> toSQLType(charCountExpr.dataType),
+                "inputExpr" -> toSQLExpr(charCountExpr))))
+        } else if (charCountExpr.eval() == null) {
+          Some(
+            DataTypeMismatch(
+              errorSubClass = "UNEXPECTED_NULL",
+              messageParameters = Map("exprName" -> "charCount")))
+        } else {
+          None
+        }))
+
+  /**
+   * Expected input types from child expressions. The i-th position in the 
returned seq indicates
+   * the type requirement for the i-th child.
+   *
+   * The possible values at each position are:
+   *   1. a specific data type, e.g. LongType, StringType. 2. a non-leaf 
abstract data type, e.g.
+   *      NumericType, IntegralType, FractionalType.
+   */
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(StringType, IntegerType, StringType, StringType, StringType, 
StringType)
+
+  override def nullable: Boolean = true
+
+  /**
+   * Default behavior of evaluation according to the default nullability of 
QuinaryExpression. If
+   * subclass of QuinaryExpression override nullable, probably should also 
override this.
+   */
+  override def eval(input: InternalRow): Any = {
+    Mask.mask_first_n(
+      children(0).eval(input),
+      charCount,
+      children(2).eval(input),
+      children(3).eval(input),
+      children(4).eval(input),
+      children(5).eval(input))
+  }
+
+  /**
+   * Returns Java source code that can be compiled to evaluate this 
expression. The default
+   * behavior is to call the eval method of the expression. Concrete 
expression implementations
+   * should override this to do actual code generation.
+   *
+   * @param ctx
+   *   a [[CodegenContext]]
+   * @param ev
+   *   an [[ExprCode]] with unique terms.
+   * @return
+   *   an [[ExprCode]] containing the Java source code to generate the given 
expression
+   */
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode =
+    defineCodeGen(
+      ctx,
+      ev,
+      (input, count, upperChar, lowerChar, digitChar, otherChar, inputOpt) => {
+        s"org.apache.spark.sql.catalyst.expressions.Mask." +
+          s"mask_first_n($input, $charCount,$upperChar, $lowerChar, 
$digitChar, $otherChar);"
+      })
+
+  /**
+   * Short hand for generating septenary evaluation code. If either of the 
sub-expressions is
+   * null, the result of this computation is assumed to be null.
+   *
+   * @param f
+   *   function that accepts the 7 non-null evaluation result names of 
children and returns Java
+   *   code to compute the output.
+   */
+  override protected def nullSafeCodeGen(
+      ctx: CodegenContext,
+      ev: ExprCode,
+      f: (String, String, String, String, String, String, Option[String]) => 
String): ExprCode = {
+    val firstGen = children(0).genCode(ctx)
+    val secGen = children(1).genCode(ctx)
+    val thirdGen = children(2).genCode(ctx)
+    val fourthGen = children(3).genCode(ctx)
+    val fifthGen = children(4).genCode(ctx)
+    val sixthGen = children(5).genCode(ctx)
+    val resultCode =
+      f(
+        firstGen.value,
+        secGen.value,
+        thirdGen.value,
+        fourthGen.value,
+        fifthGen.value,
+        sixthGen.value,
+        None)
+    ev.copy(
+      code = code"""
+        ${firstGen.code}
+        ${thirdGen.code}
+        ${fourthGen.code}
+        ${fifthGen.code}
+        ${CodeGenerator.javaType(dataType)} ${ev.value} = 
${CodeGenerator.defaultValue(dataType)};
+        $resultCode""",
+      isNull = FalseLiteral)
+  }
+
+  /**
+   * Returns the [[DataType]] of the result of evaluating this expression. It 
is invalid to query
+   * the dataType of an unresolved expression (i.e., when `resolved` == false).
+   */
+  override def dataType: DataType = StringType
+
+  /**
+   * Returns a Seq of the children of this node. Children should not change. 
Immutability required
+   * for containsChild optimization
+   */
+  override def children: Seq[Expression] =
+    Seq(input, charCountExpr, upperChar, lowerChar, digitChar, otherChar)
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): MaskFirstN =
+    copy(
+      input = newChildren(0),
+      charCountExpr = newChildren(1),
+      upperChar = newChildren(2),
+      lowerChar = newChildren(3),
+      digitChar = newChildren(4),
+      otherChar = newChildren(5))
+}
 
 object Mask {
   // Default character to replace upper-case characters
-  private val MASKED_UPPERCASE = 'X'
+  val MASKED_UPPERCASE = 'X'
   // Default character to replace lower-case characters
-  private val MASKED_LOWERCASE = 'x'
+  val MASKED_LOWERCASE = 'x'
   // Default character to replace digits
-  private val MASKED_DIGIT = 'n'
+  val MASKED_DIGIT = 'n'
   // This value helps to retain original value in the input by ignoring the 
replacement rules
-  private val MASKED_IGNORE = null
+  val MASKED_IGNORE = null
+  // Default number of characters to be masked
+  val DEFAULT_CHAR_COUNT = 4

Review Comment:
   please sort these alphabetically



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala:
##########
@@ -257,19 +271,272 @@ case class Mask(
       otherChar = newChildren(4))
 }
 
-case class MaskArgument(maskChar: Char, ignore: Boolean)
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage =
+    """_FUNC_(input[, charCount, upperChar, lowerChar, digitChar, otherChar]) 
- masks the first n characters of given string value.
+       The function masks the first n characters of the value with 'X' or 'x', 
and numbers with 'n'.
+       This can be useful for creating copies of tables with sensitive 
information removed.
+       Error behavior: null value as replacement argument will throw 
AnalysisError.
+      """,
+  arguments = """
+    Arguments:
+      * input      - string value to mask. Supported types: STRING, VARCHAR, 
CHAR
+      * charCount  - number of characters to be masked. Default value: 4
+      * upperChar  - character to replace upper-case characters with. Specify 
NULL to retain original character. Default value: 'X'
+      * lowerChar  - character to replace lower-case characters with. Specify 
NULL to retain original character. Default value: 'x'
+      * digitChar  - character to replace digit characters with. Specify NULL 
to retain original character. Default value: 'n'
+      * otherChar  - character to replace all other characters with. Specify 
NULL to retain original character. Default value: NULL
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_('abcd-EFGH-8765-4321');
+        xxxx-EFGH-8765-4321
+      > SELECT _FUNC_('abcd-EFGH-8765-4321', 9);
+        xxxx-XXXX-8765-4321
+      > SELECT _FUNC_('abcd-EFGH-8765-@$#', 14);
+        xxxx-XXXX-nnnn-@$#
+      > SELECT _FUNC_('abcd-EFGH-8765-@$#', 15, 'x', 'X', 'n', 'o');
+        XXXXoxxxxonnnno@$#
+      > SELECT _FUNC_('abcd-EFGH-8765-@$#', 20, 'x', 'X', 'n', 'o');
+        XXXXoxxxxonnnnoooo
+      > SELECT _FUNC_('AbCD123-@$#', 10,'Q', 'q', 'd', 'o');
+        QqQQdddooo#
+      > SELECT _FUNC_('AbCD123-@$#', 10, NULL, 'q', 'd', 'o');
+        AqCDdddooo#
+      > SELECT _FUNC_('AbCD123-@$#', 10, NULL, NULL, 'd', 'o');
+        AbCDdddooo#
+      > SELECT _FUNC_('AbCD123-@$#', 10, NULL, NULL, NULL, 'o');
+        AbCD123ooo#
+      > SELECT _FUNC_(NULL);
+        NULL
+      > SELECT _FUNC_(NULL, 1, NULL, NULL, 'o');
+        NULL
+  """,
+  since = "3.4.0",
+  group = "string_funcs")
+// scalastyle:on line.size.limit
+case class MaskFirstN(
+    input: Expression,
+    charCountExpr: Expression,
+    upperChar: Expression,
+    lowerChar: Expression,
+    digitChar: Expression,
+    otherChar: Expression)
+    extends SeptenaryExpression
+    with Maskable
+    with ExpectsInputTypes
+    with QueryErrorsBase {
+
+  def this(input: Expression) =
+    this(
+      input,
+      Literal(Mask.DEFAULT_CHAR_COUNT),
+      Literal(Mask.MASKED_UPPERCASE),
+      Literal(Mask.MASKED_LOWERCASE),
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(input: Expression, charCountExpr: Expression) =
+    this(
+      input,
+      charCountExpr,
+      Literal(Mask.MASKED_UPPERCASE),
+      Literal(Mask.MASKED_LOWERCASE),
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(input: Expression, charCountExpr: Expression, upperChar: 
Expression) =
+    this(
+      input,
+      charCountExpr,
+      upperChar,
+      Literal(Mask.MASKED_LOWERCASE),
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(
+      input: Expression,
+      charCountExpr: Expression,
+      upperChar: Expression,
+      lowerChar: Expression) =
+    this(
+      input,
+      charCountExpr,
+      upperChar,
+      lowerChar,
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(
+      input: Expression,
+      charCountExpr: Expression,
+      upperChar: Expression,
+      lowerChar: Expression,
+      digitChar: Expression) =
+    this(
+      input,
+      charCountExpr,
+      upperChar,
+      lowerChar,
+      digitChar,
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  @transient
+  private lazy val charCount = {
+    val value = charCountExpr.eval().asInstanceOf[Int]
+    if (value < 0) 0 else value
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult =
+    validateInputDataTypes(
+      super.checkInputDataTypes(),
+      Seq(
+        (upperChar, "upperChar"),
+        (lowerChar, "lowerChar"),
+        (digitChar, "digitChar"),
+        (otherChar, "otherChar")),
+      () =>
+        Seq(if (!charCountExpr.foldable) {
+          Some(
+            DataTypeMismatch(
+              errorSubClass = "NON_FOLDABLE_INPUT",
+              messageParameters = Map(
+                "inputName" -> "charCount",
+                "inputType" -> toSQLType(charCountExpr.dataType),
+                "inputExpr" -> toSQLExpr(charCountExpr))))
+        } else if (charCountExpr.eval() == null) {
+          Some(
+            DataTypeMismatch(
+              errorSubClass = "UNEXPECTED_NULL",
+              messageParameters = Map("exprName" -> "charCount")))
+        } else {
+          None
+        }))
+
+  /**
+   * Expected input types from child expressions. The i-th position in the 
returned seq indicates
+   * the type requirement for the i-th child.
+   *
+   * The possible values at each position are:
+   *   1. a specific data type, e.g. LongType, StringType. 2. a non-leaf 
abstract data type, e.g.
+   *      NumericType, IntegralType, FractionalType.
+   */
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(StringType, IntegerType, StringType, StringType, StringType, 
StringType)
+
+  override def nullable: Boolean = true
+
+  /**
+   * Default behavior of evaluation according to the default nullability of 
QuinaryExpression. If
+   * subclass of QuinaryExpression override nullable, probably should also 
override this.
+   */
+  override def eval(input: InternalRow): Any = {
+    Mask.mask_first_n(
+      children(0).eval(input),
+      charCount,
+      children(2).eval(input),

Review Comment:
   we are evaluating all the constant arguments for each row here. Can we 
please perform this once and cache the results instead, for efficiency?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala:
##########
@@ -23,9 +23,54 @@ import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.errors.QueryErrorsBase
-import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}
+import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, 
StringType}
 import org.apache.spark.unsafe.types.UTF8String
 
+/**
+ * The trait does Input Data Type validation .
+ */
+private[sql] trait Maskable extends QueryErrorsBase {
+
+  def validateInputDataTypes(
+      typeCheckResult: TypeCheckResult,
+      expressions: Seq[(Expression, String)],
+      additionalValidaton: () => Seq[Option[TypeCheckResult]] = () => 
Seq(None))
+      : TypeCheckResult = {
+
+    if (typeCheckResult.isSuccess) {
+      (expressions
+        .map { case (exp: Expression, message: String) =>
+          validateInputDataType(exp, message)
+        } ++ additionalValidaton()).flatten.headOption
+        .getOrElse(typeCheckResult)
+    } else {
+      typeCheckResult
+    }
+  }
+
+  private def validateInputDataType(exp: Expression, message: String): 
Option[TypeCheckResult] = {
+    if (!exp.foldable) {
+      Some(
+        DataTypeMismatch(
+          errorSubClass = "NON_FOLDABLE_INPUT",
+          messageParameters = Map(
+            "inputName" -> message,
+            "inputType" -> toSQLType(exp.dataType),
+            "inputExpr" -> toSQLExpr(exp))))
+    } else {
+      val replaceChar = exp.eval()
+      if (replaceChar != null && replaceChar.asInstanceOf[UTF8String].numChars 
!= 1) {

Review Comment:
   please use a pattern match against `exp.eval()` instead to (1) check that it 
is in fact an instance of `UTF8String` and (2) combine the new `if` with the 
previous `else`, de-denting the rest of the block?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala:
##########
@@ -257,19 +271,272 @@ case class Mask(
       otherChar = newChildren(4))
 }
 
-case class MaskArgument(maskChar: Char, ignore: Boolean)
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage =
+    """_FUNC_(input[, charCount, upperChar, lowerChar, digitChar, otherChar]) 
- masks the first n characters of given string value.
+       The function masks the first n characters of the value with 'X' or 'x', 
and numbers with 'n'.
+       This can be useful for creating copies of tables with sensitive 
information removed.
+       Error behavior: null value as replacement argument will throw 
AnalysisError.
+      """,
+  arguments = """
+    Arguments:
+      * input      - string value to mask. Supported types: STRING, VARCHAR, 
CHAR
+      * charCount  - number of characters to be masked. Default value: 4
+      * upperChar  - character to replace upper-case characters with. Specify 
NULL to retain original character. Default value: 'X'
+      * lowerChar  - character to replace lower-case characters with. Specify 
NULL to retain original character. Default value: 'x'
+      * digitChar  - character to replace digit characters with. Specify NULL 
to retain original character. Default value: 'n'
+      * otherChar  - character to replace all other characters with. Specify 
NULL to retain original character. Default value: NULL
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_('abcd-EFGH-8765-4321');
+        xxxx-EFGH-8765-4321
+      > SELECT _FUNC_('abcd-EFGH-8765-4321', 9);
+        xxxx-XXXX-8765-4321
+      > SELECT _FUNC_('abcd-EFGH-8765-@$#', 14);
+        xxxx-XXXX-nnnn-@$#
+      > SELECT _FUNC_('abcd-EFGH-8765-@$#', 15, 'x', 'X', 'n', 'o');
+        XXXXoxxxxonnnno@$#
+      > SELECT _FUNC_('abcd-EFGH-8765-@$#', 20, 'x', 'X', 'n', 'o');
+        XXXXoxxxxonnnnoooo
+      > SELECT _FUNC_('AbCD123-@$#', 10,'Q', 'q', 'd', 'o');
+        QqQQdddooo#
+      > SELECT _FUNC_('AbCD123-@$#', 10, NULL, 'q', 'd', 'o');
+        AqCDdddooo#
+      > SELECT _FUNC_('AbCD123-@$#', 10, NULL, NULL, 'd', 'o');
+        AbCDdddooo#
+      > SELECT _FUNC_('AbCD123-@$#', 10, NULL, NULL, NULL, 'o');
+        AbCD123ooo#
+      > SELECT _FUNC_(NULL);
+        NULL
+      > SELECT _FUNC_(NULL, 1, NULL, NULL, 'o');
+        NULL
+  """,
+  since = "3.4.0",
+  group = "string_funcs")
+// scalastyle:on line.size.limit
+case class MaskFirstN(
+    input: Expression,
+    charCountExpr: Expression,
+    upperChar: Expression,
+    lowerChar: Expression,
+    digitChar: Expression,
+    otherChar: Expression)
+    extends SeptenaryExpression
+    with Maskable
+    with ExpectsInputTypes
+    with QueryErrorsBase {
+
+  def this(input: Expression) =
+    this(
+      input,
+      Literal(Mask.DEFAULT_CHAR_COUNT),
+      Literal(Mask.MASKED_UPPERCASE),
+      Literal(Mask.MASKED_LOWERCASE),
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(input: Expression, charCountExpr: Expression) =
+    this(
+      input,
+      charCountExpr,
+      Literal(Mask.MASKED_UPPERCASE),
+      Literal(Mask.MASKED_LOWERCASE),
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(input: Expression, charCountExpr: Expression, upperChar: 
Expression) =
+    this(
+      input,
+      charCountExpr,
+      upperChar,
+      Literal(Mask.MASKED_LOWERCASE),
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(
+      input: Expression,
+      charCountExpr: Expression,
+      upperChar: Expression,
+      lowerChar: Expression) =
+    this(
+      input,
+      charCountExpr,
+      upperChar,
+      lowerChar,
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(
+      input: Expression,
+      charCountExpr: Expression,
+      upperChar: Expression,
+      lowerChar: Expression,
+      digitChar: Expression) =
+    this(
+      input,
+      charCountExpr,
+      upperChar,
+      lowerChar,
+      digitChar,
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  @transient
+  private lazy val charCount = {
+    val value = charCountExpr.eval().asInstanceOf[Int]
+    if (value < 0) 0 else value
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult =
+    validateInputDataTypes(
+      super.checkInputDataTypes(),
+      Seq(
+        (upperChar, "upperChar"),
+        (lowerChar, "lowerChar"),
+        (digitChar, "digitChar"),
+        (otherChar, "otherChar")),
+      () =>

Review Comment:
   can you please move this lambda into a separate method for better code 
health?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala:
##########
@@ -257,19 +271,272 @@ case class Mask(
       otherChar = newChildren(4))
 }
 
-case class MaskArgument(maskChar: Char, ignore: Boolean)
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage =
+    """_FUNC_(input[, charCount, upperChar, lowerChar, digitChar, otherChar]) 
- masks the first n characters of given string value.
+       The function masks the first n characters of the value with 'X' or 'x', 
and numbers with 'n'.
+       This can be useful for creating copies of tables with sensitive 
information removed.
+       Error behavior: null value as replacement argument will throw 
AnalysisError.
+      """,
+  arguments = """
+    Arguments:
+      * input      - string value to mask. Supported types: STRING, VARCHAR, 
CHAR
+      * charCount  - number of characters to be masked. Default value: 4
+      * upperChar  - character to replace upper-case characters with. Specify 
NULL to retain original character. Default value: 'X'
+      * lowerChar  - character to replace lower-case characters with. Specify 
NULL to retain original character. Default value: 'x'
+      * digitChar  - character to replace digit characters with. Specify NULL 
to retain original character. Default value: 'n'
+      * otherChar  - character to replace all other characters with. Specify 
NULL to retain original character. Default value: NULL
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_('abcd-EFGH-8765-4321');
+        xxxx-EFGH-8765-4321
+      > SELECT _FUNC_('abcd-EFGH-8765-4321', 9);
+        xxxx-XXXX-8765-4321
+      > SELECT _FUNC_('abcd-EFGH-8765-@$#', 14);
+        xxxx-XXXX-nnnn-@$#
+      > SELECT _FUNC_('abcd-EFGH-8765-@$#', 15, 'x', 'X', 'n', 'o');
+        XXXXoxxxxonnnno@$#
+      > SELECT _FUNC_('abcd-EFGH-8765-@$#', 20, 'x', 'X', 'n', 'o');
+        XXXXoxxxxonnnnoooo
+      > SELECT _FUNC_('AbCD123-@$#', 10,'Q', 'q', 'd', 'o');
+        QqQQdddooo#
+      > SELECT _FUNC_('AbCD123-@$#', 10, NULL, 'q', 'd', 'o');
+        AqCDdddooo#
+      > SELECT _FUNC_('AbCD123-@$#', 10, NULL, NULL, 'd', 'o');
+        AbCDdddooo#
+      > SELECT _FUNC_('AbCD123-@$#', 10, NULL, NULL, NULL, 'o');
+        AbCD123ooo#
+      > SELECT _FUNC_(NULL);
+        NULL
+      > SELECT _FUNC_(NULL, 1, NULL, NULL, 'o');
+        NULL
+  """,
+  since = "3.4.0",
+  group = "string_funcs")
+// scalastyle:on line.size.limit
+case class MaskFirstN(
+    input: Expression,
+    charCountExpr: Expression,
+    upperChar: Expression,
+    lowerChar: Expression,
+    digitChar: Expression,
+    otherChar: Expression)
+    extends SeptenaryExpression
+    with Maskable
+    with ExpectsInputTypes
+    with QueryErrorsBase {
+
+  def this(input: Expression) =
+    this(
+      input,
+      Literal(Mask.DEFAULT_CHAR_COUNT),
+      Literal(Mask.MASKED_UPPERCASE),
+      Literal(Mask.MASKED_LOWERCASE),
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(input: Expression, charCountExpr: Expression) =
+    this(
+      input,
+      charCountExpr,
+      Literal(Mask.MASKED_UPPERCASE),
+      Literal(Mask.MASKED_LOWERCASE),
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(input: Expression, charCountExpr: Expression, upperChar: 
Expression) =
+    this(
+      input,
+      charCountExpr,
+      upperChar,
+      Literal(Mask.MASKED_LOWERCASE),
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(
+      input: Expression,
+      charCountExpr: Expression,
+      upperChar: Expression,
+      lowerChar: Expression) =
+    this(
+      input,
+      charCountExpr,
+      upperChar,
+      lowerChar,
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(
+      input: Expression,
+      charCountExpr: Expression,
+      upperChar: Expression,
+      lowerChar: Expression,
+      digitChar: Expression) =
+    this(
+      input,
+      charCountExpr,
+      upperChar,
+      lowerChar,
+      digitChar,
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  @transient
+  private lazy val charCount = {
+    val value = charCountExpr.eval().asInstanceOf[Int]
+    if (value < 0) 0 else value
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult =
+    validateInputDataTypes(
+      super.checkInputDataTypes(),
+      Seq(
+        (upperChar, "upperChar"),
+        (lowerChar, "lowerChar"),
+        (digitChar, "digitChar"),
+        (otherChar, "otherChar")),
+      () =>
+        Seq(if (!charCountExpr.foldable) {

Review Comment:
   let's use a pattern match for these checks.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to