HyukjinKwon commented on code in PR #36063:
URL: https://github.com/apache/spark/pull/36063#discussion_r848995606


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala:
##########
@@ -1014,3 +1015,51 @@ case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBase
   override protected def withNewChildrenInternal(newChildren: 
IndexedSeq[Expression]): PercentRank =
     copy(children = newChildren)
 }
+
+/**
+ * Exponential Weighted Moment. This expression is dedicated only for Pandas API on Spark.
+ * An exponentially weighted window is similar to an expanding window but with each prior point
+ * being exponentially weighted down relative to the current point.
+ * See https://pandas.pydata.org/docs/user_guide/window.html#exponentially-weighted-window
+ * for details.
+ * Currently, only weighted moving average is supported. In general, it is calculated as
+ *    y_t = \frac{\sum_{i=0}^t w_i x_{t-i}}{\sum_{i=0}^t w_i},
+ * where x_t is the input, y_t is the result and the w_i are the weights.
+ */
+@DeveloperApi
+@Experimental
+@Unstable
+case class EWM(input: Expression, alpha: Double)
+  extends AggregateWindowFunction with UnaryLike[Expression] {
+  assert(0 < alpha && alpha <= 1)
+
+  override def dataType: DataType = DoubleType
+
+  private val numerator = AttributeReference("numerator", DoubleType, nullable 
= false)()
+  private val denominator = AttributeReference("denominator", DoubleType, 
nullable = false)()
+  override def aggBufferAttributes: Seq[AttributeReference] = numerator :: 
denominator :: Nil
+
+  override val initialValues: Seq[Expression] = Seq(Literal(0.0), Literal(0.0))
+
+  override val updateExpressions: Seq[Expression] = {
+    val beta = Literal(1.0 - alpha)
+    val casted = input.cast(DoubleType)
+    // TODO: after adding param ignore_na, we can remove this check
+    val error = RaiseError(Literal("Input values Must not be Null or 
NaN")).cast(DoubleType)

Review Comment:
   ```suggestion
       val error = RaiseError(Literal("Input values must not be null or NaN")).cast(DoubleType)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to