This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 9125e6a0 feat: Add random row generator in data generator (#451)
9125e6a0 is described below

commit 9125e6a04dce7c86bb0e4f4f297c8c2b70a39559
Author: advancedxy <[email protected]>
AuthorDate: Fri May 24 07:00:16 2024 +0800

    feat: Add random row generator in data generator (#451)
    
    * feat: Add random row generator in data gen
    
    * fix
    
    * remove array type match case, which should already been handled in RandomDataGenerator.forType
    
    * fix style issue
    
    * address comments
---
 .../scala/org/apache/comet/DataGenerator.scala     | 41 ++++++++++++++++++
 .../org/apache/comet/DataGeneratorSuite.scala      | 49 ++++++++++++++++++++++
 2 files changed, 90 insertions(+)

diff --git a/spark/src/test/scala/org/apache/comet/DataGenerator.scala b/spark/src/test/scala/org/apache/comet/DataGenerator.scala
index 691a371b..80e7c228 100644
--- a/spark/src/test/scala/org/apache/comet/DataGenerator.scala
+++ b/spark/src/test/scala/org/apache/comet/DataGenerator.scala
@@ -21,14 +21,20 @@ package org.apache.comet
 
 import scala.util.Random
 
+import org.apache.spark.sql.{RandomDataGenerator, Row}
+import org.apache.spark.sql.types.{StringType, StructType}
+
 object DataGenerator {
   // note that we use `def` rather than `val` intentionally here so that
   // each test suite starts with a fresh data generator to help ensure
   // that tests are deterministic
   def DEFAULT = new DataGenerator(new Random(42))
+  // matches the probability of nulls in Spark's RandomDataGenerator
+  private val PROBABILITY_OF_NULL: Float = 0.1f
 }
 
 class DataGenerator(r: Random) {
+  import DataGenerator._
 
   /** Generate a random string using the specified characters */
   def generateString(chars: String, maxLen: Int): String = {
@@ -95,4 +101,39 @@ class DataGenerator(r: Random) {
       Range(0, n).map(_ => r.nextLong())
   }
 
+  // Generate a random row according to the schema, the string filed in the struct could be
+  // configured to generate strings by passing a stringGen function. Other types are delegated
+  // to Spark's RandomDataGenerator.
+  def generateRow(schema: StructType, stringGen: Option[() => String] = None): Row = {
+    val fields = schema.fields.map { f =>
+      f.dataType match {
+        case StructType(children) =>
+          generateRow(StructType(children), stringGen)
+        case StringType if stringGen.isDefined =>
+          val gen = stringGen.get
+          val data = if (f.nullable && r.nextFloat() <= PROBABILITY_OF_NULL) {
+            null
+          } else {
+            gen()
+          }
+          data
+        case _ =>
+          val gen = RandomDataGenerator.forType(f.dataType, f.nullable, r) match {
+            case Some(g) => g
+            case None =>
+              throw new IllegalStateException(s"No RandomDataGenerator for type ${f.dataType}")
+          }
+          gen()
+      }
+    }.toSeq
+    Row.fromSeq(fields)
+  }
+
+  def generateRows(
+      num: Int,
+      schema: StructType,
+      stringGen: Option[() => String] = None): Seq[Row] = {
+    Range(0, num).map(_ => generateRow(schema, stringGen))
+  }
+
 }
diff --git a/spark/src/test/scala/org/apache/comet/DataGeneratorSuite.scala b/spark/src/test/scala/org/apache/comet/DataGeneratorSuite.scala
new file mode 100644
index 00000000..02dfb9d7
--- /dev/null
+++ b/spark/src/test/scala/org/apache/comet/DataGeneratorSuite.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet
+
+import org.apache.spark.sql.CometTestBase
+import org.apache.spark.sql.types.StructType
+
+class DataGeneratorSuite extends CometTestBase {
+
+  test("test configurable stringGen in row generator") {
+    val gen = DataGenerator.DEFAULT
+    val chars = "abcde"
+    val maxLen = 10
+    val stringGen = () => gen.generateString(chars, maxLen)
+    val numRows = 100
+    val schema = new StructType().add("a", "string")
+    var numNulls = 0
+    gen
+      .generateRows(numRows, schema, Some(stringGen))
+      .foreach(row => {
+        if (row.getString(0) != null) {
+          assert(row.getString(0).forall(chars.toSeq.contains))
+          assert(row.getString(0).length <= maxLen)
+        } else {
+          numNulls += 1
+        }
+      })
+    // 0.1 null probability
+    assert(numNulls >= 0.05 * numRows && numNulls <= 0.15 * numRows)
+  }
+
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to