Repository: spark
Updated Branches:
  refs/heads/master 4bac703eb -> 43b149fb8


[SPARK-14850][ML] convert primitive array from/to unsafe array directly in VectorUDT/MatrixUDT

## What changes were proposed in this pull request?

This PR adds `fromPrimitiveArray` and `toPrimitiveArray` in `UnsafeArrayData`, so that we can do the conversion much faster in VectorUDT/MatrixUDT.
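
For illustration, a minimal round-trip through the new API (mirroring the new `UnsafeArraySuite` below; the literal values are arbitrary):

    import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData

    // One bulk copy of the primitive backing array, no per-element boxing.
    val unsafe = UnsafeArrayData.fromPrimitiveArray(Array(1.1, 2.2, 3.3))
    assert(unsafe.numElements == 3)
    assert(unsafe.getDouble(1) == 2.2)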

## How was this patch tested?

Existing tests and the new test suite `UnsafeArraySuite`.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #12640 from cloud-fan/ml.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/43b149fb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/43b149fb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/43b149fb

Branch: refs/heads/master
Commit: 43b149fb885a27f9467aab28e5195f6f03aadcf0
Parents: 4bac703
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Fri Apr 29 23:04:51 2016 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Fri Apr 29 23:04:51 2016 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/linalg/Matrices.scala    | 11 ++-
 .../org/apache/spark/mllib/linalg/Vectors.scala |  9 ++-
 .../linalg/UDTSerializationBenchmark.scala      | 70 ++++++++++++++++++++
 .../catalyst/expressions/UnsafeArrayData.java   | 64 +++++++++++++++++-
 .../sql/catalyst/expressions/UnsafeMapData.java |  2 +-
 .../sql/catalyst/util/UnsafeArraySuite.scala    | 44 ++++++++++++
 6 files changed, 186 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/43b149fb/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 90fa4fb..076cca6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -27,8 +27,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{linalg => newlinalg}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData}
 import org.apache.spark.sql.types._
 
 /**
@@ -194,9 +193,9 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
         row.setByte(0, 0)
         row.setInt(1, sm.numRows)
         row.setInt(2, sm.numCols)
-        row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any])))
-        row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any])))
-        row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any])))
+        row.update(3, UnsafeArrayData.fromPrimitiveArray(sm.colPtrs))
+        row.update(4, UnsafeArrayData.fromPrimitiveArray(sm.rowIndices))
+        row.update(5, UnsafeArrayData.fromPrimitiveArray(sm.values))
         row.setBoolean(6, sm.isTransposed)
 
       case dm: DenseMatrix =>
@@ -205,7 +204,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
         row.setInt(2, dm.numCols)
         row.setNullAt(3)
         row.setNullAt(4)
-        row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any])))
+        row.update(5, UnsafeArrayData.fromPrimitiveArray(dm.values))
         row.setBoolean(6, dm.isTransposed)
     }
     row

http://git-wip-us.apache.org/repos/asf/spark/blob/43b149fb/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 6e3da6b..132e54a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -33,8 +33,7 @@ import org.apache.spark.annotation.{AlphaComponent, Since}
 import org.apache.spark.ml.{linalg => newlinalg}
 import org.apache.spark.mllib.util.NumericParser
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData}
 import org.apache.spark.sql.types._
 
 /**
@@ -216,15 +215,15 @@ class VectorUDT extends UserDefinedType[Vector] {
         val row = new GenericMutableRow(4)
         row.setByte(0, 0)
         row.setInt(1, size)
-        row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any])))
-        row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
+        row.update(2, UnsafeArrayData.fromPrimitiveArray(indices))
+        row.update(3, UnsafeArrayData.fromPrimitiveArray(values))
         row
       case DenseVector(values) =>
         val row = new GenericMutableRow(4)
         row.setByte(0, 1)
         row.setNullAt(1)
         row.setNullAt(2)
-        row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
+        row.update(3, UnsafeArrayData.fromPrimitiveArray(values))
         row
     }
   }
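
As the hunks above show, the old path boxed every element through `map(_.asInstanceOf[Any])` before wrapping it in a `GenericArrayData`, while the new call hands the primitive backing array to a single bulk copy. Side by side (`values: Array[Double]`, as in the UDT code):

    // Before: allocates one boxed java.lang.Double per element.
    row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
    // After: a single Platform.copyMemory of the primitive array.
    row.update(3, UnsafeArrayData.fromPrimitiveArray(values))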

http://git-wip-us.apache.org/repos/asf/spark/blob/43b149fb/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
new file mode 100644
index 0000000..be7110a
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.linalg
+
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.util.Benchmark
+
+/**
+ * Serialization benchmark for VectorUDT.
+ */
+object UDTSerializationBenchmark {
+
+  def main(args: Array[String]): Unit = {
+    val iters = 1e2.toInt
+    val numRows = 1e3.toInt
+
+    val encoder = ExpressionEncoder[Vector].defaultBinding
+
+    val vectors = (1 to numRows).map { i =>
+      Vectors.dense(Array.fill(1e5.toInt)(1.0 * i))
+    }.toArray
+    val rows = vectors.map(encoder.toRow)
+
+    val benchmark = new Benchmark("VectorUDT de/serialization", numRows, iters)
+
+    benchmark.addCase("serialize") { _ =>
+      var sum = 0
+      var i = 0
+      while (i < numRows) {
+        sum += encoder.toRow(vectors(i)).numFields
+        i += 1
+      }
+    }
+
+    benchmark.addCase("deserialize") { _ =>
+      var sum = 0
+      var i = 0
+      while (i < numRows) {
+        sum += encoder.fromRow(rows(i)).numActives
+        i += 1
+      }
+    }
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4
+    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+
+    VectorUDT de/serialization:         Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    serialize                                 380 /  392          0.0      379730.0       1.0X
+    deserialize                               138 /  142          0.0      137816.6       2.8X
+    */
+    benchmark.run()
+  }
+}
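
To put the committed numbers in perspective: each row is a dense vector of 1e5 doubles, so 379730.0 ns per serialized row works out to roughly 3.8 ns per element, and 137816.6 ns per deserialized row to roughly 1.4 ns per element (the Relative column is normalized to the serialize case, so deserialization is about 2.8x faster).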

http://git-wip-us.apache.org/repos/asf/spark/blob/43b149fb/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 648625b..02a863b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -47,7 +47,7 @@ import org.apache.spark.unsafe.types.UTF8String;
 * Instances of `UnsafeArrayData` act as pointers to row data stored in this format.
  */
 // todo: there is a lot of duplicated code between UnsafeRow and UnsafeArrayData.
-public class UnsafeArrayData extends ArrayData {
+public final class UnsafeArrayData extends ArrayData {
 
   private Object baseObject;
   private long baseOffset;
@@ -81,7 +81,7 @@ public class UnsafeArrayData extends ArrayData {
   }
 
   public Object[] array() {
-    throw new UnsupportedOperationException("Only supported on GenericArrayData.");
+    throw new UnsupportedOperationException("Not supported on UnsafeArrayData.");
   }
 
   /**
@@ -336,4 +336,64 @@ public class UnsafeArrayData extends ArrayData {
     arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
     return arrayCopy;
   }
+
+  public static UnsafeArrayData fromPrimitiveArray(int[] arr) {
+    if (arr.length > (Integer.MAX_VALUE - 4) / 8) {
+      throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " +
+        "it's too big.");
+    }
+
+    final int offsetRegionSize = 4 * arr.length;
+    final int valueRegionSize = 4 * arr.length;
+    final int totalSize = 4 + offsetRegionSize + valueRegionSize;
+    final byte[] data = new byte[totalSize];
+
+    Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length);
+
+    int offsetPosition = Platform.BYTE_ARRAY_OFFSET + 4;
+    int valueOffset = 4 + offsetRegionSize;
+    for (int i = 0; i < arr.length; i++) {
+      Platform.putInt(data, offsetPosition, valueOffset);
+      offsetPosition += 4;
+      valueOffset += 4;
+    }
+
+    Platform.copyMemory(arr, Platform.INT_ARRAY_OFFSET, data,
+      Platform.BYTE_ARRAY_OFFSET + 4 + offsetRegionSize, valueRegionSize);
+
+    UnsafeArrayData result = new UnsafeArrayData();
+    result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize);
+    return result;
+  }
+
+  public static UnsafeArrayData fromPrimitiveArray(double[] arr) {
+    if (arr.length > (Integer.MAX_VALUE - 4) / 12) {
+      throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " +
+        "it's too big.");
+    }
+
+    final int offsetRegionSize = 4 * arr.length;
+    final int valueRegionSize = 8 * arr.length;
+    final int totalSize = 4 + offsetRegionSize + valueRegionSize;
+    final byte[] data = new byte[totalSize];
+
+    Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length);
+
+    int offsetPosition = Platform.BYTE_ARRAY_OFFSET + 4;
+    int valueOffset = 4 + offsetRegionSize;
+    for (int i = 0; i < arr.length; i++) {
+      Platform.putInt(data, offsetPosition, valueOffset);
+      offsetPosition += 4;
+      valueOffset += 8;
+    }
+
+    Platform.copyMemory(arr, Platform.DOUBLE_ARRAY_OFFSET, data,
+      Platform.BYTE_ARRAY_OFFSET + 4 + offsetRegionSize, valueRegionSize);
+
+    UnsafeArrayData result = new UnsafeArrayData();
+    result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize);
+    return result;
+  }
+
+  // TODO: add more specialized methods.
 }
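
The overflow guards in the two methods above follow from the layout they write: a 4-byte element count, a 4-byte offset slot per element, then the value region. An int element therefore costs 4 + 4 = 8 bytes, so 4 + 8 * arr.length must stay within Integer.MAX_VALUE, i.e. arr.length <= (Integer.MAX_VALUE - 4) / 8; a double costs 4 + 8 = 12 bytes per element, which gives the (Integer.MAX_VALUE - 4) / 12 bound.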

http://git-wip-us.apache.org/repos/asf/spark/blob/43b149fb/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
index 651eb1f..0700148 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
@@ -30,7 +30,7 @@ import org.apache.spark.unsafe.Platform;
  * [unsafe key array numBytes] [unsafe key array] [unsafe value array]
  */
 // TODO: Use a more efficient format which doesn't depend on unsafe array.
-public class UnsafeMapData extends MapData {
+public final class UnsafeMapData extends MapData {
 
   private Object baseObject;
   private long baseOffset;

http://git-wip-us.apache.org/repos/asf/spark/blob/43b149fb/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
new file mode 100644
index 0000000..1685276
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
+
+class UnsafeArraySuite extends SparkFunSuite {
+
+  test("from primitive int array") {
+    val array = Array(1, 10, 100)
+    val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
+    assert(unsafe.numElements == 3)
+    assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 4 * 3)
+    assert(unsafe.getInt(0) == 1)
+    assert(unsafe.getInt(1) == 10)
+    assert(unsafe.getInt(2) == 100)
+  }
+
+  test("from primitive double array") {
+    val array = Array(1.1, 2.2, 3.3)
+    val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
+    assert(unsafe.numElements == 3)
+    assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 8 * 3)
+    assert(unsafe.getDouble(0) == 1.1)
+    assert(unsafe.getDouble(1) == 2.2)
+    assert(unsafe.getDouble(2) == 3.3)
+  }
+}
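
The expected sizes asserted in these tests follow the same arithmetic as the format itself: 4 (count) + 4 * 3 (offsets) + 4 * 3 (int values) = 28 bytes for the int array, and 4 + 4 * 3 + 8 * 3 = 40 bytes for the double array.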

