zhengruifeng commented on pull request #35893:
URL: https://github.com/apache/spark/pull/35893#issuecomment-1071948976
Just noticed that UDFs skip null values, so directly using SQL functions
is better since they can check for NULL:
```
val df = sc.parallelize(Seq(("1", 2, 3.0), (null, 4, 5.0), ("0", 2,
Double.PositiveInfinity))).toDF("a", "b", "c")
def checkBinaryLabel = udf {
label: Double =>
require(label == 0 || label == 1,
s"Labels MUST be in {0, 1}, but got $label")
label
}
scala> df.select(checkBinaryLabel(col("a").cast("double"))).show
+----------------------+
|UDF(cast(a as double))|
+----------------------+
| 1.0|
| null|
| 0.0|
+----------------------+
def checkBinaryLabels(labelCol: String): Column = {
val casted = col(labelCol).cast(DoubleType)
when(casted.isNull, raise_error(lit("Labels MUST NOT be NULL")))
.when(casted =!= 0 && casted =!= 1,
raise_error(concat(lit("Labels MUST be in {0, 1}, but got "),
casted)))
.otherwise(casted)
}
scala> df.select(checkBinaryLabels("a")).show
22/03/18 09:32:59 ERROR Executor: Exception in task 2.0 in stage 5.0 (TID 19)
java.lang.RuntimeException: Labels MUST NOT be NULL
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]