Repository: spark
Updated Branches:
  refs/heads/master 3a7494dfe -> 6e36d8d56


[SPARK-22829] Add new built-in function date_trunc()

## What changes were proposed in this pull request?

Adding date_trunc() as a built-in function.
`date_trunc` is common in other databases, but neither Spark nor Hive supports 
it. `date_trunc` is commonly used by data scientists and business intelligence 
applications such as Superset (https://github.com/apache/incubator-superset).
We do have `trunc`, but it only works at the 'MONTH' and 'YEAR' levels on 
DateType input.

date_trunc() in other databases:
AWS Redshift: http://docs.aws.amazon.com/redshift/latest/dg/r_DATE_TRUNC.html
PostgreSQL: https://www.postgresql.org/docs/9.1/static/functions-datetime.html
Presto: https://prestodb.io/docs/current/functions/datetime.html

## How was this patch tested?

Unit tests

(Please explain how this patch was tested. E.g. unit tests, integration tests, 
manual tests)
(If this patch involves UI changes, please attach a screenshot; otherwise, 
remove this)

Please review http://spark.apache.org/contributing.html before opening a pull 
request.

Author: Youngbin Kim <ykim...@hotmail.com>

Closes #20015 from youngbink/date_trunc.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6e36d8d5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6e36d8d5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6e36d8d5

Branch: refs/heads/master
Commit: 6e36d8d56279a2c5c92c8df8e89ee99b514817e7
Parents: 3a7494d
Author: Youngbin Kim <ykim...@hotmail.com>
Authored: Tue Dec 19 20:22:33 2017 -0800
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Tue Dec 19 20:22:33 2017 -0800

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 |  20 ++-
 .../catalyst/analysis/FunctionRegistry.scala    |   1 +
 .../expressions/datetimeExpressions.scala       | 170 ++++++++++++++-----
 .../spark/sql/catalyst/util/DateTimeUtils.scala | 102 +++++++++--
 .../expressions/DateExpressionsSuite.scala      |  73 +++++++-
 .../sql/catalyst/util/DateTimeUtilsSuite.scala  |  70 ++++++++
 .../scala/org/apache/spark/sql/functions.scala  |  15 ++
 .../apache/spark/sql/DateFunctionsSuite.scala   |  46 +++++
 8 files changed, 445 insertions(+), 52 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6e36d8d5/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 4e0fadd..5453005 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1099,7 +1099,7 @@ def trunc(date, format):
     """
     Returns date truncated to the unit specified by the format.
 
-    :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm'
+    :param format: 'year', 'yyyy', 'yy' or 'month', 'mon', 'mm'
 
     >>> df = spark.createDataFrame([('1997-02-28',)], ['d'])
     >>> df.select(trunc(df.d, 'year').alias('year')).collect()
@@ -1111,6 +1111,24 @@ def trunc(date, format):
     return Column(sc._jvm.functions.trunc(_to_java_column(date), format))
 
 
+@since(2.3)
+def date_trunc(format, timestamp):
+    """
+    Returns timestamp truncated to the unit specified by the format.
+
+    Note that the argument order is the reverse of :func:`trunc`: the format
+    comes first and the value to truncate second.
+
+    :param format: 'year', 'yyyy', 'yy', 'month', 'mon', 'mm',
+        'day', 'dd', 'hour', 'minute', 'second', 'week', 'quarter'
+    :param timestamp: column holding the timestamp values to truncate
+
+    >>> df = spark.createDataFrame([('1997-02-28 05:02:11',)], ['t'])
+    >>> df.select(date_trunc('year', df.t).alias('year')).collect()
+    [Row(year=datetime.datetime(1997, 1, 1, 0, 0))]
+    >>> df.select(date_trunc('mon', df.t).alias('month')).collect()
+    [Row(month=datetime.datetime(1997, 2, 1, 0, 0))]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.date_trunc(format, 
_to_java_column(timestamp)))
+
+
 @since(1.5)
 def next_day(date, dayOfWeek):
     """

http://git-wip-us.apache.org/repos/asf/spark/blob/6e36d8d5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 11538bd..5ddb398 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -392,6 +392,7 @@ object FunctionRegistry {
     expression[ToUnixTimestamp]("to_unix_timestamp"),
     expression[ToUTCTimestamp]("to_utc_timestamp"),
     expression[TruncDate]("trunc"),
+    expression[TruncTimestamp]("date_trunc"),
     expression[UnixTimestamp]("unix_timestamp"),
     expression[DayOfWeek]("dayofweek"),
     expression[WeekOfYear]("weekofyear"),

http://git-wip-us.apache.org/repos/asf/spark/blob/6e36d8d5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index cfec7f8..59c3e3d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -1294,80 +1294,79 @@ case class ParseToTimestamp(left: Expression, format: 
Option[Expression], child:
   override def dataType: DataType = TimestampType
 }
 
-/**
- * Returns date truncated to the unit specified by the format.
- */
-// scalastyle:off line.size.limit
-@ExpressionDescription(
-  usage = "_FUNC_(date, fmt) - Returns `date` with the time portion of the day 
truncated to the unit specified by the format model `fmt`.",
-  examples = """
-    Examples:
-      > SELECT _FUNC_('2009-02-12', 'MM');
-       2009-02-01
-      > SELECT _FUNC_('2015-10-27', 'YEAR');
-       2015-01-01
-  """,
-  since = "1.5.0")
-// scalastyle:on line.size.limit
-case class TruncDate(date: Expression, format: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes {
-  override def left: Expression = date
-  override def right: Expression = format
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
-  override def dataType: DataType = DateType
+trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes {
+  val instant: Expression
+  val format: Expression
   override def nullable: Boolean = true
-  override def prettyName: String = "trunc"
 
   private lazy val truncLevel: Int =
     DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
 
-  override def eval(input: InternalRow): Any = {
+  /**
+   * @param input internalRow (time)
+   * @param maxLevel Maximum level that can be used for truncation (e.g MONTH 
for Date input)
+   * @param truncFunc function: (time, level) => time
+   */
+  protected def evalHelper(input: InternalRow, maxLevel: Int)(
+    truncFunc: (Any, Int) => Any): Any = {
     val level = if (format.foldable) {
       truncLevel
     } else {
       DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
     }
-    if (level == -1) {
-      // unknown format
+    if (level == DateTimeUtils.TRUNC_INVALID || level > maxLevel) {
+      // unknown format or too large level
       null
     } else {
-      val d = date.eval(input)
-      if (d == null) {
+      val t = instant.eval(input)
+      if (t == null) {
         null
       } else {
-        DateTimeUtils.truncDate(d.asInstanceOf[Int], level)
+        truncFunc(t, level)
       }
     }
   }
 
-  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+  protected def codeGenHelper(
+      ctx: CodegenContext,
+      ev: ExprCode,
+      maxLevel: Int,
+      orderReversed: Boolean = false)(
+      truncFunc: (String, String) => String)
+    : ExprCode = {
     val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
 
     if (format.foldable) {
-      if (truncLevel == -1) {
+      if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) {
         ev.copy(code = s"""
           boolean ${ev.isNull} = true;
           ${ctx.javaType(dataType)} ${ev.value} = 
