Github user sethah commented on a diff in the pull request:
https://github.com/apache/spark/pull/10355#discussion_r47933061
--- Diff: mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
---
@@ -275,6 +274,40 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
val model = dt.fit(df)
}
+ test("DecisionTree should support all NumericType labels") {
+ val dfWithIntLabels = TreeTests.setMetadata(sqlContext.createDataFrame(Seq(
--- End diff --
It might be less verbose to create the DataFrame once and then add the
other label column types to that same DataFrame. Something like:
```scala
import org.apache.spark.sql.types._

// Derive one label column per numeric type from the existing
// "labelIndex" column of the test's base DataFrame `df`.
val dfWithTypes = df
  .withColumn("shortLabel", df("labelIndex").cast(ShortType))
  .withColumn("longLabel", df("labelIndex").cast(LongType))
  .withColumn("intLabel", df("labelIndex").cast(IntegerType))
  .withColumn("floatLabel", df("labelIndex").cast(FloatType))
  .withColumn("decimalLabel", df("labelIndex").cast(DecimalType(10, 0)))
```
Then just change the label column between training runs. I'm not sure which
way is better, but this would avoid copying the DataFrame setup code ~5 times
per test.
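A minimal sketch of the "change the label column between training runs" step, written as a fragment for the body of the suite's test. It assumes `df` is the existing fixture with a Double `labelIndex` column (carrying the number-of-classes metadata that `TreeTests.setMetadata` attaches) and a `features` column, and that the classifier accepts non-Double numeric labels as this PR intends. The names `labelTypes`, `labelMeta` and `reference`, and comparing `numClasses`/`numNodes` against a Double-label model, are hypothetical choices for illustration only. The original label metadata is re-attached with `Column.as(alias, metadata)` as a precaution, since a `cast` may not carry the source column's metadata over.

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.sql.types._

// Assumption: `df` is the test's base DataFrame with a Double "labelIndex"
// column (with class-count metadata) and a "features" column.
val labelMeta = df.schema("labelIndex").metadata

// Hypothetical list of (column name, label type) pairs to exercise.
val labelTypes: Seq[(String, DataType)] = Seq(
  "shortLabel" -> ShortType,
  "intLabel" -> IntegerType,
  "longLabel" -> LongType,
  "floatLabel" -> FloatType,
  "decimalLabel" -> DecimalType(10, 0))

// Add one cast copy of the label per type, re-attaching the original
// metadata in case the cast drops it.
val dfWithTypes = labelTypes.foldLeft(df) { case (acc, (name, dataType)) =>
  acc.withColumn(name, acc("labelIndex").cast(dataType).as(name, labelMeta))
}

// Fit a reference model on the original Double label, then refit with only
// the label column changed and check that the resulting trees agree.
val dt = new DecisionTreeClassifier()
val reference = dt.setLabelCol("labelIndex").fit(dfWithTypes)
labelTypes.foreach { case (labelCol, _) =>
  val model = dt.setLabelCol(labelCol).fit(dfWithTypes)
  assert(model.numClasses == reference.numClasses)
  assert(model.numNodes == reference.numNodes)
}
```

Because `setLabelCol` mutates and returns the same estimator, one `DecisionTreeClassifier` is reused across all runs; the reference model is already fitted, so later calls do not affect it.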