[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...

2018-07-02 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/21546#discussion_r199478745
  
--- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala ---
@@ -34,17 +33,19 @@ private[sql] object PythonSQLUtils {
   }
 
   /**
-   * Python Callable function to convert ArrowPayloads into a 
[[DataFrame]].
+   * Python callable function to read a file in Arrow stream format and 
create a [[DataFrame]]
+   * using each serialized ArrowRecordBatch as a partition.
*
-   * @param payloadRDD A JavaRDD of ArrowPayloads.
-   * @param schemaString JSON Formatted Schema for ArrowPayloads.
* @param sqlContext The active [[SQLContext]].
-   * @return The converted [[DataFrame]].
+   * @param filename File to read the Arrow stream from.
+   * @param schemaString JSON Formatted Spark schema for Arrow batches.
+   * @return A new [[DataFrame]].
*/
-  def arrowPayloadToDataFrame(
-  payloadRDD: JavaRDD[Array[Byte]],
-  schemaString: String,
-  sqlContext: SQLContext): DataFrame = {
-ArrowConverters.toDataFrame(payloadRDD, schemaString, sqlContext)
+  def arrowReadStreamFromFile(
--- End diff --

Can we call it `arrowFileToDataFrame` or something... 
`arrowReadStreamFromFile` and `readArrowStreamFromFile` are just too similar...


---




[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...

2018-07-02 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/21546#discussion_r199498622
  
--- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
 ---
@@ -38,70 +39,75 @@ import org.apache.spark.util.Utils
 
 
 /**
- * Store Arrow data in a form that can be serialized by Spark and served 
to a Python process.
+ * Writes serialized ArrowRecordBatches to a DataOutputStream in the Arrow 
stream format.
  */
-private[sql] class ArrowPayload private[sql] (payload: Array[Byte]) 
extends Serializable {
+private[sql] class ArrowBatchStreamWriter(
+schema: StructType,
+out: OutputStream,
+timeZoneId: String) {
 
-  /**
-   * Convert the ArrowPayload to an ArrowRecordBatch.
-   */
-  def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = {
-ArrowConverters.byteArrayToBatch(payload, allocator)
-  }
+  val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+  val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+  // Write the Arrow schema first, before batches
+  MessageSerializer.serialize(writeChannel, arrowSchema)
 
   /**
-   * Get the ArrowPayload as a type that can be served to Python.
+   * Consume iterator to write each serialized ArrowRecordBatch to the 
stream.
*/
-  def asPythonSerializable: Array[Byte] = payload
-}
-
-/**
- * Iterator interface to iterate over Arrow record batches and return rows
- */
-private[sql] trait ArrowRowIterator extends Iterator[InternalRow] {
+  def writeBatches(arrowBatchIter: Iterator[Array[Byte]]): Unit = {
+arrowBatchIter.foreach { batchBytes =>
+  writeChannel.write(batchBytes)
+}
+  }
 
   /**
-   * Return the schema loaded from the Arrow record batch being iterated 
over
+   * End the Arrow stream, does not close output stream.
*/
-  def schema: StructType
+  def end(): Unit = {
+// Write End of Stream
--- End diff --

this comment can be removed I think


---




[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...

2018-07-02 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/21546#discussion_r199502733
  
--- Diff: core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala 
---
@@ -398,6 +398,25 @@ private[spark] object PythonRDD extends Logging {
* data collected from this job, and the secret for 
authentication.
*/
   def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
+serveToStream(threadName) { out =>
+  writeIteratorToStream(items, new DataOutputStream(out))
+}
+  }
+
+  /**
+   * Create a socket server and background thread to execute the block of 
code
+   * for the given DataOutputStream.
+   *
+   * The socket server can only accept one connection, or close if no 
connection
+   * in 15 seconds.
+   *
+   * Once a connection comes in, it will execute the block of code and 
pass in
+   * the socket output stream.
+   *
+   * The thread will terminate after the block of code is executed or any
+   * exceptions happen.
+   */
+  private[spark] def serveToStream(threadName: String)(block: OutputStream 
=> Unit): Array[Any] = {
--- End diff --

Can you change `block` to `writeFunc` or something? `block` makes me think of thread blocking.
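
For illustration, a minimal sketch of what the suggested rename would look like. This is an illustrative stand-in, not the PythonRDD implementation: the real method spins up a socket server, while here a `ByteArrayOutputStream` keeps the sketch self-contained.

```scala
import java.io.{ByteArrayOutputStream, OutputStream}

// Stand-in for serveToStream with the suggested parameter name.
def serveToStream(threadName: String)(writeFunc: OutputStream => Unit): Array[Byte] = {
  val out = new ByteArrayOutputStream()
  writeFunc(out)  // the caller-supplied write function, formerly named `block`
  out.toByteArray
}

// Usage mirrors the call site in serveIterator:
val served = serveToStream("serve-demo") { out => out.write(Array[Byte](1, 2, 3)) }
```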


---




[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...

2018-07-02 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/21546#discussion_r199482134
  
--- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
 ---
@@ -183,34 +182,111 @@ private[sql] object ArrowConverters {
   }
 
   /**
-   * Convert a byte array to an ArrowRecordBatch.
+   * Load a serialized ArrowRecordBatch.
*/
-  private[arrow] def byteArrayToBatch(
+  private[arrow] def loadBatch(
   batchBytes: Array[Byte],
   allocator: BufferAllocator): ArrowRecordBatch = {
-val in = new ByteArrayReadableSeekableByteChannel(batchBytes)
-val reader = new ArrowFileReader(in, allocator)
-
-// Read a batch from a byte stream, ensure the reader is closed
-Utils.tryWithSafeFinally {
-  val root = reader.getVectorSchemaRoot  // throws IOException
-  val unloader = new VectorUnloader(root)
-  reader.loadNextBatch()  // throws IOException
-  unloader.getRecordBatch
-} {
-  reader.close()
-}
+val in = new ByteArrayInputStream(batchBytes)
+MessageSerializer.deserializeMessageBatch(new 
ReadChannel(Channels.newChannel(in)), allocator)
+  .asInstanceOf[ArrowRecordBatch]  // throws IOException
   }
 
+  /**
+   * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches.
+   */
   private[sql] def toDataFrame(
-  payloadRDD: JavaRDD[Array[Byte]],
+  arrowBatchRDD: JavaRDD[Array[Byte]],
   schemaString: String,
   sqlContext: SQLContext): DataFrame = {
-val rdd = payloadRDD.rdd.mapPartitions { iter =>
+val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
+val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
+val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
   val context = TaskContext.get()
-  ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), 
context)
+  ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
 }
-val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
 sqlContext.internalCreateDataFrame(rdd, schema)
   }
+
+  /**
+   * Read a file as an Arrow stream and parallelize as an RDD of 
serialized ArrowRecordBatches.
+   */
+  private[sql] def readArrowStreamFromFile(
+  sqlContext: SQLContext,
+  filename: String): JavaRDD[Array[Byte]] = {
+val fileStream = new FileInputStream(filename)
+try {
+  // Create array so that we can safely close the file
+  val batches = getBatchesFromStream(fileStream.getChannel).toArray
+  // Parallelize the record batches to create an RDD
+  JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, 
batches.length))
+} finally {
+  fileStream.close()
+}
+  }
+
+  /**
+   * Read an Arrow stream input and return an iterator of serialized 
ArrowRecordBatches.
+   */
+  private[sql] def getBatchesFromStream(in: SeekableByteChannel): 
Iterator[Array[Byte]] = {
+
+// TODO: this could be moved to Arrow
+def readMessageLength(in: ReadChannel): Int = {
+  val buffer = ByteBuffer.allocate(4)
+  if (in.readFully(buffer) != 4) {
+return 0
+  }
+  MessageSerializer.bytesToInt(buffer.array())
+}
+
+// TODO: this could be moved to Arrow
+def loadMessage(in: ReadChannel, messageLength: Int, buffer: 
ByteBuffer): Message = {
+  if (in.readFully(buffer) != messageLength) {
+throw new java.io.IOException(
+  "Unexpected end of stream trying to read message.")
+  }
+  buffer.rewind()
+  Message.getRootAsMessage(buffer)
+}
+
+
+// Create an iterator to get each serialized ArrowRecordBatch from a 
stream
+new Iterator[Array[Byte]] {
+  val inputChannel = new ReadChannel(in)
--- End diff --

do we not need to close this when the iterator has been consumed?
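
For context, a self-contained sketch of the pattern the question points at: an iterator that releases its underlying resource once it has been fully consumed. Names here are illustrative, not from the PR; `readNext` returns `None` at end of input, mirroring `readNextBatch()` returning null in the diff above.

```scala
import java.io.Closeable

// An iterator that closes its resource as soon as the last element has been read.
def closingIterator[T](resource: Closeable)(readNext: () => Option[T]): Iterator[T] =
  new Iterator[T] {
    private var nextItem: Option[T] = readNext()
    if (nextItem.isEmpty) resource.close()  // nothing to read at all

    override def hasNext: Boolean = nextItem.isDefined

    override def next(): T = {
      val item = nextItem.get
      nextItem = readNext()
      if (nextItem.isEmpty) resource.close()  // last element handed out, release now
      item
    }
  }
```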


---




[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...

2018-07-02 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/21546#discussion_r199476976
  
--- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
 ---
@@ -183,34 +182,111 @@ private[sql] object ArrowConverters {
   }
 
   /**
-   * Convert a byte array to an ArrowRecordBatch.
+   * Load a serialized ArrowRecordBatch.
*/
-  private[arrow] def byteArrayToBatch(
+  private[arrow] def loadBatch(
   batchBytes: Array[Byte],
   allocator: BufferAllocator): ArrowRecordBatch = {
-val in = new ByteArrayReadableSeekableByteChannel(batchBytes)
-val reader = new ArrowFileReader(in, allocator)
-
-// Read a batch from a byte stream, ensure the reader is closed
-Utils.tryWithSafeFinally {
-  val root = reader.getVectorSchemaRoot  // throws IOException
-  val unloader = new VectorUnloader(root)
-  reader.loadNextBatch()  // throws IOException
-  unloader.getRecordBatch
-} {
-  reader.close()
-}
+val in = new ByteArrayInputStream(batchBytes)
+MessageSerializer.deserializeMessageBatch(new 
ReadChannel(Channels.newChannel(in)), allocator)
+  .asInstanceOf[ArrowRecordBatch]  // throws IOException
   }
 
+  /**
+   * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches.
+   */
   private[sql] def toDataFrame(
-  payloadRDD: JavaRDD[Array[Byte]],
+  arrowBatchRDD: JavaRDD[Array[Byte]],
   schemaString: String,
   sqlContext: SQLContext): DataFrame = {
-val rdd = payloadRDD.rdd.mapPartitions { iter =>
+val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
+val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
+val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
   val context = TaskContext.get()
-  ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), 
context)
+  ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
 }
-val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
 sqlContext.internalCreateDataFrame(rdd, schema)
   }
+
+  /**
+   * Read a file as an Arrow stream and parallelize as an RDD of 
serialized ArrowRecordBatches.
+   */
+  private[sql] def readArrowStreamFromFile(
+  sqlContext: SQLContext,
+  filename: String): JavaRDD[Array[Byte]] = {
+val fileStream = new FileInputStream(filename)
+try {
+  // Create array so that we can safely close the file
+  val batches = getBatchesFromStream(fileStream.getChannel).toArray
+  // Parallelize the record batches to create an RDD
+  JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, 
batches.length))
+} finally {
+  fileStream.close()
+}
+  }
+
+  /**
+   * Read an Arrow stream input and return an iterator of serialized 
ArrowRecordBatches.
+   */
+  private[sql] def getBatchesFromStream(in: SeekableByteChannel): 
Iterator[Array[Byte]] = {
+
+// TODO: this could be moved to Arrow
+def readMessageLength(in: ReadChannel): Int = {
+  val buffer = ByteBuffer.allocate(4)
+  if (in.readFully(buffer) != 4) {
+return 0
+  }
+  MessageSerializer.bytesToInt(buffer.array())
+}
+
+// TODO: this could be moved to Arrow
+def loadMessage(in: ReadChannel, messageLength: Int, buffer: 
ByteBuffer): Message = {
+  if (in.readFully(buffer) != messageLength) {
+throw new java.io.IOException(
+  "Unexpected end of stream trying to read message.")
+  }
+  buffer.rewind()
+  Message.getRootAsMessage(buffer)
+}
+
+
+// Create an iterator to get each serialized ArrowRecordBatch from a 
stream
+new Iterator[Array[Byte]] {
+  val inputChannel = new ReadChannel(in)
+  var batch: Array[Byte] = readNextBatch()
+
+  override def hasNext: Boolean = batch != null
+
+  override def next(): Array[Byte] = {
+val prevBatch = batch
+batch = readNextBatch()
+prevBatch
+  }
+
+  def readNextBatch(): Array[Byte] = {
+val messageLength = readMessageLength(inputChannel)
+if (messageLength == 0) {
+  return null
+}
+
+val buffer = ByteBuffer.allocate(messageLength)
+val msg = loadMessage(inputChannel, messageLength, buffer)
+val bodyLength = msg.bodyLength().asInstanceOf[Int]
--- End diff --

why not `toInt`?
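
For reference, both forms narrow a Long to an Int by truncation; `.toInt` just states the intent as a numeric conversion rather than a cast. The value below is made up.

```scala
val bodyLength: Long = 128L
val viaCast = bodyLength.asInstanceOf[Int]  // 128
val viaToInt = bodyLength.toInt             // 128, and reads as a deliberate conversion
```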


---


[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...

2018-07-02 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/21546#discussion_r199496002
  
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---
@@ -3236,13 +3237,50 @@ class Dataset[T] private[sql](
   }
 
   /**
-   * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
+   * Collect a Dataset as Arrow batches and serve stream to PySpark.
*/
   private[sql] def collectAsArrowToPython(): Array[Any] = {
+val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
+
 withAction("collectAsArrowToPython", queryExecution) { plan =>
-  val iter: Iterator[Array[Byte]] =
-toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
-  PythonRDD.serveIterator(iter, "serve-Arrow")
+  PythonRDD.serveToStream("serve-Arrow") { outputStream =>
+val out = new DataOutputStream(outputStream)
+val batchWriter = new ArrowBatchStreamWriter(schema, out, 
timeZoneId)
+val arrowBatchRdd = getArrowBatchRdd(plan)
+val numPartitions = arrowBatchRdd.partitions.length
+
+// Batches ordered by index of partition + fractional value of 
batch # in partition
+val batchOrder = new ArrayBuffer[Float]()
+var partitionCount = 0
+
+// Handler to eagerly write batches to Python out of order
+def handlePartitionBatches(index: Int, arrowBatches: 
Array[Array[Byte]]): Unit = {
+  if (arrowBatches.nonEmpty) {
+batchWriter.writeBatches(arrowBatches.iterator)
+(0 until arrowBatches.length).foreach { i =>
+  batchOrder.append(index + i / arrowBatches.length)
--- End diff --

This code: `(0 until array.length).map(i => i / array.length)` is guaranteed to produce only zero values, isn't it? The code works, since `sortBy` evidently preserves the ordering of equal elements, but you may as well do `batchOrder.append(index)` since it's the same.
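
A quick illustration of the integer-division point; the `.toFloat` variant is presumably what was intended.

```scala
val n = 4
(0 until n).map(i => i / n)          // Vector(0, 0, 0, 0): Int / Int truncates to zero
(0 until n).map(i => i.toFloat / n)  // Vector(0.0, 0.25, 0.5, 0.75)
```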


---




[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...

2018-07-02 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/21546#discussion_r199497456
  
--- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
 ---
@@ -183,34 +182,111 @@ private[sql] object ArrowConverters {
   }
 
   /**
-   * Convert a byte array to an ArrowRecordBatch.
+   * Load a serialized ArrowRecordBatch.
*/
-  private[arrow] def byteArrayToBatch(
+  private[arrow] def loadBatch(
   batchBytes: Array[Byte],
   allocator: BufferAllocator): ArrowRecordBatch = {
-val in = new ByteArrayReadableSeekableByteChannel(batchBytes)
-val reader = new ArrowFileReader(in, allocator)
-
-// Read a batch from a byte stream, ensure the reader is closed
-Utils.tryWithSafeFinally {
-  val root = reader.getVectorSchemaRoot  // throws IOException
-  val unloader = new VectorUnloader(root)
-  reader.loadNextBatch()  // throws IOException
-  unloader.getRecordBatch
-} {
-  reader.close()
-}
+val in = new ByteArrayInputStream(batchBytes)
+MessageSerializer.deserializeMessageBatch(new 
ReadChannel(Channels.newChannel(in)), allocator)
+  .asInstanceOf[ArrowRecordBatch]  // throws IOException
   }
 
+  /**
+   * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches.
+   */
   private[sql] def toDataFrame(
-  payloadRDD: JavaRDD[Array[Byte]],
+  arrowBatchRDD: JavaRDD[Array[Byte]],
   schemaString: String,
   sqlContext: SQLContext): DataFrame = {
-val rdd = payloadRDD.rdd.mapPartitions { iter =>
+val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
+val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
+val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
   val context = TaskContext.get()
-  ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), 
context)
+  ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
 }
-val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
 sqlContext.internalCreateDataFrame(rdd, schema)
   }
+
+  /**
+   * Read a file as an Arrow stream and parallelize as an RDD of 
serialized ArrowRecordBatches.
+   */
+  private[sql] def readArrowStreamFromFile(
+  sqlContext: SQLContext,
+  filename: String): JavaRDD[Array[Byte]] = {
+val fileStream = new FileInputStream(filename)
+try {
+  // Create array so that we can safely close the file
+  val batches = getBatchesFromStream(fileStream.getChannel).toArray
+  // Parallelize the record batches to create an RDD
+  JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, 
batches.length))
+} finally {
+  fileStream.close()
+}
+  }
+
+  /**
+   * Read an Arrow stream input and return an iterator of serialized 
ArrowRecordBatches.
+   */
+  private[sql] def getBatchesFromStream(in: SeekableByteChannel): 
Iterator[Array[Byte]] = {
+
+// TODO: this could be moved to Arrow
+def readMessageLength(in: ReadChannel): Int = {
+  val buffer = ByteBuffer.allocate(4)
+  if (in.readFully(buffer) != 4) {
+return 0
+  }
+  MessageSerializer.bytesToInt(buffer.array())
+}
+
+// TODO: this could be moved to Arrow
+def loadMessage(in: ReadChannel, messageLength: Int, buffer: 
ByteBuffer): Message = {
+  if (in.readFully(buffer) != messageLength) {
+throw new java.io.IOException(
+  "Unexpected end of stream trying to read message.")
+  }
+  buffer.rewind()
+  Message.getRootAsMessage(buffer)
+}
+
+
+// Create an iterator to get each serialized ArrowRecordBatch from a 
stream
+new Iterator[Array[Byte]] {
+  val inputChannel = new ReadChannel(in)
+  var batch: Array[Byte] = readNextBatch()
+
+  override def hasNext: Boolean = batch != null
+
+  override def next(): Array[Byte] = {
+val prevBatch = batch
+batch = readNextBatch()
+prevBatch
+  }
+
+  def readNextBatch(): Array[Byte] = {
+val messageLength = readMessageLength(inputChannel)
+if (messageLength == 0) {
+  return null
+}
+
+val buffer = ByteBuffer.allocate(messageLength)
+val msg = loadMessage(inputChannel, messageLength, buffer)
+val bodyLength = msg.bodyLength().asInstanceOf[Int]
+
+if (msg.headerType() == MessageHeader.RecordBatch) {
+  val allbuf = ByteBuffer.allocate(4 + messageLen

[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...

2018-07-02 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/21546#discussion_r199484323
  
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---
@@ -3236,13 +3237,50 @@ class Dataset[T] private[sql](
   }
 
   /**
-   * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
+   * Collect a Dataset as Arrow batches and serve stream to PySpark.
*/
   private[sql] def collectAsArrowToPython(): Array[Any] = {
+val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
+
 withAction("collectAsArrowToPython", queryExecution) { plan =>
-  val iter: Iterator[Array[Byte]] =
-toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
-  PythonRDD.serveIterator(iter, "serve-Arrow")
+  PythonRDD.serveToStream("serve-Arrow") { outputStream =>
+val out = new DataOutputStream(outputStream)
+val batchWriter = new ArrowBatchStreamWriter(schema, out, 
timeZoneId)
+val arrowBatchRdd = getArrowBatchRdd(plan)
+val numPartitions = arrowBatchRdd.partitions.length
+
+// Batches ordered by index of partition + fractional value of 
batch # in partition
+val batchOrder = new ArrayBuffer[Float]()
+var partitionCount = 0
+
+// Handler to eagerly write batches to Python out of order
+def handlePartitionBatches(index: Int, arrowBatches: 
Array[Array[Byte]]): Unit = {
+  if (arrowBatches.nonEmpty) {
+batchWriter.writeBatches(arrowBatches.iterator)
+(0 until arrowBatches.length).foreach { i =>
--- End diff --

intellij would like you to know about `arrowBatches.indices` :grin:
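
For reference, the two spellings visit the same positions; `indices` just avoids writing out the range. The sample data below is made up.

```scala
val arrowBatches = Array(Array[Byte](1), Array[Byte](2), Array[Byte](3))
(0 until arrowBatches.length).foreach(i => println(i))  // 0, 1, 2
arrowBatches.indices.foreach(i => println(i))           // 0, 1, 2
```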


---




[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...

2018-07-02 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/21546#discussion_r199508609
  
--- Diff: python/pyspark/serializers.py ---
@@ -184,27 +184,59 @@ def loads(self, obj):
 raise NotImplementedError
 
 
-class ArrowSerializer(FramedSerializer):
+class BatchOrderSerializer(Serializer):
 """
-Serializes bytes as Arrow data with the Arrow file format.
+Deserialize a stream of batches followed by batch order information.
 """
 
-def dumps(self, batch):
+def __init__(self, serializer):
+self.serializer = serializer
+self.batch_order = []
+
+def dump_stream(self, iterator, stream):
+return self.serializer.dump_stream(iterator, stream)
+
+def load_stream(self, stream):
+for batch in self.serializer.load_stream(stream):
+yield batch
+num = read_int(stream)
+for i in xrange(num):
+index = read_int(stream)
+self.batch_order.append(index)
+raise StopIteration()
+
+def get_batch_order(self):
--- End diff --

maybe we should initialize `self.batch_order = None`, and add `assert 
self.batch_order is not None` here.


---




[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...

2018-07-02 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/21546#discussion_r199482021
  
--- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
 ---
@@ -183,34 +182,111 @@ private[sql] object ArrowConverters {
   }
 
   /**
-   * Convert a byte array to an ArrowRecordBatch.
+   * Load a serialized ArrowRecordBatch.
*/
-  private[arrow] def byteArrayToBatch(
+  private[arrow] def loadBatch(
   batchBytes: Array[Byte],
   allocator: BufferAllocator): ArrowRecordBatch = {
-val in = new ByteArrayReadableSeekableByteChannel(batchBytes)
-val reader = new ArrowFileReader(in, allocator)
-
-// Read a batch from a byte stream, ensure the reader is closed
-Utils.tryWithSafeFinally {
-  val root = reader.getVectorSchemaRoot  // throws IOException
-  val unloader = new VectorUnloader(root)
-  reader.loadNextBatch()  // throws IOException
-  unloader.getRecordBatch
-} {
-  reader.close()
-}
+val in = new ByteArrayInputStream(batchBytes)
+MessageSerializer.deserializeMessageBatch(new 
ReadChannel(Channels.newChannel(in)), allocator)
+  .asInstanceOf[ArrowRecordBatch]  // throws IOException
   }
 
+  /**
+   * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches.
+   */
   private[sql] def toDataFrame(
-  payloadRDD: JavaRDD[Array[Byte]],
+  arrowBatchRDD: JavaRDD[Array[Byte]],
   schemaString: String,
   sqlContext: SQLContext): DataFrame = {
-val rdd = payloadRDD.rdd.mapPartitions { iter =>
+val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
+val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
+val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
   val context = TaskContext.get()
-  ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), 
context)
+  ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
 }
-val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
 sqlContext.internalCreateDataFrame(rdd, schema)
   }
+
+  /**
+   * Read a file as an Arrow stream and parallelize as an RDD of 
serialized ArrowRecordBatches.
+   */
+  private[sql] def readArrowStreamFromFile(
+  sqlContext: SQLContext,
+  filename: String): JavaRDD[Array[Byte]] = {
+val fileStream = new FileInputStream(filename)
+try {
+  // Create array so that we can safely close the file
+  val batches = getBatchesFromStream(fileStream.getChannel).toArray
+  // Parallelize the record batches to create an RDD
+  JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, 
batches.length))
+} finally {
+  fileStream.close()
+}
+  }
+
+  /**
+   * Read an Arrow stream input and return an iterator of serialized 
ArrowRecordBatches.
+   */
+  private[sql] def getBatchesFromStream(in: SeekableByteChannel): 
Iterator[Array[Byte]] = {
+
+// TODO: this could be moved to Arrow
+def readMessageLength(in: ReadChannel): Int = {
+  val buffer = ByteBuffer.allocate(4)
+  if (in.readFully(buffer) != 4) {
+return 0
+  }
+  MessageSerializer.bytesToInt(buffer.array())
+}
+
+// TODO: this could be moved to Arrow
+def loadMessage(in: ReadChannel, messageLength: Int, buffer: 
ByteBuffer): Message = {
+  if (in.readFully(buffer) != messageLength) {
+throw new java.io.IOException(
+  "Unexpected end of stream trying to read message.")
+  }
+  buffer.rewind()
+  Message.getRootAsMessage(buffer)
+}
+
+
--- End diff --

delete extra line


---




[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...

2018-07-02 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/21546#discussion_r199499070
  
--- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
 ---
@@ -38,70 +39,75 @@ import org.apache.spark.util.Utils
 
 
 /**
- * Store Arrow data in a form that can be serialized by Spark and served 
to a Python process.
+ * Writes serialized ArrowRecordBatches to a DataOutputStream in the Arrow 
stream format.
  */
-private[sql] class ArrowPayload private[sql] (payload: Array[Byte]) 
extends Serializable {
+private[sql] class ArrowBatchStreamWriter(
+schema: StructType,
+out: OutputStream,
+timeZoneId: String) {
 
-  /**
-   * Convert the ArrowPayload to an ArrowRecordBatch.
-   */
-  def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = {
-ArrowConverters.byteArrayToBatch(payload, allocator)
-  }
+  val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+  val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+  // Write the Arrow schema first, before batches
+  MessageSerializer.serialize(writeChannel, arrowSchema)
 
   /**
-   * Get the ArrowPayload as a type that can be served to Python.
+   * Consume iterator to write each serialized ArrowRecordBatch to the 
stream.
*/
-  def asPythonSerializable: Array[Byte] = payload
-}
-
-/**
- * Iterator interface to iterate over Arrow record batches and return rows
- */
-private[sql] trait ArrowRowIterator extends Iterator[InternalRow] {
+  def writeBatches(arrowBatchIter: Iterator[Array[Byte]]): Unit = {
+arrowBatchIter.foreach { batchBytes =>
--- End diff --

nit: `arrowBatchIter.foreach(writeChannel.write)`


---




[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...

2018-07-02 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/21546#discussion_r199371158
  
--- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
 ---
@@ -183,34 +182,111 @@ private[sql] object ArrowConverters {
   }
 
   /**
-   * Convert a byte array to an ArrowRecordBatch.
+   * Load a serialized ArrowRecordBatch.
*/
-  private[arrow] def byteArrayToBatch(
+  private[arrow] def loadBatch(
   batchBytes: Array[Byte],
   allocator: BufferAllocator): ArrowRecordBatch = {
-val in = new ByteArrayReadableSeekableByteChannel(batchBytes)
-val reader = new ArrowFileReader(in, allocator)
-
-// Read a batch from a byte stream, ensure the reader is closed
-Utils.tryWithSafeFinally {
-  val root = reader.getVectorSchemaRoot  // throws IOException
-  val unloader = new VectorUnloader(root)
-  reader.loadNextBatch()  // throws IOException
-  unloader.getRecordBatch
-} {
-  reader.close()
-}
+val in = new ByteArrayInputStream(batchBytes)
+MessageSerializer.deserializeMessageBatch(new 
ReadChannel(Channels.newChannel(in)), allocator)
+  .asInstanceOf[ArrowRecordBatch]  // throws IOException
   }
 
+  /**
+   * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches.
+   */
   private[sql] def toDataFrame(
-  payloadRDD: JavaRDD[Array[Byte]],
+  arrowBatchRDD: JavaRDD[Array[Byte]],
   schemaString: String,
   sqlContext: SQLContext): DataFrame = {
-val rdd = payloadRDD.rdd.mapPartitions { iter =>
+val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
+val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
+val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
   val context = TaskContext.get()
-  ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), 
context)
+  ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
 }
-val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
 sqlContext.internalCreateDataFrame(rdd, schema)
   }
+
+  /**
+   * Read a file as an Arrow stream and parallelize as an RDD of 
serialized ArrowRecordBatches.
+   */
+  private[sql] def readArrowStreamFromFile(
+  sqlContext: SQLContext,
+  filename: String): JavaRDD[Array[Byte]] = {
+val fileStream = new FileInputStream(filename)
+try {
+  // Create array so that we can safely close the file
+  val batches = getBatchesFromStream(fileStream.getChannel).toArray
+  // Parallelize the record batches to create an RDD
+  JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, 
batches.length))
+} finally {
+  fileStream.close()
+}
+  }
+
+  /**
+   * Read an Arrow stream input and return an iterator of serialized 
ArrowRecordBatches.
+   */
+  private[sql] def getBatchesFromStream(in: SeekableByteChannel): 
Iterator[Array[Byte]] = {
+
+// TODO: this could be moved to Arrow
+def readMessageLength(in: ReadChannel): Int = {
+  val buffer = ByteBuffer.allocate(4)
+  if (in.readFully(buffer) != 4) {
+return 0
+  }
+  MessageSerializer.bytesToInt(buffer.array())
+}
+
+// TODO: this could be moved to Arrow
+def loadMessage(in: ReadChannel, messageLength: Int, buffer: 
ByteBuffer): Message = {
+  if (in.readFully(buffer) != messageLength) {
+throw new java.io.IOException(
+  "Unexpected end of stream trying to read message.")
+  }
+  buffer.rewind()
+  Message.getRootAsMessage(buffer)
+}
+
+
+// Create an iterator to get each serialized ArrowRecordBatch from a 
stream
+new Iterator[Array[Byte]] {
+  val inputChannel = new ReadChannel(in)
+  var batch: Array[Byte] = readNextBatch()
+
+  override def hasNext: Boolean = batch != null
+
+  override def next(): Array[Byte] = {
+val prevBatch = batch
+batch = readNextBatch()
+prevBatch
+  }
+
+  def readNextBatch(): Array[Byte] = {
--- End diff --

Mostly I'm just curious, is there any point in making this a private method?


---




[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...

2018-06-29 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/21546#discussion_r199275753
  
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---
@@ -3236,13 +3237,50 @@ class Dataset[T] private[sql](
   }
 
   /**
-   * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
+   * Collect a Dataset as Arrow batches and serve stream to PySpark.
*/
   private[sql] def collectAsArrowToPython(): Array[Any] = {
+val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
+
 withAction("collectAsArrowToPython", queryExecution) { plan =>
-  val iter: Iterator[Array[Byte]] =
-toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
-  PythonRDD.serveIterator(iter, "serve-Arrow")
+  PythonRDD.serveToStream("serve-Arrow") { outputStream =>
+val out = new DataOutputStream(outputStream)
+val batchWriter = new ArrowBatchStreamWriter(schema, out, 
timeZoneId)
+val arrowBatchRdd = getArrowBatchRdd(plan)
+val numPartitions = arrowBatchRdd.partitions.length
+
+// Batches ordered by index of partition + batch number for that 
partition
+val batchOrder = new ArrayBuffer[Int]()
+var partitionCount = 0
+
+// Handler to eagerly write batches to Python out of order
+def handlePartitionBatches(index: Int, arrowBatches: 
Array[Array[Byte]]): Unit = {
+  if (arrowBatches.nonEmpty) {
+batchWriter.writeBatches(arrowBatches.iterator)
+(0 until arrowBatches.length).foreach { i =>
+  batchOrder.append(index + i)
+}
+  }
+  partitionCount += 1
+
+  // After last batch, end the stream and write batch order
+  if (partitionCount == numPartitions) {
+batchWriter.end()
+out.writeInt(batchOrder.length)
+// Batch order indices are from 0 to N-1 batches, sorted by 
order they arrived
+batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, i) =>
--- End diff --

Does this logic do what you intend? It interleaves batches.

```python
df = spark.range(64).toDF("a")
df.rdd.getNumPartitions()  # 8
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 4)
pdf = df.toPandas()
pdf['a'].values
# array([ 0,  1,  2,  3,  8,  9, 10, 11,  4,  5,  6,  7, 16, 17, 18, 19, 12,
#   13, 14, 15, 20, 21, 22, 23, 24, 25, 26, 27, 32, 33, 34, 35, 28, 29,
#   30, 31, 36, 37, 38, 39, 40, 41, 42, 43, 48, 49, 50, 51, 44, 45, 46,
#   47, 56, 57, 58, 59, 52, 53, 54, 55, 60, 61, 62, 63])
```


---




[GitHub] spark issue #20629: [SPARK-23451][ML] Deprecate KMeans.computeCost

2018-06-28 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/20629
  
+1 for @mgaido91's plan


---




[GitHub] spark issue #16722: [SPARK-19591][ML][MLlib] Add sample weights to decision ...

2018-06-22 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/16722
  
Yes, feel free to take this over.


---




[GitHub] spark pull request #19680: [SPARK-22461][ML] Refactor Spark ML model summari...

2018-06-01 Thread sethah
Github user sethah closed the pull request at:

https://github.com/apache/spark/pull/19680


---




[GitHub] spark pull request #20701: [SPARK-23528][ML] Add numIter to ClusteringSummar...

2018-03-26 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20701#discussion_r177200509
  
--- Diff: 
mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala ---
@@ -46,6 +47,10 @@ class KMeansModel @Since("2.4.0") (@Since("1.0.0") val 
clusterCenters: Array[Vec
   private val clusterCentersWithNorm =
 if (clusterCenters == null) null else clusterCenters.map(new 
VectorWithNorm(_))
 
+  @Since("2.4.0")
--- End diff --

Why does this constructor need to be public?


---




[GitHub] spark issue #20632: [SPARK-3159][ML] Add decision tree pruning

2018-03-02 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/20632
  
Merged with master. Thanks!


---




[GitHub] spark pull request #20632: [SPARK-3159][ML] Add decision tree pruning

2018-03-02 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r171982559
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
@@ -703,4 +707,16 @@ private object RandomForestSuite {
 val (indices, values) = map.toSeq.sortBy(_._1).unzip
 Vectors.sparse(size, indices.toArray, values.toArray)
   }
+
+  @tailrec
+  private def getSumLeafCounters(nodes: List[Node], acc: Long = 0): Long =
--- End diff --

Need to enclose the function body in curly braces


---




[GitHub] spark issue #20632: [SPARK-3159][ML] Add decision tree pruning

2018-03-02 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/20632
  
@asolimando I thought of one more thing on the tests that I'd like to have. 
Other than that I think this is ready. 

@srowen For some reason the tests won't run... Do you have any insight into 
this?


---




[GitHub] spark pull request #20632: [SPARK-3159][ML] Add decision tree pruning

2018-03-02 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r171899289
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
@@ -631,10 +634,70 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
 val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
 assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
   }
+
+  
///
+  // Tests for pruning of redundant subtrees (generated by a split 
improving the
+  // impurity measure, but always leading to the same prediction).
+  
///
+
+  test("SPARK-3159 tree model redundancy - binary classification") {
+// The following dataset is set up such that splitting over feature 1 
for points having
+// feature 0 = 0 improves the impurity measure, despite the prediction 
will always be 0
+// in both branches.
+val arr = Array(
+  LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+  LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+  LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+  LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
+  LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
+  LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
+)
+val rdd = sc.parallelize(arr)
+
+val numClasses = 2
+val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity 
= Gini, maxDepth = 4,
+  numClasses = numClasses, maxBins = 32)
+
+val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
+  seed = 42, instr = None).head
+
+val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
+  seed = 42, instr = None, prune = false).head
+
+assert(prunedTree.numNodes === 5)
+assert(unprunedTree.numNodes === 7)
+  }
+
+  test("SPARK-3159 tree model redundancy - regression") {
+// The following dataset is set up such that splitting over feature 0 
for points having
+// feature 1 = 1 improves the impurity measure, despite the prediction 
will always be 0.5
+// in both branches.
+val arr = Array(
+  LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+  LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+  LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+  LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
+  LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
+  LabeledPoint(0.0, Vectors.dense(1.0, 1.0)),
+  LabeledPoint(0.5, Vectors.dense(1.0, 1.0))
+)
+val rdd = sc.parallelize(arr)
+
+val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = 
Variance, maxDepth = 4,
+  numClasses = 0, maxBins = 32)
+
+val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
+  seed = 42, instr = None).head
+
+val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
+  seed = 42, instr = None, prune = false).head
+
--- End diff --

Would you mind adding a check in both tests to make sure that the count of all the leaf nodes sums to the total count (i.e. 6)? That way we make sure we don't lose information when merging the leaves. You can do it via `leafNode.impurityStats.count`.
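
A hedged sketch of what such an assertion could look like, assuming the `ml.tree` Node API (`InternalNode.leftChild`/`rightChild` and `impurityStats.count`) is visible from the test package, and reusing `prunedTree`/`unprunedTree` from the quoted test:

```scala
import org.apache.spark.ml.tree.{InternalNode, LeafNode, Node}

// Sum the instance counts stored in the leaves; pruning should not change the total.
def sumLeafCounts(node: Node): Long = node match {
  case n: InternalNode => sumLeafCounts(n.leftChild) + sumLeafCounts(n.rightChild)
  case l: LeafNode => l.impurityStats.count
}

assert(sumLeafCounts(prunedTree.rootNode) === 6)
assert(sumLeafCounts(unprunedTree.rootNode) === 6)
```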


---




[GitHub] spark issue #20709: [SPARK-18844][MLLIB] Adding more binary classification e...

2018-03-02 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/20709
  
You don't need to (and should not) open a new PR to fix merge conflicts. 
Just fix them through git, on the same branch.


---




[GitHub] spark issue #20632: [SPARK-3159][ML] Add decision tree pruning

2018-03-01 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/20632
  
Jenkins test this please.


---




[GitHub] spark issue #20709: [SPARK-18844][MLLIB] Adding more binary classification e...

2018-03-01 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/20709
  
Why did you close the old one and re-open this? The discussion is lost now.


---




[GitHub] spark issue #20708: [SPARK-21209][MLLLIB] Implement Incremental PCA algorith...

2018-03-01 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/20708
  
* Only committers can trigger the tests.
* MLlib is in maintenance-only mode, so we wouldn't accept this patch as is.
* If this were to go into ML, I think you'd need to discuss it more thoroughly on the JIRA. Another good alternative is to make this a Spark package.
* It's best if you write unit tests and adhere to the style guides when you submit new patches.

I would recommend closing this PR until more discussion has taken place 
about whether or not this is a good fit for Spark ML. Thanks!


---




[GitHub] spark pull request #18998: [SPARK-21748][ML] Migrate the implementation of H...

2018-02-28 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/18998#discussion_r171425701
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala 
---
@@ -93,11 +97,21 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") 
override val uid: String)
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
 val outputSchema = transformSchema(dataset.schema)
-val hashingTF = new 
feature.HashingTF($(numFeatures)).setBinary($(binary))
-// TODO: Make the hashingTF.transform natively in ml framework to 
avoid extra conversion.
-val t = udf { terms: Seq[_] => hashingTF.transform(terms).asML }
+val hashUDF = udf { (terms: Seq[_]) =>
+  val ids = terms.map { term =>
--- End diff --

Ok, I'm just wondering why you changed the code? The old one uses a mutable hashmap, which is different from the approach here (and is slower, in my tests).
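
For readers following along, an illustrative sketch of the two counting styles being compared. Names are not from the PR, and the exact aggregation in the new udf is truncated in the quote above, so the second helper is only a guess at a map-then-aggregate style.

```scala
import scala.collection.mutable

// Old mllib-style transform: accumulate term-frequency counts in a mutable map.
def countMutable(ids: Seq[Int]): Map[Int, Double] = {
  val freq = mutable.HashMap.empty[Int, Double]
  ids.foreach(id => freq(id) = freq.getOrElse(id, 0.0) + 1.0)
  freq.toMap
}

// A map-then-aggregate alternative in the spirit of the new udf.
def countGroupBy(ids: Seq[Int]): Map[Int, Double] =
  ids.groupBy(identity).map { case (id, group) => id -> group.size.toDouble }

// Both give Map(1 -> 2.0, 2 -> 1.0) for Seq(1, 1, 2); only the performance differs.
```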


---




[GitHub] spark issue #20632: [SPARK-3159][ML] Add decision tree pruning

2018-02-28 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/20632
  
Jenkins retest this please.


---




[GitHub] spark issue #20632: [SPARK-3159][ML] Add decision tree pruning

2018-02-28 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/20632
  
Squashing makes it impossible to review the history of the code review, so 
I don't think it's a good idea. It's fine for now. 

This LGTM. Let's see if @srowen or @jkbradley have any thoughts.


---




[GitHub] spark issue #20632: [SPARK-3159] added subtree pruning in the translation fr...

2018-02-28 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/20632
  
@asolimando Can you change the title to include `[ML]` and also shorten it. 
Maybe just: `Add decision tree pruning`


---




[GitHub] spark issue #20632: [SPARK-3159] added subtree pruning in the translation fr...

2018-02-28 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/20632
  
@srowen Do we need you to trigger the tests? I'm not sure why they haven't 
been run...


---




[GitHub] spark issue #20632: [SPARK-3159] added subtree pruning in the translation fr...

2018-02-28 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/20632
  
Jenkins test this please.


---




[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-27 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r171071499
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala ---
@@ -266,15 +265,24 @@ private[tree] class LearningNode(
 var isLeaf: Boolean,
 var stats: ImpurityStats) extends Serializable {
 
+  def toNode: Node = toNode(prune = true)
+
   /**
* Convert this [[LearningNode]] to a regular [[Node]], and recurse on 
any children.
*/
-  def toNode: Node = {
-if (leftChild.nonEmpty) {
-  assert(rightChild.nonEmpty && split.nonEmpty && stats != null,
+  def toNode(prune: Boolean = true): Node = {
+
+if (!leftChild.isEmpty || !rightChild.isEmpty) {
+  assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty 
&& stats != null,
 "Unknown error during Decision Tree learning.  Could not convert 
LearningNode to Node.")
-  new InternalNode(stats.impurityCalculator.predict, stats.impurity, 
stats.gain,
-leftChild.get.toNode, rightChild.get.toNode, split.get, 
stats.impurityCalculator)
+  (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match {
+// when both children make the same prediction, collapse into a 
single leaf
--- End diff --

On second thought, I'm not sure the comment is useful since it just 
explains what the code does. I vote either no comment or we explain why this 
happens, i.e. you can improve impurity without changing the prediction. 
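
A small numeric illustration of that point, with made-up class counts: splitting a node with counts (8, 2) into children (5, 1) and (3, 1) lowers the weighted Gini impurity even though every node still predicts class 0.

```scala
// Gini impurity for a vector of class counts.
def gini(counts: Seq[Double]): Double = {
  val total = counts.sum
  1.0 - counts.map(c => (c / total) * (c / total)).sum
}

val parent = gini(Seq(8.0, 2.0))                                      // 0.32
val children = 0.6 * gini(Seq(5.0, 1.0)) + 0.4 * gini(Seq(3.0, 1.0))  // ≈ 0.317
// children < parent: the split improves impurity, yet both children still predict
// class 0, which is exactly the redundancy that pruning collapses.
```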


---




[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-27 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r171071028
  
--- Diff: 
mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala ---
@@ -92,6 +92,7 @@ private[spark] object RandomForest extends Logging {
   featureSubsetStrategy: String,
   seed: Long,
   instr: Option[Instrumentation[_]],
+  prune: Boolean = true,
--- End diff --

how about a comment here: `prune: Boolean = true, // exposed for testing 
only, real trees are always pruned`


---




[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-27 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r171070831
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
@@ -631,10 +634,99 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
 val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
 assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
   }
+
+  
///
+  // Tests for pruning of redundant subtrees (generated by a split 
improving the
+  // impurity measure, but always leading to the same prediction).
+  
///
+
+  test("[SPARK-3159] tree model redundancy - binary classification") {
+// The following dataset is set up such that splitting over feature 1 
for points having
+// feature 0 = 0 improves the impurity measure, despite the prediction 
will always be 0
+// in both branches.
+val arr = Array(
+  LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+  LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+  LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+  LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
+  LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
+  LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
+)
+val rdd = sc.parallelize(arr)
+
+val numClasses = 2
+val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity 
= Gini, maxDepth = 4,
+  numClasses = numClasses, maxBins = 32)
+
+val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
+  seed = 42, instr = None).head
+
+val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
+  seed = 42, instr = None, prune = false).head
+
+assert(prunedTree.numNodes === 5)
+assert(unprunedTree.numNodes === 7)
+  }
+
+  test("[SPARK-3159] tree model redundancy - multiclass classification") {
--- End diff --

I think it's fine to just leave the binary classification test.


---




[GitHub] spark pull request #18998: [SPARK-21748][ML] Migrate the implementation of H...

2018-02-27 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/18998#discussion_r171025256
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala 
---
@@ -93,11 +97,21 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") 
override val uid: String)
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
 val outputSchema = transformSchema(dataset.schema)
-val hashingTF = new 
feature.HashingTF($(numFeatures)).setBinary($(binary))
-// TODO: Make the hashingTF.transform natively in ml framework to 
avoid extra conversion.
-val t = udf { terms: Seq[_] => hashingTF.transform(terms).asML }
+val hashUDF = udf { (terms: Seq[_]) =>
+  val ids = terms.map { term =>
--- End diff --

Why did you implement this differently than the old one?


---




[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-26 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r170738256
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala ---
@@ -541,7 +541,9 @@ object DecisionTreeSuite extends SparkFunSuite {
 Array[LabeledPoint] = {
 val arr = new Array[LabeledPoint](3000)
 for (i <- 0 until 3000) {
-  if (i < 1000) {
+  // [SPARK-3159] 1001 instead of 1000 to adapt "Multiclass 
classification stump with 10-ary
--- End diff --

I think leaving this out is fine. The tests will fail if it gets changed.


---




[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-26 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r170738944
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
@@ -631,10 +634,99 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
 val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
 assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
   }
+
+  
///
+  // Tests for pruning of redundant subtrees (generated by a split 
improving the
+  // impurity measure, but always leading to the same prediction).
+  
///
+
+  test("[SPARK-3159] tree model redundancy - binary classification") {
+// The following dataset is set up such that splitting over feature 1 
for points having
+// feature 0 = 0 improves the impurity measure, despite the prediction 
will always be 0
+// in both branches.
+val arr = Array(
+  LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+  LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+  LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+  LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
+  LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
+  LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
+)
+val rdd = sc.parallelize(arr)
+
+val numClasses = 2
+val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity 
= Gini, maxDepth = 4,
+  numClasses = numClasses, maxBins = 32)
+
+val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
+  seed = 42, instr = None).head
+
+val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
+  seed = 42, instr = None, prune = false).head
+
+assert(prunedTree.numNodes === 5)
+assert(unprunedTree.numNodes === 7)
+  }
+
+  test("[SPARK-3159] tree model redundancy - multiclass classification") {
--- End diff --

Why do we need to test binary and multiclass separately?


---




[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-26 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r170738068
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala ---
@@ -269,12 +268,19 @@ private[tree] class LearningNode(
   /**
* Convert this [[LearningNode]] to a regular [[Node]], and recurse on 
any children.
*/
-  def toNode: Node = {
-if (leftChild.nonEmpty) {
-  assert(rightChild.nonEmpty && split.nonEmpty && stats != null,
+  def toNode(prune: Boolean = true): Node = {
--- End diff --

If you just overload the method then you don't need to change the existing 
function calls.

```scala
def toNode: Node = toNode(prune = true)
```


---




[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-23 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r170410905
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala ---
@@ -541,7 +541,7 @@ object DecisionTreeSuite extends SparkFunSuite {
 Array[LabeledPoint] = {
 val arr = new Array[LabeledPoint](3000)
 for (i <- 0 until 3000) {
-  if (i < 1000) {
+  if (i < 1001) {
--- End diff --

this is the type of thing that will puzzle someone down the line. I'm ok 
with it, though. :stuck_out_tongue_closed_eyes:


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-23 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r170412046
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
@@ -631,6 +651,160 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
 val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
 assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
   }
+
+  test("[SPARK-3159] tree model redundancy - binary classification") {
+val numClasses = 2
+
+val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity 
= Gini, maxDepth = 4,
+  numClasses = numClasses, maxBins = 32)
+
+val dt = buildRedundantDecisionTree(numClasses, 20, strategy = 
strategy)
+
+/* Expected tree structure tested below:
+ root
+  left1   right1
+left2  right2
+
+  pred(left1)  = 0
+  pred(left2)  = 1
+  pred(right2) = 0
+ */
+assert(dt.rootNode.numDescendants === 4)
+assert(dt.rootNode.subtreeDepth === 2)
+
+assert(dt.rootNode.isInstanceOf[InternalNode])
+
+// left 1 prediction test
+assert(dt.rootNode.asInstanceOf[InternalNode].leftChild.prediction === 
0)
+
+val right1 = dt.rootNode.asInstanceOf[InternalNode].rightChild
+assert(right1.isInstanceOf[InternalNode])
+
+// left 2 prediction test
+assert(right1.asInstanceOf[InternalNode].leftChild.prediction === 1)
+// right 2 prediction test
+assert(right1.asInstanceOf[InternalNode].rightChild.prediction === 0)
+  }
+
+  test("[SPARK-3159] tree model redundancy - multiclass classification") {
+val numClasses = 4
+
+val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity 
= Gini, maxDepth = 4,
+  numClasses = numClasses, maxBins = 32)
+
+val dt = buildRedundantDecisionTree(numClasses, 20, strategy = 
strategy)
+
+/* Expected tree structure tested below:
+root
+left1 right1
+left2  right2  left3  right3
+
+  pred(left2)  = 0
+  pred(right2)  = 1
+  pred(left3) = 2
+  pred(right3) = 1
+ */
+assert(dt.rootNode.numDescendants === 6)
+assert(dt.rootNode.subtreeDepth === 2)
+
+assert(dt.rootNode.isInstanceOf[InternalNode])
+
+val left1 = dt.rootNode.asInstanceOf[InternalNode].leftChild
+val right1 = dt.rootNode.asInstanceOf[InternalNode].rightChild
+
+assert(left1.isInstanceOf[InternalNode])
+
+// left 2 prediction test
+assert(left1.asInstanceOf[InternalNode].leftChild.prediction === 0)
+// right 2 prediction test
+assert(left1.asInstanceOf[InternalNode].rightChild.prediction === 1)
+
+assert(right1.isInstanceOf[InternalNode])
+
+// left 3 prediction test
+assert(right1.asInstanceOf[InternalNode].leftChild.prediction === 2)
+// right 3 prediction test
+assert(right1.asInstanceOf[InternalNode].rightChild.prediction === 1)
+  }
+
+  test("[SPARK-3159] tree model redundancy - regression") {
+val numClasses = 2
+
+val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = 
Variance,
+  maxDepth = 3, maxBins = 10, numClasses = numClasses)
+
+val dt = buildRedundantDecisionTree(numClasses, 20, strategy = 
strategy)
+
+/* Expected tree structure tested below:
+root
+1 2
+   1_1 1_2 2_1 2_2
+  1_1_1   1_1_2   1_2_1   1_2_2   2_1_1   2_1_2
+
+  pred(1_1_1)  = 0.5
+  pred(1_1_2)  = 0.0
+  pred(1_2_1)  = 0.0
+  pred(1_2_2)  = 0.25
+  pred(2_1_1)  = 1.0
+  pred(2_1_2)  = 0.
+  pred(2_2)= 0.5
+ */
+
+assert(dt.rootNode.numDescendants === 12)
--- End diff --

Ok, trying to understand these tests. From what I can tell, you've written 
a data generator that produces random points and, somewhat by chance, 
yields redundant tree nodes if the tree is not pruned. You're relying on the 
random seed to give you a tree that should have exactly 12 descendants after 
pruning. 

I think these may be overly complicated. IMO, we just need to test the 
situation that causes this by creating a simple dataset where splitting improves the 
impurity but does not change the prediction. For example, 
for the Gini impurity you might have the following:

```scala
val data = Array(
  LabeledPoint(0.0, Vectors.dense(1.0)),
  LabeledPoint(0.0, Vectors.dense(0.0)),
  LabeledPoint(0.0, Vectors.dense(0.0)),
  LabeledPoint(1.0, Vectors.dense(0.0)))
```

[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-23 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r170410747
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
@@ -402,20 +405,40 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
   LabeledPoint(1.0, Vectors.dense(2.0)))
 val input = sc.parallelize(arr)
 
+val seed = 42
+val numTrees = 1
+
 // Must set maxBins s.t. the feature will be treated as an ordered 
categorical feature.
 val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity 
= Gini, maxDepth = 1,
   numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
 
-val model = RandomForest.run(input, strategy, numTrees = 1, 
featureSubsetStrategy = "all",
-  seed = 42, instr = None).head
-model.rootNode match {
-  case n: InternalNode => n.split match {
-case s: CategoricalSplit =>
-  assert(s.leftCategories === Array(1.0))
-case _ => throw new AssertionError("model.rootNode.split was not a 
CategoricalSplit")
-  }
-  case _ => throw new AssertionError("model.rootNode was not an 
InternalNode")
-}
+val metadata = DecisionTreeMetadata.buildMetadata(input, strategy, 
numTrees = numTrees,
+  featureSubsetStrategy = "all")
+val splits = RandomForest.findSplits(input, metadata, seed = seed)
+
+val treeInput = TreePoint.convertToTreeRDD(input, splits, metadata)
+val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput,
+  strategy.subsamplingRate, numTrees, false, seed = seed)
+
+val topNode = LearningNode.emptyNode(nodeIndex = 1)
+assert(topNode.isLeaf === false)
+assert(topNode.stats === null)
+
+val nodesForGroup = Map(0 -> Array(topNode))
+val treeToNodeToIndexInfo = Map(0 -> Map(
+  topNode.id -> new RandomForest.NodeIndexInfo(0, None)
+))
+val nodeStack = new mutable.ArrayStack[(Int, LearningNode)]
+val bestSplit = RandomForest.findBestSplits(baggedInput, metadata, 
Map(0 -> topNode),
+  nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
+
+assert(topNode.split.isDefined, "rootNode does not have a split")
--- End diff --

I'm a fan of just calling `foreach` like:

```scala
topNode.split.foreach { split =>
  assert(split.isInstanceOf[CategoricalSplit])
  assert(split.toOld.categories === Array(1.0))
}
```


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-23 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r170410687
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
@@ -402,20 +407,35 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
   LabeledPoint(1.0, Vectors.dense(2.0)))
 val input = sc.parallelize(arr)
 
+val numTrees = 1
+
 // Must set maxBins s.t. the feature will be treated as an ordered 
categorical feature.
 val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity 
= Gini, maxDepth = 1,
   numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
 
-val model = RandomForest.run(input, strategy, numTrees = 1, 
featureSubsetStrategy = "all",
-  seed = 42, instr = None).head
-model.rootNode match {
-  case n: InternalNode => n.split match {
-case s: CategoricalSplit =>
-  assert(s.leftCategories === Array(1.0))
-case _ => throw new AssertionError("model.rootNode.split was not a 
CategoricalSplit")
-  }
-  case _ => throw new AssertionError("model.rootNode was not an 
InternalNode")
-}
+val metadata = DecisionTreeMetadata.buildMetadata(input, strategy, 
numTrees = numTrees,
+  featureSubsetStrategy = "all")
+val splits = RandomForest.findSplits(input, metadata, seed = seed)
+
+val treeInput = TreePoint.convertToTreeRDD(input, splits, metadata)
+val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput,
+  strategy.subsamplingRate, numTrees, false, seed = seed)
+
+val topNode = LearningNode.emptyNode(nodeIndex = 1)
+val nodesForGroup = Map(0 -> Array(topNode))
+val treeToNodeToIndexInfo = Map(0 -> Map(
+  topNode.id -> new RandomForest.NodeIndexInfo(0, None)
+))
+val nodeStack = new mutable.ArrayStack[(Int, LearningNode)]
+val bestSplit = RandomForest.findBestSplits(baggedInput, metadata, 
Map(0 -> topNode),
--- End diff --

This method returns unit. Doesn't make sense to assign its output.


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-23 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r170410775
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala ---
@@ -283,10 +292,12 @@ private[tree] class LearningNode(
 // Here we want to keep same behavior with the old 
mllib.DecisionTreeModel
 new LeafNode(stats.impurityCalculator.predict, -1.0, 
stats.impurityCalculator)
   }
-
 }
   }
 
+  /** @return true iff a node is a leaf. */
+  private def isLeafNode(): Boolean = leftChild.isEmpty && 
rightChild.isEmpty
--- End diff --

I wouldn't mind just removing this change. What constitutes a leaf node is 
now fuzzy, and if you inline the check in the one place it's used there is no 
confusion. At any rate, you don't need the parentheses after the method name.
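
i.e., at the one call site, something like this (a sketch mirroring the `toNode` code in the diff):

```scala
// inline the check rather than keeping a separate isLeafNode helper
if (leftChild.isEmpty && rightChild.isEmpty) {
  // leaf: emit a LeafNode
} else {
  // internal node: recurse on the children
}
```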


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-23 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r170412098
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
@@ -631,6 +651,160 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
 val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
 assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
   }
+
+  test("[SPARK-3159] tree model redundancy - binary classification") {
+val numClasses = 2
+
+val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity 
= Gini, maxDepth = 4,
+  numClasses = numClasses, maxBins = 32)
+
+val dt = buildRedundantDecisionTree(numClasses, 20, strategy = 
strategy)
+
+/* Expected tree structure tested below:
+ root
+  left1   right1
+left2  right2
+
+  pred(left1)  = 0
+  pred(left2)  = 1
+  pred(right2) = 0
+ */
+assert(dt.rootNode.numDescendants === 4)
+assert(dt.rootNode.subtreeDepth === 2)
+
+assert(dt.rootNode.isInstanceOf[InternalNode])
+
+// left 1 prediction test
+assert(dt.rootNode.asInstanceOf[InternalNode].leftChild.prediction === 
0)
+
+val right1 = dt.rootNode.asInstanceOf[InternalNode].rightChild
+assert(right1.isInstanceOf[InternalNode])
+
+// left 2 prediction test
+assert(right1.asInstanceOf[InternalNode].leftChild.prediction === 1)
+// right 2 prediction test
+assert(right1.asInstanceOf[InternalNode].rightChild.prediction === 0)
+  }
+
+  test("[SPARK-3159] tree model redundancy - multiclass classification") {
+val numClasses = 4
+
+val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity 
= Gini, maxDepth = 4,
+  numClasses = numClasses, maxBins = 32)
+
+val dt = buildRedundantDecisionTree(numClasses, 20, strategy = 
strategy)
+
+/* Expected tree structure tested below:
+root
+left1 right1
+left2  right2  left3  right3
+
+  pred(left2)  = 0
+  pred(right2)  = 1
+  pred(left3) = 2
+  pred(right3) = 1
+ */
+assert(dt.rootNode.numDescendants === 6)
+assert(dt.rootNode.subtreeDepth === 2)
+
+assert(dt.rootNode.isInstanceOf[InternalNode])
+
+val left1 = dt.rootNode.asInstanceOf[InternalNode].leftChild
+val right1 = dt.rootNode.asInstanceOf[InternalNode].rightChild
+
+assert(left1.isInstanceOf[InternalNode])
+
+// left 2 prediction test
+assert(left1.asInstanceOf[InternalNode].leftChild.prediction === 0)
+// right 2 prediction test
+assert(left1.asInstanceOf[InternalNode].rightChild.prediction === 1)
+
+assert(right1.isInstanceOf[InternalNode])
+
+// left 3 prediction test
+assert(right1.asInstanceOf[InternalNode].leftChild.prediction === 2)
+// right 3 prediction test
+assert(right1.asInstanceOf[InternalNode].rightChild.prediction === 1)
+  }
+
+  test("[SPARK-3159] tree model redundancy - regression") {
+val numClasses = 2
+
+val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = 
Variance,
+  maxDepth = 3, maxBins = 10, numClasses = numClasses)
+
+val dt = buildRedundantDecisionTree(numClasses, 20, strategy = 
strategy)
+
+/* Expected tree structure tested below:
+root
+1 2
+   1_1 1_2 2_1 2_2
+  1_1_1   1_1_2   1_2_1   1_2_2   2_1_1   2_1_2
+
+  pred(1_1_1)  = 0.5
+  pred(1_1_2)  = 0.0
+  pred(1_2_1)  = 0.0
+  pred(1_2_2)  = 0.25
+  pred(2_1_1)  = 1.0
+  pred(2_1_2)  = 0.
+  pred(2_2)= 0.5
+ */
+
+assert(dt.rootNode.numDescendants === 12)
--- End diff --

The tree tests are already so long and complicated that I think it's 
important to simplify where possible. These tests are useful as they are, but 
it probably won't be obvious to future devs why or how they work. Also, if we can 
avoid adding data generation code, that would be nice (there's already plenty of 
code like that lying around the test suites). 


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-23 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r170410851
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
@@ -18,17 +18,20 @@
 package org.apache.spark.ml.tree.impl
 
 import scala.collection.mutable
+import scala.util.Random
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkContext, SparkFunSuite}
 import org.apache.spark.ml.classification.DecisionTreeClassificationModel
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, 
EnsembleTestHelper}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, 
QuantileStrategy, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.configuration.FeatureType._
--- End diff --

unused


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-23 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r170410834
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala ---
@@ -270,11 +269,21 @@ private[tree] class LearningNode(
* Convert this [[LearningNode]] to a regular [[Node]], and recurse on 
any children.
*/
   def toNode: Node = {
-if (leftChild.nonEmpty) {
-  assert(rightChild.nonEmpty && split.nonEmpty && stats != null,
+
+// convert to an inner node only when:
+//  -) the node is not a leaf, and
+//  -) the subtree rooted at this node cannot be replaced by a single 
leaf
+// (i.e., there at least two different leaf predictions appear in 
the subtree)
--- End diff --

This comment seems out of place now. You might just say `// when both 
children make the same prediction, collapse into single leaf` or something 
similar below the first case statement.


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-21 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r169806418
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
@@ -402,20 +405,40 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
   LabeledPoint(1.0, Vectors.dense(2.0)))
 val input = sc.parallelize(arr)
 
+val seed = 42
+val numTrees = 1
+
 // Must set maxBins s.t. the feature will be treated as an ordered 
categorical feature.
 val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity 
= Gini, maxDepth = 1,
   numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
 
-val model = RandomForest.run(input, strategy, numTrees = 1, 
featureSubsetStrategy = "all",
-  seed = 42, instr = None).head
-model.rootNode match {
-  case n: InternalNode => n.split match {
-case s: CategoricalSplit =>
-  assert(s.leftCategories === Array(1.0))
-case _ => throw new AssertionError("model.rootNode.split was not a 
CategoricalSplit")
-  }
-  case _ => throw new AssertionError("model.rootNode was not an 
InternalNode")
-}
+val metadata = DecisionTreeMetadata.buildMetadata(input, strategy, 
numTrees = numTrees,
+  featureSubsetStrategy = "all")
+val splits = RandomForest.findSplits(input, metadata, seed = seed)
+
+val treeInput = TreePoint.convertToTreeRDD(input, splits, metadata)
+val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput,
+  strategy.subsamplingRate, numTrees, false, seed = seed)
+
+val topNode = LearningNode.emptyNode(nodeIndex = 1)
+assert(topNode.isLeaf === false)
+assert(topNode.stats === null)
+
+val nodesForGroup = Map(0 -> Array(topNode))
+val treeToNodeToIndexInfo = Map(0 -> Map(
+  topNode.id -> new RandomForest.NodeIndexInfo(0, None)
+))
+val nodeStack = new mutable.ArrayStack[(Int, LearningNode)]
+val bestSplit = RandomForest.findBestSplits(baggedInput, metadata, 
Map(0 -> topNode),
+  nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
+
+assert(topNode.split.isDefined, "rootNode does not have a split")
--- End diff --

This isn't the purpose of the test, I'd get rid of it.


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-21 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r169834234
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
@@ -362,10 +365,10 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
 assert(topNode.isLeaf === false)
 assert(topNode.stats === null)
 
-val nodesForGroup = Map((0, Array(topNode)))
-val treeToNodeToIndexInfo = Map((0, Map(
-  (topNode.id, new RandomForest.NodeIndexInfo(0, None))
-)))
+val nodesForGroup = Map(0 -> Array(topNode))
--- End diff --

These are fine, but I slightly prefer leaving stuff like this out. These 
aren't strictly style violations, and they distract reviewers from the actual 
changes.


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-21 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r169834018
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala ---
@@ -359,29 +339,6 @@ class DecisionTreeSuite extends SparkFunSuite with 
MLlibTestSparkContext {
 assert(rootNode.stats.isEmpty)
   }
 
-  test("do not choose split that does not satisfy min instance per node 
requirements") {
-// if a split does not satisfy min instances per node requirements,
-// this split is invalid, even though the information gain of split is 
large.
-val arr = Array(
-  LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
--- End diff --

Same issue here. You can fix this by inverting the labels, which is probably easier 
than moving and rewriting the test.
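
e.g. something like this, applied to the test's data (just a sketch; `arr` stands for the LabeledPoint array the test builds):

```scala
// flip the binary labels (0.0 <-> 1.0) in the existing test data
val inverted = arr.map(lp => LabeledPoint(1.0 - lp.label, lp.features))
```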


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-21 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r169806576
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
@@ -402,20 +405,40 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
   LabeledPoint(1.0, Vectors.dense(2.0)))
 val input = sc.parallelize(arr)
 
+val seed = 42
+val numTrees = 1
+
 // Must set maxBins s.t. the feature will be treated as an ordered 
categorical feature.
 val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity 
= Gini, maxDepth = 1,
   numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
 
-val model = RandomForest.run(input, strategy, numTrees = 1, 
featureSubsetStrategy = "all",
-  seed = 42, instr = None).head
-model.rootNode match {
-  case n: InternalNode => n.split match {
-case s: CategoricalSplit =>
-  assert(s.leftCategories === Array(1.0))
-case _ => throw new AssertionError("model.rootNode.split was not a 
CategoricalSplit")
-  }
-  case _ => throw new AssertionError("model.rootNode was not an 
InternalNode")
-}
+val metadata = DecisionTreeMetadata.buildMetadata(input, strategy, 
numTrees = numTrees,
+  featureSubsetStrategy = "all")
+val splits = RandomForest.findSplits(input, metadata, seed = seed)
+
+val treeInput = TreePoint.convertToTreeRDD(input, splits, metadata)
+val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput,
+  strategy.subsamplingRate, numTrees, false, seed = seed)
+
+val topNode = LearningNode.emptyNode(nodeIndex = 1)
+assert(topNode.isLeaf === false)
+assert(topNode.stats === null)
+
+val nodesForGroup = Map(0 -> Array(topNode))
+val treeToNodeToIndexInfo = Map(0 -> Map(
+  topNode.id -> new RandomForest.NodeIndexInfo(0, None)
+))
+val nodeStack = new mutable.ArrayStack[(Int, LearningNode)]
+val bestSplit = RandomForest.findBestSplits(baggedInput, metadata, 
Map(0 -> topNode),
+  nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
+
+assert(topNode.split.isDefined, "rootNode does not have a split")
+
+val rootNodeSplit = topNode.split.get
+
+assert(rootNodeSplit.isInstanceOf[CategoricalSplit])
+
+assert(topNode.split.get.toOld.categories  === Array(1.0))
--- End diff --

nit: there's an extra space in there. Also, you should either remove 
`rootNodeSplit` or use it everywhere


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-21 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r169833178
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala ---
@@ -303,26 +303,6 @@ class DecisionTreeSuite extends SparkFunSuite with 
MLlibTestSparkContext {
 assert(split.threshold < 2020)
   }
 
-  test("Multiclass classification stump with 10-ary (ordered) categorical 
features") {
--- End diff --

Regarding this test - it fails now for a silly reason. Because of the data, 
the tree that gets built winds up with a right node that has an equal count of labels 
1.0 and 2.0. It breaks the tie by predicting 1.0, which the left node also predicts. 
You can modify the data generating method to:

```scala
def generateCategoricalDataPointsForMulticlassForOrderedFeatures():
    Array[LabeledPoint] = {
  val arr = new Array[LabeledPoint](3000)
  for (i <- 0 until 3000) {
    if (i < 1001) {
      arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
    } else if (i < 2000) {
      arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0))
    } else {
      arr(i) = new LabeledPoint(1.0, Vectors.dense(2.0, 2.0))
    }
  }
  arr
}
```
so that 2.0 will be predicted. I slightly prefer this, assuming all other 
tests pass (I checked some of the suites). The less we move around code that is 
mostly unrelated to this change, the better.


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-21 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r169806090
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
@@ -402,20 +405,40 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
   LabeledPoint(1.0, Vectors.dense(2.0)))
 val input = sc.parallelize(arr)
 
+val seed = 42
--- End diff --

Make this a `private val seed = 42` member of `RandomForestSuite`.


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-21 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r169737984
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala ---
@@ -287,6 +291,34 @@ private[tree] class LearningNode(
 }
   }
 
+  /**
+   * @return true iff the node is a leaf.
+   */
+  private def isLeafNode(): Boolean = leftChild.isEmpty && 
rightChild.isEmpty
+
+  // the set of (leaf) predictions appearing in the subtree rooted at the 
given node.
+  private lazy val leafPredictions: Set[Double] = {
--- End diff --

This is a recursive method. It essentially does this:

```scala
def prune(node: Node): Node = {
  val left = prune(node.left)
  val right = prune(node.right)
  if (left.isLeaf && right.isLeaf && left.prediction == right.prediction) {
    new LeafNode
  } else {
    new InternalNode
  }
}
```

It starts pruning at the bottom and then works its way up.


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-21 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r169749740
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala ---
@@ -287,6 +291,34 @@ private[tree] class LearningNode(
 }
   }
 
+  /**
+   * @return true iff the node is a leaf.
+   */
+  private def isLeafNode(): Boolean = leftChild.isEmpty && 
rightChild.isEmpty
+
+  // the set of (leaf) predictions appearing in the subtree rooted at the 
given node.
+  private lazy val leafPredictions: Set[Double] = {
--- End diff --

The recursion pattern is the same as before, since it still just calls 
`toNode` on left and right when they are defined.


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-21 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r169735198
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala ---
@@ -287,6 +291,34 @@ private[tree] class LearningNode(
 }
   }
 
+  /**
+   * @return true iff the node is a leaf.
+   */
+  private def isLeafNode(): Boolean = leftChild.isEmpty && 
rightChild.isEmpty
+
+  // the set of (leaf) predictions appearing in the subtree rooted at the 
given node.
+  private lazy val leafPredictions: Set[Double] = {
--- End diff --

What other cases are there that I'm missing? For one thing, the current 
tests pass with this change. 

This will prune a subtree of arbitrary depth that has a single distinct 
prediction in its leaf nodes.


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20632: [SPARK-3159] added subtree pruning in the transla...

2018-02-21 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20632#discussion_r169703784
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala ---
@@ -287,6 +291,34 @@ private[tree] class LearningNode(
 }
   }
 
+  /**
+   * @return true iff the node is a leaf.
+   */
+  private def isLeafNode(): Boolean = leftChild.isEmpty && 
rightChild.isEmpty
+
+  // the set of (leaf) predictions appearing in the subtree rooted at the 
given node.
+  private lazy val leafPredictions: Set[Double] = {
--- End diff --

This will store a potentially very large collection at each node; for deep 
regression trees the storage cost could be significant. We can accomplish the 
same thing without storing them:

```scala
def toNode: Node = {
  // convert to an inner node only when:
  //  -) the node is not a leaf, and
  //  -) the subtree rooted at this node cannot be replaced by a single leaf
  //     (i.e., at least two different leaf predictions appear in the subtree)
  if (!isLeafNode) {
    assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null,
      "Unknown error during Decision Tree learning.  Could not convert LearningNode to Node.")
    (leftChild.get.toNode, rightChild.get.toNode) match {
      case (l: LeafNode, r: LeafNode) if l.prediction == r.prediction =>
        new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator)
      case (l, r) =>
        new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
          l, r, split.get, stats.impurityCalculator)
    }
  } else {
    if (stats.valid) {
      new LeafNode(stats.impurityCalculator.predict, stats.impurity,
        stats.impurityCalculator)
    } else {
      // Here we want to keep the same behavior as the old mllib.DecisionTreeModel
      new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
    }
  }
}
```


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20472: [SPARK-22751][ML]Improve ML RandomForest shuffle ...

2018-02-20 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20472#discussion_r169391525
  
--- Diff: 
mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala ---
@@ -1001,11 +996,18 @@ private[spark] object RandomForest extends Logging {
 } else {
   val numSplits = metadata.numSplits(featureIndex)
 
-  // get count for each distinct value
-  val (valueCountMap, numSamples) = 
featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
+  // get count for each distinct value except zero value
+  val (partValueCountMap, partNumSamples) = 
featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
 case ((m, cnt), x) =>
   (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
   }
+
+  // Calculate the number of samples for finding splits
+  val numSamples: Int = (samplesFractionForFindSplits(metadata) * 
metadata.numExamples).toInt
--- End diff --

The main problem I see with this is that the sampling we do for split 
finding is _approximate_. Just as an example: say you have 1000 samples, and 
you take 20% for split finding. Your actual sampled RDD has 220 samples in it, 
and 210 of those are non-zero. So, `partNumSamples = 210`, `numSamples = 200` 
and you wind up with `numSamples - partNumSamples = -10` zero values. This is 
not something you expect to happen often (since we care about the highly sparse 
case), but something that we need to consider. We could just require the 
subtraction to be non-negative (and live with a bit of approximation), or you 
could call `count` on the sampled RDD but I don't think it's worth it. Thoughts?
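
Something like this would do it (just a sketch, reusing the variable names from the diff above):

```scala
// don't let an over-sampled partition produce a negative count for the
// implicit zero values; clamp at zero and accept the small approximation
val numZeros = math.max(0, numSamples - partNumSamples)
val valueCountMap: Map[Double, Int] =
  if (numZeros > 0) partValueCountMap + (0.0 -> numZeros) else partValueCountMap
```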


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20472: [SPARK-22751][ML]Improve ML RandomForest shuffle ...

2018-02-20 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20472#discussion_r169386551
  
--- Diff: 
mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala ---
@@ -1001,11 +996,18 @@ private[spark] object RandomForest extends Logging {
 } else {
   val numSplits = metadata.numSplits(featureIndex)
 
-  // get count for each distinct value
-  val (valueCountMap, numSamples) = 
featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
+  // get count for each distinct value except zero value
+  val (partValueCountMap, partNumSamples) = 
featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
 case ((m, cnt), x) =>
   (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
   }
+
+  // Calculate the number of samples for finding splits
+  val numSamples: Int = (samplesFractionForFindSplits(metadata) * 
metadata.numExamples).toInt
+
+  // add zero value count and get complete statistics
+  val valueCountMap: Map[Double, Int] = partValueCountMap + (0.0 -> 
(numSamples - partNumSamples))
--- End diff --

There can be negative values right?


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark issue #20472: [SPARK-22751][ML]Improve ML RandomForest shuffle perform...

2018-02-20 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/20472
  
@srowen Can you trigger the tests?


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark issue #20332: [SPARK-23138][ML][DOC] Multiclass logistic regression su...

2018-01-29 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/20332
  
Thanks a lot for your review, @MLnick!


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20332: [SPARK-23138][ML][DOC] Multiclass logistic regres...

2018-01-29 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20332#discussion_r164476639
  
--- Diff: docs/ml-classification-regression.md ---
@@ -125,7 +123,8 @@ Continuing the earlier example:
 
[`LogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionSummary)
 provides a summary for a
 
[`LogisticRegressionModel`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionModel).
-Currently, only binary classification is supported. Support for multiclass 
model summaries will be added in the future.
+In the case of binary classification, certain additional metrics are
--- End diff --

There isn't a `binarySummary` method for Python.


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark issue #20411: [SPARK-17139][ML][FOLLOW-UP] update LogisticRegressionSu...

2018-01-26 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/20411
  
This is already fixed in #20332.


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20332: [SPARK-23138][ML][DOC] Multiclass logistic regres...

2018-01-26 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20332#discussion_r164151869
  
--- Diff: docs/ml-classification-regression.md ---
@@ -97,10 +97,6 @@ only available on the driver.
 
[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary)
 provides a summary for a
 
[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel).
-Currently, only binary classification is supported and the
--- End diff --

Done.


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20332: [SPARK-23138][ML][DOC] Multiclass logistic regres...

2018-01-26 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20332#discussion_r164151796
  
--- Diff: docs/ml-classification-regression.md ---
@@ -125,7 +117,6 @@ Continuing the earlier example:
 
[`LogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionSummary)
 provides a summary for a
 
[`LogisticRegressionModel`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionModel).
--- End diff --

Done.


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20332: [SPARK-23138][ML][DOC] Multiclass logistic regres...

2018-01-26 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20332#discussion_r164151687
  
--- Diff: 
examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala
 ---
@@ -49,6 +49,48 @@ object MulticlassLogisticRegressionWithElasticNetExample 
{
 // Print the coefficients and intercept for multinomial logistic 
regression
 println(s"Coefficients: \n${lrModel.coefficientMatrix}")
 println(s"Intercepts: \n${lrModel.interceptVector}")
+
+val trainingSummary = lrModel.summary
+
+val objectiveHistory = trainingSummary.objectiveHistory
--- End diff --

Done


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20332: [SPARK-23138][ML][DOC] Multiclass logistic regres...

2018-01-26 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/20332#discussion_r164151731
  
--- Diff: 
examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py 
---
@@ -43,6 +43,43 @@
 # Print the coefficients and intercept for multinomial logistic 
regression
 print("Coefficients: \n" + str(lrModel.coefficientMatrix))
 print("Intercept: " + str(lrModel.interceptVector))
+
+trainingSummary = lrModel.summary
+
+# Obtain the objective per iteration
+objectiveHistory = trainingSummary.objectiveHistory
+print("objectiveHistory:")
+for objective in objectiveHistory:
+print(objective)
+
+print("False positive rate by label:")
--- End diff --

Done


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark issue #20332: [SPARK-23138][ML][DOC] Multiclass summary example and us...

2018-01-19 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/20332
  
@jkbradley @MLnick 


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #20332: [SPARK-23138][ML][DOC] Multiclass summary example...

2018-01-19 Thread sethah
GitHub user sethah opened a pull request:

https://github.com/apache/spark/pull/20332

[SPARK-23138][ML][DOC] Multiclass summary example and user guide

## What changes were proposed in this pull request?

User guide and examples are updated to reflect multiclass logistic 
regression summary which was added in 
[SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139).

I did not make a separate summary example, but added the summary code to 
the multiclass example that already existed. I don't see the need for a 
separate example for the summary. 

## How was this patch tested?

Docs and examples only. Ran all examples locally using spark-submit.


You can merge this pull request into a Git repository by running:

$ git pull https://github.com/sethah/spark multiclass_summary_example

Alternatively you can review and apply these changes as the patch at:

https://github.com/apache/spark/pull/20332.patch

To close this pull request, make a commit to your master/trunk branch
with (at least) the following in the commit message:

This closes #20332


commit 9299fc83d2edab956bd13b2e1c985f64dcd2643e
Author: sethah <shendrickson@...>
Date:   2018-01-19T17:52:10Z

adding examples for python, scala, and java

commit bf076ed09abb3bb474e0925b3b9c4dbc6e90771a
Author: sethah <shendrickson@...>
Date:   2018-01-19T18:43:01Z

use binaryTrainingSummary

commit d0aa9f19550deb620e515ec33004be365c5439be
Author: sethah <shendrickson@...>
Date:   2018-01-19T18:46:16Z

import cleanup

commit cb6c811e98d9739a7c1608880b2d0037cdeb5990
Author: sethah <shendrickson@...>
Date:   2018-01-19T18:51:28Z

clarify user guide




---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark issue #20188: [SPARK-22993][ML] Clarify HasCheckpointInterval param do...

2018-01-09 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/20188
  
Thanks, the latest commit should fix it. 


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark issue #20188: [SPARK-22993][ML] Clarify HasCheckpointInterval param do...

2018-01-09 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/20188
  
Good call @felixcheung! Will update shortly.


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #19876: [ML][SPARK-11171][SPARK-11239] Add PMML export to...

2018-01-09 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r160503466
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
@@ -126,15 +180,69 @@ abstract class MLWriter extends BaseReadWrite with 
Logging {
 this
   }
 
+  // override for Java compatibility
+  override def session(sparkSession: SparkSession): this.type = 
super.session(sparkSession)
--- End diff --

`@Since` tags needed here as well.


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #19876: [ML][SPARK-11171][SPARK-11239] Add PMML export to...

2018-01-09 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r160484001
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala 
---
@@ -1044,6 +1056,50 @@ class LinearRegressionSuite extends MLTest with 
DefaultReadWriteTest {
   LinearRegressionSuite.allParamSettings, checkModelData)
   }
 
+  test("pmml export") {
+val lr = new LinearRegression()
+val model = lr.fit(datasetWithWeight)
+def checkModel(pmml: PMML): Unit = {
+  val dd = pmml.getDataDictionary
+  assert(dd.getNumberOfFields === 3)
+  val fields = dd.getDataFields.asScala
+  assert(fields(0).getName().toString === "field_0")
+  assert(fields(0).getOpType() == OpType.CONTINUOUS)
+  val pmmlRegressionModel = 
pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel]
+  val pmmlPredictors = 
pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors
+  val pmmlWeights = 
pmmlPredictors.asScala.map(_.getCoefficient()).toList
+  assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3)
+  assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3)
+}
+testPMMLWrite(sc, model, checkModel)
+  }
+
+  test("unsupported export format") {
+val lr = new LinearRegression()
+val model = lr.fit(datasetWithWeight)
+intercept[SparkException] {
--- End diff --

Don't this test and the one below it test the same thing? I think we could 
remove the first one.


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #19876: [ML][SPARK-11171][SPARK-11239] Add PMML export to...

2018-01-09 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r160461644
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
@@ -126,15 +180,69 @@ abstract class MLWriter extends BaseReadWrite with 
Logging {
 this
   }
 
+  // override for Java compatibility
+  override def session(sparkSession: SparkSession): this.type = 
super.session(sparkSession)
+
+  // override for Java compatibility
+  override def context(sqlContext: SQLContext): this.type = 
super.session(sqlContext.sparkSession)
+}
+
+/**
+ * A ML Writer which delegates based on the requested format.
+ */
+class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging {
+  private var source: String = "internal"
+
   /**
-   * Overwrites if the output path already exists.
+   * Specifies the format of ML export (e.g. PMML, internal, or
+   * the fully qualified class name for export).
*/
-  @Since("1.6.0")
-  def overwrite(): this.type = {
-shouldOverwrite = true
+  @Since("2.3.0")
+  def format(source: String): this.type = {
+this.source = source
 this
   }
 
+  /**
+   * Dispatches the save to the correct MLFormat.
+   */
+  @Since("2.3.0")
+  @throws[IOException]("If the input path already exists but overwrite is 
not enabled.")
+  @throws[SparkException]("If multiple sources for a given short name 
format are found.")
+  override protected def saveImpl(path: String) = {
+val loader = Utils.getContextOrSparkClassLoader
+val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], 
loader)
+val stageName = stage.getClass.getName
+val targetName = s"${source}+${stageName}"
--- End diff --

You don't need the braces here: `s"$source+$stageName"` works.


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #19876: [ML][SPARK-11171][SPARK-11239] Add PMML export to...

2018-01-09 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r160503640
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
@@ -85,12 +87,55 @@ private[util] sealed trait BaseReadWrite {
   protected final def sc: SparkContext = sparkSession.sparkContext
 }
 
+/**
+ * ML export formats for should implement this trait so that users can 
specify a shortname rather
+ * than the fully qualified class name of the exporter.
+ *
+ * A new instance of this class will be instantiated each time a DDL call 
is made.
+ *
+ * @since 2.3.0
+ */
+@InterfaceStability.Evolving
+trait MLFormatRegister {
+  /**
+   * The string that represents the format that this data source provider 
uses. This is
+   * overridden by children to provide a nice alias for the data source. 
For example:
+   *
+   * {{{
+   *   override def shortName(): String =
+   *   "pmml+org.apache.spark.ml.regression.LinearRegressionModel"
+   * }}}
+   * Indicates that this format is capable of saving Spark's own 
LinearRegressionModel in pmml.
+   *
+   * Format discovery is done using a ServiceLoader so make sure to list 
your format in
+   * META-INF/services.
+   * @since 2.3.0
+   */
+  def shortName(): String
+}
+
+/**
+ * Implemented by objects that provide ML exportability.
+ *
+ * A new instance of this class will be instantiated each time a DDL call 
is made.
+ *
+ * @since 2.3.0
+ */
+@InterfaceStability.Evolving
+trait MLWriterFormat {
--- End diff --

Do we need the actual `@Since` annotations here, though?


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #19876: [ML][SPARK-11171][SPARK-11239] Add PMML export to...

2018-01-09 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r160496808
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
@@ -126,15 +180,69 @@ abstract class MLWriter extends BaseReadWrite with 
Logging {
 this
   }
 
+  // override for Java compatibility
+  override def session(sparkSession: SparkSession): this.type = 
super.session(sparkSession)
+
+  // override for Java compatibility
+  override def context(sqlContext: SQLContext): this.type = 
super.session(sqlContext.sparkSession)
+}
+
+/**
+ * A ML Writer which delegates based on the requested format.
+ */
+class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging {
+  private var source: String = "internal"
+
   /**
-   * Overwrites if the output path already exists.
+   * Specifies the format of ML export (e.g. PMML, internal, or
+   * the fully qualified class name for export).
*/
-  @Since("1.6.0")
-  def overwrite(): this.type = {
-shouldOverwrite = true
+  @Since("2.3.0")
+  def format(source: String): this.type = {
+this.source = source
 this
   }
 
+  /**
+   * Dispatches the save to the correct MLFormat.
+   */
+  @Since("2.3.0")
+  @throws[IOException]("If the input path already exists but overwrite is 
not enabled.")
+  @throws[SparkException]("If multiple sources for a given short name 
format are found.")
+  override protected def saveImpl(path: String) = {
--- End diff --

Missing an explicit return type (`: Unit`).


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #19876: [ML][SPARK-11171][SPARK-11239] Add PMML export to...

2018-01-09 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r160462794
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
@@ -85,12 +87,55 @@ private[util] sealed trait BaseReadWrite {
   protected final def sc: SparkContext = sparkSession.sparkContext
 }
 
+/**
+ * ML export formats for should implement this trait so that users can 
specify a shortname rather
+ * than the fully qualified class name of the exporter.
+ *
+ * A new instance of this class will be instantiated each time a DDL call 
is made.
+ *
+ * @since 2.3.0
+ */
+@InterfaceStability.Evolving
+trait MLFormatRegister {
+  /**
+   * The string that represents the format that this data source provider 
uses. This is
+   * overridden by children to provide a nice alias for the data source. 
For example:
+   *
+   * {{{
+   *   override def shortName(): String =
+   *   "pmml+org.apache.spark.ml.regression.LinearRegressionModel"
--- End diff --

What about adding a second abstract method, `def stageName(): String`, 
instead of having it all packed into one string?
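
Roughly, something like this (just a sketch of the shape; doc comments and annotations elided):

```scala
trait MLFormatRegister {
  /** Short alias for the format, e.g. "pmml" or "internal". */
  def shortName(): String

  /** Fully qualified class name of the stage this format can save. */
  def stageName(): String
}
```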


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #19876: [ML][SPARK-11171][SPARK-11239] Add PMML export to...

2018-01-09 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r160502536
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
@@ -85,12 +87,55 @@ private[util] sealed trait BaseReadWrite {
   protected final def sc: SparkContext = sparkSession.sparkContext
 }
 
+/**
+ * ML export formats for should implement this trait so that users can 
specify a shortname rather
+ * than the fully qualified class name of the exporter.
+ *
+ * A new instance of this class will be instantiated each time a DDL call 
is made.
+ *
+ * @since 2.3.0
+ */
+@InterfaceStability.Evolving
+trait MLFormatRegister {
+  /**
+   * The string that represents the format that this data source provider 
uses. This is
+   * overridden by children to provide a nice alias for the data source. 
For example:
+   *
+   * {{{
+   *   override def shortName(): String =
+   *   "pmml+org.apache.spark.ml.regression.LinearRegressionModel"
+   * }}}
+   * Indicates that this format is capable of saving Spark's own 
LinearRegressionModel in pmml.
+   *
+   * Format discovery is done using a ServiceLoader so make sure to list 
your format in
+   * META-INF/services.
+   * @since 2.3.0
+   */
+  def shortName(): String
+}
+
+/**
+ * Implemented by objects that provide ML exportability.
+ *
+ * A new instance of this class will be instantiated each time a DDL call 
is made.
+ *
+ * @since 2.3.0
+ */
+@InterfaceStability.Evolving
+trait MLWriterFormat {
+  /**
+   * Function write the provided pipeline stage out.
--- End diff --

Should add a full doc comment here with `@param` annotations. Also, should it be 
"Function to write ..."?
  


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #19876: [ML][SPARK-11171][SPARK-11239] Add PMML export to...

2018-01-09 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r160501723
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
@@ -126,15 +180,69 @@ abstract class MLWriter extends BaseReadWrite with 
Logging {
 this
   }
 
+  // override for Java compatibility
+  override def session(sparkSession: SparkSession): this.type = 
super.session(sparkSession)
+
+  // override for Java compatibility
+  override def context(sqlContext: SQLContext): this.type = 
super.session(sqlContext.sparkSession)
+}
+
+/**
+ * A ML Writer which delegates based on the requested format.
+ */
+class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging {
+  private var source: String = "internal"
+
   /**
-   * Overwrites if the output path already exists.
+   * Specifies the format of ML export (e.g. PMML, internal, or
+   * the fully qualified class name for export).
*/
-  @Since("1.6.0")
-  def overwrite(): this.type = {
-shouldOverwrite = true
+  @Since("2.3.0")
+  def format(source: String): this.type = {
+this.source = source
 this
   }
 
+  /**
+   * Dispatches the save to the correct MLFormat.
+   */
+  @Since("2.3.0")
+  @throws[IOException]("If the input path already exists but overwrite is 
not enabled.")
+  @throws[SparkException]("If multiple sources for a given short name 
format are found.")
+  override protected def saveImpl(path: String) = {
+val loader = Utils.getContextOrSparkClassLoader
+val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], 
loader)
+val stageName = stage.getClass.getName
+val targetName = s"${source}+${stageName}"
+val formats = serviceLoader.asScala.toList
+val shortNames = formats.map(_.shortName())
+val writerCls = 
formats.filter(_.shortName().equalsIgnoreCase(targetName)) match {
+  // requested name did not match any given registered alias
+  case Nil =>
+Try(loader.loadClass(source)) match {
+  case Success(writer) =>
+// Found the ML writer using the fully qualified path
+writer
+  case Failure(error) =>
+throw new SparkException(
+  s"Could not load requested format $source for $stageName 
($targetName) had $formats" +
+  s"supporting $shortNames", error)
+}
+  case head :: Nil =>
+head.getClass
+  case _ =>
+// Multiple sources
+throw new SparkException(
+  s"Multiple writers found for $source+$stageName, try using the 
class name of the writer")
+}
+if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) {
+  val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat]
--- End diff --

This will fail, non-intuitively, if anyone ever extends `MLWriterFormat` 
with a constructor that has more than zero arguments. Meaning:

```scala
class DummyLinearRegressionWriter(someParam: Int) extends MLWriterFormat
```

will raise `java.lang.NoSuchMethodException: org.apache.spark.ml.regression.DummyLinearRegressionWriter.<init>()`
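A minimal sketch, reusing the names from the diff above, of how the dispatch could surface a clearer error instead of the bare reflection failure:

```scala
// Sketch only: request the zero-argument constructor explicitly so a format
// implemented with constructor arguments fails with an actionable message.
val writer =
  try {
    writerCls.getConstructor().newInstance().asInstanceOf[MLWriterFormat]
  } catch {
    case e: NoSuchMethodException =>
      throw new SparkException(
        s"ML writer ${writerCls.getName} must define a zero-argument constructor", e)
  }
```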
  


---




[GitHub] spark pull request #19876: [ML][SPARK-11171][SPARK-11239] Add PMML export to...

2018-01-09 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r160503322
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
@@ -126,15 +180,69 @@ abstract class MLWriter extends BaseReadWrite with 
Logging {
 this
   }
 
+  // override for Java compatibility
+  override def session(sparkSession: SparkSession): this.type = 
super.session(sparkSession)
+
+  // override for Java compatibility
+  override def context(sqlContext: SQLContext): this.type = 
super.session(sqlContext.sparkSession)
+}
+
+/**
+ * A ML Writer which delegates based on the requested format.
+ */
+class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging {
--- End diff --

need `@Since("2.3.0")` here?


---




[GitHub] spark pull request #19876: [ML][SPARK-11171][SPARK-11239] Add PMML export to...

2018-01-09 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r160471845
  
--- Diff: 
mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala ---
@@ -710,15 +711,57 @@ class LinearRegressionModel private[ml] (
   }
 
   /**
-   * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML 
instance.
+   * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for 
this ML instance.
*
* For [[LinearRegressionModel]], this does NOT currently save the 
training [[summary]].
* An option to save [[summary]] may be added in the future.
*
* This also does not save the [[parent]] currently.
*/
   @Since("1.6.0")
-  override def write: MLWriter = new 
LinearRegressionModel.LinearRegressionModelWriter(this)
+  override def write: GeneralMLWriter = new GeneralMLWriter(this)
+}
+
+/** A writer for LinearRegression that handles the "internal" (or default) 
format */
+private class InternalLinearRegressionModelWriter()
+  extends MLWriterFormat with MLFormatRegister {
+
+  override def shortName(): String =
+"internal+org.apache.spark.ml.regression.LinearRegressionModel"
+
+  private case class Data(intercept: Double, coefficients: Vector, scale: 
Double)
+
+  override def write(path: String, sparkSession: SparkSession,
+optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+val instance = stage.asInstanceOf[LinearRegressionModel]
+val sc = sparkSession.sparkContext
+// Save metadata and Params
+DefaultParamsWriter.saveMetadata(instance, path, sc)
+// Save model data: intercept, coefficients, scale
+val data = Data(instance.intercept, instance.coefficients, 
instance.scale)
+val dataPath = new Path(path, "data").toString
+
sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+  }
+}
+
+/** A writer for LinearRegression that handles the "pmml" format */
+private class PMMLLinearRegressionModelWriter()
--- End diff --

I could be wrong, but I think we prefer just omitting the `()`?


---




[GitHub] spark pull request #19876: [ML][SPARK-11171][SPARK-11239] Add PMML export to...

2018-01-09 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r160463657
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
@@ -126,15 +180,69 @@ abstract class MLWriter extends BaseReadWrite with 
Logging {
 this
   }
 
+  // override for Java compatibility
+  override def session(sparkSession: SparkSession): this.type = 
super.session(sparkSession)
+
+  // override for Java compatibility
+  override def context(sqlContext: SQLContext): this.type = 
super.session(sqlContext.sparkSession)
+}
+
+/**
+ * A ML Writer which delegates based on the requested format.
+ */
+class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging {
+  private var source: String = "internal"
+
   /**
-   * Overwrites if the output path already exists.
+   * Specifies the format of ML export (e.g. PMML, internal, or
--- End diff --

change to `e.g. "pmml", "internal", or the fully qualified class name for 
export)."`


---




[GitHub] spark pull request #19876: [ML][SPARK-11171][SPARK-11239] Add PMML export to...

2018-01-09 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r160483562
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala 
---
@@ -1044,6 +1056,50 @@ class LinearRegressionSuite extends MLTest with 
DefaultReadWriteTest {
   LinearRegressionSuite.allParamSettings, checkModelData)
   }
 
+  test("pmml export") {
+val lr = new LinearRegression()
+val model = lr.fit(datasetWithWeight)
+def checkModel(pmml: PMML): Unit = {
+  val dd = pmml.getDataDictionary
+  assert(dd.getNumberOfFields === 3)
+  val fields = dd.getDataFields.asScala
+  assert(fields(0).getName().toString === "field_0")
+  assert(fields(0).getOpType() == OpType.CONTINUOUS)
+  val pmmlRegressionModel = 
pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel]
+  val pmmlPredictors = 
pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors
+  val pmmlWeights = 
pmmlPredictors.asScala.map(_.getCoefficient()).toList
+  assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3)
+  assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3)
+}
+testPMMLWrite(sc, model, checkModel)
+  }
+
+  test("unsupported export format") {
+val lr = new LinearRegression()
+val model = lr.fit(datasetWithWeight)
+intercept[SparkException] {
+  model.write.format("boop").save("boop")
+}
+intercept[SparkException] {
+  model.write.format("com.holdenkarau.boop").save("boop")
+}
+withClue("ML source org.apache.spark.SparkContext is not a valid 
MLWriterFormat") {
+  intercept[SparkException] {
+model.write.format("org.apache.spark.SparkContext").save("boop2")
+  }
+}
+  }
+
+  test("dummy export format is called") {
+val lr = new LinearRegression()
+val model = lr.fit(datasetWithWeight)
+withClue("Dummy writer doesn't write") {
+  intercept[Exception] {
--- End diff --

this just catches any exception. Can we do something like 

```scala
val thrown = intercept[Exception] {
  model.write.format("org.apache.spark.ml.regression.DummyLinearRegressionWriter").save("")
}
assert(thrown.getMessage.contains("Dummy writer doesn't write."))
```
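For reference, the kind of test-only writer that snippet assumes; this is just a sketch, not part of the diff:

```scala
package org.apache.spark.ml.regression

import scala.collection.mutable

import org.apache.spark.SparkException
import org.apache.spark.ml.PipelineStage
import org.apache.spark.ml.util.MLWriterFormat
import org.apache.spark.sql.SparkSession

// Hypothetical test-only format: it never writes anything, it only throws,
// so the suite can assert that the requested writer class was actually invoked.
class DummyLinearRegressionWriter extends MLWriterFormat {
  override def write(path: String, session: SparkSession,
      optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
    throw new SparkException("Dummy writer doesn't write.")
  }
}
```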


---




[GitHub] spark pull request #19876: [ML][SPARK-11171][SPARK-11239] Add PMML export to...

2018-01-09 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r160461560
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
@@ -85,12 +87,55 @@ private[util] sealed trait BaseReadWrite {
   protected final def sc: SparkContext = sparkSession.sparkContext
 }
 
+/**
+ * ML export formats for should implement this trait so that users can 
specify a shortname rather
+ * than the fully qualified class name of the exporter.
+ *
+ * A new instance of this class will be instantiated each time a DDL call 
is made.
--- End diff --

Was this supposed to be retained from the `DataSourceRegister`?


---




[GitHub] spark pull request #19876: [ML][SPARK-11171][SPARK-11239] Add PMML export to...

2018-01-09 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r160506592
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala 
---
@@ -1044,6 +1056,50 @@ class LinearRegressionSuite extends MLTest with 
DefaultReadWriteTest {
   LinearRegressionSuite.allParamSettings, checkModelData)
   }
 
+  test("pmml export") {
+val lr = new LinearRegression()
+val model = lr.fit(datasetWithWeight)
+def checkModel(pmml: PMML): Unit = {
+  val dd = pmml.getDataDictionary
+  assert(dd.getNumberOfFields === 3)
+  val fields = dd.getDataFields.asScala
+  assert(fields(0).getName().toString === "field_0")
+  assert(fields(0).getOpType() == OpType.CONTINUOUS)
+  val pmmlRegressionModel = 
pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel]
+  val pmmlPredictors = 
pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors
+  val pmmlWeights = 
pmmlPredictors.asScala.map(_.getCoefficient()).toList
+  assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3)
+  assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3)
+}
+testPMMLWrite(sc, model, checkModel)
+  }
+
+  test("unsupported export format") {
+val lr = new LinearRegression()
+val model = lr.fit(datasetWithWeight)
+intercept[SparkException] {
+  model.write.format("boop").save("boop")
+}
+intercept[SparkException] {
+  model.write.format("com.holdenkarau.boop").save("boop")
+}
+withClue("ML source org.apache.spark.SparkContext is not a valid 
MLWriterFormat") {
+  intercept[SparkException] {
+model.write.format("org.apache.spark.SparkContext").save("boop2")
+  }
+}
+  }
+
+  test("dummy export format is called") {
--- End diff --

We can also add tests for the `MLFormatRegister` similar to 
`DDLSourceLoadSuite`. Just add a `META-INF/services/` directory to 
`src/test/resources/`
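Roughly what that would exercise, as a sketch (the registered format and its short name are made up for illustration):

```scala
import java.util.ServiceLoader

import scala.collection.JavaConverters._

import org.apache.spark.ml.util.MLFormatRegister

// With a test format listed in
// src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister,
// the same ServiceLoader lookup used by GeneralMLWriter should discover it.
val registered = ServiceLoader.load(classOf[MLFormatRegister]).asScala.toList
assert(registered.map(_.shortName())
  .contains("fake+org.apache.spark.ml.regression.LinearRegressionModel"))
```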


---




[GitHub] spark pull request #19876: [ML][SPARK-11171][SPARK-11239] Add PMML export to...

2018-01-09 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r160463225
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
@@ -85,12 +87,55 @@ private[util] sealed trait BaseReadWrite {
   protected final def sc: SparkContext = sparkSession.sparkContext
 }
 
+/**
+ * ML export formats for should implement this trait so that users can 
specify a shortname rather
+ * than the fully qualified class name of the exporter.
+ *
+ * A new instance of this class will be instantiated each time a DDL call 
is made.
+ *
+ * @since 2.3.0
+ */
+@InterfaceStability.Evolving
+trait MLFormatRegister {
+  /**
+   * The string that represents the format that this data source provider 
uses. This is
+   * overridden by children to provide a nice alias for the data source. 
For example:
--- End diff --

"data source" -> "model format"?


---




[GitHub] spark issue #20188: [SPARK-22993][ML] Clarify HasCheckpointInterval param do...

2018-01-08 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/20188
  
cc @srowen @holdenk The MLlib counterparts actually mention this, but for some reason the note never got ported over to the ML package.

The only caveat I can think of is that this doc is shared by every algorithm that inherits the param, and a new algorithm might not ignore it; it could instead throw an error or set the checkpoint dir itself. For now, ALS, LDA, and GBT all use the param and simply ignore it when the checkpoint dir is not set.
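For concreteness, a sketch of the interaction the note documents (ALS and the paths are just examples):

```scala
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("checkpoint-interval-example")
  .getOrCreate()

// checkpointInterval only takes effect once a checkpoint directory is set;
// without the next line, ALS (and LDA, GBTs) simply ignore the setting.
spark.sparkContext.setCheckpointDir("/tmp/spark-checkpoints")

val als = new ALS().setCheckpointInterval(10)  // checkpoint cached data every 10 iterations
```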


---




[GitHub] spark pull request #20188: [SPARK-22993][ML] Clarify HasCheckpointInterval p...

2018-01-08 Thread sethah
GitHub user sethah opened a pull request:

https://github.com/apache/spark/pull/20188

[SPARK-22993][ML] Clarify HasCheckpointInterval param doc

## What changes were proposed in this pull request?

Add a note to the `HasCheckpointInterval` parameter doc that clarifies that 
this setting is ignored when no checkpoint directory has been set on the spark 
context.

## How was this patch tested?

No tests necessary, just a doc update.


You can merge this pull request into a Git repository by running:

$ git pull https://github.com/sethah/spark als_checkpoint_doc

Alternatively you can review and apply these changes as the patch at:

https://github.com/apache/spark/pull/20188.patch

To close this pull request, make a commit to your master/trunk branch
with (at least) the following in the commit message:

This closes #20188


commit 752d0ba2702fd5bbd0134e335384a3694db0011a
Author: sethah <shendrickson@...>
Date:   2018-01-08T20:30:15Z

update shared param doc




---




[GitHub] spark pull request #19876: [WIP][ML][SPARK-11171][SPARK-11239] Add PMML expo...

2017-12-12 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r156389238
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
@@ -126,15 +180,69 @@ abstract class MLWriter extends BaseReadWrite with 
Logging {
 this
   }
 
+  // override for Java compatibility
+  override def session(sparkSession: SparkSession): this.type = 
super.session(sparkSession)
+
+  // override for Java compatibility
+  override def context(sqlContext: SQLContext): this.type = 
super.session(sqlContext.sparkSession)
+}
+
+/**
+ * A ML Writer which delegates based on the requested format.
+ */
+class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging {
--- End diff --

Perhaps for another PR, but maybe we could add a method here:

```scala
  def pmml(path: String): Unit = {
    this.source = "pmml"
    save(path)
  }
```
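Hypothetical usage, next to the chain this PR already supports (the path is a placeholder):

```scala
// With this PR: explicit format dispatch
model.write.format("pmml").save("/tmp/lr-pmml")

// With the proposed helper: the same thing in one call
model.write.pmml("/tmp/lr-pmml")
```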


---




[GitHub] spark pull request #19876: [WIP][ML][SPARK-11171][SPARK-11239] Add PMML expo...

2017-12-12 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r155157785
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
@@ -126,15 +180,69 @@ abstract class MLWriter extends BaseReadWrite with 
Logging {
 this
   }
 
+  // override for Java compatibility
+  override def session(sparkSession: SparkSession): this.type = 
super.session(sparkSession)
+
+  // override for Java compatibility
+  override def context(sqlContext: SQLContext): this.type = 
super.session(sqlContext.sparkSession)
+}
+
+/**
+ * A ML Writer which delegates based on the requested format.
+ */
+class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging {
+  private var source: String = "internal"
+
   /**
-   * Overwrites if the output path already exists.
+   * Specifies the format of ML export (e.g. PMML, internal, or
+   * the fully qualified class name for export).
*/
-  @Since("1.6.0")
-  def overwrite(): this.type = {
-shouldOverwrite = true
+  @Since("2.3.0")
+  def format(source: String): this.type = {
+this.source = source
 this
   }
 
+  /**
+   * Dispatches the save to the correct MLFormat.
+   */
+  @Since("2.3.0")
+  @throws[IOException]("If the input path already exists but overwrite is 
not enabled.")
+  @throws[SparkException]("If multiple sources for a given short name 
format are found.")
+  override protected def saveImpl(path: String) = {
+val loader = Utils.getContextOrSparkClassLoader
+val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], 
loader)
+val stageName = stage.getClass.getName
+val targetName = s"${source}+${stageName}"
+val formats = serviceLoader.asScala.toList
+val shortNames = formats.map(_.shortName())
+val writerCls = 
formats.filter(_.shortName().equalsIgnoreCase(targetName)) match {
+  // requested name did not match any given registered alias
+  case Nil =>
+Try(loader.loadClass(source)) match {
+  case Success(writer) =>
+// Found the ML writer using the fully qualified path
+writer
+  case Failure(error) =>
+throw new SparkException(
+  s"Could not load requested format $source for $stageName 
($targetName) had $formats" +
+  s"supporting $shortNames", error)
+}
+  case head :: Nil =>
+head.getClass
+  case _ =>
+// Multiple sources
+throw new SparkException(
+  s"Multiple writers found for $source+$stageName, try using the 
class name of the writer")
+}
+if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) {
+  val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat]
+  writer.write(path, sparkSession, optionMap, stage)
+} else {
+  throw new SparkException("ML source $source is not a valid 
MLWriterFormat")
--- End diff --

nit: need string interpolation here
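i.e. roughly:

```scala
throw new SparkException(s"ML source $source is not a valid MLWriterFormat")
```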


---




[GitHub] spark pull request #19876: [WIP][ML][SPARK-11171][SPARK-11239] Add PMML expo...

2017-12-12 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r156370588
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
@@ -85,12 +87,55 @@ private[util] sealed trait BaseReadWrite {
   protected final def sc: SparkContext = sparkSession.sparkContext
 }
 
+/**
+ * ML export formats for should implement this trait so that users can 
specify a shortname rather
+ * than the fully qualified class name of the exporter.
+ *
+ * A new instance of this class will be instantiated each time a DDL call 
is made.
+ *
+ * @since 2.3.0
+ */
+@InterfaceStability.Evolving
+trait MLFormatRegister {
+  /**
+   * The string that represents the format that this data source provider 
uses. This is
+   * overridden by children to provide a nice alias for the data source. 
For example:
+   *
+   * {{{
+   *   override def shortName(): String =
+   *   "pmml+org.apache.spark.ml.regression.LinearRegressionModel"
+   * }}}
+   * Indicates that this format is capable of saving Spark's own 
LinearRegressionModel in pmml.
+   *
+   * Format discovery is done using a ServiceLoader so make sure to list 
your format in
+   * META-INF/services.
+   * @since 2.3.0
+   */
+  def shortName(): String
+}
+
+/**
+ * Implemented by objects that provide ML exportability.
+ *
+ * A new instance of this class will be instantiated each time a DDL call 
is made.
+ *
+ * @since 2.3.0
+ */
+@InterfaceStability.Evolving
+trait MLWriterFormat {
+  /**
+   * Function write the provided pipeline stage out.
+   */
+  def write(path: String, session: SparkSession, optionMap: 
mutable.Map[String, String],
+stage: PipelineStage)
--- End diff --

return type?


---




[GitHub] spark pull request #19876: [WIP][ML][SPARK-11171][SPARK-11239] Add PMML expo...

2017-12-12 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r156388871
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala 
---
@@ -994,6 +998,38 @@ class LinearRegressionSuite
   LinearRegressionSuite.allParamSettings, checkModelData)
   }
 
+  test("pmml export") {
+val lr = new LinearRegression()
+val model = lr.fit(datasetWithWeight)
+def checkModel(pmml: PMML): Unit = {
+  val dd = pmml.getDataDictionary
+  assert(dd.getNumberOfFields === 3)
+  val fields = dd.getDataFields.asScala
+  assert(fields(0).getName().toString === "field_0")
+  assert(fields(0).getOpType() == OpType.CONTINUOUS)
+  val pmmlRegressionModel = 
pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel]
+  val pmmlPredictors = 
pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors
+  val pmmlWeights = 
pmmlPredictors.asScala.map(_.getCoefficient()).toList
+  assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3)
+  assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3)
+}
+testPMMLWrite(sc, model, checkModel)
+  }
+
+  test("unsupported export format") {
--- End diff --

Would be great to have a test that verifies that this works with third 
party implementations. Specifically, that something like 
`model.write.format("org.apache.spark.ml.MyDummyWriter").save(path)` works.


---




[GitHub] spark pull request #19876: [WIP][ML][SPARK-11171][SPARK-11239] Add PMML expo...

2017-12-12 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r156381361
  
--- Diff: 
mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala ---
@@ -554,7 +555,49 @@ class LinearRegressionModel private[ml] (
* This also does not save the [[parent]] currently.
*/
   @Since("1.6.0")
-  override def write: MLWriter = new 
LinearRegressionModel.LinearRegressionModelWriter(this)
+  override def write: GeneralMLWriter = new GeneralMLWriter(this)
--- End diff --

The doc above this is wrong.


---




[GitHub] spark issue #19876: [WIP][ML][SPARK-11171][SPARK-11239] Add PMML export to S...

2017-12-12 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/19876
  
@holdenk Do you mind leaving some comments on the intent and benefits of this new API, for other reviewers? For example, which use cases it enables (adding third-party PFA support, or other third-party export tools?), and why we need to add PMML support when there are already tools that do this, such as [jpmml-sparkml](https://github.com/jpmml/jpmml-sparkml).

Also, this is two changes in one PR: adding an API for generic model export and adding PMML export to LinearRegression. I think it makes sense to separate the two and just focus on the new API here. What do you think?


---




[GitHub] spark pull request #19904: [SPARK-22707][ML] Optimize CrossValidator memory ...

2017-12-07 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19904#discussion_r155710913
  
--- Diff: 
mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala ---
@@ -146,25 +147,18 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") 
override val uid: String)
   val validationDataset = sparkSession.createDataFrame(validation, 
schema).cache()
   logDebug(s"Train split $splitIndex with multiple sets of 
parameters.")
 
+  val completeFitCount = new AtomicInteger(0)
--- End diff --

My understanding of Scala futures may be off here, but this seems to change the behavior: the unpersist operation will now happen in one of the training threads, instead of asynchronously in its own thread. I'm not sure how much of an effect that will have.

Why can't you just put all the logic in one map statement, like below?

```scala
// Fit models in a Future for training in parallel
val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
  Future[Model[_]] {
    val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]

    if (collectSubModelsParam) {
      subModels.get(splitIndex)(paramIndex) = model
    }
    // TODO: duplicate evaluator to take extra params from input
    val metric = eval.evaluate(model.transform(validationDataset, paramMap))
    logDebug(s"Got metric $metric for model trained with $paramMap.")
    metric
  } (executionContext)
}

// Unpersist training data only when all models have trained
Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext)
  .onComplete { _ => trainingDataset.unpersist() } (executionContext)
```



---




[GitHub] spark issue #19904: [SPARK-22707][ML] Optimize CrossValidator memory occupat...

2017-12-07 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/19904
  
Can you share your test/results with us?


---




[GitHub] spark issue #18729: [SPARK-21526] [MLlib] Add support to ML LogisticRegressi...

2017-12-05 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/18729
  
I actually completely agree about perfect being the enemy of good in this 
case. We should provide something workable that can be safely modified in the 
future if needed. Still, this needs to be done in the other PR 
https://github.com/apache/spark/pull/18610, so I suggest providing feedback and 
code review on that one until it gets merged, then we can revisit here. Thanks!


---




[GitHub] spark pull request #19638: [SPARK-22422][ML] Add Adjusted R2 to RegressionMe...

2017-11-08 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19638#discussion_r149731177
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala 
---
@@ -764,13 +764,17 @@ class LinearRegressionSuite
   (Intercept) 6.3022157  0.00186003388   <2e-16 ***
   V2  4.6982442  0.00118053980   <2e-16 ***
   V3  7.1994344  0.00090447961   <2e-16 ***
+
+  # R code for r2adj
--- End diff --

Is there something wrong with the code I pasted above? That worked for me 
when I was using the R shell.


---




[GitHub] spark pull request #19638: [SPARK-22422][ML] Add Adjusted R2 to RegressionMe...

2017-11-07 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19638#discussion_r149559666
  
--- Diff: 
mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala 
---
@@ -764,13 +764,17 @@ class LinearRegressionSuite
   (Intercept) 6.3022157  0.00186003388   <2e-16 ***
   V2  4.6982442  0.00118053980   <2e-16 ***
   V3  7.1994344  0.00090447961   <2e-16 ***
+
+  # R code for r2adj
--- End diff --

There may be some confusion. If you type that code, as-is, into an R shell, it will not work: it references a variable called `X1`, which is never defined. When we provide R code in comments like this, we intend for it to be copied and pasted into a shell and just work. So, as written, it does not function.


---




[GitHub] spark issue #19680: [SPARK-22641][ML] Refactor Spark ML model summaries

2017-11-06 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/19680
  
I think it's important to restructure these summaries to inherit from the 
same traits, so different methods can be re-used. That structure has to live 
somewhere and there isn't really a logical place except its own module. The 
individual implementation classes don't _have_ to move, as noted above, but it 
is neater. I think it depends on how averse we are to breaking compatibility.


---




[GitHub] spark issue #19680: [SPARK-22641][ML] Refactor Spark ML model summaries

2017-11-06 Thread sethah
Github user sethah commented on the issue:

https://github.com/apache/spark/pull/19680
  
MiMa failures :) These APIs were all marked as experimental, which does give us some freedom to move them, though I know we prefer to avoid it. It's mainly complaining that we moved these classes from `org.apache.spark.ml.classification` to `org.apache.spark.ml.summary`. We could keep the new traits in the `summary` package and keep the actual implementations in the `classification/regression/clustering` packages, but that is less desirable IMO. I'll wait for some feedback from other reviewers.


---



