Repository: flink
Updated Branches:
  refs/heads/master 1a6bab3ef -> bce355093


[FLINK-2259] [ml] Add Splitter for train, validation and test data set generation

This closes #1898.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/bce35509
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/bce35509
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/bce35509

Branch: refs/heads/master
Commit: bce3550936d9984298cba9f4dd45e9965fc45762
Parents: 1a6bab3
Author: Trevor Grant <[email protected]>
Authored: Fri Apr 15 17:37:51 2016 -0500
Committer: Till Rohrmann <[email protected]>
Committed: Fri Jun 17 16:58:13 2016 +0200

----------------------------------------------------------------------
 docs/apis/batch/libs/ml/cross_validation.md     | 175 ++++++++++++++++
 docs/apis/batch/libs/ml/index.md                |   9 +
 .../flink/ml/preprocessing/Splitter.scala       | 210 +++++++++++++++++++
 .../ml/preprocessing/SplitterITSuite.scala      |  98 +++++++++
 4 files changed, 492 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/bce35509/docs/apis/batch/libs/ml/cross_validation.md
----------------------------------------------------------------------
diff --git a/docs/apis/batch/libs/ml/cross_validation.md b/docs/apis/batch/libs/ml/cross_validation.md
new file mode 100644
index 0000000..2473317
--- /dev/null
+++ b/docs/apis/batch/libs/ml/cross_validation.md
@@ -0,0 +1,175 @@
+---
+mathjax: include
+title: Cross Validation
+
+# Sub navigation
+sub-nav-group: batch
+sub-nav-parent: flinkml
+sub-nav-title: Cross Validation
+---
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+* This will be replaced by the TOC
+{:toc}
+
+## Description
+
+A prevalent problem when applying machine learning algorithms is *overfitting*: the algorithm
+"memorizes" the training data but does a poor job of generalizing to out-of-sample cases. A common
+method for dealing with overfitting is to hold back a subset of the data from the training algorithm
+and then measure the fitted model's performance on this hold-out set. This is commonly known as
+*cross validation*: a model is trained on one subset of the data and then *validated* on another.
+
+## Cross Validation Strategies
+
+There are several strategies for holding out data. FlinkML has convenience methods for the
+following:
+
+- Train-Test Splits
+- Train-Test-Holdout Splits
+- K-Fold Splits
+- Multi-Random Splits
+
+### Train-Test Splits
+
+The simplest way of splitting is `trainTestSplit`. This split takes a DataSet and a parameter
+*fraction*, which indicates the portion of the DataSet that should be allocated to the training
+set. It also takes two optional parameters, *precise* and *seed*.
+
+By default, the split is done by randomly assigning each observation to the training DataSet
+with probability *fraction*. When *precise* is `true`, however, additional steps are taken to
+ensure that the size of the training set is as close as possible to $\text{fraction} \cdot n$,
+where $n$ is the number of elements in the DataSet. This requires counting the DataSet first and
+is therefore less efficient.
+
+The method returns a new `TrainTestDataSet` object which has a `.training` attribute containing
+the training DataSet and a `.testing` attribute containing the testing DataSet.
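A minimal sketch of the two sampling modes described above, written against plain Scala collections rather than Flink `DataSet`s so that it runs stand-alone. `TrainTestSketch` and its `split` method are hypothetical names, not part of FlinkML; the precise branch only imitates the idea of drawing a fixed number of elements.

```scala
import scala.util.Random

// Hypothetical stand-alone sketch on plain Scala collections, not the FlinkML API.
object TrainTestSketch {

  /** Returns (training, testing). `fraction` must lie in the open interval (0,1). */
  def split[T](data: Seq[T], fraction: Double, precise: Boolean, seed: Long): (Seq[T], Seq[T]) = {
    require(fraction > 0 && fraction < 1, "fraction must be in (0,1)")
    val rng = new Random(seed)
    if (!precise) {
      // Default mode: each element independently goes to training with probability
      // `fraction`, so the training size only approximates fraction * data.size.
      data.partition(_ => rng.nextDouble() < fraction)
    } else {
      // Precise mode: pick exactly round(fraction * n) distinct positions for training.
      val k = math.round(fraction * data.size).toInt
      val chosen: Set[Int] = rng.shuffle(data.indices.toVector).take(k).toSet
      val (train, test) = data.zipWithIndex.partition { case (_, i) => chosen.contains(i) }
      (train.map(_._1), test.map(_._1))
    }
  }
}
```

With 100 elements, `fraction = 0.6` and `precise = true`, the training side has exactly 60 elements; in the default mode its size varies with the seed.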
+
+
+### Train-Test-Holdout Splits
+
+In some cases, a testing set is used so often to tune an algorithm that the algorithm effectively
+'learns' it. To combat this issue, a train-test-holdout strategy introduces a secondary hold-out
+set, aptly called the *holdout* set.
+
+Traditionally, the algorithm would be trained and tested as usual, and then a final test of the
+algorithm would be done on the holdout set. Ideally, prediction errors/model scores on the holdout
+set would not be significantly different from those observed on the testing set.
+
+In a train-test-holdout strategy we sacrifice sample size for the initial fitting algorithm in
+exchange for increased confidence that our model is not overfit.
+
+When using the `trainTestHoldoutSplit` splitter, the *fraction* `Double` is replaced by
+*fracTuple*, a tuple of three `Double`s. The first element corresponds to the portion to be used
+for training, the second to testing, and the third to holdout. These weights are *relative*,
+e.g. the tuple `(3.0, 2.0, 1.0)` would result in approximately 50% of the observations being in
+the training set, 33% in the testing set, and 17% in the holdout set.
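As a worked example of these relative weights: writing $w_i$ for the $i$-th weight (notation introduced here only for illustration), each weight is divided by the sum of all weights, which gives the expected share $p_i$ of observations in split $i$:

$$
p_i = \frac{w_i}{\sum_j w_j}, \qquad (3.0,\ 2.0,\ 1.0) \;\mapsto\; \left(\tfrac{3}{6},\ \tfrac{2}{6},\ \tfrac{1}{6}\right) \approx (0.50,\ 0.33,\ 0.17)
$$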
+
+### K-Fold Splits
+
+In a *k-fold* strategy, the DataSet is randomly split into *k* subsets of approximately equal
+size. Then, for each of the *k* subsets, a `TrainTestDataSet` is created in which that subset is
+the `.testing` DataSet and the union of the remaining subsets is the `.training` DataSet.
+
+For each training set, an algorithm is trained and then evaluated on its predictions for the
+associated testing set. When an algorithm has consistent scores (e.g. prediction errors) across
+the held-out datasets, we can have some confidence that our approach (e.g. choice of algorithm,
+algorithm parameters, number of iterations) is robust against overfitting.
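The fold construction can be sketched on a local collection as follows. This is a hypothetical, stand-alone illustration (the `KFoldSketch` object and `TrainTest` case class are not FlinkML API) that mirrors what `kFoldSplit` does in `Splitter.scala`: each element gets a uniformly random fold index, and each fold in turn becomes the testing set while the union of the other folds becomes the training set.

```scala
import scala.util.Random

// Hypothetical stand-alone sketch on plain Scala collections, not the FlinkML API.
object KFoldSketch {

  final case class TrainTest[T](training: Seq[T], testing: Seq[T])

  def kFold[T](data: Seq[T], k: Int, seed: Long): Seq[TrainTest[T]] = {
    require(k >= 2, "k must be at least 2")
    val rng = new Random(seed)
    // Equal weights: every element picks one of the k folds uniformly at random,
    // so fold sizes are only approximately equal.
    val tagged: Seq[(Int, T)] = data.map(x => (rng.nextInt(k), x))
    val folds: IndexedSeq[Seq[T]] =
      (0 until k).map(f => tagged.collect { case (fold, x) if fold == f => x })
    // Fold i is the testing set; the union of all other folds is the training set.
    folds.indices.map { i =>
      val training = folds.indices.filter(_ != i).flatMap(j => folds(j))
      TrainTest(training, folds(i))
    }
  }
}
```

Because every element belongs to exactly one fold, the testing sets of all *k* pairs together cover the input exactly once.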
+
+<a href="https://en.wikipedia.org/wiki/Cross-validation_(statistics)#k-fold_cross-validation">K-Fold Cross Validation</a>
+
+### Multi-Random Splits
+
+The *multi-random* strategy can be thought of as a more general form of the *train-test-holdout*
+strategy. In fact, `trainTestHoldoutSplit` is a simple wrapper for `multiRandomSplit` which also
+packages the datasets into a `TrainTestHoldoutDataSet` object.
+
+The first major difference is that `multiRandomSplit` takes an array of fractions of any length,
+so one can, for example, create multiple holdout sets. Similarly, `kFoldSplit` is a wrapper for
+`multiRandomSplit`; the difference is that `kFoldSplit` creates subsets of approximately equal
+size, whereas `multiRandomSplit` creates subsets of any relative size.
+
+The second major difference is that `multiRandomSplit` returns an array of DataSets whose length
+equals the length of the *fraction array* passed as an argument, and whose sizes are
+approximately proportional to its entries.
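A stand-alone sketch of weighted random assignment, assuming plain Scala collections instead of the Flink `DataSet` API; `MultiRandomSketch` is a hypothetical name. Each element draws a uniform number and lands in the first bucket whose normalized cumulative weight exceeds it, which is one common way to sample from a discrete distribution (the FlinkML code delegates this step to Commons Math's `EnumeratedIntegerDistribution`).

```scala
import scala.util.Random

// Hypothetical stand-alone sketch on plain Scala collections, not the FlinkML API.
object MultiRandomSketch {

  def multiSplit[T](data: Seq[T], weights: Array[Double], seed: Long): IndexedSeq[Seq[T]] = {
    require(weights.nonEmpty && weights.forall(_ >= 0) && weights.sum > 0,
      "weights must be non-negative with a positive sum")
    val total = weights.sum
    // Normalized cumulative weights, e.g. (0.5, 0.25, 0.25) -> (0.5, 0.75, 1.0).
    val cumulative: Array[Double] = weights.scanLeft(0.0)(_ + _).tail.map(_ / total)
    val rng = new Random(seed)
    val tagged: Seq[(Int, T)] = data.map { x =>
      val u = rng.nextDouble() // uniform in [0, 1)
      // First bucket whose cumulative weight exceeds u; fall back to the last bucket
      // as a guard against floating-point rounding in the cumulative sums.
      val bucket = cumulative.indexWhere(c => u < c) match {
        case -1 => weights.length - 1
        case i  => i
      }
      (bucket, x)
    }
    (0 until weights.length).map(b => tagged.collect { case (bucket, x) if bucket == b => x })
  }
}
```

The weights need not sum to 1: `Array(0.5, 0.1, 0.1, 0.1, 0.1)` behaves like `Array(5.0, 1.0, 1.0, 1.0, 1.0)`.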
+
+## Parameters
+
+The various `Splitter` methods share many parameters.
+
+ <table class="table table-bordered">
+  <thead>
+    <tr>
+      <th class="text-left" style="width: 20%">Parameter</th>
+      <th class="text-center">Type</th>
+      <th class="text-center">Description</th>
+      <th class="text-right">Used by Method</th>
+    </tr>
+  </thead>
+
+  <tbody>
+    <tr>
+      <td><code>input</code></td>
+      <td><code>DataSet[T]</code></td>
+      <td>DataSet to be split.</td>
+      <td>
+      <code>randomSplit</code><br>
+      <code>multiRandomSplit</code><br>
+      <code>kFoldSplit</code><br>
+      <code>trainTestSplit</code><br>
+      <code>trainTestHoldoutSplit</code>
+      </td>
+    </tr>
+    <tr>
+      <td><code>seed</code></td>
+      <td><code>Long</code></td>
+      <td>
+        <p>
+          Seed for the random number generator that assigns elements to the output
+          DataSets. Defaults to a random value.
+        </p>
+      </td>
+      <td>
+      <code>randomSplit</code><br>
+      <code>multiRandomSplit</code><br>
+      <code>kFoldSplit</code><br>
+      <code>trainTestSplit</code><br>
+      <code>trainTestHoldoutSplit</code>
+      </td>
+    </tr>
+    <tr>
+      <td><code>precise</code></td>
+      <td><code>Boolean</code></td>
+      <td>When true, make additional effort to make DataSets as close to the prescribed
+      proportions as possible.</td>
+      <td>
+      <code>randomSplit</code><br>
+      <code>trainTestSplit</code>
+      </td>
+    </tr>
+    <tr>
+      <td><code>fraction</code></td>
+      <td><code>Double</code></td>
+      <td>The portion of the <code>input</code> to assign to the first (or
+      <code>.training</code>) DataSet. Must be in the open range (0,1).</td>
+      <td><code>randomSplit</code><br>
+        <code>trainTestSplit</code>
+      </td>
+    </tr>
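+    <tr>
+      <td><code>fracArray</code></td>
+      <td><code>Array[Double]</code></td>
+      <td>An array of relative proportions for the output DataSets. The values are normalized
+      internally, so they need not sum to 1 or lie within the range (0,1).</td>
+      <td>
+      <code>multiRandomSplit</code>
+      </td>
+    </tr>
+    <tr>
+      <td><code>fracTuple</code></td>
+      <td><code>(Double, Double, Double)</code></td>
+      <td>Relative proportions of the training, testing, and holdout DataSets; normalized
+      internally. Defaults to <code>(0.6, 0.3, 0.1)</code>.</td>
+      <td>
+      <code>trainTestHoldoutSplit</code>
+      </td>
+    </tr>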
+    <tr>
+      <td><code>kFolds</code></td>
+      <td><code>Int</code></td>
+      <td>The number of subsets to break the <code>input</code> DataSet into.</td>
+      <td><code>kFoldSplit</code></td>
+    </tr>
+
+  </tbody>
+</table>
+
+## Examples
+
+{% highlight scala %}
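+import org.apache.flink.api.scala._
+import org.apache.flink.ml.common.LabeledVector
+import org.apache.flink.ml.preprocessing.Splitter
+import org.apache.flink.ml.preprocessing.Splitter.{TrainTestDataSet, TrainTestHoldoutDataSet}
+
+// An input DataSet; it does not have to be of type LabeledVector
+val data: DataSet[LabeledVector] = ...
+
+// A simple train-test split: ~60% training, with precise sizing enabled
+val dataTrainTest: TrainTestDataSet[LabeledVector] = Splitter.trainTestSplit(data, 0.6, true)
+
+// A train-test-holdout split with relative proportions 6:3:1
+val dataTrainTestHO: TrainTestHoldoutDataSet[LabeledVector] =
+  Splitter.trainTestHoldoutSplit(data, (6.0, 3.0, 1.0))
+
+// An array of k = 10 TrainTestDataSets
+val dataKFolded: Array[TrainTestDataSet[LabeledVector]] = Splitter.kFoldSplit(data, 10)
+
+// An array of 5 DataSets with relative proportions 5:1:1:1:1
+val dataMultiRandom: Array[DataSet[LabeledVector]] =
+  Splitter.multiRandomSplit(data, Array(0.5, 0.1, 0.1, 0.1, 0.1))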
+{% endhighlight %}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/bce35509/docs/apis/batch/libs/ml/index.md
----------------------------------------------------------------------
diff --git a/docs/apis/batch/libs/ml/index.md b/docs/apis/batch/libs/ml/index.md
index b956287..39b3a02 100644
--- a/docs/apis/batch/libs/ml/index.md
+++ b/docs/apis/batch/libs/ml/index.md
@@ -66,6 +66,7 @@ FlinkML currently supports the following algorithms:
 ### Utilities
 
 * [Distance Metrics](distance_metrics.html)
+* [Cross Validation](cross_validation.html)
 
 ## Getting Started
 
@@ -90,10 +91,18 @@ Now you can start solving your analysis task.
 The following code snippet shows how easy it is to train a multiple linear regression model.
 
 {% highlight scala %}
+
+
 // LabeledVector is a feature vector with a label (class or real value)
 val trainingData: DataSet[LabeledVector] = ...
 val testingData: DataSet[Vector] = ...
 
+// Alternatively, a Splitter can be used to break up a single DataSet into training and
+// testing data (in place of the two definitions above).
+val dataSet: DataSet[LabeledVector] = ...
+val trainTestData: TrainTestDataSet[LabeledVector] = Splitter.trainTestSplit(dataSet)
+val trainingData: DataSet[LabeledVector] = trainTestData.training
+val testingData: DataSet[Vector] = trainTestData.testing.map(lv => lv.vector)
+
 val mlr = MultipleLinearRegression()
   .setStepsize(1.0)
   .setIterations(100)

http://git-wip-us.apache.org/repos/asf/flink/blob/bce35509/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/preprocessing/Splitter.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/preprocessing/Splitter.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/preprocessing/Splitter.scala
new file mode 100644
index 0000000..46b1462
--- /dev/null
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/preprocessing/Splitter.scala
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.preprocessing
+
+import org.apache.flink.api.common.typeinfo.{TypeInformation, BasicTypeInfo}
+import org.apache.flink.api.java.Utils
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.DataSet
+import org.apache.flink.api.scala.utils._
+
+
+import org.apache.flink.ml.common.{FlinkMLTools, ParameterMap, WithParameters}
+import org.apache.flink.util.Collector
+import _root_.scala.reflect.ClassTag
+
+object Splitter {
+
+  case class TrainTestDataSet[T: TypeInformation : ClassTag](training: DataSet[T],
+                                                             testing: DataSet[T])
+
+  case class TrainTestHoldoutDataSet[T: TypeInformation : ClassTag](training: DataSet[T],
+                                                                    testing: DataSet[T],
+                                                                    holdout: DataSet[T])
+  // --------------------------------------------------------------------------------------------
+  //  randomSplit
+  // --------------------------------------------------------------------------------------------
+  /**
+   * Split a DataSet by the probability fraction of each element.
+   *
+   * @param input           DataSet to be split
+   * @param fraction        Probability that each element is chosen; must be in the open
+   *                        interval (0,1). This fraction refers to the first element in the
+   *                        resulting array.
+   * @param precise         Sampling by default is random and can result in slightly lop-sided
+   *                        sample sets. When precise is true, the first DataSet is forced to
+   *                        contain round(fraction * count) elements; however, this requires
+   *                        counting the input and is therefore less efficient.
+   * @param seed            Random number generator seed.
+   * @return An array of two DataSets
+   */
+
+  def randomSplit[T: TypeInformation : ClassTag](
+      input: DataSet[T],
+      fraction: Double,
+      precise: Boolean = false,
+      seed: Long = Utils.RNG.nextLong())
+    : Array[DataSet[T]] = {
+    import org.apache.flink.api.scala._
+
+    val indexedInput: DataSet[(Long, T)] = input.zipWithUniqueId
+
+    if ((fraction >= 1) || (fraction <= 0)) {
+      throw new IllegalArgumentException("sampling fraction outside of (0,1) bounds is nonsensical")
+    }
+
+    val leftSplit: DataSet[(Long, T)] = precise match {
+      case false => indexedInput.sample(false, fraction, seed)
+      case true => {
+        val count = indexedInput.count()  // todo: count only needed for precise and kills perf.
+        val numOfSamples = math.round(fraction * count).toInt
+        indexedInput.sampleWithSize(false, numOfSamples, seed)
+      }
+    }
+
+    val leftSplitLight = leftSplit.map(o => (o._1, false))
+
+    // Anti-join on the unique id: keep every element that was not sampled into leftSplit
+    val rightSplit: DataSet[T] = indexedInput.leftOuterJoin[(Long, Boolean)](leftSplitLight)
+      .where(0)
+      .equalTo(0).apply {
+        (full: (Long,T) , left: (Long, Boolean), collector: Collector[T]) =>
+        if (left == null) {
+          collector.collect(full._2)
+        }
+    }
+
+    Array(leftSplit.map(o => o._2), rightSplit)
+  }
+
+  // --------------------------------------------------------------------------------------------
+  //  multiRandomSplit
+  // --------------------------------------------------------------------------------------------
+  /**
+   * Split a DataSet into multiple DataSets according to an array of relative proportions.
+   *
+   * @param input           DataSet to be split
+   * @param fracArray       An array of PROPORTIONS for splitting the DataSet. Unlike the fraction
+   *                        parameter of randomSplit, these values are not restricted to (0,1).
+   *                        The number of splits is dictated by the length of this array. The
+   *                        numbers are normalized, e.g. Array(1.0, 2.0) would yield two DataSets
+   *                        with a 33/67% split.
+   * @param seed            Random number generator seed.
+   * @return An array of DataSets whose length is equal to the length of fracArray
+   */
+  def multiRandomSplit[T: TypeInformation : ClassTag](
+      input: DataSet[T],
+      fracArray: Array[Double],
+      seed: Long = Utils.RNG.nextLong())
+    : Array[DataSet[T]] = {
+
+    import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution
+
+    val eid = new EnumeratedIntegerDistribution((0 to fracArray.length - 1).toArray, fracArray)
+
+    eid.reseedRandomGenerator(seed)
+
+    val tempDS: DataSet[(Int,T)] = input.map(o => (eid.sample, o))
+
+    val splits = fracArray.length
+    val outputArray = new Array[DataSet[T]]( splits )
+
+    for (k <- 0 to splits-1){
+      outputArray(k) = tempDS.filter(o => o._1 == k)
+                             .map(o => o._2)
+    }
+
+    outputArray
+  }
+
+  // --------------------------------------------------------------------------------------------
+  //  kFoldSplit
+  // --------------------------------------------------------------------------------------------
+  /**
+   * Split a DataSet into an array of TrainTestDataSets.
+   *
+   * @param input           DataSet to be split
+   * @param kFolds          The number of TrainTestDataSets to be returned. Each 'testing' set
+   *                        will be approximately 1/k of the DataSet, randomly sampled; the
+   *                        'training' set will be the remainder of the DataSet. The DataSet is
+   *                        split into kFolds first, so that no observation will occur in
+   *                        multiple folds.
+   * @param seed            Random number generator seed.
+   * @return An array of TrainTestDataSets
+   */
+  def kFoldSplit[T: TypeInformation : ClassTag](
+      input: DataSet[T],
+      kFolds: Int,
+      seed: Long = Utils.RNG.nextLong())
+    : Array[TrainTestDataSet[T]] = {
+
+    val fracs = Array.fill(kFolds)(1.0)
+    val dataSetArray = multiRandomSplit(input, fracs, seed)
+
+    dataSetArray.map( ds => TrainTestDataSet(dataSetArray.filter(_ != ds)
+                                                         .reduce(_ union _),
+                                             ds))
+
+  }
+
+  // --------------------------------------------------------------------------------------------
+  //  trainTestSplit
+  // --------------------------------------------------------------------------------------------
+  /**
+   * A wrapper for randomSplit that yields a TrainTestDataSet
+   *
+   * @param input           DataSet to be split
+   * @param fraction        Probability that each element is chosen; must be in the open
+   *                        interval (0,1). This fraction refers to the training DataSet of the
+   *                        resulting TrainTestDataSet.
+   * @param precise         Sampling by default is random and can result in slightly lop-sided
+   *                        sample sets. When precise is true, the training DataSet is forced to
+   *                        contain round(fraction * count) elements; however, this requires
+   *                        counting the input and is therefore less efficient.
+   * @param seed            Random number generator seed.
+   * @return A TrainTestDataSet
+   */
+  def trainTestSplit[T: TypeInformation : ClassTag](
+      input: DataSet[T],
+      fraction: Double = 0.6,
+      precise: Boolean = false,
+      seed: Long = Utils.RNG.nextLong())
+    : TrainTestDataSet[T] = {
+    val dataSetArray = randomSplit(input, fraction, precise, seed)
+    TrainTestDataSet(dataSetArray(0), dataSetArray(1))
+  }
+
+  // --------------------------------------------------------------------------------------------
+  //  trainTestHoldoutSplit
+  // --------------------------------------------------------------------------------------------
+  /**
+   * A wrapper for multiRandomSplit that yields a TrainTestHoldoutDataSet
+   *
+   * @param input           DataSet to be split
+   * @param fracTuple       A tuple of three doubles, where the first element specifies the
+   *                        size of the training set, the second element the testing set, and
+   *                        the third element the holdout set. These are proportional and
+   *                        will be normalized internally.
+   * @param seed            Random number generator seed.
+   * @return A TrainTestHoldoutDataSet
+   */
+  def trainTestHoldoutSplit[T: TypeInformation : ClassTag](
+      input: DataSet[T],
+      fracTuple: Tuple3[Double, Double, Double] = (0.6,0.3,0.1),
+      seed: Long = Utils.RNG.nextLong())
+    : TrainTestHoldoutDataSet[T] = {
+    val fracArray = Array(fracTuple._1, fracTuple._2, fracTuple._3)
+    val dataSetArray = multiRandomSplit(input, fracArray, seed)
+    TrainTestHoldoutDataSet(dataSetArray(0), dataSetArray(1), dataSetArray(2))
+  }
+}
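`randomSplit` builds its second DataSet as the complement of the sample: it tags elements with unique ids, samples them, and keeps the rows of a left outer join that found no match. The local sketch below mirrors those steps with plain Scala collections; `AntiJoinSketch` is a hypothetical name, and a `Map` lookup stands in for Flink's `leftOuterJoin`.

```scala
import scala.util.Random

// Hypothetical stand-alone sketch on plain Scala collections, not the FlinkML API.
object AntiJoinSketch {

  def randomSplitLocal[T](data: Seq[T], fraction: Double, seed: Long): (Seq[T], Seq[T]) = {
    require(fraction > 0 && fraction < 1, "fraction must be in (0,1)")
    val rng = new Random(seed)
    // 1. Tag every element with a unique id (the role of zipWithUniqueId).
    val indexed: Seq[(Long, T)] = data.zipWithIndex.map { case (x, i) => (i.toLong, x) }
    // 2. Bernoulli-sample the tagged elements (the role of sample(false, fraction, seed)).
    val leftSplit: Seq[(Long, T)] = indexed.filter(_ => rng.nextDouble() < fraction)
    // 3. Keep only the ids of the sample (the role of leftSplitLight).
    val leftIds: Map[Long, Boolean] = leftSplit.map { case (id, _) => id -> false }.toMap
    // 4. Left outer join on the id: rows without a match (the `left == null` case)
    //    form the complement, i.e. the second split.
    val rightSplit: Seq[T] = indexed.collect { case (id, x) if leftIds.get(id).isEmpty => x }
    (leftSplit.map(_._2), rightSplit)
  }
}
```

Joining on a unique id rather than on the element itself keeps duplicate input values from being dropped or double-counted.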

http://git-wip-us.apache.org/repos/asf/flink/blob/bce35509/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/preprocessing/SplitterITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/preprocessing/SplitterITSuite.scala b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/preprocessing/SplitterITSuite.scala
new file mode 100644
index 0000000..231cb69
--- /dev/null
+++ b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/preprocessing/SplitterITSuite.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.ml.preprocessing
+
+import org.apache.flink.api.scala.ExecutionEnvironment
+import org.apache.flink.api.scala._
+import org.apache.flink.test.util.FlinkTestBase
+import org.scalatest.{Matchers, FlatSpec}
+import org.apache.flink.ml.math.Vector
+import org.apache.flink.api.scala.utils._
+
+
+class SplitterITSuite extends FlatSpec
+  with Matchers
+  with FlinkTestBase {
+
+  behavior of "Flink's DataSet Splitter"
+
+  import MinMaxScalerData._
+
+  it should "result in datasets with no elements in common and all elements used" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    val dataSet = env.fromCollection(data)
+
+    val splitDataSets = Splitter.randomSplit(dataSet.zipWithUniqueId, 0.5)
+
+    (splitDataSets(0).union(splitDataSets(1)).count()) should equal(dataSet.count())
+
+    splitDataSets(0).join(splitDataSets(1)).where(0).equalTo(0).count() should equal(0)
+  }
+
+  it should "result in datasets of an expected size when precise" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    val dataSet = env.fromCollection(data)
+
+    val splitDataSets = Splitter.randomSplit(dataSet, 0.5, true)
+
+    val expectedLength = data.size.toDouble * 0.5
+
+    splitDataSets(0).count().toDouble should equal(expectedLength +- 1.0)
+  }
+
+  it should "result in expected number of datasets" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    val dataSet = env.fromCollection(data)
+
+    val fracArray = Array(0.5, 0.25, 0.25)
+
+    val splitDataSets = Splitter.multiRandomSplit(dataSet, fracArray)
+
+    splitDataSets.length should equal(fracArray.length)
+  }
+
+  it should "produce TrainTestDataSets in which training size is greater than testing size" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    val dataSet = env.fromCollection(data)
+
+    val dataSetArray = Splitter.kFoldSplit(dataSet, 4)
+
+    (dataSetArray(1).testing.count() < dataSetArray(1).training.count()) should be(true)
+
+  }
+
+  it should "throw an exception if sample fraction is nonsensical" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    val dataSet = env.fromCollection(data)
+
+    // fraction below the valid (0,1) range
+    intercept[IllegalArgumentException] {
+      val splitDataSets = Splitter.randomSplit(dataSet, -0.2)
+    }
+
+    // fraction above the valid (0,1) range
+    intercept[IllegalArgumentException] {
+      val splitDataSets = Splitter.randomSplit(dataSet, 1.2)
+    }
+
+  }
+}
