This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b249cb8af35 [SPARK-46538][ML] Fix the ambiguous column reference issue in `ALSModel.transform`
b249cb8af35 is described below

commit b249cb8af35588583a63785fdf9b683955fb7ce1
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Fri Dec 29 09:27:22 2023 +0800

    [SPARK-46538][ML] Fix the ambiguous column reference issue in `ALSModel.transform`
    
    ### What changes were proposed in this pull request?
    The column references in `ALSModel.transform` may be ambiguous in some cases.
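
    The fix aliases the transform input and both factor DataFrames, then joins on qualified column names. A minimal PySpark sketch of that pattern (`ratings` and `factors` are illustrative stand-ins, not names from this patch):

    ```python
    from pyspark.sql import SparkSession, functions as sf

    spark = SparkSession.builder.getOrCreate()

    # Illustrative stand-ins for the transform input and a factor table.
    ratings = spark.createDataFrame([(1, 2)], ["user", "item"])
    factors = spark.createDataFrame([(1, [0.1, 0.2])], ["id", "features"])

    # Alias both sides and reference columns by qualified name, so the
    # analyzer cannot confuse attributes that share a name across inputs.
    joined = (
        ratings.alias("input")
        .join(factors.alias("user_factors"),
              sf.col("input.user") == sf.col("user_factors.id"), "left")
        .select(sf.col("input.*"), sf.col("user_factors.features"))
    )
    joined.show()
    ```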
    
    ### Why are the changes needed?
    To fix a bug.

    Before this fix, the test fails with:
    ```
    JVM stacktrace:
    org.apache.spark.sql.AnalysisException: [MISSING_ATTRIBUTES.RESOLVED_ATTRIBUTE_APPEAR_IN_OPERATION] Resolved attribute(s) "features", "features" missing from "user", "item", "id", "features", "id", "features" in operator !Project [user#60, item#63, UDF(features#50, features#54) AS prediction#94]. Attribute(s) with the same name appear in the operation: "features", "features".
    Please check if the right attribute(s) are used. SQLSTATE: XX000;
    ```
    
    and
    
    ```
    
    pyspark.errors.exceptions.captured.AnalysisException: Column features#50, features#46 are ambiguous. It's probably because you joined several Datasets together, and some of these Datasets are the same. This column points to one of the Datasets but Spark is unable to figure out which one. Please alias the Datasets with different names via `Dataset.as` before joining them, and specify the column using qualified name, e.g. `df.as("a").join(df.as("b"), $"a.id" > $"b.id")`. You can also se [...]
    
    JVM stacktrace:
    org.apache.spark.sql.AnalysisException: Column features#50, features#46 are ambiguous. It's probably because you joined several Datasets together, and some of these Datasets are the same. This column points to one of the Datasets but Spark is unable to figure out which one. Please alias the Datasets with different names via `Dataset.as` before joining them, and specify the column using qualified name, e.g. `df.as("a").join(df.as("b"), $"a.id" > $"b.id")`. You can also set spark.sql.an [...]
    ```
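
    The failure is easy to hit whenever the transform input is derived from the model's own factor tables. A rough self-contained reproduction, mirroring the new test added below:

    ```python
    import pyspark.sql.functions as sf
    from pyspark.ml.recommendation import ALS
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    data = spark.createDataFrame(
        [[1, 15, 1], [1, 2, 2], [2, 3, 4], [2, 2, 5]],
        ["user", "item", "rating"],
    )
    model = ALS(userCol="user", itemCol="item", ratingCol="rating",
                maxIter=1, seed=42).fit(data)

    # The input reuses the model's own factor tables, so the unaliased joins
    # inside ALSModel.transform resolved `id`/`features` ambiguously.
    users = model.userFactors.select(sf.col("id").alias("user"))
    items = model.itemFactors.select(sf.col("id").alias("item"))
    model.transform(users.crossJoin(items)).count()  # failed before this fix
    ```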
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, this is a bug fix.
    
    ### How was this patch tested?
    Added a unit test.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #44526 from zhengruifeng/ml_als_reference.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 dev/sparktestsupport/modules.py                    |  1 +
 .../org/apache/spark/ml/recommendation/ALS.scala   | 21 +++++--
 python/pyspark/ml/tests/test_als.py                | 68 ++++++++++++++++++++++
 3 files changed, 84 insertions(+), 6 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 4ccef788ce8..8595e7ec0e6 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -627,6 +627,7 @@ pyspark_ml = Module(
         "pyspark.ml.tuning",
         # unittests
         "pyspark.ml.tests.test_algorithms",
+        "pyspark.ml.tests.test_als",
         "pyspark.ml.tests.test_base",
         "pyspark.ml.tests.test_evaluation",
         "pyspark.ml.tests.test_feature",
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 65c7d399a88..1e6be16ef62 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -324,13 +324,22 @@ class ALSModel private[ml] (
     // create a new column named map(predictionCol) by running the predict UDF.
     val validatedUsers = checkIntegers(dataset, $(userCol))
     val validatedItems = checkIntegers(dataset, $(itemCol))
+
+    val validatedInputAlias = Identifiable.randomUID("__als_validated_input")
+    val itemFactorsAlias = Identifiable.randomUID("__als_item_factors")
+    val userFactorsAlias = Identifiable.randomUID("__als_user_factors")
+
     val predictions = dataset
-      .join(userFactors,
-        validatedUsers === userFactors("id"), "left")
-      .join(itemFactors,
-        validatedItems === itemFactors("id"), "left")
-      .select(dataset("*"),
-        predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
+      .withColumns(Seq($(userCol), $(itemCol)), Seq(validatedUsers, validatedItems))
+      .alias(validatedInputAlias)
+      .join(userFactors.alias(userFactorsAlias),
+        col(s"${validatedInputAlias}.${$(userCol)}") === 
col(s"${userFactorsAlias}.id"), "left")
+      .join(itemFactors.alias(itemFactorsAlias),
+        col(s"${validatedInputAlias}.${$(itemCol)}") === 
col(s"${itemFactorsAlias}.id"), "left")
+      .select(col(s"${validatedInputAlias}.*"),
+        predict(col(s"${userFactorsAlias}.features"), col(s"${itemFactorsAlias}.features"))
+          .alias($(predictionCol)))
+
     getColdStartStrategy match {
       case ALSModel.Drop =>
         predictions.na.drop("all", Seq($(predictionCol)))
diff --git a/python/pyspark/ml/tests/test_als.py b/python/pyspark/ml/tests/test_als.py
new file mode 100644
index 00000000000..8eec0d93776
--- /dev/null
+++ b/python/pyspark/ml/tests/test_als.py
@@ -0,0 +1,68 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tempfile
+import unittest
+
+import pyspark.sql.functions as sf
+from pyspark.ml.recommendation import ALS, ALSModel
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class ALSTest(ReusedSQLTestCase):
+    def test_ambiguous_column(self):
+        data = self.spark.createDataFrame(
+            [[1, 15, 1], [1, 2, 2], [2, 3, 4], [2, 2, 5]],
+            ["user", "item", "rating"],
+        )
+        model = ALS(
+            userCol="user",
+            itemCol="item",
+            ratingCol="rating",
+            numUserBlocks=10,
+            numItemBlocks=10,
+            maxIter=1,
+            seed=42,
+        ).fit(data)
+
+        with tempfile.TemporaryDirectory() as d:
+            model.write().overwrite().save(d)
+            loaded_model = ALSModel().load(d)
+
+            with self.sql_conf({"spark.sql.analyzer.failAmbiguousSelfJoin": False}):
+                users = loaded_model.userFactors.select(sf.col("id").alias("user"))
+                items = loaded_model.itemFactors.select(sf.col("id").alias("item"))
+                predictions = loaded_model.transform(users.crossJoin(items))
+                self.assertTrue(predictions.count() > 0)
+
+            with self.sql_conf({"spark.sql.analyzer.failAmbiguousSelfJoin": True}):
+                users = loaded_model.userFactors.select(sf.col("id").alias("user"))
+                items = loaded_model.itemFactors.select(sf.col("id").alias("item"))
+                predictions = loaded_model.transform(users.crossJoin(items))
+                self.assertTrue(predictions.count() > 0)
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.test_als import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
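
(The `dev/sparktestsupport/modules.py` entry above registers the new module with the PySpark test runner; assuming the usual Spark dev setup, it should also be runnable directly, e.g. `python/run-tests --testnames 'pyspark.ml.tests.test_als'`.)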

