zhengruifeng commented on code in PR #49703:
URL: https://github.com/apache/spark/pull/49703#discussion_r1931433500
##########
python/pyspark/ml/tests/test_feature.py:
##########
@@ -1269,12 +1270,25 @@ def test_rformula_string_indexer_order_type(self):
)
rf = RFormula(formula="y ~ x + s",
stringIndexerOrderType="alphabetDesc")
self.assertEqual(rf.getStringIndexerOrderType(), "alphabetDesc")
- transformedDF = rf.fit(df).transform(df)
+ model = rf.fit(df)
+ transformedDF = model.transform(df)
observed = transformedDF.select("features").collect()
expected = [[1.0, 0.0], [2.0, 1.0], [0.0, 0.0]]
for i in range(0, len(expected)):
self.assertTrue(all(observed[i]["features"].toArray() ==
expected[i]))
+ # save & load
+ with tempfile.TemporaryDirectory(prefix="vector_indexer") as d:
+ rf.write().overwrite().save(d)
+ rf2 = RFormula.load(d)
+ self.assertEqual(str(rf), str(rf2))
+
+ model.write().overwrite().save(d)
+ model2 = RFormulaModel.load(d)
+ # TODO: fix str(model)
Review Comment:
Created SPARK-51015 to track this issue: `__str__` needs a complex object
that cannot be handled as a literal.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]