This is an automated email from the ASF dual-hosted git repository.

jmalkin pushed a commit to branch scala_compat
in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git

commit 9b6ab613cc6f5a5eeefbd5ff0f72bb0b3ba25c62
Author: Jon Malkin <[email protected]>
AuthorDate: Thu Jan 23 18:24:20 2025 -0800

    Rename KllExpressions file, make compatible with scala 2.12 and 2.13
---
 .github/workflows/ci.yaml                          | 10 +++--
 build.sbt                                          | 17 +++++---
 ...ons.scala => KllDoublesSketchExpressions.scala} |  4 +-
 src/test/scala/org/apache/spark/sql/KllTest.scala  | 45 +++++++++++-----------
 .../scala/org/apache/spark/sql/ThetaTest.scala     | 12 +++---
 5 files changed, 49 insertions(+), 39 deletions(-)

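At a high level, this commit stops hard-coding Scala 2.12 in build.sbt: the Scala and Spark versions are now read from the environment with defaults, and the CI matrix gains a scala axis so both binary versions are exercised. A minimal build.sbt sketch of the environment-driven pattern (the default version numbers below are simply the ones used in this commit, not a prescription):

    // Pick versions from the CI environment, falling back to defaults.
    scalaVersion := sys.env.getOrElse("SCALA_VERSION", "2.12.20")

    val sparkVersion = settingKey[String]("The version of Spark")
    sparkVersion := sys.env.getOrElse("SPARK_VERSION", "3.5.4")
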
diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 42d8359..28e5cae 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -22,10 +22,12 @@ jobs:
       fail-fast: false
       matrix:
         jdk: [ 8, 11, 17 ]
+        scala: [ 2.12.20, 2.13.16 ]
         spark: [ 3.4.4, 3.5.4 ]
 
     env:
       JDK_VERSION: ${{ matrix.jdk }}
+      SCALA_VERSION: ${{ matrix.scala }}
       SPARK_VERSION: ${{ matrix.spark }}
 
     steps:
@@ -38,8 +40,8 @@ jobs:
         uses: actions/cache@v4
         with:
           path: ~/.m2/repository
-          key: build-${{ runner.os }}-jdk-${{ matrix.jdk }}-spark-${{ matrix.spark }}-${{ hashFiles('**/pom.xml') }}
-          restore-keys: build-${{ runner.os }}-jdk-${{matrix.jdk}}-spark-${{ matrix.spark }}-maven-
+          key: build-${{ runner.os }}-jdk-${{ matrix.jdk }}-scala-${{ matrix.scala }}-spark-${{ matrix.spark }}-${{ hashFiles('**/pom.xml') }}
+          restore-keys: build-${{ runner.os }}-jdk-${{matrix.jdk}}-scala-${{ matrix.scala }}-spark-${{ matrix.spark }}-maven-
 
       - name: Setup JDK
         uses: actions/setup-java@v4
@@ -53,9 +55,11 @@ jobs:
       - name: Setup SBT
         uses: sbt/setup-sbt@v1
 
-      - name: Echo Java Version
+      - name: Echo config versions
         run: >
           java -version
+          echo Scala version: $SCALA_VERSION
+          echo Spark version: $SPARK_VERSION
 
       - name: Build and test
         run: >
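For reference, with the added scala axis the workflow's matrix fans out to 3 JDKs x 2 Scala versions x 2 Spark versions = 12 build-and-test jobs per run (before any excludes).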
diff --git a/build.sbt b/build.sbt
index a399f4f..acef8a1 100644
--- a/build.sbt
+++ b/build.sbt
@@ -1,3 +1,4 @@
+import scala.xml.dtd.DEFAULT
 /*
  * Licensed to the Apache Software Foundation (ASF) under one or more
  * contributor license agreements.  See the NOTICE file distributed with
@@ -17,15 +18,21 @@
 
 name := "datasketches-spark"
 version := "1.0-SNAPSHOT"
-scalaVersion := "2.12.20"
+
+DEFAULT_SCALA_VERSION := "2.12.20"
+DEFAULT_SPARK_VERSION := "3.5.4"
+DEFAULT_JDK_VERSION := "11"
 
 organization := "org.apache.datasketches"
 description := "The Apache DataSketches package for Spark"
 
 licenses += ("Apache-2.0", url("http://www.apache.org/licenses/LICENSE-2.0";))
 
+val scalaVersion = settingKey[String]("The version of Scala")
+scalaVersion := sys.env.getOrElse("SCALA_VERSION", DEFAULT_SCALA_VERSION)
+
 val sparkVersion = settingKey[String]("The version of Spark")
-sparkVersion := sys.env.getOrElse("SPARK_VERSION", "3.5.4")
+sparkVersion := sys.env.getOrElse("SPARK_VERSION", DEFAULT_SPARK_VERSION)
 
 // determine our java version
 val jvmVersionString = settingKey[String]("The JVM version")
@@ -45,7 +52,7 @@ val jvmVersionMap = Map(
 val jvmVersion = settingKey[String]("The JVM major version")
 jvmVersion := jvmVersionMap.collectFirst {
   case (prefix, (major, _)) if jvmVersionString.value.startsWith(prefix) => major
-}.getOrElse("11")
+}.getOrElse(DEFAULT_JDK_VERSION)
 
 // look up the associated datasketches-java version
 val dsJavaVersion = settingKey[String]("The DataSketches Java version")
@@ -59,8 +66,8 @@ Test / scalacOptions ++= Seq("-encoding", "UTF-8", "-release", jvmVersion.value)
 
 libraryDependencies ++= Seq(
   "org.apache.datasketches" % "datasketches-java" % dsJavaVersion.value % 
"compile",
-  "org.scala-lang" % "scala-library" % "2.12.6",
-  "org.apache.spark" %% "spark-sql" % sparkVersion.value % "provided",
+  "org.scala-lang" % "scala-library" % scalaVersion.value, // scala3-library 
may need to use %%
+  ("org.apache.spark" %% "spark-sql" % sparkVersion.value % 
"provided").cross(CrossVersion.for3Use2_13),
   "org.scalatest" %% "scalatest" % "3.2.19" % "test",
   "org.scalatestplus" %% "junit-4-13" % "3.2.19.0" % "test"
 )
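For context on the dependency change above: the %% operator makes sbt append the Scala binary suffix to the artifact name, so a single dependency line resolves the right Spark artifact for whichever Scala version is active, and CrossVersion.for3Use2_13 would fall back to the _2.13 artifact if the build were ever run on Scala 3 (Spark 3.x publishes no _3 artifacts). A rough illustration with assumed versions:

    // scalaVersion 2.12.20 -> resolves spark-sql_2.12
    // scalaVersion 2.13.16 -> resolves spark-sql_2.13
    // scalaVersion 3.x     -> for3Use2_13 keeps spark-sql_2.13
    ("org.apache.spark" %% "spark-sql" % "3.5.4" % "provided")
      .cross(CrossVersion.for3Use2_13)
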
diff --git a/src/main/scala/org/apache/spark/sql/kll/expressions/KllExpressions.scala b/src/main/scala/org/apache/spark/sql/kll/expressions/KllDoublesSketchExpressions.scala
similarity index 99%
rename from src/main/scala/org/apache/spark/sql/kll/expressions/KllExpressions.scala
rename to src/main/scala/org/apache/spark/sql/kll/expressions/KllDoublesSketchExpressions.scala
index 246af27..0ff03e6 100644
--- a/src/main/scala/org/apache/spark/sql/kll/expressions/KllExpressions.scala
+++ b/src/main/scala/org/apache/spark/sql/kll/expressions/KllDoublesSketchExpressions.scala
@@ -260,7 +260,7 @@ case class KllDoublesSketchGetPmfCdf(sketchExpr: Expression,
     if (!isInclusiveExpr.foldable) {
       return TypeCheckResult.TypeCheckFailure(s"isInclusiveExpr must be 
foldable, but got: ${isInclusiveExpr}")
     }
-    if (splitPointsExpr.eval().asInstanceOf[GenericArrayData].numElements == 0) {
+    if (splitPointsExpr.eval().asInstanceOf[GenericArrayData].numElements() == 0) {
       return TypeCheckResult.TypeCheckFailure(s"splitPointsExpr must not be empty")
     }
 
@@ -269,7 +269,7 @@ case class KllDoublesSketchGetPmfCdf(sketchExpr: Expression,
 
   override def nullSafeEval(sketchInput: Any, splitPointsInput: Any, isInclusiveInput: Any): Any = {
     val sketchBytes = sketchInput.asInstanceOf[Array[Byte]]
-    val splitPoints = splitPointsInput.asInstanceOf[GenericArrayData].toDoubleArray
+    val splitPoints = splitPointsInput.asInstanceOf[GenericArrayData].toDoubleArray()
     val sketch = KllDoublesSketch.wrap(Memory.wrap(sketchBytes))
 
     val result: Array[Double] =
diff --git a/src/test/scala/org/apache/spark/sql/KllTest.scala b/src/test/scala/org/apache/spark/sql/KllTest.scala
index c18f6cd..5123774 100644
--- a/src/test/scala/org/apache/spark/sql/KllTest.scala
+++ b/src/test/scala/org/apache/spark/sql/KllTest.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql
 
 import scala.util.Random
 import org.apache.spark.sql.functions._
-import scala.collection.mutable.WrappedArray
 import org.apache.spark.sql.types.{StructType, StructField, IntegerType, BinaryType}
 
 import org.apache.spark.sql.functions_datasketches_kll._
@@ -31,7 +30,7 @@ class KllTest extends SparkSessionManager {
   import spark.implicits._
 
   // helper method to check if two arrays are equal
-  private def compareArrays(ref: Array[Double], tst: WrappedArray[Double]) {
+  private def compareArrays(ref: Array[Double], tst: Array[Double]): Unit = {
     val tstArr = tst.toArray
     if (ref.length != tstArr.length)
       throw new AssertionError("Array lengths do not match: " + ref.length + " != " + tstArr.length)
@@ -48,7 +47,7 @@ class KllTest extends SparkSessionManager {
     // produce a List[Row] of (id, sk)
     for (i <- 1 to numClass) yield {
       val sk = KllDoublesSketch.newHeapInstance(200)
-      for (j <- 0 until numSamples) sk.update(Random.nextDouble)
+      for (j <- 0 until numSamples) sk.update(Random.nextDouble())
       dataList.add(Row(i, sk))
     }
 
@@ -68,7 +67,7 @@ class KllTest extends SparkSessionManager {
     // produce a Seq(Array(id, sk))
     val data = for (i <- 1 to numClass) yield {
       val sk = KllDoublesSketch.newHeapInstance(200)
-      for (j <- 0 until numSamples) sk.update(Random.nextDouble)
+      for (j <- 0 until numSamples) sk.update(Random.nextDouble())
       Row(i, sk.toByteArray)
     }
 
@@ -90,7 +89,7 @@ class KllTest extends SparkSessionManager {
     val sketchDf = data.agg(kll_sketch_double_agg_build("value").as("sketch"))
     val result: Row = sketchDf.select(kll_sketch_double_get_min($"sketch").as("min"),
                                       kll_sketch_double_get_max($"sketch").as("max")
-                                      ).head
+                                      ).head()
 
     val minValue = result.getAs[Double]("min")
     val maxValue = result.getAs[Double]("max")
@@ -103,19 +102,19 @@ class KllTest extends SparkSessionManager {
       kll_sketch_double_get_pmf($"sketch", splitPoints, 
false).as("pmf_exclusive"),
       kll_sketch_double_get_cdf($"sketch", splitPoints).as("cdf_inclusive"),
       kll_sketch_double_get_cdf($"sketch", splitPoints, 
false).as("cdf_exclusive")
-    ).head
+    ).head()
 
     val pmf_incl = Array[Double](0.2, 0.3, 0.5, 0.0)
-    compareArrays(pmf_incl, pmfCdfResult.getAs[WrappedArray[Double]]("pmf_inclusive"))
+    compareArrays(pmf_incl, pmfCdfResult.getAs[Seq[Double]]("pmf_inclusive").toArray)
 
     val pmf_excl = Array[Double](0.2, 0.29, 0.51, 0.0)
-    compareArrays(pmf_excl, pmfCdfResult.getAs[WrappedArray[Double]]("pmf_exclusive"))
+    compareArrays(pmf_excl, pmfCdfResult.getAs[Seq[Double]]("pmf_exclusive").toArray)
 
     val cdf_incl = Array[Double](0.2, 0.5, 1.0, 1.0)
-    compareArrays(cdf_incl, pmfCdfResult.getAs[WrappedArray[Double]]("cdf_inclusive"))
+    compareArrays(cdf_incl, pmfCdfResult.getAs[Seq[Double]]("cdf_inclusive").toArray)
 
     val cdf_excl = Array[Double](0.2, 0.49, 1.0, 1.0)
-    compareArrays(cdf_excl, pmfCdfResult.getAs[WrappedArray[Double]]("cdf_exclusive"))
+    compareArrays(cdf_excl, pmfCdfResult.getAs[Seq[Double]]("cdf_exclusive").toArray)
   }
 
   test("Kll Doubles Sketch via SQL") {
@@ -135,8 +134,8 @@ class KllTest extends SparkSessionManager {
       |  data_table
     """.stripMargin
     )
-    val minValue = kllDf.head.getAs[Double]("min")
-    val maxValue = kllDf.head.getAs[Double]("max")
+    val minValue = kllDf.head().getAs[Double]("min")
+    val maxValue = kllDf.head().getAs[Double]("max")
     assert(minValue == 1.0)
     assert(maxValue == n.toDouble)
 
@@ -154,26 +153,26 @@ class KllTest extends SparkSessionManager {
       |   FROM
       |     data_table) t
       """.stripMargin
-    ).head
+    ).head()
 
     val pmf_incl = Array[Double](0.2, 0.3, 0.5, 0.0)
-    compareArrays(pmf_incl, pmfCdfResult.getAs[WrappedArray[Double]]("pmf_inclusive"))
+    compareArrays(pmf_incl, pmfCdfResult.getAs[Seq[Double]]("pmf_inclusive").toArray)
 
     val pmf_excl = Array[Double](0.2, 0.29, 0.51, 0.0)
-    compareArrays(pmf_excl, pmfCdfResult.getAs[WrappedArray[Double]]("pmf_exclusive"))
+    compareArrays(pmf_excl, pmfCdfResult.getAs[Seq[Double]]("pmf_exclusive").toArray)
 
     val cdf_incl = Array[Double](0.2, 0.5, 1.0, 1.0)
-    compareArrays(cdf_incl, pmfCdfResult.getAs[WrappedArray[Double]]("cdf_inclusive"))
+    compareArrays(cdf_incl, pmfCdfResult.getAs[Seq[Double]]("cdf_inclusive").toArray)
 
     val cdf_excl = Array[Double](0.2, 0.49, 1.0, 1.0)
-    compareArrays(cdf_excl, pmfCdfResult.getAs[WrappedArray[Double]]("cdf_exclusive"))
+    compareArrays(cdf_excl, pmfCdfResult.getAs[Seq[Double]]("cdf_exclusive").toArray)
   }
 
   test("KLL Doubles Merge via Scala") {
     val data = generateData().toDF("id", "value")
 
     // compute global min and max
-    val minMax: Row = data.agg(min("value").as("min"), 
max("value").as("max")).collect.head
+    val minMax: Row = data.agg(min("value").as("min"), 
max("value").as("max")).collect().head
     val globalMin = minMax.getAs[Double]("min")
     val globalMax = minMax.getAs[Double]("max")
 
@@ -187,7 +186,7 @@ class KllTest extends SparkSessionManager {
     // check min and max
     var result: Row = mergedSketchDf.select(kll_sketch_double_get_min($"sketch").as("min"),
                                             kll_sketch_double_get_max($"sketch").as("max"))
-                                    .head
+                                    .head()
 
     var sketchMin = result.getAs[Double]("min")
     var sketchMax = result.getAs[Double]("max")
@@ -202,7 +201,7 @@ class KllTest extends SparkSessionManager {
     // check min and max
     result = mergedSketchDf.select(kll_sketch_double_get_min($"sketch").as("min"),
                                    kll_sketch_double_get_max($"sketch").as("max"))
-                           .head
+                           .head()
 
     sketchMin = result.getAs[Double]("min")
     sketchMax = result.getAs[Double]("max")
@@ -219,7 +218,7 @@ class KllTest extends SparkSessionManager {
     data.createOrReplaceTempView("data_table")
 
     // compute global min and max from dataframe
-    val minMax: Row = data.agg(min("value").as("min"), max("value").as("max")).head
+    val minMax: Row = data.agg(min("value").as("min"), max("value").as("max")).head()
     val globalMin = minMax.getAs[Double]("min")
     val globalMax = minMax.getAs[Double]("max")
 
@@ -255,7 +254,7 @@ class KllTest extends SparkSessionManager {
     )
 
     // check min and max
-    var result: Row = mergedSketchDf.head
+    var result: Row = mergedSketchDf.head()
     var sketchMin = result.getAs[Double]("min")
     var sketchMax = result.getAs[Double]("max")
 
@@ -279,7 +278,7 @@ class KllTest extends SparkSessionManager {
     )
 
     // check min and max
-    result = mergedSketchDf.head
+    result = mergedSketchDf.head()
     sketchMin = result.getAs[Double]("min")
     sketchMax = result.getAs[Double]("max")
 
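The WrappedArray-to-Seq change in the test above is the Scala 2.13 compatibility piece for reading array columns out of a Row: the concrete collection Spark hands back differs by Scala version (typically a WrappedArray on 2.12 and an ArraySeq on 2.13), and both are Seq, so reading the column as Seq[Double] and converting compiles on either. A minimal sketch of the pattern (the row and column name here are assumed for illustration):

    // row: org.apache.spark.sql.Row with an array<double> column named "pmf"
    // works on both Scala 2.12 and 2.13, since both concrete types are Seq[Double]
    val pmf: Array[Double] = row.getAs[Seq[Double]]("pmf").toArray
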
diff --git a/src/test/scala/org/apache/spark/sql/ThetaTest.scala b/src/test/scala/org/apache/spark/sql/ThetaTest.scala
index 4b7f0a4..6c93d25 100644
--- a/src/test/scala/org/apache/spark/sql/ThetaTest.scala
+++ b/src/test/scala/org/apache/spark/sql/ThetaTest.scala
@@ -28,7 +28,7 @@ class ThetaTest extends SparkSessionManager {
     val data = (for (i <- 1 to n) yield i).toDF("value")
 
     val sketchDf = data.agg(theta_sketch_agg_build("value").as("sketch"))
-    val result: Row = sketchDf.select(theta_sketch_get_estimate("sketch").as("estimate")).head
+    val result: Row = sketchDf.select(theta_sketch_get_estimate("sketch").as("estimate")).head()
 
     assert(result.getAs[Double]("estimate") == 100.0)
   }
@@ -46,7 +46,7 @@ class ThetaTest extends SparkSessionManager {
       FROM
         theta_input_table
     """)
-    assert(df.head.getAs[Double]("estimate") == 100.0)
+    assert(df.head().getAs[Double]("estimate") == 100.0)
   }
 
   test("Theta Sketch build via SQL with lgk") {
@@ -62,7 +62,7 @@ class ThetaTest extends SparkSessionManager {
       FROM
         theta_input_table
     """)
-    assert(df.head.getAs[Double]("estimate") == 100.0)
+    assert(df.head().getAs[Double]("estimate") == 100.0)
   }
 
   test("Theta Union via Scala") {
@@ -72,7 +72,7 @@ class ThetaTest extends SparkSessionManager {
 
     val groupedDf = data.groupBy("group").agg(theta_sketch_agg_build("value").as("sketch"))
     val mergedDf = groupedDf.agg(theta_sketch_agg_union("sketch").as("merged"))
-    val result: Row = mergedDf.select(theta_sketch_get_estimate("merged").as("estimate")).head
+    val result: Row = mergedDf.select(theta_sketch_get_estimate("merged").as("estimate")).head()
     assert(result.getAs[Double]("estimate") == numDistinct)
   }
 
@@ -100,7 +100,7 @@ class ThetaTest extends SparkSessionManager {
       FROM
         theta_sketch_table
     """)
-    assert(mergedDf.head.getAs[Double]("estimate") == numDistinct)
+    assert(mergedDf.head().getAs[Double]("estimate") == numDistinct)
   }
 
   test("Theta Union via SQL with lgk") {
@@ -126,7 +126,7 @@ class ThetaTest extends SparkSessionManager {
       FROM
         theta_sketch_table
     """)
-    assert(mergedDf.head.getAs[Double]("estimate") == numDistinct)
+    assert(mergedDf.head().getAs[Double]("estimate") == numDistinct)
   }
 
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]