dtenedor commented on code in PR #48004:
URL: https://github.com/apache/spark/pull/48004#discussion_r1755297072


##########
sql/core/src/test/scala/org/apache/spark/sql/ExpressionsSchemaSuite.scala:
##########
@@ -118,11 +118,11 @@ class ExpressionsSchemaSuite extends QueryTest with 
SharedSparkSession {
         // SET spark.sql.parser.escapedStringLiterals=true
         example.split("  > 
").tail.filterNot(_.trim.startsWith("SET")).take(1).foreach {
           case _ if funcName == "from_avro" || funcName == "to_avro" ||
-            funcName == "from_protobuf" || funcName == "to_protobuf" =>
+            funcName == "from_protobuf" || funcName == "to_protobuf" || 
funcName == "uniform" =>

Review Comment:
   I updated this so `uniform` is no longer excluded.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala:
##########
@@ -181,3 +189,215 @@ case class Randn(child: Expression, hideSeed: Boolean = 
false) extends RDG {
 object Randn {
   def apply(seed: Long): Randn = Randn(Literal(seed, LongType))
 }
+
+@ExpressionDescription(
+  usage = """
+    _FUNC_(min, max[, seed]) - Returns a random value with independent and 
identically
+      distributed (i.i.d.) values with the specified range of numbers. The 
random seed is optional.
+      The provided numbers specifying the minimum and maximum values of the 
range must be constant.
+      If both of these numbers are integers, then the result will also be an 
integer. Otherwise if
+      one or both of these are floating-point numbers, then the result will 
also be a floating-point
+      number.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(10, 20) > 0;
+      true
+  """,
+  since = "4.0.0",
+  group = "math_funcs")
+case class Uniform(min: Expression, max: Expression, seed: Expression)
+  extends RuntimeReplaceable with TernaryLike[Expression] with 
ExpressionWithRandomSeed {
+  def this(min: Expression, max: Expression) =
+    this(min, max, Literal(Uniform.random.nextLong(), LongType))
+
+  final override lazy val deterministic: Boolean = false
+  override val nodePatterns: Seq[TreePattern] =
+    Seq(RUNTIME_REPLACEABLE, EXPRESSION_WITH_RANDOM_SEED)
+
+  override val dataType: DataType = {
+    val first = min.dataType
+    val second = max.dataType
+    (min.dataType, max.dataType) match {
+      case _ if !valid(min) || !valid(max) => NullType
+      case (_, LongType) | (LongType, _) if Seq(first, second).forall(integer) 
=> LongType
+      case (_, IntegerType) | (IntegerType, _) if Seq(first, 
second).forall(integer) => IntegerType
+      case (_, ShortType) | (ShortType, _) if Seq(first, 
second).forall(integer) => ShortType
+      case (_, DoubleType) | (DoubleType, _) => DoubleType
+      case (_, FloatType) | (FloatType, _) => FloatType
+      case _ => NullType
+    }
+  }
+
+  private def valid(e: Expression): Boolean = e.dataType match {
+    case _ if !e.foldable => false
+    case _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: 
DoubleType => true
+    case _ => false
+  }
+
+  private def integer(t: DataType): Boolean = t match {
+    case _: ShortType | _: IntegerType | _: LongType => true
+    case _ => false
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess
+    Seq(min, max, seed).zipWithIndex.foreach { case (expr: Expression, index: 
Int) =>
+      if (!valid(expr)) {
+        result = DataTypeMismatch(
+          errorSubClass = "UNEXPECTED_INPUT_TYPE",

Review Comment:
   Sounds good, done.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala:
##########
@@ -181,3 +189,215 @@ case class Randn(child: Expression, hideSeed: Boolean = 
false) extends RDG {
 object Randn {
   def apply(seed: Long): Randn = Randn(Literal(seed, LongType))
 }
+
+@ExpressionDescription(
+  usage = """
+    _FUNC_(min, max[, seed]) - Returns a random value with independent and 
identically
+      distributed (i.i.d.) values with the specified range of numbers. The 
random seed is optional.
+      The provided numbers specifying the minimum and maximum values of the 
range must be constant.
+      If both of these numbers are integers, then the result will also be an 
integer. Otherwise if
+      one or both of these are floating-point numbers, then the result will 
also be a floating-point
+      number.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(10, 20) > 0;
+      true
+  """,
+  since = "4.0.0",
+  group = "math_funcs")
+case class Uniform(min: Expression, max: Expression, seed: Expression)
+  extends RuntimeReplaceable with TernaryLike[Expression] with 
ExpressionWithRandomSeed {
+  def this(min: Expression, max: Expression) =
+    this(min, max, Literal(Uniform.random.nextLong(), LongType))
+
+  final override lazy val deterministic: Boolean = false
+  override val nodePatterns: Seq[TreePattern] =
+    Seq(RUNTIME_REPLACEABLE, EXPRESSION_WITH_RANDOM_SEED)
+
+  override val dataType: DataType = {
+    val first = min.dataType
+    val second = max.dataType
+    (min.dataType, max.dataType) match {
+      case _ if !valid(min) || !valid(max) => NullType
+      case (_, LongType) | (LongType, _) if Seq(first, second).forall(integer) 
=> LongType
+      case (_, IntegerType) | (IntegerType, _) if Seq(first, 
second).forall(integer) => IntegerType
+      case (_, ShortType) | (ShortType, _) if Seq(first, 
second).forall(integer) => ShortType
+      case (_, DoubleType) | (DoubleType, _) => DoubleType
+      case (_, FloatType) | (FloatType, _) => FloatType
+      case _ => NullType
+    }
+  }
+
+  private def valid(e: Expression): Boolean = e.dataType match {
+    case _ if !e.foldable => false
+    case _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: 
DoubleType => true
+    case _ => false
+  }
+
+  private def integer(t: DataType): Boolean = t match {
+    case _: ShortType | _: IntegerType | _: LongType => true
+    case _ => false
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess
+    Seq(min, max, seed).zipWithIndex.foreach { case (expr: Expression, index: 
Int) =>
+      if (!valid(expr)) {
+        result = DataTypeMismatch(
+          errorSubClass = "UNEXPECTED_INPUT_TYPE",
+          messageParameters = Map(
+            "paramIndex" -> ordinalNumber(index),
+            "requiredType" -> "constant value of integer or floating-point",

Review Comment:
   Good point, updated the error messages accordingly.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala:
##########
@@ -181,3 +189,215 @@ case class Randn(child: Expression, hideSeed: Boolean = 
false) extends RDG {
 object Randn {
   def apply(seed: Long): Randn = Randn(Literal(seed, LongType))
 }
+
+@ExpressionDescription(
+  usage = """
+    _FUNC_(min, max[, seed]) - Returns a random value with independent and 
identically
+      distributed (i.i.d.) values with the specified range of numbers. The 
random seed is optional.
+      The provided numbers specifying the minimum and maximum values of the 
range must be constant.
+      If both of these numbers are integers, then the result will also be an 
integer. Otherwise if
+      one or both of these are floating-point numbers, then the result will 
also be a floating-point
+      number.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(10, 20) > 0;
+      true
+  """,
+  since = "4.0.0",
+  group = "math_funcs")
+case class Uniform(min: Expression, max: Expression, seed: Expression)
+  extends RuntimeReplaceable with TernaryLike[Expression] with 
ExpressionWithRandomSeed {
+  def this(min: Expression, max: Expression) =
+    this(min, max, Literal(Uniform.random.nextLong(), LongType))
+
+  final override lazy val deterministic: Boolean = false
+  override val nodePatterns: Seq[TreePattern] =
+    Seq(RUNTIME_REPLACEABLE, EXPRESSION_WITH_RANDOM_SEED)
+
+  override val dataType: DataType = {
+    val first = min.dataType
+    val second = max.dataType
+    (min.dataType, max.dataType) match {
+      case _ if !valid(min) || !valid(max) => NullType
+      case (_, LongType) | (LongType, _) if Seq(first, second).forall(integer) 
=> LongType
+      case (_, IntegerType) | (IntegerType, _) if Seq(first, 
second).forall(integer) => IntegerType
+      case (_, ShortType) | (ShortType, _) if Seq(first, 
second).forall(integer) => ShortType
+      case (_, DoubleType) | (DoubleType, _) => DoubleType
+      case (_, FloatType) | (FloatType, _) => FloatType
+      case _ => NullType
+    }
+  }
+
+  private def valid(e: Expression): Boolean = e.dataType match {
+    case _ if !e.foldable => false
+    case _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: 
DoubleType => true
+    case _ => false
+  }
+
+  private def integer(t: DataType): Boolean = t match {
+    case _: ShortType | _: IntegerType | _: LongType => true
+    case _ => false
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess
+    Seq(min, max, seed).zipWithIndex.foreach { case (expr: Expression, index: 
Int) =>
+      if (!valid(expr)) {
+        result = DataTypeMismatch(
+          errorSubClass = "UNEXPECTED_INPUT_TYPE",
+          messageParameters = Map(
+            "paramIndex" -> ordinalNumber(index),
+            "requiredType" -> "constant value of integer or floating-point",
+            "inputSql" -> toSQLExpr(expr),
+            "inputType" -> toSQLType(expr.dataType)))
+      }
+    }
+    result
+  }
+
+  override def first: Expression = min
+  override def second: Expression = max
+  override def third: Expression = seed
+
+  override def seedExpression: Expression = seed
+  override def withNewSeed(newSeed: Long): Expression =
+    Uniform(min, max, Literal(newSeed, LongType))
+
+  override def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): 
Expression =
+    Uniform(newFirst, newSecond, newThird)
+
+  override def toString: String = prettyName + truncatedString(
+    Seq(min, max), "(", ", ", ")", SQLConf.get.maxToStringFields)

Review Comment:
   I reverted this method override now.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala:
##########
@@ -181,3 +189,215 @@ case class Randn(child: Expression, hideSeed: Boolean = 
false) extends RDG {
 object Randn {
   def apply(seed: Long): Randn = Randn(Literal(seed, LongType))
 }
+
+@ExpressionDescription(
+  usage = """
+    _FUNC_(min, max[, seed]) - Returns a random value with independent and 
identically
+      distributed (i.i.d.) values with the specified range of numbers. The 
random seed is optional.
+      The provided numbers specifying the minimum and maximum values of the 
range must be constant.
+      If both of these numbers are integers, then the result will also be an 
integer. Otherwise if
+      one or both of these are floating-point numbers, then the result will 
also be a floating-point
+      number.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(10, 20) > 0;
+      true
+  """,
+  since = "4.0.0",
+  group = "math_funcs")
+case class Uniform(min: Expression, max: Expression, seed: Expression)
+  extends RuntimeReplaceable with TernaryLike[Expression] with 
ExpressionWithRandomSeed {
+  def this(min: Expression, max: Expression) =
+    this(min, max, Literal(Uniform.random.nextLong(), LongType))
+
+  final override lazy val deterministic: Boolean = false
+  override val nodePatterns: Seq[TreePattern] =
+    Seq(RUNTIME_REPLACEABLE, EXPRESSION_WITH_RANDOM_SEED)
+
+  override val dataType: DataType = {
+    val first = min.dataType
+    val second = max.dataType
+    (min.dataType, max.dataType) match {
+      case _ if !valid(min) || !valid(max) => NullType
+      case (_, LongType) | (LongType, _) if Seq(first, second).forall(integer) 
=> LongType
+      case (_, IntegerType) | (IntegerType, _) if Seq(first, 
second).forall(integer) => IntegerType
+      case (_, ShortType) | (ShortType, _) if Seq(first, 
second).forall(integer) => ShortType
+      case (_, DoubleType) | (DoubleType, _) => DoubleType
+      case (_, FloatType) | (FloatType, _) => FloatType
+      case _ => NullType
+    }
+  }
+
+  private def valid(e: Expression): Boolean = e.dataType match {
+    case _ if !e.foldable => false
+    case _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: 
DoubleType => true
+    case _ => false
+  }
+
+  private def integer(t: DataType): Boolean = t match {
+    case _: ShortType | _: IntegerType | _: LongType => true

Review Comment:
   Good question; we specifically want to exclude `ByteType` here, thus the 
explicit list of integer types.



##########
sql/core/src/test/resources/sql-functions/sql-expression-schema.md:
##########
@@ -265,6 +265,7 @@
 | org.apache.spark.sql.catalyst.expressions.RaiseErrorExpressionBuilder | 
raise_error | SELECT raise_error('custom error message') | 
struct<raise_error(USER_RAISED_EXCEPTION, map(errorMessage, custom error 
message)):void> |
 | org.apache.spark.sql.catalyst.expressions.Rand | rand | SELECT rand() | 
struct<rand():double> |
 | org.apache.spark.sql.catalyst.expressions.Rand | random | SELECT random() | 
struct<rand():double> |
+| org.apache.spark.sql.catalyst.expressions.RandStr | randstr | SELECT 
randstr(3, 0) | struct<randstr(3, 0):string> |

Review Comment:
   I updated this so `uniform` is no longer excluded.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to