This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 29e09473a81 [SPARK-44154] Implement bitmap functions
29e09473a81 is described below

commit 29e09473a816f1cb0ddea3e49d6dcb492ad473d9
Author: Gene Pang <gene.p...@databricks.com>
AuthorDate: Wed Jul 5 17:39:03 2023 -0700

    [SPARK-44154] Implement bitmap functions
    
    ### What changes were proposed in this pull request?
    
    Implemented bitmap functions. The functions are:
    - `bitmap_bucket_number()`: returns the bucket number for a given input number
    - `bitmap_bit_position()`: returns the bit position for a given input number
    - `bitmap_count()`: returns the number of set bits in an input bitmap
    - `bitmap_construct_agg()`: aggregation function that aggregates input bit positions and creates a bitmap
    - `bitmap_or_agg()`: aggregation function that performs a bitwise OR on all the input bitmaps
    
    ### Why are the changes needed?
    
    These functions can be used to count distinct values for integer columns. For example:
    
    ```sql
    SELECT sum(cnt) FROM (
      SELECT bitmap_bucket_number(c),
        bitmap_count(bitmap_construct_agg(bitmap_bit_position(c))) cnt
      FROM table GROUP BY 1
    )
    ```
    
    is equivalent to:
    
    ```sql
    SELECT count(distinct c) FROM table
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. After this PR, these functions are usable in queries.
    
    ### How was this patch tested?
    
    New tests were added.
    
    Closes #41623 from gene-db/bitmap-fns.
    
    Authored-by: Gene Pang <gene.p...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../src/main/resources/error/error-classes.json    |   6 +
 .../expressions/BitmapExpressionUtils.java         |  59 ++++
 .../sql/catalyst/analysis/FunctionRegistry.scala   |   7 +
 .../catalyst/expressions/bitmapExpressions.scala   | 305 +++++++++++++++++++++
 .../spark/sql/errors/QueryExecutionErrors.scala    |  12 +
 .../expressions/BitmapExpressionUtilsSuite.scala   |  90 ++++++
 .../sql-functions/sql-expression-schema.md         |   5 +
 .../spark/sql/BitmapExpressionsQuerySuite.scala    | 151 ++++++++++
 .../sql/errors/QueryExecutionErrorsSuite.scala     |  28 ++
 9 files changed, 663 insertions(+)

diff --git a/common/utils/src/main/resources/error/error-classes.json 
b/common/utils/src/main/resources/error/error-classes.json
index 3126fb9519b..8bdb02470ef 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -1038,6 +1038,12 @@
     ],
     "sqlState" : "22003"
   },
+  "INVALID_BITMAP_POSITION" : {
+    "message" : [
+      "The 0-indexed bitmap position <bitPosition> is out of bounds. The 
bitmap has <bitmapNumBits> bits (<bitmapNumBytes> bytes)."
+    ],
+    "sqlState" : "22003"
+  },
   "INVALID_BOUNDARY" : {
     "message" : [
       "The boundary <boundary> is invalid: <invalidValue>."
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/BitmapExpressionUtils.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/BitmapExpressionUtils.java
new file mode 100644
index 00000000000..e11aea62664
--- /dev/null
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/BitmapExpressionUtils.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions;
+
+/**
+ * Static helpers used by the bitmap SQL expressions: bucket / bit-position arithmetic,
+ * counting of set bits, and in-place merging of bitmaps.
+ */
+public class BitmapExpressionUtils {
+  /** Number of bytes in a bitmap. */
+  public static final int NUM_BYTES = 4 * 1024;
+
+  /** Number of bits in a bitmap. */
+  public static final int NUM_BITS = 8 * NUM_BYTES;
+
+  /**
+   * Returns the bucket number for {@code value}.
+   *
+   * <p>Positive values are grouped in runs of {@link #NUM_BITS}: 1..NUM_BITS map to bucket 1,
+   * the next NUM_BITS values to bucket 2, and so on. Zero and negative values use truncating
+   * division, so 0 and -1..-(NUM_BITS - 1) all map to bucket 0.
+   *
+   * @param value the input number
+   * @return the bucket number
+   */
+  public static long bitmapBucketNumber(long value) {
+    if (value > 0) {
+      return 1 + (value - 1) / NUM_BITS;
+    }
+    return value / NUM_BITS;
+  }
+
+  /**
+   * Returns the 0-based bit position of {@code value} inside its bucket.
+   *
+   * @param value the input number
+   * @return {@code (value - 1) % NUM_BITS} for positive values, {@code (-value) % NUM_BITS}
+   *         otherwise
+   */
+  public static long bitmapBitPosition(long value) {
+    if (value > 0) {
+      // inputs: (1 -> NUM_BITS) map to positions (0 -> NUM_BITS - 1)
+      return (value - 1) % NUM_BITS;
+    }
+    return (-value) % NUM_BITS;
+  }
+
+  /**
+   * Counts the set bits over every byte of {@code bitmap}.
+   *
+   * @param bitmap the bitmap to inspect; any length is accepted
+   * @return the total number of bits set to 1
+   */
+  public static long bitmapCount(byte[] bitmap) {
+    long count = 0;
+    for (byte b : bitmap) {
+      // Mask to the low 8 bits so sign extension of negative bytes is not counted.
+      count += Integer.bitCount(b & 0x0FF);
+    }
+    return count;
+  }
+
+  /**
+   * Merges both bitmaps with a bitwise OR and writes the result into {@code bitmap1}.
+   *
+   * <p>Only the common prefix (the shorter of the two lengths) is merged; any remaining bytes
+   * of {@code bitmap1} are left untouched and {@code bitmap2} is never modified.
+   */
+  public static void bitmapMerge(byte[] bitmap1, byte[] bitmap2) {
+    // Computed once instead of on every loop iteration.
+    final int commonLength = Math.min(bitmap1.length, bitmap2.length);
+    for (int i = 0; i < commonLength; ++i) {
+      bitmap1[i] |= bitmap2[i];
+    }
+  }
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index eaf1bf21d34..a9bda2e0b7c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -792,6 +792,13 @@ object FunctionRegistry {
     expression[BitwiseGet]("bit_get"),
     expression[BitwiseGet]("getbit", true),
 
+    // bitmap functions and aggregates
+    expression[BitmapBucketNumber]("bitmap_bucket_number"),
+    expression[BitmapBitPosition]("bitmap_bit_position"),
+    expression[BitmapConstructAgg]("bitmap_construct_agg"),
+    expression[BitmapCount]("bitmap_count"),
+    expression[BitmapOrAgg]("bitmap_or_agg"),
+
     // json
     expression[StructsToJson]("to_json"),
     expression[JsonToStructs]("from_json"),
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitmapExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitmapExpressions.scala
new file mode 100644
index 00000000000..350ca5c2525
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitmapExpressions.scala
@@ -0,0 +1,305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
+import org.apache.spark.sql.catalyst.trees.UnaryLike
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, 
LongType, StructType}
+
+/**
+ * SQL function `bitmap_bucket_number`. Declares a single LongType input and is replaced at
+ * runtime by a static call to `BitmapExpressionUtils.bitmapBucketNumber`.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(child) - Returns the bucket number for the given input child expression.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(123);
+       1
+      > SELECT _FUNC_(0);
+       0
+  """,
+  since = "3.5.0",
+  group = "misc_funcs"
+)
+case class BitmapBucketNumber(child: Expression)
+  extends UnaryExpression with RuntimeReplaceable with ImplicitCastInputTypes {
+
+  override def prettyName: String = "bitmap_bucket_number"
+
+  override def dataType: DataType = LongType
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(LongType)
+
+  // The computation itself lives in the Java utility class; the call is declared as never
+  // returning null.
+  override lazy val replacement: Expression = {
+    val arguments = Seq(child)
+    StaticInvoke(
+      classOf[BitmapExpressionUtils],
+      LongType,
+      "bitmapBucketNumber",
+      arguments,
+      inputTypes,
+      returnNullable = false)
+  }
+
+  override protected def withNewChildInternal(newChild: Expression): BitmapBucketNumber =
+    copy(child = newChild)
+}
+
+/**
+ * SQL function `bitmap_bit_position`. Declares a single LongType input and is replaced at
+ * runtime by a static call to `BitmapExpressionUtils.bitmapBitPosition`.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(child) - Returns the bit position for the given input child expression.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(1);
+       0
+      > SELECT _FUNC_(123);
+       122
+  """,
+  since = "3.5.0",
+  group = "misc_funcs"
+)
+case class BitmapBitPosition(child: Expression)
+  extends UnaryExpression with RuntimeReplaceable with ImplicitCastInputTypes {
+
+  override def prettyName: String = "bitmap_bit_position"
+
+  override def dataType: DataType = LongType
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(LongType)
+
+  // The computation itself lives in the Java utility class; the call is declared as never
+  // returning null.
+  override lazy val replacement: Expression = {
+    val arguments = Seq(child)
+    StaticInvoke(
+      classOf[BitmapExpressionUtils],
+      LongType,
+      "bitmapBitPosition",
+      arguments,
+      inputTypes,
+      returnNullable = false)
+  }
+
+  override protected def withNewChildInternal(newChild: Expression): BitmapBitPosition =
+    copy(child = newChild)
+}
+
+/**
+ * SQL function `bitmap_count`. Accepts only a BinaryType child and is replaced at runtime by a
+ * static call to `BitmapExpressionUtils.bitmapCount`.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(child) - Returns the number of set bits in the child bitmap.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(X '1010');
+       2
+      > SELECT _FUNC_(X 'FFFF');
+       16
+      > SELECT _FUNC_(X '0');
+       0
+  """,
+  since = "3.5.0",
+  group = "misc_funcs"
+)
+case class BitmapCount(child: Expression)
+  extends UnaryExpression with RuntimeReplaceable {
+
+  override def prettyName: String = "bitmap_count"
+
+  override def dataType: DataType = LongType
+
+  // No implicit casting here: anything other than a binary child is rejected outright.
+  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
+    case BinaryType => TypeCheckResult.TypeCheckSuccess
+    case _ => TypeCheckResult.TypeCheckFailure("Bitmap must be a BinaryType")
+  }
+
+  override lazy val replacement: Expression = {
+    val arguments = Seq(child)
+    val argumentTypes = Seq(BinaryType)
+    StaticInvoke(
+      classOf[BitmapExpressionUtils],
+      LongType,
+      "bitmapCount",
+      arguments,
+      argumentTypes,
+      returnNullable = false)
+  }
+
+  override protected def withNewChildInternal(newChild: Expression): BitmapCount =
+    copy(child = newChild)
+}
+
+/**
+ * Aggregate behind `bitmap_construct_agg`: every non-null input is treated as a 0-based bit
+ * position and the corresponding bit is set in a bitmap of
+ * `BitmapExpressionUtils.NUM_BYTES` bytes. Positions outside the bitmap raise
+ * `INVALID_BITMAP_POSITION`.
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(child) - Returns a bitmap with the positions of the bits set from all the values from
+    the child expression. The child expression will most likely be bitmap_bit_position().
+  """,
+  // scalastyle:off line.size.limit
+  examples = """
+    Examples:
+      > SELECT substring(hex(_FUNC_((bitmap_bit_position(col)))), 0, 6) FROM VALUES (1), (2), (3) AS tab(col);
+       070000
+      > SELECT substring(hex(_FUNC_((bitmap_bit_position(col)))), 0, 6) FROM VALUES (1), (1), (1) AS tab(col);
+       010000
+  """,
+  // scalastyle:on line.size.limit
+  since = "3.5.0",
+  group = "agg_funcs"
+)
+case class BitmapConstructAgg(
+    child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends ImperativeAggregate with ImplicitCastInputTypes with UnaryLike[Expression] {
+
+  def this(child: Expression) = this(child, 0, 0)
+
+  // The aggregation buffer is a single, fixed size binary column.
+  private val bitmapAttr = AttributeReference("bitmap", BinaryType, nullable = false)()
+
+  /** Allocates an all-zero bitmap of the standard size. */
+  private def emptyBitmap(): Array[Byte] = new Array[Byte](BitmapExpressionUtils.NUM_BYTES)
+
+  override def prettyName: String = "bitmap_construct_agg"
+
+  override def dataType: DataType = BinaryType
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(LongType)
+
+  override def nullable: Boolean = false
+
+  override def aggBufferAttributes: Seq[AttributeReference] = bitmapAttr :: Nil
+
+  override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+
+  override def inputAggBufferAttributes: Seq[AttributeReference] =
+    aggBufferAttributes.map(_.newInstance())
+
+  // With no input rows the result is an empty (all-zero) bitmap rather than null.
+  override def defaultResult: Option[Literal] = Option(Literal(emptyBitmap()))
+
+  override protected def withNewChildInternal(newChild: Expression): BitmapConstructAgg =
+    copy(child = newChild)
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def initialize(buffer: InternalRow): Unit = {
+    buffer.update(mutableAggBufferOffset, emptyBitmap())
+  }
+
+  override def update(buffer: InternalRow, input: InternalRow): Unit = {
+    child.eval(input) match {
+      case null => // null positions do not contribute to the bitmap
+      case value =>
+        // NOTE(review): the array is mutated in place and never written back, which relies on
+        // getBinary handing out the buffer's own array rather than a copy - confirm for the
+        // row implementation in use.
+        val bitmap = buffer.getBinary(mutableAggBufferOffset)
+        val bitPosition = value.asInstanceOf[Long]
+        val numBits = 8 * bitmap.length
+
+        if (!(bitPosition >= 0 && bitPosition < numBits)) {
+          throw QueryExecutionErrors.invalidBitmapPositionError(bitPosition, bitmap.length)
+        }
+
+        val byteIndex = (bitPosition / 8).toInt
+        val mask = 1 << (bitPosition % 8).toInt
+        bitmap(byteIndex) = (bitmap(byteIndex) | mask).toByte
+    }
+  }
+
+  override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = {
+    // OR the partial bitmap from buffer2 into buffer1's bitmap.
+    BitmapExpressionUtils.bitmapMerge(
+      buffer1.getBinary(mutableAggBufferOffset),
+      buffer2.getBinary(inputAggBufferOffset))
+  }
+
+  override def eval(buffer: InternalRow): Any = buffer.getBinary(mutableAggBufferOffset)
+}
+
+/**
+ * Aggregate behind `bitmap_or_agg`: ORs every non-null binary input into a bitmap of
+ * `BitmapExpressionUtils.NUM_BYTES` bytes. Only a BinaryType child is accepted.
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(child) - Returns a bitmap that is the bitwise OR of all of the bitmaps from the child
+    expression. The input should be bitmaps created from bitmap_construct_agg().
+  """,
+  // scalastyle:off line.size.limit
+  examples = """
+    Examples:
+      > SELECT substring(hex(_FUNC_(col)), 0, 6) FROM VALUES (X '10'), (X '20'), (X '40') AS tab(col);
+       700000
+      > SELECT substring(hex(_FUNC_(col)), 0, 6) FROM VALUES (X '10'), (X '10'), (X '10') AS tab(col);
+       100000
+  """,
+  // scalastyle:on line.size.limit
+  since = "3.5.0",
+  group = "agg_funcs"
+)
+case class BitmapOrAgg(
+    child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends ImperativeAggregate with UnaryLike[Expression] {
+
+  def this(child: Expression) = this(child, 0, 0)
+
+  // The aggregation buffer is a single, fixed size binary column.
+  private val bitmapAttr = AttributeReference("bitmap", BinaryType, nullable = false)()
+
+  /** Allocates an all-zero bitmap of the standard size. */
+  private def emptyBitmap(): Array[Byte] = new Array[Byte](BitmapExpressionUtils.NUM_BYTES)
+
+  override def prettyName: String = "bitmap_or_agg"
+
+  override def dataType: DataType = BinaryType
+
+  override def nullable: Boolean = false
+
+  // No implicit casting here: anything other than a binary child is rejected outright.
+  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
+    case BinaryType => TypeCheckResult.TypeCheckSuccess
+    case _ => TypeCheckResult.TypeCheckFailure("Bitmap must be a BinaryType")
+  }
+
+  override def aggBufferAttributes: Seq[AttributeReference] = bitmapAttr :: Nil
+
+  override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+
+  override def inputAggBufferAttributes: Seq[AttributeReference] =
+    aggBufferAttributes.map(_.newInstance())
+
+  // With no input rows the result is an empty (all-zero) bitmap rather than null.
+  override def defaultResult: Option[Literal] = Option(Literal(emptyBitmap()))
+
+  override protected def withNewChildInternal(newChild: Expression): BitmapOrAgg =
+    copy(child = newChild)
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def initialize(buffer: InternalRow): Unit = {
+    buffer.update(mutableAggBufferOffset, emptyBitmap())
+  }
+
+  override def update(buffer: InternalRow, input: InternalRow): Unit = {
+    child.eval(input) match {
+      case null => // null bitmaps do not contribute to the result
+      case value =>
+        val inputBitmap = value.asInstanceOf[Array[Byte]]
+        // bitmapMerge only covers the common prefix of the two arrays, so bytes of an input
+        // longer than the buffer are not merged.
+        BitmapExpressionUtils.bitmapMerge(buffer.getBinary(mutableAggBufferOffset), inputBitmap)
+    }
+  }
+
+  override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = {
+    // OR the partial bitmap from buffer2 into buffer1's bitmap.
+    BitmapExpressionUtils.bitmapMerge(
+      buffer1.getBinary(mutableAggBufferOffset),
+      buffer2.getBinary(inputAggBufferOffset))
+  }
+
+  override def eval(buffer: InternalRow): Any = buffer.getBinary(mutableAggBufferOffset)
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 630cf2fa55a..eded5e6534f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -262,6 +262,18 @@ private[sql] object QueryExecutionErrors extends 
QueryErrorsBase {
       summary = getSummary(context))
   }
 
+  /**
+   * Builds the INVALID_BITMAP_POSITION error for a bit position that falls outside a bitmap.
+   *
+   * @param bitPosition the offending 0-indexed bit position
+   * @param bitmapNumBytes size of the bitmap in bytes; the bit count reported is 8 times this
+   */
+  def invalidBitmapPositionError(
+      bitPosition: Long,
+      bitmapNumBytes: Long): ArrayIndexOutOfBoundsException = {
+    val bitmapNumBits = bitmapNumBytes * 8
+    val parameters = Map(
+      "bitPosition" -> bitPosition.toString,
+      "bitmapNumBytes" -> bitmapNumBytes.toString,
+      "bitmapNumBits" -> bitmapNumBits.toString)
+    new SparkArrayIndexOutOfBoundsException(
+      errorClass = "INVALID_BITMAP_POSITION",
+      messageParameters = parameters,
+      context = Array.empty,
+      summary = "")
+  }
+
   def invalidFractionOfSecondError(): DateTimeException = {
     new SparkDateTimeException(
       errorClass = "INVALID_FRACTION_OF_SECOND",
diff --git 
a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/BitmapExpressionUtilsSuite.scala
 
b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/BitmapExpressionUtilsSuite.scala
new file mode 100644
index 00000000000..ee1f4026fed
--- /dev/null
+++ 
b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/BitmapExpressionUtilsSuite.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+
+/** Unit tests for the static helpers in `BitmapExpressionUtils`. */
+class BitmapExpressionUtilsSuite extends SparkFunSuite {
+
+  /** Asserts that `f` maps each input to the expected value it is paired with. */
+  private def checkMapping(f: Long => Long)(cases: (Long, Long)*): Unit = {
+    cases.foreach { case (input, expected) =>
+      assert(f(input) == expected)
+    }
+  }
+
+  private val bucketNumber: Long => Long = v => BitmapExpressionUtils.bitmapBucketNumber(v)
+
+  private val bitPosition: Long => Long = v => BitmapExpressionUtils.bitmapBitPosition(v)
+
+  test("bitmap_bucket_number with positive inputs") {
+    checkMapping(bucketNumber)(
+      0L -> 0L, 1L -> 1L, 2L -> 1L, 3L -> 1L,
+      32768L -> 1L, 32769L -> 2L, 32770L -> 2L)
+  }
+
+  test("bitmap_bucket_number with negative inputs") {
+    checkMapping(bucketNumber)(
+      -1L -> 0L, -2L -> 0L, -3L -> 0L,
+      -32767L -> 0L, -32768L -> -1L, -32769L -> -1L)
+  }
+
+  test("bitmap_bit_position with positive inputs") {
+    checkMapping(bitPosition)(
+      0L -> 0L, 1L -> 0L, 2L -> 1L, 3L -> 2L,
+      32768L -> 32767L, 32769L -> 0L, 32770L -> 1L)
+  }
+
+  test("bitmap_bit_position with negative inputs") {
+    checkMapping(bitPosition)(
+      -1L -> 1L, -2L -> 2L, -3L -> 3L,
+      -32767L -> 32767L, -32768L -> 0L, -32769L -> 1L)
+  }
+
+  /** A fresh, all-zero bitmap of the standard size. */
+  private def createBitmap(): Array[Byte] = new Array[Byte](BitmapExpressionUtils.NUM_BYTES)
+
+  /** Resets every byte of `bitmap` to zero. */
+  private def clearBitmap(bitmap: Array[Byte]): Unit = {
+    java.util.Arrays.fill(bitmap, 0.toByte)
+  }
+
+  /** ORs the low 8 bits of `bits` into the byte at `bytePos`. */
+  private def setBitmapBits(bitmap: Array[Byte], bytePos: Int, bits: Int): Unit = {
+    val merged = (bitmap(bytePos) & 0x0ff) | (bits & 0x0ff)
+    bitmap(bytePos) = merged.toByte
+  }
+
+  test("bitmap_count empty") {
+    assert(BitmapExpressionUtils.bitmapCount(createBitmap()) == 0L)
+  }
+
+  test("bitmap_count") {
+    val bitmap = createBitmap()
+
+    // A single bit in the first byte.
+    setBitmapBits(bitmap, 0, 0x01)
+    assert(BitmapExpressionUtils.bitmapCount(bitmap) == 1L)
+
+    // A fully set first byte.
+    clearBitmap(bitmap)
+    setBitmapBits(bitmap, 0, 0xff)
+    assert(BitmapExpressionUtils.bitmapCount(bitmap) == 8L)
+
+    // 0x22 adds two more bits in the second byte.
+    setBitmapBits(bitmap, 1, 0x22)
+    assert(BitmapExpressionUtils.bitmapCount(bitmap) == 10L)
+
+    // 0x67 adds five more bits in the very last byte.
+    setBitmapBits(bitmap, bitmap.length - 1, 0x67)
+    assert(BitmapExpressionUtils.bitmapCount(bitmap) == 15L)
+  }
+}
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md 
b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 32c4c02b1b2..f979a138e20 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -46,6 +46,11 @@
 | org.apache.spark.sql.catalyst.expressions.Base64 | base64 | SELECT 
base64('Spark SQL') | struct<base64(Spark SQL):string> |
 | org.apache.spark.sql.catalyst.expressions.Bin | bin | SELECT bin(13) | 
struct<bin(13):string> |
 | org.apache.spark.sql.catalyst.expressions.BitLength | bit_length | SELECT 
bit_length('Spark SQL') | struct<bit_length(Spark SQL):int> |
+| org.apache.spark.sql.catalyst.expressions.BitmapBitPosition | 
bitmap_bit_position | SELECT bitmap_bit_position(1) | 
struct<bitmap_bit_position(1):bigint> |
+| org.apache.spark.sql.catalyst.expressions.BitmapBucketNumber | 
bitmap_bucket_number | SELECT bitmap_bucket_number(123) | 
struct<bitmap_bucket_number(123):bigint> |
+| org.apache.spark.sql.catalyst.expressions.BitmapConstructAgg | 
bitmap_construct_agg | SELECT 
substring(hex(bitmap_construct_agg((bitmap_bit_position(col)))), 0, 6) FROM 
VALUES (1), (2), (3) AS tab(col) | 
struct<substring(hex(bitmap_construct_agg(bitmap_bit_position(col))), 0, 
6):string> |
+| org.apache.spark.sql.catalyst.expressions.BitmapCount | bitmap_count | 
SELECT bitmap_count(X '1010') | struct<bitmap_count(X'1010'):bigint> |
+| org.apache.spark.sql.catalyst.expressions.BitmapOrAgg | bitmap_or_agg | 
SELECT substring(hex(bitmap_or_agg(col)), 0, 6) FROM VALUES (X '10'), (X '20'), 
(X '40') AS tab(col) | struct<substring(hex(bitmap_or_agg(col)), 0, 6):string> |
 | org.apache.spark.sql.catalyst.expressions.BitwiseAnd | & | SELECT 3 & 5 | 
struct<(3 & 5):int> |
 | org.apache.spark.sql.catalyst.expressions.BitwiseCount | bit_count | SELECT 
bit_count(0) | struct<bit_count(0):int> |
 | org.apache.spark.sql.catalyst.expressions.BitwiseGet | bit_get | SELECT 
bit_get(11, 0) | struct<bit_get(11, 0):tinyint> |
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/BitmapExpressionsQuerySuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/BitmapExpressionsQuerySuite.scala
new file mode 100644
index 00000000000..76b9019475a
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/BitmapExpressionsQuerySuite.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.test.SharedSparkSession
+
+/**
+ * End-to-end checks that queries built from the bitmap functions return the same answers as
+ * the equivalent `count(distinct ...)` queries.
+ *
+ * Every fixture is registered with `createOrReplaceTempView`, so cleanup uses `withTempView`.
+ */
+class BitmapExpressionsQuerySuite extends QueryTest with SharedSparkSession {
+  import testImplicits._
+
+  test("bitmap_construct_agg") {
+    val table = "bitmaps_table"
+    withTempView(table) {
+      (0 until 10000).toDF("id").selectExpr("100 * cast(id / 2 as int) col")
+        .createOrReplaceTempView(table)
+
+      val expected = spark.sql(
+        s"""
+           | select count (distinct col) c from $table
+           |""".stripMargin).collect()
+
+      // One bitmap per bucket; the per-bucket bit counts add up to the distinct count.
+      val df = spark.sql(
+        s"""
+          | select sum(c) from (
+          |   select bitmap_bucket_number(col) bn,
+          |   bitmap_count(bitmap_construct_agg(bitmap_bit_position(col))) c
+          |   from $table
+          |   group by 1
+          | )
+          |""".stripMargin)
+      checkAnswer(df, expected)
+    }
+  }
+
+  test("grouping bitmap_construct_agg") {
+    val table = "bitmaps_table"
+    withTempView(table) {
+      (0 until 10000).toDF("id").selectExpr(
+        "(id % 4) part",
+        "100 * cast(id / 8 as int) col")
+        .createOrReplaceTempView(table)
+
+      val expected = spark.sql(
+        s"""
+           | select part, count (distinct col) c from $table group by 1 order by 1
+           |""".stripMargin).collect()
+
+      val df = spark.sql(
+        s"""
+           | select part, sum(c) from (
+           |   select part, bitmap_bucket_number(col) bn,
+           |   bitmap_count(bitmap_construct_agg(bitmap_bit_position(col))) c
+           |   from $table group by 1, 2 order by 1, 2
+           | ) group by 1 order by 1
+           |""".stripMargin)
+      checkAnswer(df, expected)
+    }
+  }
+
+  test("precomputed bitmaps") {
+    val table = "bitmaps_table"
+    val precomputed = "precomputed_table"
+    withTempView(table, precomputed) {
+      (0 until 10000).toDF("id").selectExpr(
+        "(id % 4) part1",
+        "((id + 7) % 3) part2",
+        "100 * cast(id / 17 as int) col")
+        .createOrReplaceTempView(table)
+      // Bitmaps are materialized once per (part1, part2, bucket) and reused below.
+      spark.sql(
+        s"""
+           | select part1, part2, bitmap_bucket_number(col) bn,
+           | bitmap_construct_agg(bitmap_bit_position(col)) bm
+           | from $table group by 1, 2, 3
+           |""".stripMargin).createOrReplaceTempView(precomputed)
+
+      // Compute over both partitions
+      {
+        val expected = spark.sql(
+          s"""
+             | select part1, part2, count (distinct col) c from $table group by 1, 2 order by 1, 2
+             |""".stripMargin).collect()
+
+        val df = spark.sql(
+          s"""
+             | select part1, part2, sum(bitmap_count(bm))
+             | from $precomputed group by 1, 2 order by 1, 2
+             |""".stripMargin)
+        checkAnswer(df, expected)
+      }
+
+      // Compute over one of the partitions: bitmaps of the same bucket are OR-ed together
+      // before counting.
+      Seq("part1", "part2").foreach { part =>
+        val expected = spark.sql(
+          s"""
+             | select $part, count (distinct col) c from $table group by 1 order by 1
+             |""".stripMargin).collect()
+
+        val df = spark.sql(
+          s"""
+             | select $part, sum(c) from (
+             |   select $part, bn, bitmap_count(bitmap_or_agg(bm)) c
+             |   from $precomputed group by 1, 2
+             | ) group by 1 order by 1
+             |""".stripMargin)
+        checkAnswer(df, expected)
+      }
+    }
+  }
+
+  test("bitmap functions with floats") {
+    val table = "bitmaps_table"
+    withTempView(table) {
+      (0 until 10000).toDF("id").selectExpr(
+        "(id % 4) part",
+        "100 * id + cast(id / 8.0 as float) col")
+        .createOrReplaceTempView(table)
+
+      val expected = spark.sql(
+        s"""
+           | select part, count (distinct col) c from $table group by 1 order by 1
+           |""".stripMargin).collect()
+
+      val df = spark.sql(
+        s"""
+           | select part, sum(c) from (
+           |   select part, bitmap_bucket_number(col) bn,
+           |   bitmap_count(bitmap_construct_agg(bitmap_bit_position(col))) c
+           |   from $table group by 1, 2 order by 1, 2
+           | ) group by 1 order by 1
+           |""".stripMargin)
+      checkAnswer(df, expected)
+    }
+  }
+}
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
index 37cdfcf10c7..df82e3c268f 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
@@ -937,6 +937,34 @@ class QueryExecutionErrorsSuite
       sqlState = "XX000")
   }
 
+  // 32768 is one past the last valid position of a 4096-byte (32768-bit) bitmap.
+  test("INVALID_BITMAP_POSITION: position out of bounds") {
+    val query = "select bitmap_construct_agg(col) from values (32768) as tab(col)"
+    val thrown = intercept[SparkException] {
+      sql(query).collect()
+    }
+    val cause = thrown.getCause.asInstanceOf[SparkArrayIndexOutOfBoundsException]
+    val expectedParameters = Map(
+      "bitPosition" -> "32768",
+      "bitmapNumBytes" -> "4096",
+      "bitmapNumBits" -> "32768")
+    checkError(
+      exception = cause,
+      errorClass = "INVALID_BITMAP_POSITION",
+      parameters = expectedParameters,
+      sqlState = "22003")
+  }
+
+  // Negative positions are rejected with the same error class as too-large ones.
+  test("INVALID_BITMAP_POSITION: negative position") {
+    val query = "select bitmap_construct_agg(col) from values (-1) as tab(col)"
+    val thrown = intercept[SparkException] {
+      sql(query).collect()
+    }
+    val cause = thrown.getCause.asInstanceOf[SparkArrayIndexOutOfBoundsException]
+    val expectedParameters = Map(
+      "bitPosition" -> "-1",
+      "bitmapNumBytes" -> "4096",
+      "bitmapNumBits" -> "32768")
+    checkError(
+      exception = cause,
+      errorClass = "INVALID_BITMAP_POSITION",
+      parameters = expectedParameters,
+      sqlState = "22003")
+  }
+
   test("SPARK-43589: Use bytesToString instead of shift operation") {
     checkError(
       exception = intercept[SparkException] {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to