[
https://issues.apache.org/jira/browse/FLINK-1723?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=14618606#comment-14618606
]
ASF GitHub Bot commented on FLINK-1723:
---------------------------------------
Github user thvasilo commented on a diff in the pull request:
https://github.com/apache/flink/pull/891#discussion_r34148455
--- Diff: flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/evaluation/CrossValidationITSuite.scala ---
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.ml.evaluation
+
+import org.apache.flink.api.scala._
+import org.apache.flink.ml.common.ParameterMap
+import org.apache.flink.ml.preprocessing.StandardScaler
+import org.apache.flink.ml.regression.RegressionData._
+import org.apache.flink.ml.regression.{MultipleLinearRegression, RegressionData}
+import org.apache.flink.test.util.FlinkTestBase
+
+import org.scalatest.{FlatSpec, Matchers}
+
+class CrossValidationITSuite extends FlatSpec with Matchers with FlinkTestBase {
+ behavior of "the cross-validation suite"
+
+ it should "be able to split the input into K folds" in {
+ // Original code from the Apache Spark project
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromCollection(1 to 100)
+ val collectedData = data.collect().sorted
+
+ val twoFolds = KFold(2).folds(data, 42L)
+ twoFolds(0)._1.collect().sorted shouldEqual twoFolds(1)._2.collect().sorted
+ twoFolds(0)._2.collect().sorted shouldEqual twoFolds(1)._1.collect().sorted
+
+ for (folds <- 2 to 10) {
+ for (seed <- 1 to 5) {
+ val foldedDataSets = KFold(folds).folds(data, seed)
+ foldedDataSets.length shouldEqual folds
+
+ foldedDataSets.foreach { case (training, testing) =>
+ val result = testing.union(training).collect().sorted
+ val testingSize = testing.collect().size.toDouble
+ testingSize should be > 0.0
+
+ // Within 4 standard deviations of the mean
+ val p = 1 / folds.toDouble
+ val range = 4 * math.sqrt(100 * p * (1 - p))
+ val expected = 100 * p
+ val lowerBound = expected - range
+ val upperBound = expected + range
+ //Ensure size of test data is within expected bounds
+ testingSize should be > lowerBound
+ testingSize should be < upperBound
+ training.collect().size should be > 0
+
+ // The combined set should contain all data
+ result shouldEqual collectedData
+ }
+ // K fold cross validation should only have each element in the validation set exactly once
+ foldedDataSets.map(_._2).reduce((x, y) => x.union(y)).collect().sorted shouldEqual
+ data.collect().sorted
+ }
+ }
+ }
+
+ def fixture = new {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ import RegressionData._
+
+
+ val inputDS = env.fromCollection(data)
+
+ val mlr = MultipleLinearRegression()
+ .setStepsize(10.0)
+ .setIterations(100)
+
+ println()
--- End diff --
It prints a line between consecutive test runs; I just have it there so
I can more easily see what is happening. The tests don't really do anything
yet; they just print results.
> Add cross validation for model evaluation
> -----------------------------------------
>
> Key: FLINK-1723
> URL: https://issues.apache.org/jira/browse/FLINK-1723
> Project: Flink
> Issue Type: New Feature
> Components: Machine Learning Library
> Reporter: Till Rohrmann
> Assignee: Theodore Vasiloudis
> Labels: ML
>
> Cross validation [1] is a standard tool to estimate the test error for a
> model. As such it is a crucial tool for every machine learning library.
> The cross validation should work with arbitrary Estimators and error metrics.
> A first cross validation strategy it should support is the k-fold cross
> validation.
> Resources:
> [1] [http://en.wikipedia.org/wiki/Cross-validation]
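
For readers unfamiliar with the strategy named in the issue, here is a minimal sketch of k-fold splitting on a plain in-memory Scala collection. It is not the KFold class from the pull request and does not use the Flink DataSet API; the object name KFoldSketch and the method kFold are hypothetical. It also assigns elements to folds deterministically after a seeded shuffle, whereas the size bounds in the quoted test suggest the pull request samples fold membership randomly, so treat this only as an illustration of the idea.

```scala
import scala.util.Random

// Hypothetical illustration only; not the KFold implementation under review.
object KFoldSketch {

  // Returns k (training, testing) pairs. Every element lands in exactly one
  // testing set, and each training set holds the remaining elements.
  def kFold[T](data: Seq[T], k: Int, seed: Long): Seq[(Seq[T], Seq[T])] = {
    require(k >= 2 && k <= data.size, "k must be between 2 and the number of elements")
    // A seeded shuffle keeps the split reproducible for a given seed.
    val indexed = new Random(seed).shuffle(data).zipWithIndex
    (0 until k).map { fold =>
      // Position i in the shuffled order goes to fold i % k,
      // so fold sizes differ by at most one.
      val (testing, training) = indexed.partition { case (_, i) => i % k == fold }
      (training.map(_._1), testing.map(_._1))
    }
  }

  def main(args: Array[String]): Unit = {
    val data = 1 to 100
    val folds = kFold(data, 5, seed = 42L)
    folds.foreach { case (training, testing) =>
      // Each pair together covers the whole input.
      assert((training ++ testing).sorted == data)
    }
    // The testing sets together contain every element exactly once.
    assert(folds.flatMap(_._2).sorted == data)
    println(s"${folds.size} folds, test sizes: ${folds.map(_._2.size).mkString(", ")}")
  }
}
```

The tuple order (training, testing) follows the convention used in the quoted test. An estimator would be fitted on each training set and scored on the matching testing set, with the k scores averaged to estimate the test error.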