GitHub user mengxr commented on a diff in the pull request:
https://github.com/apache/spark/pull/21195#discussion_r185984527
--- Diff:
mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
---
@@ -256,6 +258,42 @@ class GaussianMixtureSuite extends SparkFunSuite with
MLlibTestSparkContext
val expectedMatrix = GaussianMixture.unpackUpperTriangularMatrix(4,
triangularValues)
assert(symmetricMatrix === expectedMatrix)
}
+
+ test("GaussianMixture with Array input") {
+ val featuresColNameD = "array_double_features"
+ val featuresColNameF = "array_float_features"
+ val doubleUDF = udf { (features: Vector) =>
+ val featureArray = Array.fill[Double](features.size)(0.0)
+ features.foreachActive((idx, value) => featureArray(idx) =
value.toFloat)
+ featureArray
+ }
+ val floatUDF = udf { (features: Vector) =>
+ val featureArray = Array.fill[Float](features.size)(0.0f)
+ features.foreachActive((idx, value) => featureArray(idx) =
value.toFloat)
+ featureArray
+ }
+ val newdatasetD = dataset.withColumn(featuresColNameD,
doubleUDF(col("features")))
+ .drop("features")
+ val newdatasetF = dataset.withColumn(featuresColNameF,
floatUDF(col("features")))
+ .drop("features")
+ assert(newdatasetD.schema(featuresColNameD).dataType.equals(new
ArrayType(DoubleType, false)))
+ assert(newdatasetF.schema(featuresColNameF).dataType.equals(new
ArrayType(FloatType, false)))
+
+ val gmD = new GaussianMixture().setK(k).setMaxIter(1)
+ .setFeaturesCol(featuresColNameD).setSeed(1)
+ val gmF = new GaussianMixture().setK(k).setMaxIter(1)
+ .setFeaturesCol(featuresColNameF).setSeed(1)
+ val modelD = gmD.fit(newdatasetD)
+ val modelF = gmF.fit(newdatasetF)
+ val transformedD = modelD.transform(newdatasetD)
+ val transformedF = modelF.transform(newdatasetF)
+ val predictDifference = transformedD.select("prediction")
+ .except(transformedF.select("prediction"))
+ assert(predictDifference.count() == 0)
+ val probabilityDifference = transformedD.select("probability")
+ .except(transformedF.select("probability"))
+ assert(probabilityDifference.count() == 0)
--- End diff --
ditto
---
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org