[
https://issues.apache.org/jira/browse/SPARK-49261?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel
]
Dongjoon Hyun reassigned SPARK-49261:
-------------------------------------
Assignee: Bruce Robbins
> Correlation between lit and round during grouping
> -------------------------------------------------
>
> Key: SPARK-49261
> URL: https://issues.apache.org/jira/browse/SPARK-49261
> Project: Spark
> Issue Type: Bug
> Components: PySpark
> Affects Versions: 3.5.0
> Environment: Databricks DBR 14.3
> Spark 3.5.0
> Scala 2.12
> Reporter: Krystian Kulig
> Assignee: Bruce Robbins
> Priority: Major
> Labels: pull-request-available
>
> Running the following code:
>
> {code:java}
> import pyspark.sql.functions as F
> from decimal import Decimal
> data = [
>     (1, 100, Decimal("1.1"), "L", True),
>     (2, 200, Decimal("1.2"), "H", False),
>     (2, 300, Decimal("2.345"), "E", False),
> ]
> columns = ["group_a", "id", "amount", "selector_a", "selector_b"]
> df = spark.createDataFrame(data, schema=columns)
> df_final = (
>     df.select(
>         F.lit(6).alias("run_number"),
>         F.lit("AA").alias("run_type"),
>         F.col("group_a"),
>         F.col("id"),
>         F.col("amount"),
>         F.col("selector_a"),
>         F.col("selector_b"),
>     )
>     .withColumn(
>         "amount_c",
>         F.when(
>             (F.col("selector_b") == False)
>             & (F.col("selector_a").isin(["L", "H", "E"])),
>             F.col("amount"),
>         ).otherwise(F.lit(None)),
>     )
>     .withColumn(
>         "count_of_amount_c",
>         F.when(
>             (F.col("selector_b") == False)
>             & (F.col("selector_a").isin(["L", "H", "E"])),
>             F.col("id"),
>         ).otherwise(F.lit(None)),
>     )
> )
> group_by_cols = [
>     "run_number",
>     "group_a",
>     "run_type",
> ]
> df_final = df_final.groupBy(group_by_cols).agg(
>     F.countDistinct("id").alias("count_of_amount"),
>     F.round(F.sum("amount") / 1000, 1).alias("total_amount"),
>     F.sum("amount_c").alias("amount_c"),
>     F.countDistinct("count_of_amount_c").alias("count_of_amount_c"),
> )
> df_final = (
>     df_final
>     .withColumn(
>         "total_amount",
>         F.round(F.col("total_amount") / 1000, 6),
>     )
>     .withColumn(
>         "count_of_amount", F.col("count_of_amount").cast("int")
>     )
>     .withColumn(
>         "count_of_amount_c",
>         F.when(
>             F.col("amount_c").isNull(), F.lit(None).cast("int")
>         ).otherwise(F.col("count_of_amount_c").cast("int")),
>     )
> )
> df_final = df_final.select(
>     F.col("total_amount"),
>     "run_number",
>     "group_a",
>     "run_type",
>     "count_of_amount",
>     "amount_c",
>     "count_of_amount_c",
> )
> df_final.show()
> {code}
> produces the following error:
> {code:java}
> [INTERNAL_ERROR] Couldn't find total_amount#1046 in
> [group_a#984L,count_of_amount#1054,amount_c#1033,count_of_amount_c#1034L]
> SQLSTATE: XX000 {code}
> with the following stack trace:
> {code:java}
> org.apache.spark.SparkException: [INTERNAL_ERROR] Couldn't find
> total_amount#1046 in
> [group_a#984L,count_of_amount#1054,amount_c#1033,count_of_amount_c#1034L]
> SQLSTATE: XX000 at
> org.apache.spark.SparkException$.internalError(SparkException.scala:97) at
> org.apache.spark.SparkException$.internalError(SparkException.scala:101) at
> org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:81)
> at
> org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:74)
> at
> org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:505)
> at
> org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(origin.scala:83)
> at
> org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:505)
> at
> org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:481)
> at
> org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:449) at
> org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:74)
> at
> org.apache.spark.sql.catalyst.expressions.BindReferences$.$anonfun$bindReferences$1(BoundAttribute.scala:97)
> at
> scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286) at
> scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62) at
> scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55) at
> scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49) at
> scala.collection.TraversableLike.map(TraversableLike.scala:286) at
> scala.collection.TraversableLike.map$(TraversableLike.scala:279) at
> scala.collection.AbstractTraversable.map(Traversable.scala:108) at
> org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReferences(BoundAttribute.scala:97)
> at
> org.apache.spark.sql.execution.ProjectExec.doConsume(basicPhysicalOperators.scala:74)
> at
> org.apache.spark.sql.execution.CodegenSupport.consume(WholeStageCodegenExec.scala:202)
> at
> org.apache.spark.sql.execution.CodegenSupport.consume$(WholeStageCodegenExec.scala:155)
> at
> org.apache.spark.sql.execution.aggregate.HashAggregateExec.consume(HashAggregateExec.scala:51)
> at
> org.apache.spark.sql.execution.aggregate.HashAggregateExec.generateResultFunction(HashAggregateExec.scala:411)
> at
> org.apache.spark.sql.execution.aggregate.HashAggregateExec.doConsumeWithKeys(HashAggregateExec.scala:995)
> at
> org.apache.spark.sql.execution.aggregate.AggregateCodegenSupport.doConsume(AggregateCodegenSupport.scala:81)
> at
> org.apache.spark.sql.execution.aggregate.AggregateCodegenSupport.doConsume$(AggregateCodegenSupport.scala:77)
> at
> org.apache.spark.sql.execution.aggregate.HashAggregateExec.doConsume(HashAggregateExec.scala:51)
> at
> org.apache.spark.sql.execution.CodegenSupport.constructDoConsumeFunction(WholeStageCodegenExec.scala:229)
> at
> org.apache.spark.sql.execution.CodegenSupport.consume(WholeStageCodegenExec.scala:200)
> at
> org.apache.spark.sql.execution.CodegenSupport.consume$(WholeStageCodegenExec.scala:155)
> at
> org.apache.spark.sql.execution.InputAdapter.consume(WholeStageCodegenExec.scala:506)
> at
> org.apache.spark.sql.execution.InputRDDCodegen.doProduce(WholeStageCodegenExec.scala:493)
> at
> org.apache.spark.sql.execution.InputRDDCodegen.doProduce$(WholeStageCodegenExec.scala:466)
> at
> org.apache.spark.sql.execution.InputAdapter.doProduce(WholeStageCodegenExec.scala:506)
> at
> org.apache.spark.sql.execution.CodegenSupport.$anonfun$produce$1(WholeStageCodegenExec.scala:100)
> at
> org.apache.spark.sql.execution.SparkPlan$.org$apache$spark$sql$execution$SparkPlan$$withExecuteQueryLogging(SparkPlan.scala:130)
> at
> org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:385)
> at
> org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165)
> at
> org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:381) at
> org.apache.spark.sql.execution.CodegenSupport.produce(WholeStageCodegenExec.scala:95)
> at
> org.apache.spark.sql.execution.CodegenSupport.produce$(WholeStageCodegenExec.scala:94)
> at
> org.apache.spark.sql.execution.InputAdapter.produce(WholeStageCodegenExec.scala:506)
> at
> org.apache.spark.sql.execution.aggregate.HashAggregateExec.doProduceWithKeys(HashAggregateExec.scala:629)
> at
> org.apache.spark.sql.execution.aggregate.AggregateCodegenSupport.doProduce(AggregateCodegenSupport.scala:73)
> at
> org.apache.spark.sql.execution.aggregate.AggregateCodegenSupport.doProduce$(AggregateCodegenSupport.scala:69)
> at
> org.apache.spark.sql.execution.aggregate.HashAggregateExec.doProduce(HashAggregateExec.scala:51)
> at
> org.apache.spark.sql.execution.CodegenSupport.$anonfun$produce$1(WholeStageCodegenExec.scala:100)
> at
> org.apache.spark.sql.execution.SparkPlan$.org$apache$spark$sql$execution$SparkPlan$$withExecuteQueryLogging(SparkPlan.scala:130)
> at
> org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:385)
> at
> org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165)
> at
> org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:381) at
> org.apache.spark.sql.execution.CodegenSupport.produce(WholeStageCodegenExec.scala:95)
> at
> org.apache.spark.sql.execution.CodegenSupport.produce$(WholeStageCodegenExec.scala:94)
> at
> org.apache.spark.sql.execution.aggregate.HashAggregateExec.produce(HashAggregateExec.scala:51)
> at
> org.apache.spark.sql.execution.ProjectExec.doProduce(basicPhysicalOperators.scala:59)
> at
> org.apache.spark.sql.execution.CodegenSupport.$anonfun$produce$1(WholeStageCodegenExec.scala:100)
> at
> org.apache.spark.sql.execution.SparkPlan$.org$apache$spark$sql$execution$SparkPlan$$withExecuteQueryLogging(SparkPlan.scala:130)
> at
> org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:385)
> at
> org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165)
> at
> org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:381) at
> org.apache.spark.sql.execution.CodegenSupport.produce(WholeStageCodegenExec.scala:95)
> at
> org.apache.spark.sql.execution.CodegenSupport.produce$(WholeStageCodegenExec.scala:94)
> at
> org.apache.spark.sql.execution.ProjectExec.produce(basicPhysicalOperators.scala:46)
> at
> org.apache.spark.sql.execution.WholeStageCodegenExec.doCodeGen(WholeStageCodegenExec.scala:666)
> at
> org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:729)
> at
> org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$2(SparkPlan.scala:327)
> at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:94)
> at
> org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:327)
> at
> org.apache.spark.sql.execution.SparkPlan$.org$apache$spark$sql$execution$SparkPlan$$withExecuteQueryLogging(SparkPlan.scala:130)
> at
> org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:385)
> at
> org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165)
> at
> org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:381) at
> org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:322) at
> org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:117)
> at
> org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:131)
> at
> org.apache.spark.sql.execution.qrc.InternalRowFormat$.collect(cachedSparkResults.scala:94)
> at
> org.apache.spark.sql.execution.qrc.InternalRowFormat$.collect(cachedSparkResults.scala:90)
> at
> org.apache.spark.sql.execution.qrc.InternalRowFormat$.collect(cachedSparkResults.scala:78)
> at
> org.apache.spark.sql.execution.qrc.ResultCacheManager.$anonfun$computeResult$1(ResultCacheManager.scala:549)
> at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:94)
> at
> org.apache.spark.sql.execution.qrc.ResultCacheManager.collectResult$1(ResultCacheManager.scala:540)
> at
> org.apache.spark.sql.execution.qrc.ResultCacheManager.$anonfun$computeResult$2(ResultCacheManager.scala:555)
> at
> org.apache.spark.sql.execution.adaptive.ResultQueryStageExec.$anonfun$doMaterialize$1(QueryStageExec.scala:663)
> at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:1175) at
> org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$6(SQLExecution.scala:778)
> at
> com.databricks.util.LexicalThreadLocal$Handle.runWith(LexicalThreadLocal.scala:63)
> at
> org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$5(SQLExecution.scala:778)
> at
> com.databricks.util.LexicalThreadLocal$Handle.runWith(LexicalThreadLocal.scala:63)
> at
> org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$4(SQLExecution.scala:778)
> at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62) at
> org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$3(SQLExecution.scala:777)
> at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62) at
> org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$2(SQLExecution.scala:776)
> at
> org.apache.spark.sql.execution.SQLExecution$.withOptimisticTransaction(SQLExecution.scala:798)
> at
> org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$1(SQLExecution.scala:775)
> at
> java.util.concurrent.CompletableFuture$AsyncSupply.run(CompletableFuture.java:1604)
> at
> org.apache.spark.util.threads.SparkThreadLocalCapturingRunnable.$anonfun$run$1(SparkThreadLocalForwardingThreadPoolExecutor.scala:134)
> at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23) at
> com.databricks.spark.util.IdentityClaim$.withClaim(IdentityClaim.scala:48) at
> org.apache.spark.util.threads.SparkThreadLocalCapturingHelper.$anonfun$runWithCaptured$4(SparkThreadLocalForwardingThreadPoolExecutor.scala:91)
> at
> com.databricks.unity.UCSEphemeralState$Handle.runWith(UCSEphemeralState.scala:45)
> at
> org.apache.spark.util.threads.SparkThreadLocalCapturingHelper.runWithCaptured(SparkThreadLocalForwardingThreadPoolExecutor.scala:90)
> at
> org.apache.spark.util.threads.SparkThreadLocalCapturingHelper.runWithCaptured$(SparkThreadLocalForwardingThreadPoolExecutor.scala:67)
> at
> org.apache.spark.util.threads.SparkThreadLocalCapturingRunnable.runWithCaptured(SparkThreadLocalForwardingThreadPoolExecutor.scala:131)
> at
> org.apache.spark.util.threads.SparkThreadLocalCapturingRunnable.run(SparkThreadLocalForwardingThreadPoolExecutor.scala:134)
> at
> java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
> at
> java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
> at java.lang.Thread.run(Thread.java:750)
> {code}
>
> There appears to be an interaction between *F.lit(6).alias("run_number")* and
> *F.round(F.col("total_amount") / 1000, 6)*: if the value passed to *lit* and the
> *scale* passed to *round* are the same number, e.g. *6*, the code fails.
> If the numbers are different, everything works.
> Moving *F.lit(6).alias("run_number")* to the final *select* also avoids the
> problem even when the *lit* value and the *round* scale are the same.
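> A condensed sketch of the two variants, distilled from the full reproduction above; it is untested and the simplified data and column set are hypothetical, so only the full script has been confirmed to fail on Spark 3.5.0:
> {code:java}
> import pyspark.sql.functions as F
> from decimal import Decimal
> from pyspark.sql import SparkSession
> 
> spark = SparkSession.builder.getOrCreate()
> df = spark.createDataFrame(
>     [(1, Decimal("1.1")), (2, Decimal("2.345"))], ["group_a", "amount"]
> )
> 
> def build(run_number: int):
>     # The constant literal is part of the grouping key, and the same numeric
>     # value is later used as the scale argument of round().
>     return (
>         df.select(F.lit(run_number).alias("run_number"), "group_a", "amount")
>         .groupBy("run_number", "group_a")
>         .agg(F.round(F.sum("amount") / 1000, 1).alias("total_amount"))
>         .withColumn("total_amount", F.round(F.col("total_amount") / 1000, 6))
>     )
> 
> build(7).show()  # literal (7) != round scale (6): reported to work
> build(6).show()  # literal (6) == round scale (6): reported to fail with INTERNAL_ERROR
> {code}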
> Example of the working code:
> {code:java}
> import pyspark.sql.functions as F
> from decimal import Decimal
> data = [
>     (1, 100, Decimal("1.1"), "L", True),
>     (2, 200, Decimal("1.2"), "H", False),
>     (2, 300, Decimal("2.345"), "E", False),
> ]
> columns = ["group_a", "id", "amount", "selector_a", "selector_b"]
> df = spark.createDataFrame(data, schema=columns)
> df_final = (
>     df.select(
>         F.lit(7).alias("run_number"),
>         F.lit("AA").alias("run_type"),
>         F.col("group_a"),
>         F.col("id"),
>         F.col("amount"),
>         F.col("selector_a"),
>         F.col("selector_b"),
>     )
>     .withColumn(
>         "amount_c",
>         F.when(
>             (F.col("selector_b") == False)
>             & (F.col("selector_a").isin(["L", "H", "E"])),
>             F.col("amount"),
>         ).otherwise(F.lit(None)),
>     )
>     .withColumn(
>         "count_of_amount_c",
>         F.when(
>             (F.col("selector_b") == False)
>             & (F.col("selector_a").isin(["L", "H", "E"])),
>             F.col("id"),
>         ).otherwise(F.lit(None)),
>     )
> )
> group_by_cols = [
>     "run_number",
>     "group_a",
>     "run_type",
> ]
> df_final = df_final.groupBy(group_by_cols).agg(
>     F.countDistinct("id").alias("count_of_amount"),
>     F.round(F.sum("amount") / 1000, 1).alias("total_amount"),
>     F.sum("amount_c").alias("amount_c"),
>     F.countDistinct("count_of_amount_c").alias("count_of_amount_c"),
> )
> df_final = (
>     df_final
>     .withColumn(
>         "total_amount",
>         F.round(F.col("total_amount") / 1000, 6),
>     )
>     .withColumn(
>         "count_of_amount", F.col("count_of_amount").cast("int")
>     )
>     .withColumn(
>         "count_of_amount_c",
>         F.when(
>             F.col("amount_c").isNull(), F.lit(None).cast("int")
>         ).otherwise(F.col("count_of_amount_c").cast("int")),
>     )
> )
> df_final = df_final.select(
>     F.col("total_amount"),
>     "run_number",
>     "group_a",
>     "run_type",
>     "count_of_amount",
>     "amount_c",
>     "count_of_amount_c",
> )
> df_final.show()
> {code}
> Output:
> {code:java}
> +------------+----------+-------+--------+---------------+--------------------+-----------------+
> |total_amount|run_number|group_a|run_type|count_of_amount|            amount_c|count_of_amount_c|
> +------------+----------+-------+--------+---------------+--------------------+-----------------+
> |    0.000000|         7|      2|      AA|              2|3.545000000000000000|                2|
> |    0.000000|         7|      1|      AA|              1|                NULL|             NULL|
> +------------+----------+-------+--------+---------------+--------------------+-----------------+ {code}
> Expected behavior:
> Values used in the *lit* function shouldn't interfere with the *scale*
> parameter of the *round* function.
>
>
>
--
This message was sent by Atlassian Jira
(v8.20.10#820010)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]