Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/21081#discussion_r183797106
--- Diff: mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
---
@@ -27,28 +26,38 @@ import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType}
private[spark] object DatasetUtils {
/**
- * preprocessing the input feature column to Vector
- * @param dataset DataFrame with columns for features
- * @param colName column name for features
- * @return Vector feature column
+ * Cast a column in a Dataset to Vector type.
+ *
+ * The supported data types of the input column are
+ * - Vector
+ * - float/double type Array.
+ *
+ * Note: The returned column does not have Metadata.
+ *
+ * @param dataset input DataFrame
+ * @param colName column name.
+ * @return Vector column
*/
- @Since("2.4.0")
def columnToVector(dataset: Dataset[_], colName: String): Column = {
- val featuresDataType = dataset.schema(colName).dataType
- featuresDataType match {
+ val columnDataType = dataset.schema(colName).dataType
+ columnDataType match {
case _: VectorUDT => col(colName)
case fdt: ArrayType =>
val transferUDF = fdt.elementType match {
case _: FloatType => udf(f = (vector: Seq[Float]) => {
- val featureArray = Array.fill[Double](vector.size)(0.0)
- vector.indices.foreach(idx => featureArray(idx) =
vector(idx).toDouble)
- Vectors.dense(featureArray)
+ val inputArray = Array.fill[Double](vector.size)(0.0)
+ vector.indices.foreach(idx => inputArray(idx) =
vector(idx).toDouble)
+ Vectors.dense(inputArray)
})
case _: DoubleType => udf((vector: Seq[Double]) => {
Vectors.dense(vector.toArray)
})
+ case other =>
--- End diff --
Thanks! I forgot about this since this was generalized.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]