asfgit closed pull request #21509: [SPARK-24489][ML]Check for invalid input type of weight data in ml.PowerIterationClustering
URL: https://github.com/apache/spark/pull/21509
This is a PR merged from a forked repository.
Because GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (GitHub would not otherwise display it):
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
index d9a330f67e8dc..149e99d2f195a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
@@ -166,6 +166,7 @@ class PowerIterationClustering private[clustering] (
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) {
lit(1.0)
} else {
+ SchemaUtils.checkNumericType(dataset.schema, $(weightCol))
col($(weightCol)).cast(DoubleType)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
index 55b460f1a4524..0ba3ffabb75d2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
@@ -145,6 +145,21 @@ class PowerIterationClusteringSuite extends SparkFunSuite
assert(msg.contains("Similarity must be nonnegative"))
}
+ test("check for invalid input types of weight") {
+ val invalidWeightData = spark.createDataFrame(Seq(
+ (0L, 1L, "a"),
+ (2L, 3L, "b")
+ )).toDF("src", "dst", "weight")
+
+ val msg = intercept[IllegalArgumentException] {
+ new PowerIterationClustering()
+ .setWeightCol("weight")
+ .assignClusters(invalidWeightData)
+ }.getMessage
+ assert(msg.contains("requirement failed: Column weight must be of type numeric" +
+ " but was actually of type string."))
+ }
+
test("test default weight") {
val dataWithoutWeight = data.sample(0.5, 1L).select('src, 'dst)
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]