${ctx.defaultValue(dataType)};""")
       } else {
-        val d = date.genCode(ctx)
+        val t = instant.genCode(ctx)
+        val truncFuncStr = truncFunc(t.value, truncLevel.toString)
         ev.copy(code = s"""
-          ${d.code}
-          boolean ${ev.isNull} = ${d.isNull};
+          ${t.code}
+          boolean ${ev.isNull} = ${t.isNull};
           ${ctx.javaType(dataType)} ${ev.value} = 
${ctx.defaultValue(dataType)};
           if (!${ev.isNull}) {
-            ${ev.value} = $dtu.truncDate(${d.value}, $truncLevel);
+            ${ev.value} = $dtu.$truncFuncStr;
           }""")
       }
     } else {
-      nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
+      nullSafeCodeGen(ctx, ev, (left, right) => {
         val form = ctx.freshName("form")
+        val (dateVal, fmt) = if (orderReversed) {
+          (right, left)
+        } else {
+          (left, right)
+        }
+        val truncFuncStr = truncFunc(dateVal, form)
         s"""
           int $form = $dtu.parseTruncLevel($fmt);
-          if ($form == -1) {
+          if ($form == -1 || $form > $maxLevel) {
             ${ev.isNull} = true;
           } else {
-            ${ev.value} = $dtu.truncDate($dateVal, $form);
+            ${ev.value} = $dtu.$truncFuncStr
           }
         """
       })
@@ -1376,6 +1375,101 @@ case class TruncDate(date: Expression, format: 
Expression)
 }
 
 /**
+ * Returns date truncated to the unit specified by the format.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(date, fmt) - Returns `date` with the time portion of the day 
truncated to the unit specified by the format model `fmt`.
+    `fmt` should be one of ["year", "yyyy", "yy", "mon", "month", "mm"]
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_('2009-02-12', 'MM');
+       2009-02-01
+      > SELECT _FUNC_('2015-10-27', 'YEAR');
+       2015-01-01
+  """,
+  since = "1.5.0")
+// scalastyle:on line.size.limit
+case class TruncDate(date: Expression, format: Expression)
+  extends TruncInstant {
+  override def left: Expression = date
+  override def right: Expression = format
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
+  override def dataType: DataType = DateType
+  override def prettyName: String = "trunc"
+  override val instant = date
+
+  override def eval(input: InternalRow): Any = {
+    evalHelper(input, maxLevel = DateTimeUtils.TRUNC_TO_MONTH) { (d: Any, 
level: Int) =>
+      DateTimeUtils.truncDate(d.asInstanceOf[Int], level)
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    codeGenHelper(ctx, ev, maxLevel = DateTimeUtils.TRUNC_TO_MONTH) { (date: 
String, fmt: String) =>
+      s"truncDate($date, $fmt);"
+    }
+  }
+}
+
+/**
+ * Returns timestamp truncated to the unit specified by the format.
+ *
+ * The format is the first argument and the timestamp the second, i.e. the reverse of
+ * [[TruncDate]]. Evaluation yields null when the format is not a recognised truncation
+ * unit or when the timestamp is null.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(fmt, ts) - Returns timestamp `ts` truncated to the unit specified by the format model `fmt`.
+    `fmt` should be one of ["YEAR", "YYYY", "YY", "MON", "MONTH", "MM", "DAY", "DD", "HOUR", "MINUTE", "SECOND", "WEEK", "QUARTER"]
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_('YEAR', '2015-03-05T09:32:05.359');
+       2015-01-01T00:00:00
+      > SELECT _FUNC_('MM', '2015-03-05T09:32:05.359');
+       2015-03-01T00:00:00
+      > SELECT _FUNC_('DD', '2015-03-05T09:32:05.359');
+       2015-03-05T00:00:00
+      > SELECT _FUNC_('HOUR', '2015-03-05T09:32:05.359');
+       2015-03-05T09:00:00
+  """,
+  since = "2.3.0")
+// scalastyle:on line.size.limit
+case class TruncTimestamp(
+    format: Expression,
+    timestamp: Expression,
+    timeZoneId: Option[String] = None)
+  extends TruncInstant with TimeZoneAwareExpression {
+  // The format is the left child and the timestamp the right child.
+  override def left: Expression = format
+  override def right: Expression = timestamp
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, TimestampType)
+  override def dataType: TimestampType = TimestampType
+  override def prettyName: String = "date_trunc"
+  override val instant = timestamp
+  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+    copy(timeZoneId = Option(timeZoneId))
+
+  // Two-argument constructor; the time zone id is left unset (None).
+  def this(format: Expression, timestamp: Expression) = this(format, timestamp, None)
+
+  override def eval(input: InternalRow): Any = {
+    evalHelper(input, maxLevel = DateTimeUtils.TRUNC_TO_SECOND) { (t: Any, level: Int) =>
+      DateTimeUtils.truncTimestamp(t.asInstanceOf[Long], level, timeZone)
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val tz = ctx.addReferenceObj("timeZone", timeZone)
+    // orderReversed = true: here the left child is the format and the right child the value,
+    // the opposite of TruncDate.
+    codeGenHelper(ctx, ev, maxLevel = DateTimeUtils.TRUNC_TO_SECOND, orderReversed = true) {
+      (date: String, fmt: String) =>
+        s"truncTimestamp($date, $fmt, $tz);"
+    }
+  }
+}
+
+/**
  * Returns the number of days from startDate to endDate.
  */
 @ExpressionDescription(

http://git-wip-us.apache.org/repos/asf/spark/blob/6e36d8d5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index b1ed256..fa69b8a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -45,7 +45,8 @@ object DateTimeUtils {
   // it's 2440587.5, rounding up to compatible with Hive
   final val JULIAN_DAY_OF_EPOCH = 2440588
   final val SECONDS_PER_DAY = 60 * 60 * 24L
-  final val MICROS_PER_SECOND = 1000L * 1000L
+  final val MICROS_PER_MILLIS = 1000L
+  final val MICROS_PER_SECOND = MICROS_PER_MILLIS * MILLIS_PER_SECOND
   final val MILLIS_PER_SECOND = 1000L
   final val NANOS_PER_SECOND = MICROS_PER_SECOND * 1000L
   final val MICROS_PER_DAY = MICROS_PER_SECOND * SECONDS_PER_DAY
@@ -909,6 +910,15 @@ object DateTimeUtils {
     math.round(diff * 1e8) / 1e8
   }
 
+  // Thursday = 0 since 1970/Jan/01 => Thursday
+  private val SUNDAY = 3
+  private val MONDAY = 4
+  private val TUESDAY = 5
+  private val WEDNESDAY = 6
+  private val THURSDAY = 0
+  private val FRIDAY = 1
+  private val SATURDAY = 2
+
   /*
    * Returns day of week from String. Starting from Thursday, marked as 0.
    * (Because 1970-01-01 is Thursday).
@@ -916,13 +926,13 @@ object DateTimeUtils {
   def getDayOfWeekFromString(string: UTF8String): Int = {
     val dowString = string.toString.toUpperCase(Locale.ROOT)
     dowString match {
-      case "SU" | "SUN" | "SUNDAY" => 3
-      case "MO" | "MON" | "MONDAY" => 4
-      case "TU" | "TUE" | "TUESDAY" => 5
-      case "WE" | "WED" | "WEDNESDAY" => 6
-      case "TH" | "THU" | "THURSDAY" => 0
-      case "FR" | "FRI" | "FRIDAY" => 1
-      case "SA" | "SAT" | "SATURDAY" => 2
+      case "SU" | "SUN" | "SUNDAY" => SUNDAY
+      case "MO" | "MON" | "MONDAY" => MONDAY
+      case "TU" | "TUE" | "TUESDAY" => TUESDAY
+      case "WE" | "WED" | "WEDNESDAY" => WEDNESDAY
+      case "TH" | "THU" | "THURSDAY" => THURSDAY
+      case "FR" | "FRI" | "FRIDAY" => FRIDAY
+      case "SA" | "SAT" | "SATURDAY" => SATURDAY
       case _ => -1
     }
   }
@@ -944,9 +954,16 @@ object DateTimeUtils {
     date + daysToMonthEnd
   }
 
-  private val TRUNC_TO_YEAR = 1
-  private val TRUNC_TO_MONTH = 2
-  private val TRUNC_INVALID = -1
+  // Visible for testing.
+  private[sql] val TRUNC_TO_YEAR = 1
+  private[sql] val TRUNC_TO_MONTH = 2
+  private[sql] val TRUNC_TO_QUARTER = 3
+  private[sql] val TRUNC_TO_WEEK = 4
+  private[sql] val TRUNC_TO_DAY = 5
+  private[sql] val TRUNC_TO_HOUR = 6
+  private[sql] val TRUNC_TO_MINUTE = 7
+  private[sql] val TRUNC_TO_SECOND = 8
+  private[sql] val TRUNC_INVALID = -1
 
   /**
    * Returns the trunc date from original date and trunc level.
@@ -964,7 +981,62 @@ object DateTimeUtils {
   }
 
   /**
-   * Returns the truncate level, could be TRUNC_YEAR, TRUNC_MONTH, or 
TRUNC_INVALID,
+   * Returns the trunc date time from original date time and trunc level.
+   * Trunc level should be generated using `parseTruncLevel()`, should be 
between 1 and 8
+   */
+  def truncTimestamp(t: SQLTimestamp, level: Int, timeZone: TimeZone): SQLTimestamp = {
+    // Use floor division/modulo throughout: plain `/` and `%` round toward zero, which would
+    // move instants before the epoch (negative values) forward instead of truncating them.
+    val millis = Math.floorDiv(t, MICROS_PER_MILLIS)
+    val truncated = level match {
+      case TRUNC_TO_YEAR | TRUNC_TO_MONTH =>
+        // Year and month truncation is delegated to the date-level implementation.
+        val dDays = millisToDays(millis, timeZone)
+        daysToMillis(truncDate(dDays, level), timeZone)
+      case TRUNC_TO_DAY =>
+        // Shift by the zone offset, truncate, then shift back.
+        val offset = timeZone.getOffset(millis)
+        val local = millis + offset
+        local - Math.floorMod(local, MILLIS_PER_SECOND * SECONDS_PER_DAY) - offset
+      case TRUNC_TO_HOUR =>
+        // Same offset shift as for days.
+        val offset = timeZone.getOffset(millis)
+        val local = millis + offset
+        local - Math.floorMod(local, 60 * 60 * MILLIS_PER_SECOND) - offset
+      case TRUNC_TO_MINUTE =>
+        millis - Math.floorMod(millis, 60 * MILLIS_PER_SECOND)
+      case TRUNC_TO_SECOND =>
+        millis - Math.floorMod(millis, MILLIS_PER_SECOND)
+      case TRUNC_TO_WEEK =>
+        val dDays = millisToDays(millis, timeZone)
+        val prevMonday = getNextDateForDayOfWeek(dDays - 7, MONDAY)
+        daysToMillis(prevMonday, timeZone)
+      case TRUNC_TO_QUARTER =>
+        val dDays = millisToDays(millis, timeZone)
+        val monthStart = daysToMillis(truncDate(dDays, TRUNC_TO_MONTH), timeZone)
+        // The calendar must use the caller's zone, not the JVM default; otherwise the month
+        // field is rewritten in the wrong zone whenever the two differ.
+        val cal = Calendar.getInstance(timeZone)
+        cal.setTimeInMillis(monthStart)
+        val month = getQuarter(dDays) match {
+          case 1 => Calendar.JANUARY
+          case 2 => Calendar.APRIL
+          case 3 => Calendar.JULY
+          case 4 => Calendar.OCTOBER
+        }
+        cal.set(Calendar.MONTH, month)
+        cal.getTimeInMillis()
+      case _ =>
+        // caller make sure that this should never be reached
+        sys.error(s"Invalid trunc level: $level")
+    }
+    truncated * MICROS_PER_MILLIS
+  }
+
+  def truncTimestamp(d: SQLTimestamp, level: Int): SQLTimestamp = {
+    truncTimestamp(d, level, defaultTimeZone())
+  }
+
+  /**
+   * Returns the truncate level, could be TRUNC_YEAR, TRUNC_MONTH, 
TRUNC_TO_DAY, TRUNC_TO_HOUR,
+   * TRUNC_TO_MINUTE, TRUNC_TO_SECOND, TRUNC_TO_WEEK, TRUNC_TO_QUARTER or 
TRUNC_INVALID,
    * TRUNC_INVALID means unsupported truncate level.
    */
   def parseTruncLevel(format: UTF8String): Int = {
@@ -974,6 +1046,12 @@ object DateTimeUtils {
       format.toString.toUpperCase(Locale.ROOT) match {
         case "YEAR" | "YYYY" | "YY" => TRUNC_TO_YEAR
         case "MON" | "MONTH" | "MM" => TRUNC_TO_MONTH
+        case "DAY" | "DD" => TRUNC_TO_DAY
+        case "HOUR" => TRUNC_TO_HOUR
+        case "MINUTE" => TRUNC_TO_MINUTE
+        case "SECOND" => TRUNC_TO_SECOND
+        case "WEEK" => TRUNC_TO_WEEK
+        case "QUARTER" => TRUNC_TO_QUARTER
         case _ => TRUNC_INVALID
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/6e36d8d5/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index 89d99f9..63f6cee 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -527,7 +527,7 @@ class DateExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
       NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, 
StringType)), null)
   }
 
-  test("function trunc") {
+  test("TruncDate") {
     def testTrunc(input: Date, fmt: String, expected: Date): Unit = {
       checkEvaluation(TruncDate(Literal.create(input, DateType), 
Literal.create(fmt, StringType)),
         expected)
@@ -543,11 +543,82 @@ class DateExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
       testTrunc(date, fmt, Date.valueOf("2015-07-01"))
     }
     testTrunc(date, "DD", null)
+    testTrunc(date, "SECOND", null)
+    testTrunc(date, "HOUR", null)
     testTrunc(date, null, null)
     testTrunc(null, "MON", null)
     testTrunc(null, null, null)
   }
 
+  test("TruncTimestamp") {
+    def testTrunc(input: Timestamp, fmt: String, expected: Timestamp): Unit = {
+      checkEvaluation(
+        TruncTimestamp(Literal.create(fmt, StringType), Literal.create(input, 
TimestampType)),
+        expected)
+      checkEvaluation(
+        TruncTimestamp(
+          NonFoldableLiteral.create(fmt, StringType), Literal.create(input, 
TimestampType)),
+        expected)
+    }
+
+    withDefaultTimeZone(TimeZoneGMT) {
+      val inputDate = Timestamp.valueOf("2015-07-22 05:30:06")
+
+      Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach { fmt =>
+        testTrunc(
+          inputDate, fmt,
+          Timestamp.valueOf("2015-01-01 00:00:00"))
+      }
+
+      Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt =>
+        testTrunc(
+          inputDate, fmt,
+          Timestamp.valueOf("2015-07-01 00:00:00"))
+      }
+
+      Seq("DAY", "day", "DD", "dd").foreach { fmt =>
+        testTrunc(
+          inputDate, fmt,
+          Timestamp.valueOf("2015-07-22 00:00:00"))
+      }
+
+      Seq("HOUR", "hour").foreach { fmt =>
+        testTrunc(
+          inputDate, fmt,
+          Timestamp.valueOf("2015-07-22 05:00:00"))
+      }
+
+      Seq("MINUTE", "minute").foreach { fmt =>
+        testTrunc(
+          inputDate, fmt,
+          Timestamp.valueOf("2015-07-22 05:30:00"))
+      }
+
+      Seq("SECOND", "second").foreach { fmt =>
+        testTrunc(
+          inputDate, fmt,
+          Timestamp.valueOf("2015-07-22 05:30:06"))
+      }
+
+      Seq("WEEK", "week").foreach { fmt =>
+        testTrunc(
+          inputDate, fmt,
+          Timestamp.valueOf("2015-07-20 00:00:00"))
+      }
+
+      Seq("QUARTER", "quarter").foreach { fmt =>
+        testTrunc(
+          inputDate, fmt,
+          Timestamp.valueOf("2015-07-01 00:00:00"))
+      }
+
+      testTrunc(inputDate, "INVALID", null)
+      testTrunc(inputDate, null, null)
+      testTrunc(null, "MON", null)
+      testTrunc(null, null, null)
+    }
+  }
+
   test("from_unixtime") {
     val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
     val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS"

http://git-wip-us.apache.org/repos/asf/spark/blob/6e36d8d5/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
index c8cf16d..625ff38 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
@@ -563,6 +563,76 @@ class DateTimeUtilsSuite extends SparkFunSuite {
     }
   }
 
+  test("truncTimestamp") {
+    def testTrunc(
+        level: Int,
+        expected: String,
+        inputTS: SQLTimestamp,
+        timezone: TimeZone = DateTimeUtils.defaultTimeZone()): Unit = {
+      val truncated =
+        DateTimeUtils.truncTimestamp(inputTS, level, timezone)
+      val expectedTS =
+        DateTimeUtils.stringToTimestamp(UTF8String.fromString(expected))
+      assert(truncated === expectedTS.get)
+    }
+
+    val defaultInputTS =
+      
DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-05T09:32:05.359"))
+    val defaultInputTS1 =
+      
DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-31T20:32:05.359"))
+    val defaultInputTS2 =
+      
DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-04-01T02:32:05.359"))
+    val defaultInputTS3 =
+      
DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-30T02:32:05.359"))
+    val defaultInputTS4 =
+      
DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-29T02:32:05.359"))
+
+    testTrunc(DateTimeUtils.TRUNC_TO_YEAR, "2015-01-01T00:00:00", 
defaultInputTS.get)
+    testTrunc(DateTimeUtils.TRUNC_TO_MONTH, "2015-03-01T00:00:00", 
defaultInputTS.get)
+    testTrunc(DateTimeUtils.TRUNC_TO_DAY, "2015-03-05T00:00:00", 
defaultInputTS.get)
+    testTrunc(DateTimeUtils.TRUNC_TO_HOUR, "2015-03-05T09:00:00", 
defaultInputTS.get)
+    testTrunc(DateTimeUtils.TRUNC_TO_MINUTE, "2015-03-05T09:32:00", 
defaultInputTS.get)
+    testTrunc(DateTimeUtils.TRUNC_TO_SECOND, "2015-03-05T09:32:05", 
defaultInputTS.get)
+    testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-02T00:00:00", 
defaultInputTS.get)
+    testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", 
defaultInputTS1.get)
+    testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", 
defaultInputTS2.get)
+    testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", 
defaultInputTS3.get)
+    testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-23T00:00:00", 
defaultInputTS4.get)
+    testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", 
defaultInputTS.get)
+    testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", 
defaultInputTS1.get)
+    testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-04-01T00:00:00", 
defaultInputTS2.get)
+
+    for (tz <- DateTimeTestUtils.ALL_TIMEZONES) {
+      DateTimeTestUtils.withDefaultTimeZone(tz) {
+        val inputTS =
+          
DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-05T09:32:05.359"))
+        val inputTS1 =
+          
DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-31T20:32:05.359"))
+        val inputTS2 =
+          
DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-04-01T02:32:05.359"))
+        val inputTS3 =
+          
DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-30T02:32:05.359"))
+        val inputTS4 =
+          
DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-29T02:32:05.359"))
+
+        testTrunc(DateTimeUtils.TRUNC_TO_YEAR, "2015-01-01T00:00:00", 
inputTS.get, tz)
+        testTrunc(DateTimeUtils.TRUNC_TO_MONTH, "2015-03-01T00:00:00", 
inputTS.get, tz)
+        testTrunc(DateTimeUtils.TRUNC_TO_DAY, "2015-03-05T00:00:00", 
inputTS.get, tz)
+        testTrunc(DateTimeUtils.TRUNC_TO_HOUR, "2015-03-05T09:00:00", 
inputTS.get, tz)
+        testTrunc(DateTimeUtils.TRUNC_TO_MINUTE, "2015-03-05T09:32:00", 
inputTS.get, tz)
+        testTrunc(DateTimeUtils.TRUNC_TO_SECOND, "2015-03-05T09:32:05", 
inputTS.get, tz)
+        testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-02T00:00:00", 
inputTS.get, tz)
+        testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", 
inputTS1.get, tz)
+        testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", 
inputTS2.get, tz)
+        testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", 
inputTS3.get, tz)
+        testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-23T00:00:00", 
inputTS4.get, tz)
+        testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", 
inputTS.get, tz)
+        testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", 
inputTS1.get, tz)
+        testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-04-01T00:00:00", 
inputTS2.get, tz)
+      }
+    }
+  }
+
   test("daysToMillis and millisToDays") {
     val c = Calendar.getInstance(TimeZonePST)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6e36d8d5/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 3e4659b..052a3f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2798,6 +2798,21 @@ object functions {
   }
 
   /**
+   * Returns timestamp truncated to the unit specified by the format.
+   *
+   * @param format: 'year', 'yyyy', 'yy' for truncate by year,
+   *                'month', 'mon', 'mm' for truncate by month,
+   *                'day', 'dd' for truncate by day,
+   *                Other options are: 'second', 'minute', 'hour', 'week', 'quarter'
+   *
+   * @group datetime_funcs
+   * @since 2.3.0
+   */
+  def date_trunc(format: String, timestamp: Column): Column = withExpr {
+    TruncTimestamp(Literal(format), timestamp.expr)
+  }
+
+  /**
    * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time 
in UTC, and renders
    * that time as a timestamp in the given time zone. For example, 'GMT+1' 
would yield
    * '2017-07-14 03:40:00.0'.

http://git-wip-us.apache.org/repos/asf/spark/blob/6e36d8d5/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
index 3a86948..6bbf385 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
@@ -435,6 +435,52 @@ class DateFunctionsSuite extends QueryTest with 
SharedSQLContext {
       Seq(Row(Date.valueOf("2015-07-01")), Row(Date.valueOf("2014-12-01"))))
   }
 
+  test("function date_trunc") {
+    val df = Seq(
+      (1, Timestamp.valueOf("2015-07-22 10:01:40.523")),
+      (2, Timestamp.valueOf("2014-12-31 05:29:06.876"))).toDF("i", "t")
+
+    checkAnswer(
+      df.select(date_trunc("YY", col("t"))),
+      Seq(Row(Timestamp.valueOf("2015-01-01 00:00:00")),
+        Row(Timestamp.valueOf("2014-01-01 00:00:00"))))
+
+    checkAnswer(
+      df.selectExpr("date_trunc('MONTH', t)"),
+      Seq(Row(Timestamp.valueOf("2015-07-01 00:00:00")),
+        Row(Timestamp.valueOf("2014-12-01 00:00:00"))))
+
+    checkAnswer(
+      df.selectExpr("date_trunc('DAY', t)"),
+      Seq(Row(Timestamp.valueOf("2015-07-22 00:00:00")),
+        Row(Timestamp.valueOf("2014-12-31 00:00:00"))))
+
+    checkAnswer(
+      df.selectExpr("date_trunc('HOUR', t)"),
+      Seq(Row(Timestamp.valueOf("2015-07-22 10:00:00")),
+        Row(Timestamp.valueOf("2014-12-31 05:00:00"))))
+
+    checkAnswer(
+      df.selectExpr("date_trunc('MINUTE', t)"),
+      Seq(Row(Timestamp.valueOf("2015-07-22 10:01:00")),
+        Row(Timestamp.valueOf("2014-12-31 05:29:00"))))
+
+    checkAnswer(
+      df.selectExpr("date_trunc('SECOND', t)"),
+      Seq(Row(Timestamp.valueOf("2015-07-22 10:01:40")),
+        Row(Timestamp.valueOf("2014-12-31 05:29:06"))))
+
+    checkAnswer(
+      df.selectExpr("date_trunc('WEEK', t)"),
+      Seq(Row(Timestamp.valueOf("2015-07-20 00:00:00")),
+        Row(Timestamp.valueOf("2014-12-29 00:00:00"))))
+
+    checkAnswer(
+      df.selectExpr("date_trunc('QUARTER', t)"),
+      Seq(Row(Timestamp.valueOf("2015-07-01 00:00:00")),
+        Row(Timestamp.valueOf("2014-10-01 00:00:00"))))
+  }
+
   test("from_unixtime") {
     val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
     val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS"


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to