Github user WeichenXu123 commented on a diff in the pull request:
https://github.com/apache/spark/pull/20686#discussion_r171760358
--- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala ---
@@ -32,10 +31,20 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
def testRFormulaTransform[A: Encoder](
dataframe: DataFrame,
formulaModel: RFormulaModel,
- expected: DataFrame): Unit = {
+ expected: DataFrame,
+ expectedAttributes: AttributeGroup*): Unit = {
+ val resultSchema = formulaModel.transformSchema(dataframe.schema)
+ assert(resultSchema.json == expected.schema.json)
+ assert(resultSchema == expected.schema)
val (first +: rest) = expected.schema.fieldNames.toSeq
val expectedRows = expected.collect()
testTransformerByGlobalCheckFunc[A](dataframe, formulaModel, first,
rest: _*) { rows =>
+ assert(rows.head.schema.toString() == resultSchema.toString())
+ for (expectedAttributeGroup <- expectedAttributes) {
+ val attributeGroup =
+
AttributeGroup.fromStructField(rows.head.schema(expectedAttributeGroup.name))
+ assert(attributeGroup == expectedAttributeGroup)
--- End diff --
Should we use `===` instead of `==` here?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]