Repository: spark Updated Branches: refs/heads/master 73c20bf32 -> 3d09ceeef
[SPARK-14850][.2][ML] use UnsafeArrayData.fromPrimitiveArray in ml.VectorUDT/MatrixUDT ## What changes were proposed in this pull request? This PR uses `UnsafeArrayData.fromPrimitiveArray` to implement `ml.VectorUDT/MatrixUDT` to avoid boxing/unboxing. ## How was this patch tested? Existing unit tests. cc: cloud-fan Author: Xiangrui Meng <m...@databricks.com> Closes #12805 from mengxr/SPARK-14850. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3d09ceee Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3d09ceee Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3d09ceee Branch: refs/heads/master Commit: 3d09ceeef9212d4f3a8cd286ce369ace47242358 Parents: 73c20bf Author: Xiangrui Meng <m...@databricks.com> Authored: Fri Apr 29 23:51:01 2016 -0700 Committer: Xiangrui Meng <m...@databricks.com> Committed: Fri Apr 29 23:51:01 2016 -0700 ---------------------------------------------------------------------- .../scala/org/apache/spark/ml/linalg/MatrixUDT.scala | 11 +++++------ .../scala/org/apache/spark/ml/linalg/VectorUDT.scala | 9 ++++----- 2 files changed, 9 insertions(+), 11 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/3d09ceee/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala index 53f4d55..521a216 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala @@ -18,8 +18,7 @@ package org.apache.spark.ml.linalg import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.catalyst.util.GenericArrayData +import 
org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData} import org.apache.spark.sql.types._ /** @@ -53,9 +52,9 @@ private[ml] class MatrixUDT extends UserDefinedType[Matrix] { row.setByte(0, 0) row.setInt(1, sm.numRows) row.setInt(2, sm.numCols) - row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any]))) - row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any]))) - row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any]))) + row.update(3, UnsafeArrayData.fromPrimitiveArray(sm.colPtrs)) + row.update(4, UnsafeArrayData.fromPrimitiveArray(sm.rowIndices)) + row.update(5, UnsafeArrayData.fromPrimitiveArray(sm.values)) row.setBoolean(6, sm.isTransposed) case dm: DenseMatrix => @@ -64,7 +63,7 @@ private[ml] class MatrixUDT extends UserDefinedType[Matrix] { row.setInt(2, dm.numCols) row.setNullAt(3) row.setNullAt(4) - row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any]))) + row.update(5, UnsafeArrayData.fromPrimitiveArray(dm.values)) row.setBoolean(6, dm.isTransposed) } row http://git-wip-us.apache.org/repos/asf/spark/blob/3d09ceee/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala index fe93a12..c29f7f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala @@ -18,8 +18,7 @@ package org.apache.spark.ml.linalg import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData} import org.apache.spark.sql.types._ /** @@ -46,15 +45,15 @@ private[ml] class VectorUDT extends 
UserDefinedType[Vector] { val row = new GenericMutableRow(4) row.setByte(0, 0) row.setInt(1, size) - row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any]))) - row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any]))) + row.update(2, UnsafeArrayData.fromPrimitiveArray(indices)) + row.update(3, UnsafeArrayData.fromPrimitiveArray(values)) row case DenseVector(values) => val row = new GenericMutableRow(4) row.setByte(0, 1) row.setNullAt(1) row.setNullAt(2) - row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any]))) + row.update(3, UnsafeArrayData.fromPrimitiveArray(values)) row } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org