This is an automated email from the ASF dual-hosted git repository.

yiconghuang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/texera.git


The following commit(s) were added to refs/heads/main by this push:
     new 6980f19376 refactor(core): unify type ops and reuse in sort/agg (#4024)
6980f19376 is described below

commit 6980f19376c016b3505663827465aa41283e608e
Author: carloea2 <[email protected]>
AuthorDate: Fri Nov 21 11:54:38 2025 -0800

    refactor(core): unify type ops and reuse in sort/agg (#4024)
    
    ### What changes were proposed in this PR?
    
    1. **Centralize and extend `AttributeType` operations**
    
    Move and refactor the existing attribute-type helpers into
    `AttributeTypeUtils`:
    
       * `compare`, `add`, `zeroValue`, `minValue`, `maxValue`.
    * Unify null-handling semantics across these operations by using a
    single `match` on both operands instead of `if` checks followed by a
    `match`.
    
       Extend support to additional types:
    
    * Add comparison/aggregation support for `BOOLEAN`, `STRING`, and
    `BINARY`.
    
       Change numeric coercion strategy:
    
    * Coerce numeric values to `Number` instead of a specific primitive type
    (e.g., `Double`) to reduce `ClassCastException`s when the input is not
    strictly schema-validated.
    * Preserve existing comparison semantics for doubles by delegating to
    `java.lang.Double.compare` (including handling of ±∞ and `NaN`).
    
       Introduce “identity” helpers:
    
    * `zeroValue` returns an additive identity for numeric/timestamp types,
    and `Array.emptyByteArray` for `BINARY` as a safe, non-throwing
    identity.
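    * `minValue` / `maxValue`: provide lower/upper bounds for supported
    numeric and timestamp types; `minValue` also returns the empty array
    for `BINARY`. Note that `minValue` for `DOUBLE` returns
    `java.lang.Double.MIN_VALUE`, which is the smallest positive double.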
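    
    A rough illustration of the coercion and comparison points above (a
    sketch, not the code from this PR): the values are cast to `Number`
    rather than to a concrete boxed type, and ordering is delegated to
    `java.lang.Double.compare`. The names `CoercionSketch` and
    `compareAsDouble` are hypothetical and exist only in this example.
    
    ```scala
    object CoercionSketch {
      // Hypothetical helper: cast to Number, not java.lang.Double, so a boxed
      // Integer or Long arriving in a DOUBLE column does not throw
      // ClassCastException.
      def compareAsDouble(left: Any, right: Any): Int =
        java.lang.Double.compare(
          left.asInstanceOf[Number].doubleValue(),
          right.asInstanceOf[Number].doubleValue()
        )
    
      def main(args: Array[String]): Unit = {
        // A boxed Integer compared against a boxed Double.
        assert(compareAsDouble(Integer.valueOf(1), java.lang.Double.valueOf(1.5)) < 0)
        // Double.compare is a total order: -0.0 sorts before +0.0, NaN sorts last.
        assert(compareAsDouble(-0.0, 0.0) < 0)
        assert(compareAsDouble(Double.PositiveInfinity, Double.NaN) < 0)
        // java.lang.Double.MIN_VALUE is the smallest positive double,
        // not the most negative one.
        assert(java.lang.Double.MIN_VALUE > 0.0)
      }
    }
    ```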
    
    2. **Refactor operators to reuse `AttributeTypeUtils`**
    
    * `AggregationOperation`: implement `SUM` / `MIN` / `MAX` using the
    centralized helpers instead of custom per-operator logic.
    * `StableMergeSortOpExec`: reuse the typed compare logic from
    `AttributeTypeUtils`.
    * `SortPartitionsOpExec`: simplify to use a one-liner comparator based
    on `AttributeTypeUtils.compare` (or a thin wrapper) for clarity and
    reuse.
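    
    The comparator reuse described above can be sketched as follows. This
    is a simplified stand-in that handles only integer-like values, not the
    operator code itself; `SortSketch`, `compareInts`, and `lessThan` are
    hypothetical names. It shows how a three-way compare with null-first
    semantics turns into the Boolean "less than" predicate a sort needs.
    
    ```scala
    object SortSketch {
      // Hypothetical stand-in for a typed three-way compare: null sorts before
      // any non-null value, and two nulls compare as equal.
      def compareInts(left: Any, right: Any): Int =
        (left, right) match {
          case (null, null) => 0
          case (null, _)    => -1
          case (_, null)    => 1
          case _ =>
            java.lang.Integer.compare(
              left.asInstanceOf[Number].intValue(),
              right.asInstanceOf[Number].intValue()
            )
        }
    
      // "Less than" predicate derived from the three-way result.
      def lessThan(left: Any, right: Any): Boolean = compareInts(left, right) < 0
    
      def main(args: Array[String]): Unit = {
        val values: List[Any] = List(3, null, 1, 2)
        val sorted = values.sortWith(lessThan)
        // Nulls come first, then ascending order.
        assert(sorted == List(null, 1, 2, 3))
      }
    }
    ```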
    
    3. **Add tests**
    * `workflow-core/src/test/scala/org/apache/amber/core/tuple/AttributeTypeUtilsSpec.scala`
    * **compare**: Verifies correct null-handling and ordering for INTEGER,
    BOOLEAN, TIMESTAMP, STRING, and BINARY values.
    * **add**: Ensures `null` acts as identity and confirms correct addition
    for INTEGER, LONG, DOUBLE, and TIMESTAMP.
    * **zeroValue**: Checks that numeric/timestamp zero identities and empty
    binary array for BINARY are returned, and that unsupported types (e.g.,
    STRING) throw.
    * **minValue / maxValue**: Validate correct numeric and timestamp
    bounds, BINARY minimum, and exceptions for unsupported types (e.g.,
    BOOLEAN, STRING).
    * `workflow-operator/src/test/scala/org/apache/amber/operator/aggregate/AggregateOpSpec.scala`
    * Verifies `getAggregationAttribute` chooses the correct result type for
    different functions (SUM keeps input type, COUNT → INTEGER, CONCAT →
    STRING).
    * Checks `getAggFunc` SUM behavior for INTEGER and DOUBLE columns,
    ensuring correct totals and preserved fractional values.
    * Tests COUNT, CONCAT, MIN, MAX, and AVERAGE aggregations, including
    correct handling of `null` values and edge cases like “no rows”.
    * Confirms `getFinal` rewrites COUNT into a SUM on the intermediate
    count column and rewires attributes correctly for non-COUNT functions.
    * Exercises `AggregateOpExec` end-to-end: SUM grouped by a key (city)
    and combined global SUM+COUNT with no group-by keys, validating the
    produced tuples.
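    
    To make the COUNT-to-SUM rewrite above concrete, here is a minimal
    sketch of a four-stage partial aggregation over plain values. It
    assumes a simplified shape (`init`, `iterate`, `merge`, `finalAgg`);
    `PartialAgg` and `PartialCountSketch` are hypothetical names and are
    not the classes used in the repository.
    
    ```scala
    object PartialCountSketch {
      // Hypothetical, simplified four-stage aggregation over plain values.
      final case class PartialAgg[P](
          init: () => P,
          iterate: (P, Any) => P,
          merge: (P, P) => P,
          finalAgg: P => Any
      )
    
      // COUNT of non-null values: each worker counts its own rows.
      val count: PartialAgg[Int] = PartialAgg[Int](
        () => 0,
        (partial, value) => if (value == null) partial else partial + 1,
        (a, b) => a + b,
        partial => partial
      )
    
      def main(args: Array[String]): Unit = {
        val worker1 = List[Any](10, null, 5).foldLeft(count.init())(count.iterate)
        val worker2 = List[Any](7).foldLeft(count.init())(count.iterate)
        // Combining partial results is a SUM of the per-worker counts,
        // not a new COUNT over the partial rows.
        val total = count.finalAgg(count.merge(worker1, worker2))
        assert(worker1 == 2 && worker2 == 1)
        assert(total == 3)
      }
    }
    ```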
    
    
    4. **Scope / non-goals / extras**
    * No change to external APIs.
    * Main behavior changes are localized to `AttributeType` operations and
    the operators that consume them.
    
    ---
    
    **Any related issues, documentation, discussions?**
    
    * Closes: #3923
    
    **How was this PR tested?**
    
    Workflow Image:
    <img width="1684" height="859" alt="image" src="https://github.com/user-attachments/assets/2682ebdc-0f45-40c6-b304-0cea0b76b44f" />
    
    Workflow file:
    
    
    [agg_test_1.json](https://github.com/user-attachments/files/23540242/agg_test_1.json)
    
    Python benchmark:
    
    ```python
    import pandas as pd
    
    df = pd.read_csv("/mnt/data/test.csv")
    
    # Limit BEFORE sorting
    df_limited = df.head(1000)
    
    # Now sort ascending
    df_sorted = df_limited.sort_values("rna_umis", ascending=True)
    
    # Group by pass_all_filters with aggregations
    agg = df_sorted.groupby("pass_all_filters")["rna_umis"].agg(
        min="min", max="max", count="count", avg="mean", sum="sum"
    ).reset_index()
    
    agg
    
    ```
    Python Result:
    <img width="928" height="188" alt="image" src="https://github.com/user-attachments/assets/69da33cd-ada4-4b05-a3f9-ae139f8575b9" />
    
    Texera Result (Avg):
    
    pass_all_filters | min | max | count | avg | sum
    -- | -- | -- | -- | -- | --
    False | 0 | 80926 | 240 | 15987.68 | 3837043
    True | 11893 | 102559 | 760 | 35557.93 | 27024027
    
    Timestamp test, using these input values:
    - 1970-01-01T00:00:00Z
    - 2000-02-29T12:00:00Z
    - 2024-12-31T23:59:59Z
    
    
    1. Avg:
    
    - New version: 909835199750
    - Previous version: 909835199750
    
    2. Sum:
    
    - New version: 2055-03-01T05:59:59.000Z (UTC)
    - Previous version: 2055-03-01T11:59:59.000Z (UTC-6; Mexico City Time)
    
    **Was this PR authored or co-authored using generative AI tooling?**
    
    * Co-authored with ChatGPT.
---
 .../amber/core/tuple/AttributeTypeUtils.scala      | 131 ++++++++
 .../amber/core/tuple/AttributeTypeUtilsSpec.scala  | 135 ++++++++-
 .../operator/aggregate/AggregationOperation.scala  | 120 +-------
 .../sortPartitions/SortPartitionsOpExec.scala      |  23 +-
 .../amber/operator/aggregate/AggregateOpSpec.scala | 333 +++++++++++++++++++++
 5 files changed, 621 insertions(+), 121 deletions(-)

diff --git a/common/workflow-core/src/main/scala/org/apache/amber/core/tuple/AttributeTypeUtils.scala b/common/workflow-core/src/main/scala/org/apache/amber/core/tuple/AttributeTypeUtils.scala
index e4fdcb4611..7cbfb27179 100644
--- a/common/workflow-core/src/main/scala/org/apache/amber/core/tuple/AttributeTypeUtils.scala
+++ b/common/workflow-core/src/main/scala/org/apache/amber/core/tuple/AttributeTypeUtils.scala
@@ -387,6 +387,137 @@ object AttributeTypeUtils extends Serializable {
     }
   }
 
+  /** Three-way compare for the given attribute type.
+    * Returns < 0 if left < right, > 0 if left > right, 0 if equal.
+    * Null semantics: null < non-null (both null => 0).
+    */
+  @throws[UnsupportedOperationException]
+  def compare(left: Any, right: Any, attrType: AttributeType): Int =
+    (left, right) match {
+      case (null, null) => 0
+      case (null, _)    => -1
+      case (_, null)    => 1
+      case _ =>
+        attrType match {
+          case AttributeType.INTEGER =>
+            java.lang.Integer.compare(
+              left.asInstanceOf[Number].intValue(),
+              right.asInstanceOf[Number].intValue()
+            )
+          case AttributeType.LONG =>
+            java.lang.Long.compare(
+              left.asInstanceOf[Number].longValue(),
+              right.asInstanceOf[Number].longValue()
+            )
+          case AttributeType.DOUBLE =>
+            java.lang.Double.compare(
+              left.asInstanceOf[Number].doubleValue(),
+              right.asInstanceOf[Number].doubleValue()
+            ) // -Infinity < ... < -0.0 < +0.0 < ... < +Infinity < NaN
+          case AttributeType.BOOLEAN =>
+            java.lang.Boolean.compare(
+              left.asInstanceOf[Boolean],
+              right.asInstanceOf[Boolean]
+            )
+          case AttributeType.TIMESTAMP =>
+            java.lang.Long.compare(
+              left.asInstanceOf[Timestamp].getTime,
+              right.asInstanceOf[Timestamp].getTime
+            )
+          case AttributeType.STRING =>
+            left.toString.compareTo(right.toString)
+          case AttributeType.BINARY =>
+            java.util.Arrays.compareUnsigned(
+              left.asInstanceOf[Array[Byte]],
+              right.asInstanceOf[Array[Byte]]
+            )
+          case _ =>
+            throw new UnsupportedOperationException(
+              s"Unsupported attribute type for compare: $attrType"
+            )
+        }
+    }
+
+  /** Type-aware addition (null is identity). */
+  @throws[UnsupportedOperationException]
+  def add(left: Object, right: Object, attrType: AttributeType): Object =
+    (left, right) match {
+      case (null, null)  => zeroValue(attrType)
+      case (null, right) => right
+      case (left, null)  => left
+      case (left, right) =>
+        attrType match {
+          case AttributeType.INTEGER =>
+            java.lang.Integer.valueOf(
+              left.asInstanceOf[Number].intValue() + right.asInstanceOf[Number].intValue()
+            )
+          case AttributeType.LONG =>
+            java.lang.Long.valueOf(
+              left.asInstanceOf[Number].longValue() + right.asInstanceOf[Number].longValue()
+            )
+          case AttributeType.DOUBLE =>
+            java.lang.Double.valueOf(
+              left.asInstanceOf[Number].doubleValue() + right.asInstanceOf[Number].doubleValue()
+            )
+          case AttributeType.TIMESTAMP =>
+            new Timestamp(
+              left.asInstanceOf[Timestamp].getTime + right.asInstanceOf[Timestamp].getTime
+            )
+          case _ =>
+            throw new UnsupportedOperationException(
+              s"Unsupported attribute type for addition: $attrType"
+            )
+        }
+    }
+
+  /** Additive identity for supported numeric/timestamp types.
+    * For BINARY an empty array is returned as an identity value.
+    */
+  @throws[UnsupportedOperationException]
+  def zeroValue(attrType: AttributeType): Object =
+    attrType match {
+      case AttributeType.INTEGER   => java.lang.Integer.valueOf(0)
+      case AttributeType.LONG      => java.lang.Long.valueOf(0L)
+      case AttributeType.DOUBLE    => java.lang.Double.valueOf(0.0d)
+      case AttributeType.TIMESTAMP => new Timestamp(0L)
+      case AttributeType.BINARY    => Array.emptyByteArray
+      case _ =>
+        throw new UnsupportedOperationException(
+          s"Unsupported attribute type for zero value: $attrType"
+        )
+    }
+
+  /** Returns the maximum possible value for a given attribute type. */
+  @throws[UnsupportedOperationException]
+  def maxValue(attrType: AttributeType): Object =
+    attrType match {
+      case AttributeType.INTEGER   => java.lang.Integer.valueOf(Integer.MAX_VALUE)
+      case AttributeType.LONG      => java.lang.Long.valueOf(java.lang.Long.MAX_VALUE)
+      case AttributeType.DOUBLE    => java.lang.Double.valueOf(java.lang.Double.MAX_VALUE)
+      case AttributeType.TIMESTAMP => new Timestamp(java.lang.Long.MAX_VALUE)
+      case _ =>
+        throw new UnsupportedOperationException(
+          s"Unsupported attribute type for max value: $attrType"
+        )
+    }
+
+  /** Returns the minimum possible value for a given attribute type. (note Double.MIN_VALUE is > 0).
+    * For BINARY under lexicographic order, the empty array is the global minimum.
+    */
+  @throws[UnsupportedOperationException]
+  def minValue(attrType: AttributeType): Object =
+    attrType match {
+      case AttributeType.INTEGER   => java.lang.Integer.valueOf(Integer.MIN_VALUE)
+      case AttributeType.LONG      => java.lang.Long.valueOf(java.lang.Long.MIN_VALUE)
+      case AttributeType.DOUBLE    => java.lang.Double.valueOf(java.lang.Double.MIN_VALUE)
+      case AttributeType.TIMESTAMP => new Timestamp(0L)
+      case AttributeType.BINARY    => Array.emptyByteArray
+      case _ =>
+        throw new UnsupportedOperationException(
+          s"Unsupported attribute type for min value: $attrType"
+        )
+    }
+
   class AttributeTypeException(msg: String, cause: Throwable = null)
       extends IllegalArgumentException(msg, cause) {}
 }
diff --git a/common/workflow-core/src/test/scala/org/apache/amber/core/tuple/AttributeTypeUtilsSpec.scala b/common/workflow-core/src/test/scala/org/apache/amber/core/tuple/AttributeTypeUtilsSpec.scala
index 24c998b3f7..53e5f68430 100644
--- a/common/workflow-core/src/test/scala/org/apache/amber/core/tuple/AttributeTypeUtilsSpec.scala
+++ b/common/workflow-core/src/test/scala/org/apache/amber/core/tuple/AttributeTypeUtilsSpec.scala
@@ -24,11 +24,17 @@ import org.apache.amber.core.tuple.AttributeTypeUtils.{
   AttributeTypeException,
   inferField,
   inferSchemaFromRows,
-  parseField
+  parseField,
+  compare,
+  add,
+  minValue,
+  maxValue,
+  zeroValue
 }
 import org.scalatest.funsuite.AnyFunSuite
 
 class AttributeTypeUtilsSpec extends AnyFunSuite {
+
   // Unit Test for Infer Schema
 
   test("type should get inferred correctly individually") {
@@ -190,4 +196,131 @@ class AttributeTypeUtilsSpec extends AnyFunSuite {
     assert(parseField("anything", AttributeType.ANY) == "anything")
   }
 
+  test("compare correctly handles null values for different attribute types") {
+    assert(compare(null, null, INTEGER) == 0)
+    assert(compare(null, 10, INTEGER) < 0)
+    assert(compare(10, null, INTEGER) > 0)
+  }
+
+  test("compare correctly orders numeric, boolean, timestamp, string and binary values") {
+    assert(compare(1, 2, INTEGER) < 0)
+    assert(compare(2, 1, INTEGER) > 0)
+    assert(compare(5, 5, INTEGER) == 0)
+
+    assert(compare(false, true, BOOLEAN) < 0)
+    assert(compare(true, false, BOOLEAN) > 0)
+    assert(compare(true, true, BOOLEAN) == 0)
+
+    val earlierTimestamp = new java.sql.Timestamp(1000L)
+    val laterTimestamp = new java.sql.Timestamp(2000L)
+    assert(compare(earlierTimestamp, laterTimestamp, TIMESTAMP) < 0)
+    assert(compare(laterTimestamp, earlierTimestamp, TIMESTAMP) > 0)
+
+    assert(compare("apple", "banana", STRING) < 0)
+    assert(compare("banana", "apple", STRING) > 0)
+    assert(compare("same", "same", STRING) == 0)
+
+    val firstBytes = Array[Byte](0, 1, 2)
+    val secondBytes = Array[Byte](0, 2, 0)
+    assert(compare(firstBytes, secondBytes, BINARY) < 0)
+  }
+
+  test("add correctly handles null values as identity for numeric types") {
+    val integerZeroFromAdd = add(null, null, INTEGER).asInstanceOf[Int]
+    assert(integerZeroFromAdd == 0)
+
+    val rightOnlyResult =
+      add(null, java.lang.Integer.valueOf(5), INTEGER).asInstanceOf[Int]
+    assert(rightOnlyResult == 5)
+
+    val leftOnlyResult =
+      add(java.lang.Integer.valueOf(7), null, INTEGER).asInstanceOf[Int]
+    assert(leftOnlyResult == 7)
+  }
+
+  test("add correctly adds integer, long, double and timestamp values") {
+    val integerSum =
+      add(java.lang.Integer.valueOf(3), java.lang.Integer.valueOf(4), INTEGER)
+        .asInstanceOf[Int]
+    assert(integerSum == 7)
+
+    val longSum =
+      add(java.lang.Long.valueOf(10L), java.lang.Long.valueOf(5L), LONG)
+        .asInstanceOf[Long]
+    assert(longSum == 15L)
+
+    val doubleSum =
+      add(java.lang.Double.valueOf(1.5), java.lang.Double.valueOf(2.5), DOUBLE)
+        .asInstanceOf[Double]
+    assert(doubleSum == 4.0)
+
+    val firstTimestamp = new java.sql.Timestamp(1000L)
+    val secondTimestamp = new java.sql.Timestamp(2500L)
+    val timestampSum =
+      add(firstTimestamp, secondTimestamp, TIMESTAMP).asInstanceOf[java.sql.Timestamp]
+    assert(timestampSum.getTime == 3500L)
+  }
+
+  test("zeroValue returns correct numeric and timestamp identity values") {
+    val integerZero = zeroValue(INTEGER).asInstanceOf[Int]
+    val longZero = zeroValue(LONG).asInstanceOf[Long]
+    val doubleZero = zeroValue(DOUBLE).asInstanceOf[Double]
+    val timestampZero = zeroValue(TIMESTAMP).asInstanceOf[java.sql.Timestamp]
+
+    assert(integerZero == 0)
+    assert(longZero == 0L)
+    assert(doubleZero == 0.0d)
+    assert(timestampZero.getTime == 0L)
+  }
+
+  test("zeroValue returns empty binary array and fails for unsupported types") {
+    val binaryZero = zeroValue(BINARY).asInstanceOf[Array[Byte]]
+    assert(binaryZero.isEmpty)
+
+    assertThrows[UnsupportedOperationException] {
+      zeroValue(STRING)
+    }
+  }
+
+  test("maxValue returns correct maximum numeric bounds") {
+    val integerMax = maxValue(INTEGER).asInstanceOf[Int]
+    val longMax = maxValue(LONG).asInstanceOf[Long]
+    val doubleMax = maxValue(DOUBLE).asInstanceOf[Double]
+
+    assert(integerMax == Int.MaxValue)
+    assert(longMax == Long.MaxValue)
+    assert(doubleMax == Double.MaxValue)
+  }
+
+  test("maxValue returns maximum timestamp and fails for unsupported types") {
+    val timestampMax = maxValue(TIMESTAMP).asInstanceOf[java.sql.Timestamp]
+    assert(timestampMax.getTime == Long.MaxValue)
+
+    assertThrows[UnsupportedOperationException] {
+      maxValue(BOOLEAN)
+    }
+  }
+
+  test("minValue returns correct minimum numeric bounds") {
+    val integerMin = minValue(INTEGER).asInstanceOf[Int]
+    val longMin = minValue(LONG).asInstanceOf[Long]
+    val doubleMin = minValue(DOUBLE).asInstanceOf[Double]
+
+    assert(integerMin == Int.MinValue)
+    assert(longMin == Long.MinValue)
+    assert(doubleMin == java.lang.Double.MIN_VALUE)
+  }
+
+  test("minValue returns timestamp epoch and empty binary array, and fails for unsupported types") {
+    val timestampMin = minValue(TIMESTAMP).asInstanceOf[java.sql.Timestamp]
+    val binaryMin = minValue(BINARY).asInstanceOf[Array[Byte]]
+
+    assert(timestampMin.getTime == 0L)
+
+    assert(binaryMin.isEmpty)
+
+    assertThrows[UnsupportedOperationException] {
+      minValue(STRING)
+    }
+  }
 }
diff --git a/common/workflow-operator/src/main/scala/org/apache/amber/operator/aggregate/AggregationOperation.scala b/common/workflow-operator/src/main/scala/org/apache/amber/operator/aggregate/AggregationOperation.scala
index 8818d831e1..931163e9ed 100644
--- a/common/workflow-operator/src/main/scala/org/apache/amber/operator/aggregate/AggregationOperation.scala
+++ b/common/workflow-operator/src/main/scala/org/apache/amber/operator/aggregate/AggregationOperation.scala
@@ -21,11 +21,9 @@ package org.apache.amber.operator.aggregate
 
 import com.fasterxml.jackson.annotation.{JsonIgnore, JsonProperty, JsonPropertyDescription}
 import com.kjetland.jackson.jsonSchema.annotations.{JsonSchemaInject, JsonSchemaTitle}
-import org.apache.amber.core.tuple.AttributeTypeUtils.parseTimestamp
-import org.apache.amber.core.tuple.{Attribute, AttributeType, Tuple}
+import org.apache.amber.core.tuple.{Attribute, AttributeType, AttributeTypeUtils, Tuple}
 import org.apache.amber.operator.metadata.annotations.AutofillAttributeName
 
-import java.sql.Timestamp
 import javax.validation.constraints.NotNull
 
 case class AveragePartialObj(sum: Double, count: Double) extends Serializable {}
@@ -130,12 +128,12 @@ class AggregationOperation {
       )
     }
     new DistributedAggregation[Object](
-      () => zero(attributeType),
+      () => AttributeTypeUtils.zeroValue(attributeType),
       (partial, tuple) => {
         val value = tuple.getField[Object](attribute)
-        add(partial, value, attributeType)
+        AttributeTypeUtils.add(partial, value, attributeType)
       },
-      (partial1, partial2) => add(partial1, partial2, attributeType),
+      (partial1, partial2) => AttributeTypeUtils.add(partial1, partial2, attributeType),
       partial => partial
     )
   }
@@ -190,15 +188,16 @@ class AggregationOperation {
       )
     }
     new DistributedAggregation[Object](
-      () => maxValue(attributeType),
+      () => AttributeTypeUtils.maxValue(attributeType),
       (partial, tuple) => {
         val value = tuple.getField[Object](attribute)
-        val comp = compare(value, partial, attributeType)
+        val comp = AttributeTypeUtils.compare(value, partial, attributeType)
         if (value != null && comp < 0) value else partial
       },
       (partial1, partial2) =>
-        if (compare(partial1, partial2, attributeType) < 0) partial1 else partial2,
-      partial => if (partial == maxValue(attributeType)) null else partial
+        if (AttributeTypeUtils.compare(partial1, partial2, attributeType) < 0) partial1
+        else partial2,
+      partial => if (partial == AttributeTypeUtils.maxValue(attributeType)) null else partial
     )
   }
 
@@ -214,15 +213,16 @@ class AggregationOperation {
       )
     }
     new DistributedAggregation[Object](
-      () => minValue(attributeType),
+      () => AttributeTypeUtils.minValue(attributeType),
       (partial, tuple) => {
         val value = tuple.getField[Object](attribute)
-        val comp = compare(value, partial, attributeType)
+        val comp = AttributeTypeUtils.compare(value, partial, attributeType)
         if (value != null && comp > 0) value else partial
       },
       (partial1, partial2) =>
-        if (compare(partial1, partial2, attributeType) > 0) partial1 else partial2,
-      partial => if (partial == maxValue(attributeType)) null else partial
+        if (AttributeTypeUtils.compare(partial1, partial2, attributeType) > 0) partial1
+        else partial2,
+      partial => if (partial == AttributeTypeUtils.maxValue(attributeType)) null else partial
     )
   }
 
@@ -232,7 +232,7 @@ class AggregationOperation {
       return None
 
     if (tuple.getSchema.getAttribute(attribute).getType == AttributeType.TIMESTAMP)
-      Option(parseTimestamp(value.toString).getTime.toDouble)
+      Option(AttributeTypeUtils.parseTimestamp(value.toString).getTime.toDouble)
     else Option(value.toString.toDouble)
   }
 
@@ -254,94 +254,4 @@ class AggregationOperation {
       }
     )
   }
-
-  // return a.compare(b),
-  // < 0 if a < b,
-  // > 0 if a > b,
-  //   0 if a = b
-  private def compare(a: Object, b: Object, attributeType: AttributeType): Int = {
-    if (a == null && b == null) {
-      return 0
-    } else if (a == null) {
-      return -1
-    } else if (b == null) {
-      return 1
-    }
-    attributeType match {
-      case AttributeType.INTEGER => a.asInstanceOf[Integer].compareTo(b.asInstanceOf[Integer])
-      case AttributeType.DOUBLE =>
-        a.asInstanceOf[java.lang.Double].compareTo(b.asInstanceOf[java.lang.Double])
-      case AttributeType.LONG =>
-        a.asInstanceOf[java.lang.Long].compareTo(b.asInstanceOf[java.lang.Long])
-      case AttributeType.TIMESTAMP =>
-        a.asInstanceOf[Timestamp].getTime.compareTo(b.asInstanceOf[Timestamp].getTime)
-      case _ =>
-        throw new UnsupportedOperationException(
-          "Unsupported attribute type for comparison: " + attributeType
-        )
-    }
-  }
-
-  private def add(a: Object, b: Object, attributeType: AttributeType): Object = {
-    if (a == null && b == null) {
-      return zero(attributeType)
-    } else if (a == null) {
-      return b
-    } else if (b == null) {
-      return a
-    }
-    attributeType match {
-      case AttributeType.INTEGER =>
-        Integer.valueOf(a.asInstanceOf[Integer] + b.asInstanceOf[Integer])
-      case AttributeType.DOUBLE =>
-        java.lang.Double.valueOf(
-          a.asInstanceOf[java.lang.Double] + b.asInstanceOf[java.lang.Double]
-        )
-      case AttributeType.LONG =>
-        java.lang.Long.valueOf(a.asInstanceOf[java.lang.Long] + b.asInstanceOf[java.lang.Long])
-      case AttributeType.TIMESTAMP =>
-        new Timestamp(a.asInstanceOf[Timestamp].getTime + b.asInstanceOf[Timestamp].getTime)
-      case _ =>
-        throw new UnsupportedOperationException(
-          "Unsupported attribute type for addition: " + attributeType
-        )
-    }
-  }
-
-  private def zero(attributeType: AttributeType): Object =
-    attributeType match {
-      case AttributeType.INTEGER   => java.lang.Integer.valueOf(0)
-      case AttributeType.DOUBLE    => java.lang.Double.valueOf(0)
-      case AttributeType.LONG      => java.lang.Long.valueOf(0)
-      case AttributeType.TIMESTAMP => new Timestamp(0)
-      case _ =>
-        throw new UnsupportedOperationException(
-          "Unsupported attribute type for zero value: " + attributeType
-        )
-    }
-
-  private def maxValue(attributeType: AttributeType): Object =
-    attributeType match {
-      case AttributeType.INTEGER   => Integer.MAX_VALUE.asInstanceOf[Object]
-      case AttributeType.DOUBLE    => java.lang.Double.MAX_VALUE.asInstanceOf[Object]
-      case AttributeType.LONG      => java.lang.Long.MAX_VALUE.asInstanceOf[Object]
-      case AttributeType.TIMESTAMP => new Timestamp(java.lang.Long.MAX_VALUE)
-      case _ =>
-        throw new UnsupportedOperationException(
-          "Unsupported attribute type for max value: " + attributeType
-        )
-    }
-
-  private def minValue(attributeType: AttributeType): Object =
-    attributeType match {
-      case AttributeType.INTEGER   => Integer.MIN_VALUE.asInstanceOf[Object]
-      case AttributeType.DOUBLE    => java.lang.Double.MIN_VALUE.asInstanceOf[Object]
-      case AttributeType.LONG      => java.lang.Long.MIN_VALUE.asInstanceOf[Object]
-      case AttributeType.TIMESTAMP => new Timestamp(0)
-      case _ =>
-        throw new UnsupportedOperationException(
-          "Unsupported attribute type for min value: " + attributeType
-        )
-    }
-
 }
diff --git a/common/workflow-operator/src/main/scala/org/apache/amber/operator/sortPartitions/SortPartitionsOpExec.scala b/common/workflow-operator/src/main/scala/org/apache/amber/operator/sortPartitions/SortPartitionsOpExec.scala
index ac6a9da59c..5748a41da6 100644
--- a/common/workflow-operator/src/main/scala/org/apache/amber/operator/sortPartitions/SortPartitionsOpExec.scala
+++ b/common/workflow-operator/src/main/scala/org/apache/amber/operator/sortPartitions/SortPartitionsOpExec.scala
@@ -20,7 +20,7 @@
 package org.apache.amber.operator.sortPartitions
 
 import org.apache.amber.core.executor.OperatorExecutor
-import org.apache.amber.core.tuple.{AttributeType, Tuple, TupleLike}
+import org.apache.amber.core.tuple.{AttributeTypeUtils, Tuple, TupleLike}
 import org.apache.amber.util.JSONUtils.objectMapper
 
 import scala.collection.mutable.ArrayBuffer
@@ -47,18 +47,11 @@ class SortPartitionsOpExec(descString: String) extends OperatorExecutor {
 
   override def onFinish(port: Int): Iterator[TupleLike] = sortTuples()
 
-  private def compareTuples(t1: Tuple, t2: Tuple): Boolean = {
-    val attributeType = t1.getSchema.getAttribute(desc.sortAttributeName).getType
-    val attributeIndex = t1.getSchema.getIndex(desc.sortAttributeName)
-    attributeType match {
-      case AttributeType.LONG =>
-        t1.getField[Long](attributeIndex) < t2.getField[Long](attributeIndex)
-      case AttributeType.INTEGER =>
-        t1.getField[Int](attributeIndex) < t2.getField[Int](attributeIndex)
-      case AttributeType.DOUBLE =>
-        t1.getField[Double](attributeIndex) < t2.getField[Double](attributeIndex)
-      case _ =>
-        true // unsupported type
-    }
-  }
+  private def compareTuples(tuple1: Tuple, tuple2: Tuple): Boolean =
+    AttributeTypeUtils.compare(
+      tuple1.getField[Any](tuple1.getSchema.getIndex(desc.sortAttributeName)),
+      tuple2.getField[Any](tuple2.getSchema.getIndex(desc.sortAttributeName)),
+      tuple1.getSchema.getAttribute(desc.sortAttributeName).getType
+    ) < 0
+
 }
diff --git a/common/workflow-operator/src/test/scala/org/apache/amber/operator/aggregate/AggregateOpSpec.scala b/common/workflow-operator/src/test/scala/org/apache/amber/operator/aggregate/AggregateOpSpec.scala
new file mode 100644
index 0000000000..9eb405d817
--- /dev/null
+++ b/common/workflow-operator/src/test/scala/org/apache/amber/operator/aggregate/AggregateOpSpec.scala
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.amber.operator.aggregate
+
+import org.apache.amber.core.tuple.{Attribute, AttributeType, Schema, Tuple}
+import org.apache.amber.util.JSONUtils.objectMapper
+import org.scalatest.funsuite.AnyFunSuite
+
+class AggregateOpSpec extends AnyFunSuite {
+
+  /** Helpers */
+
+  private def makeAggregationOp(
+      fn: AggregationFunction,
+      attributeName: String,
+      resultName: String
+  ): AggregationOperation = {
+    val operation = new AggregationOperation()
+    operation.aggFunction = fn
+    operation.attribute = attributeName
+    operation.resultAttribute = resultName
+    operation
+  }
+
+  private def makeSchema(fields: (String, AttributeType)*): Schema =
+    Schema(fields.map { case (n, t) => new Attribute(n, t) }.toList)
+
+  private def makeTuple(schema: Schema, values: Any*): Tuple =
+    Tuple(schema, values.toArray)
+
+  test("getAggregationAttribute keeps original type for SUM") {
+    val operation = makeAggregationOp(AggregationFunction.SUM, "amount", "total_amount")
+    val attr = operation.getAggregationAttribute(AttributeType.DOUBLE)
+
+    assert(attr.getName == "total_amount")
+    assert(attr.getType == AttributeType.DOUBLE)
+  }
+
+  test("getAggregationAttribute maps COUNT result to INTEGER regardless of input type") {
+    val operation = makeAggregationOp(AggregationFunction.COUNT, "quantity", "row_count")
+    val attr = operation.getAggregationAttribute(AttributeType.LONG)
+
+    assert(attr.getName == "row_count")
+    assert(attr.getType == AttributeType.INTEGER)
+  }
+
+  test("getAggregationAttribute maps CONCAT result type to STRING") {
+    val operation = makeAggregationOp(AggregationFunction.CONCAT, "tag", "all_tags")
+    val attr = operation.getAggregationAttribute(AttributeType.INTEGER)
+
+    assert(attr.getName == "all_tags")
+    assert(attr.getType == AttributeType.STRING)
+  }
+
+  // ---------------------------------------------------------------------------
+  // Basic DistributedAggregation behaviour via AggregationOperation.getAggFunc
+  // ---------------------------------------------------------------------------
+
+  test("SUM aggregation over INTEGER column adds values correctly") {
+    val schema = makeSchema("amount" -> AttributeType.INTEGER)
+    val tuple1 = makeTuple(schema, 5)
+    val tuple2 = makeTuple(schema, 7)
+    val tuple3 = makeTuple(schema, 3)
+
+    val operation = makeAggregationOp(AggregationFunction.SUM, "amount", "total_amount")
+    val agg = operation.getAggFunc(AttributeType.INTEGER)
+
+    var partial = agg.init()
+    partial = agg.iterate(partial, tuple1)
+    partial = agg.iterate(partial, tuple2)
+    partial = agg.iterate(partial, tuple3)
+
+    val result = agg.finalAgg(partial).asInstanceOf[Number].intValue()
+    assert(result == 15)
+  }
+
+  test("SUM aggregation over DOUBLE column keeps fractional part") {
+    val schema = makeSchema("score" -> AttributeType.DOUBLE)
+    val tuple1 = makeTuple(schema, 1.25)
+    val tuple2 = makeTuple(schema, 2.75)
+
+    val operation = makeAggregationOp(AggregationFunction.SUM, "score", "total_score")
+    val agg = operation.getAggFunc(AttributeType.DOUBLE)
+
+    var partial = agg.init()
+    partial = agg.iterate(partial, tuple1)
+    partial = agg.iterate(partial, tuple2)
+
+    val result = agg.finalAgg(partial).asInstanceOf[java.lang.Double].doubleValue()
+    assert(math.abs(result - 4.0) < 1e-6)
+  }
+
+  test("COUNT aggregation with attribute == null counts all rows") {
+    val schema = makeSchema("points" -> AttributeType.INTEGER)
+    val tuple1 = makeTuple(schema, 10)
+    val tuple2 = makeTuple(schema, null)
+    val tuple3 = makeTuple(schema, 20)
+
+    val operation = makeAggregationOp(AggregationFunction.COUNT, null, "row_count")
+    val agg = operation.getAggFunc(AttributeType.INTEGER)
+
+    var partial = agg.init()
+    partial = agg.iterate(partial, tuple1)
+    partial = agg.iterate(partial, tuple2)
+    partial = agg.iterate(partial, tuple3)
+
+    val result = agg.finalAgg(partial).asInstanceOf[Number].intValue()
+    assert(result == 3)
+  }
+
+  test("COUNT aggregation with attribute set only counts non-null values") {
+    val schema = makeSchema("points" -> AttributeType.INTEGER)
+    val tuple1 = makeTuple(schema, 10)
+    val tuple2 = makeTuple(schema, null)
+    val tuple3 = makeTuple(schema, 5)
+
+    val operation = makeAggregationOp(AggregationFunction.COUNT, "points", "non_null_points")
+    val agg = operation.getAggFunc(AttributeType.INTEGER)
+
+    var partial = agg.init()
+    partial = agg.iterate(partial, tuple1)
+    partial = agg.iterate(partial, tuple2)
+    partial = agg.iterate(partial, tuple3)
+
+    val result = agg.finalAgg(partial).asInstanceOf[Number].intValue()
+    assert(result == 2)
+  }
+
+  test("CONCAT aggregation concatenates string representations with commas") {
+    val schema = makeSchema("tag" -> AttributeType.STRING)
+    val tuple1 = makeTuple(schema, "red")
+    val tuple2 = makeTuple(schema, null)
+    val tuple3 = makeTuple(schema, "blue")
+
+    val operation = makeAggregationOp(AggregationFunction.CONCAT, "tag", "all_tags")
+    val agg = operation.getAggFunc(AttributeType.STRING)
+
+    var partial = agg.init()
+    partial = agg.iterate(partial, tuple1)
+    partial = agg.iterate(partial, tuple2)
+    partial = agg.iterate(partial, tuple3)
+
+    val result = agg.finalAgg(partial).asInstanceOf[String]
+    assert(result == "red,,blue")
+  }
+
+  test("MIN aggregation finds smallest INTEGER and returns null when given no values") {
+    val schema = makeSchema("temperature" -> AttributeType.INTEGER)
+    val tuple1 = makeTuple(schema, 10)
+    val tuple2 = makeTuple(schema, -2)
+    val tuple3 = makeTuple(schema, 5)
+
+    val operation = makeAggregationOp(AggregationFunction.MIN, "temperature", "min_temp")
+    val agg = operation.getAggFunc(AttributeType.INTEGER)
+
+    // Empty case: never iterate, just finalize init
+    val emptyPartial = agg.init()
+    val emptyResult = agg.finalAgg(emptyPartial)
+    assert(emptyResult == null)
+
+    // Non-empty case
+    var partial = agg.init()
+    partial = agg.iterate(partial, tuple1)
+    partial = agg.iterate(partial, tuple2)
+    partial = agg.iterate(partial, tuple3)
+
+    val result = agg.finalAgg(partial).asInstanceOf[Number].intValue()
+    assert(result == -2)
+  }
+
+  test("MAX aggregation finds largest LONG value") {
+    val schema = makeSchema("latency" -> AttributeType.LONG)
+    val tuple1 = makeTuple(schema, 100L)
+    val tuple2 = makeTuple(schema, 50L)
+    val tuple3 = makeTuple(schema, 250L)
+
+    val operation = makeAggregationOp(AggregationFunction.MAX, "latency", "max_latency")
+    val agg = operation.getAggFunc(AttributeType.LONG)
+
+    var partial = agg.init()
+    partial = agg.iterate(partial, tuple1)
+    partial = agg.iterate(partial, tuple2)
+    partial = agg.iterate(partial, tuple3)
+
+    val result = agg.finalAgg(partial).asInstanceOf[java.lang.Long].longValue()
+    assert(result == 250L)
+  }
+
+  test("AVERAGE aggregation ignores nulls and returns null when all values are null") {
+    val schema = makeSchema("price" -> AttributeType.DOUBLE)
+    val tuple1 = makeTuple(schema, 10.0)
+    val tuple2 = makeTuple(schema, null)
+    val tuple3 = makeTuple(schema, 20.0)
+
+    val operation = makeAggregationOp(AggregationFunction.AVERAGE, "price", "avg_price")
+    val agg = operation.getAggFunc(AttributeType.DOUBLE)
+
+    // Mixed null and non-null
+    var partial = agg.init()
+    partial = agg.iterate(partial, tuple1)
+    partial = agg.iterate(partial, tuple2)
+    partial = agg.iterate(partial, tuple3)
+
+    val avg = agg.finalAgg(partial).asInstanceOf[java.lang.Double].doubleValue()
+    assert(math.abs(avg - 15.0) < 1e-6)
+
+    // All nulls
+    val allNull = makeTuple(schema, null)
+    var partialAllNull = agg.init()
+    partialAllNull = agg.iterate(partialAllNull, allNull)
+    val allNullResult = agg.finalAgg(partialAllNull)
+    assert(allNullResult == null)
+  }
+
+  // ---------------------------------------------------------------------------
+  // getFinal behaviour
+  // ---------------------------------------------------------------------------
+
+  test("getFinal rewrites COUNT into SUM over the intermediate result attribute") {
+    val operation = makeAggregationOp(AggregationFunction.COUNT, "price", "price_count")
+    val finalOp = operation.getFinal
+
+    assert(finalOp.aggFunction == AggregationFunction.SUM)
+    assert(finalOp.attribute == "price_count")
+    assert(finalOp.resultAttribute == "price_count")
+  }
+
+  test("getFinal keeps non-COUNT aggregation function and rewires attribute to resultAttribute") {
+    val operation = makeAggregationOp(AggregationFunction.SUM, "amount", "total_amount")
+    val finalOp = operation.getFinal
+
+    assert(finalOp.aggFunction == AggregationFunction.SUM)
+    assert(finalOp.attribute == "total_amount")
+    assert(finalOp.resultAttribute == "total_amount")
+  }
+
+  // ---------------------------------------------------------------------------
+  // AggregateOpExec: integration-style tests with groupBy
+  // ---------------------------------------------------------------------------
+
+  test("AggregateOpExec groups by a single key and computes SUM per group") {
+    // schema: city (group key), sales
+    val schema = makeSchema(
+      "city" -> AttributeType.STRING,
+      "sales" -> AttributeType.INTEGER
+    )
+
+    val tuple1 = makeTuple(schema, "NY", 10)
+    val tuple2 = makeTuple(schema, "SF", 20)
+    val tuple3 = makeTuple(schema, "NY", 5)
+
+    val desc = new AggregateOpDesc()
+    val sumAgg = makeAggregationOp(AggregationFunction.SUM, "sales", "total_sales")
+    desc.aggregations = List(sumAgg)
+    desc.groupByKeys = List("city")
+
+    val descJson = objectMapper.writeValueAsString(desc)
+
+    val exec = new AggregateOpExec(descJson)
+    exec.open()
+    exec.processTuple(tuple1, 0)
+    exec.processTuple(tuple2, 0)
+    exec.processTuple(tuple3, 0)
+
+    val results = exec.onFinish(0).toList
+
+    // Expect two output rows: (NY, 15) and (SF, 20)
+    val resultMap = results.map { tupleLike =>
+      val fields = tupleLike.getFields
+      val city = fields(0).asInstanceOf[String]
+      val total = fields(1).asInstanceOf[Number].intValue()
+      city -> total
+    }.toMap
+
+    assert(resultMap.size == 2)
+    assert(resultMap("NY") == 15)
+    assert(resultMap("SF") == 20)
+  }
+
+  test("AggregateOpExec performs global SUM and COUNT when there are no groupBy keys") {
+    // schema: region (ignored for aggregation), revenue
+    val schema = makeSchema(
+      "region" -> AttributeType.STRING,
+      "revenue" -> AttributeType.INTEGER
+    )
+
+    val tuple1 = makeTuple(schema, "west", 100)
+    val tuple2 = makeTuple(schema, "east", 200)
+    val tuple3 = makeTuple(schema, "west", 50)
+
+    val desc = new AggregateOpDesc()
+    val sumAgg = makeAggregationOp(AggregationFunction.SUM, "revenue", "total_revenue")
+    val countAgg = makeAggregationOp(AggregationFunction.COUNT, "revenue", "row_count")
+    desc.aggregations = List(sumAgg, countAgg)
+    desc.groupByKeys = List() // global aggregation
+
+    val descJson = objectMapper.writeValueAsString(desc)
+
+    val exec = new AggregateOpExec(descJson)
+    exec.open()
+    exec.processTuple(tuple1, 0)
+    exec.processTuple(tuple2, 0)
+    exec.processTuple(tuple3, 0)
+
+    val results = exec.onFinish(0).toList
+    assert(results.size == 1)
+
+    val fields = results.head.getFields
+    // No group keys, so fields(0) is SUM(revenue), fields(1) is COUNT(revenue)
+    val totalRevenue = fields(0).asInstanceOf[Number].intValue()
+    val rowCount = fields(1).asInstanceOf[Number].intValue()
+
+    assert(totalRevenue == 350)
+    assert(rowCount == 3)
+  }
+}

