This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new c33bdc1e [FLINK-31453] Improve the test utils to call methods explicitly
c33bdc1e is described below
commit c33bdc1e27b4d8185eb45d4aba52c5bccf024e1a
Author: JiangXin <[email protected]>
AuthorDate: Wed Mar 15 10:18:29 2023 +0800
[FLINK-31453] Improve the test utils to call methods explicitly
This closes #225.
---
.../src/test/java/org/apache/flink/ml/util/TestUtils.java | 12 +++++++-----
.../java/org/apache/flink/ml/classification/KnnTest.java | 9 ++++++---
.../org/apache/flink/ml/classification/LinearSVCTest.java | 10 ++++++++--
.../flink/ml/classification/LogisticRegressionTest.java | 12 ++++++++++--
.../apache/flink/ml/classification/NaiveBayesTest.java | 13 +++++++++++--
.../flink/ml/clustering/AgglomerativeClusteringTest.java | 5 ++++-
.../java/org/apache/flink/ml/clustering/KMeansTest.java | 6 ++++--
.../ml/evaluation/BinaryClassificationEvaluatorTest.java | 6 +++++-
.../java/org/apache/flink/ml/feature/BinarizerTest.java | 5 ++++-
.../java/org/apache/flink/ml/feature/BucketizerTest.java | 6 +++++-
.../org/apache/flink/ml/feature/CountVectorizerTest.java | 11 +++++++++--
.../test/java/org/apache/flink/ml/feature/DCTTest.java | 3 ++-
.../apache/flink/ml/feature/ElementwiseProductTest.java | 10 ++++++++--
.../org/apache/flink/ml/feature/FeatureHasherTest.java | 15 ++++++++++++---
.../java/org/apache/flink/ml/feature/HashingTFTest.java | 5 ++++-
.../test/java/org/apache/flink/ml/feature/IDFTest.java | 8 ++++++--
.../java/org/apache/flink/ml/feature/ImputerTest.java | 6 ++++--
.../java/org/apache/flink/ml/feature/InteractionTest.java | 10 ++++++++--
.../org/apache/flink/ml/feature/KBinsDiscretizerTest.java | 12 ++++++++++--
.../org/apache/flink/ml/feature/MaxAbsScalerTest.java | 11 +++++++++--
.../java/org/apache/flink/ml/feature/MinHashLSHTest.java | 9 +++++++--
.../org/apache/flink/ml/feature/MinMaxScalerTest.java | 11 +++++++++--
.../test/java/org/apache/flink/ml/feature/NGramTest.java | 2 +-
.../java/org/apache/flink/ml/feature/NormalizerTest.java | 5 ++++-
.../org/apache/flink/ml/feature/OneHotEncoderTest.java | 13 +++++++++++--
.../apache/flink/ml/feature/OnlineStandardScalerTest.java | 10 ++++++++--
.../apache/flink/ml/feature/PolynomialExpansionTest.java | 5 ++++-
.../org/apache/flink/ml/feature/RandomSplitterTest.java | 5 ++++-
.../org/apache/flink/ml/feature/RegexTokenizerTest.java | 5 ++++-
.../org/apache/flink/ml/feature/RobustScalerTest.java | 11 +++++++++--
.../org/apache/flink/ml/feature/SQLTransformerTest.java | 5 ++++-
.../org/apache/flink/ml/feature/StandardScalerTest.java | 12 ++++++++++--
.../org/apache/flink/ml/feature/StopWordsRemoverTest.java | 5 ++++-
.../java/org/apache/flink/ml/feature/TokenizerTest.java | 5 ++++-
.../flink/ml/feature/UnivariateFeatureSelectorTest.java | 12 ++++++++++--
.../flink/ml/feature/VarianceThresholdSelectorTest.java | 11 +++++++++--
.../org/apache/flink/ml/feature/VectorAssemblerTest.java | 10 ++++++++--
.../org/apache/flink/ml/feature/VectorIndexerTest.java | 12 ++++++++++--
.../org/apache/flink/ml/feature/VectorSlicerTest.java | 5 ++++-
.../ml/feature/stringindexer/IndexToStringModelTest.java | 7 ++++++-
.../flink/ml/feature/stringindexer/StringIndexerTest.java | 12 ++++++++++--
.../org/apache/flink/ml/recommendation/SwingTest.java | 3 ++-
.../apache/flink/ml/regression/LinearRegressionTest.java | 12 ++++++++++--
.../java/org/apache/flink/ml/stats/ANOVATestTest.java | 3 ++-
.../java/org/apache/flink/ml/stats/ChiSqTestTest.java | 3 ++-
.../java/org/apache/flink/ml/stats/FValueTestTest.java | 6 +++++-
46 files changed, 297 insertions(+), 77 deletions(-)
diff --git
a/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java
b/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java
index 26199788..78a1fa34 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java
@@ -45,6 +45,7 @@ import
org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.types.DataType;
import org.apache.flink.test.util.TestBaseUtils;
import org.apache.flink.types.Row;
+import org.apache.flink.util.function.BiFunctionWithException;
import org.apache.flink.util.function.FunctionWithException;
import org.apache.commons.collections.IteratorUtils;
@@ -55,7 +56,6 @@ import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.OutputStream;
-import java.lang.reflect.Method;
import java.util.Comparator;
import java.util.List;
@@ -194,7 +194,11 @@ public class TestUtils {
* stage.
*/
public static <T extends Stage<T>> T saveAndReload(
- StreamTableEnvironment tEnv, T stage, String path) throws
Exception {
+ StreamTableEnvironment tEnv,
+ T stage,
+ String path,
+ BiFunctionWithException<StreamTableEnvironment, String, T,
IOException> loadFunc)
+ throws Exception {
StreamExecutionEnvironment env =
TableUtils.getExecutionEnvironment(tEnv);
stage.save(path);
@@ -207,9 +211,7 @@ public class TestUtils {
}
}
- Method method =
- stage.getClass().getMethod("load",
StreamTableEnvironment.class, String.class);
- return (T) method.invoke(null, tEnv, path);
+ return loadFunc.apply(tEnv, path);
}
/**
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/KnnTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/KnnTest.java
index 28d9c4f7..483dc34d 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/KnnTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/KnnTest.java
@@ -195,10 +195,12 @@ public class KnnTest extends AbstractTestBase {
public void testSaveLoadAndPredict() throws Exception {
Knn knn = new Knn();
Knn loadedKnn =
- TestUtils.saveAndReload(tEnv, knn,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv, knn, tempFolder.newFolder().getAbsolutePath(),
Knn::load);
KnnModel knnModel = loadedKnn.fit(trainData);
knnModel =
- TestUtils.saveAndReload(tEnv, knnModel,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv, knnModel,
tempFolder.newFolder().getAbsolutePath(), KnnModel::load);
assertEquals(
Arrays.asList("packedFeatures", "featureNormSquares",
"labels"),
knnModel.getModelData()[0].getResolvedSchema().getColumnNames());
@@ -211,7 +213,8 @@ public class KnnTest extends AbstractTestBase {
Knn knn = new Knn();
KnnModel knnModel = knn.fit(trainData);
KnnModel newModel =
- TestUtils.saveAndReload(tEnv, knnModel,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv, knnModel,
tempFolder.newFolder().getAbsolutePath(), KnnModel::load);
Table output = newModel.transform(predictData)[0];
verifyPredictionResult(output, knn.getLabelCol(),
knn.getPredictionCol());
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java
index 9b9f7a76..b2a20eb2 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java
@@ -212,9 +212,15 @@ public class LinearSVCTest extends AbstractTestBase {
public void testSaveLoadAndPredict() throws Exception {
LinearSVC linearSVC = new LinearSVC().setWeightCol("weight");
linearSVC =
- TestUtils.saveAndReload(tEnv, linearSVC,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv, linearSVC,
tempFolder.newFolder().getAbsolutePath(), LinearSVC::load);
LinearSVCModel model = linearSVC.fit(trainDataTable);
- model = TestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
+ model =
+ TestUtils.saveAndReload(
+ tEnv,
+ model,
+ tempFolder.newFolder().getAbsolutePath(),
+ LinearSVCModel::load);
assertEquals(
Collections.singletonList("coefficient"),
model.getModelData()[0].getResolvedSchema().getColumnNames());
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
index 20db30f0..ad9a5416 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
@@ -239,9 +239,17 @@ public class LogisticRegressionTest extends
AbstractTestBase {
LogisticRegression logisticRegression = new
LogisticRegression().setWeightCol("weight");
logisticRegression =
TestUtils.saveAndReload(
- tEnv, logisticRegression,
tempFolder.newFolder().getAbsolutePath());
+ tEnv,
+ logisticRegression,
+ tempFolder.newFolder().getAbsolutePath(),
+ LogisticRegression::load);
LogisticRegressionModel model =
logisticRegression.fit(binomialDataTable);
- model = TestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
+ model =
+ TestUtils.saveAndReload(
+ tEnv,
+ model,
+ tempFolder.newFolder().getAbsolutePath(),
+ LogisticRegressionModel::load);
assertEquals(
Arrays.asList("coefficient", "modelVersion"),
model.getModelData()[0].getResolvedSchema().getColumnNames());
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java
index d727964e..cac6e4c9 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java
@@ -265,11 +265,20 @@ public class NaiveBayesTest extends AbstractTestBase {
@Test
public void testSaveLoad() throws Exception {
estimator =
- TestUtils.saveAndReload(tEnv, estimator,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv,
+ estimator,
+ tempFolder.newFolder().getAbsolutePath(),
+ NaiveBayes::load);
NaiveBayesModel model = estimator.fit(trainTable);
- model = TestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
+ model =
+ TestUtils.saveAndReload(
+ tEnv,
+ model,
+ tempFolder.newFolder().getAbsolutePath(),
+ NaiveBayesModel::load);
Table outputTable = model.transform(predictTable)[0];
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/AgglomerativeClusteringTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/AgglomerativeClusteringTest.java
index 5e360efb..292e793e 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/AgglomerativeClusteringTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/AgglomerativeClusteringTest.java
@@ -413,7 +413,10 @@ public class AgglomerativeClusteringTest extends
AbstractTestBase {
agglomerativeClustering =
TestUtils.saveAndReload(
- tEnv, agglomerativeClustering,
tempFolder.newFolder().getAbsolutePath());
+ tEnv,
+ agglomerativeClustering,
+ tempFolder.newFolder().getAbsolutePath(),
+ AgglomerativeClustering::load);
Table[] outputs = agglomerativeClustering.transform(inputDataTable);
verifyClusteringResult(
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
index ad54926f..8f196172 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
@@ -211,10 +211,12 @@ public class KMeansTest extends AbstractTestBase {
public void testSaveLoadAndPredict() throws Exception {
KMeans kmeans = new KMeans().setMaxIter(2).setK(2);
KMeans loadedKmeans =
- TestUtils.saveAndReload(tEnv, kmeans,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv, kmeans,
tempFolder.newFolder().getAbsolutePath(), KMeans::load);
KMeansModel model = loadedKmeans.fit(dataTable);
KMeansModel loadedModel =
- TestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv, model, tempFolder.newFolder().getAbsolutePath(),
KMeansModel::load);
Table output = loadedModel.transform(dataTable)[0];
assertEquals(
Arrays.asList("centroids", "weights"),
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java
index 9dd388ab..0c146a3d 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java
@@ -316,7 +316,11 @@ public class BinaryClassificationEvaluatorTest extends
AbstractTestBase {
BinaryClassificationEvaluatorParams.KS,
BinaryClassificationEvaluatorParams.AREA_UNDER_ROC);
BinaryClassificationEvaluator loadedEval =
- TestUtils.saveAndReload(tEnv, eval,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv,
+ eval,
+ tempFolder.newFolder().getAbsolutePath(),
+ BinaryClassificationEvaluator::load);
Table evalResult = loadedEval.transform(inputDataTable)[0];
assertArrayEquals(
new String[] {
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/BinarizerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/BinarizerTest.java
index bb8cff40..eb06b0c6 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/BinarizerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/BinarizerTest.java
@@ -141,7 +141,10 @@ public class BinarizerTest extends AbstractTestBase {
Binarizer loadedBinarizer =
TestUtils.saveAndReload(
- tEnv, binarizer,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ tEnv,
+ binarizer,
+ TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
+ Binarizer::load);
Table output = loadedBinarizer.transform(inputDataTable)[0];
verifyOutputResult(output, loadedBinarizer.getOutputCols());
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/BucketizerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/BucketizerTest.java
index a3b81600..2c49dfa6 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/BucketizerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/BucketizerTest.java
@@ -203,7 +203,11 @@ public class BucketizerTest extends AbstractTestBase {
.setHandleInvalid(HasHandleInvalid.KEEP_INVALID)
.setSplitsArray(splitsArray);
Bucketizer loadedBucketizer =
- TestUtils.saveAndReload(tEnv, bucketizer,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv,
+ bucketizer,
+ tempFolder.newFolder().getAbsolutePath(),
+ Bucketizer::load);
Table output = loadedBucketizer.transform(inputTable)[0];
verifyOutputResult(output, loadedBucketizer.getOutputCols(),
expectedKeepResult);
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/CountVectorizerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/CountVectorizerTest.java
index fb6be3f4..32d58ccc 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/CountVectorizerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/CountVectorizerTest.java
@@ -208,10 +208,17 @@ public class CountVectorizerTest extends AbstractTestBase
{
CountVectorizer countVectorizer = new CountVectorizer();
CountVectorizer loadedCountVectorizer =
TestUtils.saveAndReload(
- tEnv, countVectorizer,
tempFolder.newFolder().getAbsolutePath());
+ tEnv,
+ countVectorizer,
+ tempFolder.newFolder().getAbsolutePath(),
+ CountVectorizer::load);
CountVectorizerModel model = loadedCountVectorizer.fit(inputTable);
CountVectorizerModel loadedModel =
- TestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv,
+ model,
+ tempFolder.newFolder().getAbsolutePath(),
+ CountVectorizerModel::load);
assertEquals(
Arrays.asList("vocabulary"),
loadedModel.getModelData()[0].getResolvedSchema().getColumnNames());
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/DCTTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/DCTTest.java
index 5e6f2b92..36baea68 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/DCTTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/DCTTest.java
@@ -140,7 +140,8 @@ public class DCTTest extends AbstractTestBase {
DCT dct = new DCT().setInverse(true);
DCT loadedDCT =
- TestUtils.saveAndReload(tEnv, dct,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv, dct,
TEMPORARY_FOLDER.newFolder().getAbsolutePath(), DCT::load);
Table outputTable = loadedDCT.transform(inputTable)[0];
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ElementwiseProductTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ElementwiseProductTest.java
index ffcf238e..4b55f0dd 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ElementwiseProductTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ElementwiseProductTest.java
@@ -159,7 +159,10 @@ public class ElementwiseProductTest extends
AbstractTestBase {
.setScalingVec(Vectors.dense(1.1, 1.1));
ElementwiseProduct loadedElementwiseProduct =
TestUtils.saveAndReload(
- tEnv, elementwiseProduct,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ tEnv,
+ elementwiseProduct,
+ TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
+ ElementwiseProduct::load);
Table output = loadedElementwiseProduct.transform(inputDataTable)[0];
verifyOutputResult(output, loadedElementwiseProduct.getOutputCol(),
false);
}
@@ -193,7 +196,10 @@ public class ElementwiseProductTest extends
AbstractTestBase {
Vectors.sparse(5, new int[] {0, 1}, new
double[] {1.1, 1.1}));
ElementwiseProduct loadedElementwiseProduct =
TestUtils.saveAndReload(
- tEnv, elementwiseProduct,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ tEnv,
+ elementwiseProduct,
+ TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
+ ElementwiseProduct::load);
Table output = loadedElementwiseProduct.transform(inputDataTable)[0];
verifyOutputResult(output, loadedElementwiseProduct.getOutputCol(),
true);
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FeatureHasherTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FeatureHasherTest.java
index a09b47ef..fae16692 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FeatureHasherTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FeatureHasherTest.java
@@ -103,7 +103,10 @@ public class FeatureHasherTest extends AbstractTestBase {
.setNumFeatures(1000);
FeatureHasher loadedFeatureHasher =
TestUtils.saveAndReload(
- tEnv, featureHash,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ tEnv,
+ featureHash,
+ TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
+ FeatureHasher::load);
Table output = loadedFeatureHasher.transform(inputDataTable)[0];
verifyOutputResult(output, loadedFeatureHasher.getOutputCol());
}
@@ -117,7 +120,10 @@ public class FeatureHasherTest extends AbstractTestBase {
.setNumFeatures(1000);
FeatureHasher loadedFeatureHasher =
TestUtils.saveAndReload(
- tEnv, featureHash,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ tEnv,
+ featureHash,
+ TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
+ FeatureHasher::load);
Table output = loadedFeatureHasher.transform(inputDataTable)[0];
verifyOutputResult(output, loadedFeatureHasher.getOutputCol());
}
@@ -137,7 +143,10 @@ public class FeatureHasherTest extends AbstractTestBase {
.setNumFeatures(1000);
FeatureHasher loadedFeatureHasher =
TestUtils.saveAndReload(
- tEnv, featureHash,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ tEnv,
+ featureHash,
+ TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
+ FeatureHasher::load);
Table output = loadedFeatureHasher.transform(inputDataTable)[0];
verifyOutputResult(output, loadedFeatureHasher.getOutputCol());
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/HashingTFTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/HashingTFTest.java
index 938e4f19..920eb1b3 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/HashingTFTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/HashingTFTest.java
@@ -158,7 +158,10 @@ public class HashingTFTest extends AbstractTestBase {
HashingTF hashingTF = new HashingTF();
HashingTF loadedHashingTF =
TestUtils.saveAndReload(
- tEnv, hashingTF,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ tEnv,
+ hashingTF,
+ TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
+ HashingTF::load);
Table output = loadedHashingTF.transform(inputDataTable)[0];
verifyOutputResult(output, loadedHashingTF.getOutputCol(),
EXPECTED_OUTPUT);
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/IDFTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/IDFTest.java
index 4d4d8e1b..9bed73ec 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/IDFTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/IDFTest.java
@@ -141,10 +141,14 @@ public class IDFTest extends AbstractTestBase {
@Test
public void testSaveLoadAndPredict() throws Exception {
IDF idf = new IDF();
- idf = TestUtils.saveAndReload(tEnv, idf,
tempFolder.newFolder().getAbsolutePath());
+ idf =
+ TestUtils.saveAndReload(
+ tEnv, idf, tempFolder.newFolder().getAbsolutePath(),
IDF::load);
IDFModel model = idf.fit(inputTable);
- model = TestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
+ model =
+ TestUtils.saveAndReload(
+ tEnv, model, tempFolder.newFolder().getAbsolutePath(),
IDFModel::load);
assertEquals(
Arrays.asList("idf", "docFreq", "numDocs"),
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ImputerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ImputerTest.java
index 760a9979..d8805a70 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ImputerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ImputerTest.java
@@ -213,10 +213,12 @@ public class ImputerTest extends AbstractTestBase {
.setInputCols("f1", "f2", "f3", "f4")
.setOutputCols("o1", "o2", "o3", "o4");
Imputer loadedImputer =
- TestUtils.saveAndReload(tEnv, imputer,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv, imputer,
tempFolder.newFolder().getAbsolutePath(), Imputer::load);
ImputerModel model = loadedImputer.fit(trainDataTable);
ImputerModel loadedModel =
- TestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv, model, tempFolder.newFolder().getAbsolutePath(),
ImputerModel::load);
assertEquals(
Collections.singletonList("surrogates"),
model.getModelData()[0].getResolvedSchema().getColumnNames());
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/InteractionTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/InteractionTest.java
index e38f26cd..0b0cc995 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/InteractionTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/InteractionTest.java
@@ -142,7 +142,10 @@ public class InteractionTest extends AbstractTestBase {
Interaction loadedInteraction =
TestUtils.saveAndReload(
- tEnv, interaction,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ tEnv,
+ interaction,
+ TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
+ Interaction::load);
Table output = loadedInteraction.transform(inputDataTable)[0];
verifyOutputResult(output, loadedInteraction.getOutputCol(),
EXPECTED_SPARSE_OUTPUT);
@@ -155,7 +158,10 @@ public class InteractionTest extends AbstractTestBase {
Interaction loadedInteraction =
TestUtils.saveAndReload(
- tEnv, interaction,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ tEnv,
+ interaction,
+ TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
+ Interaction::load);
Table output = loadedInteraction.transform(inputDataTable)[0];
verifyOutputResult(output, loadedInteraction.getOutputCol(),
EXPECTED_DENSE_OUTPUT);
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/KBinsDiscretizerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/KBinsDiscretizerTest.java
index 47fecfbd..ed49a44b 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/KBinsDiscretizerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/KBinsDiscretizerTest.java
@@ -206,10 +206,18 @@ public class KBinsDiscretizerTest extends
AbstractTestBase {
new
KBinsDiscretizer().setNumBins(3).setStrategy(KBinsDiscretizerParams.UNIFORM);
kBinsDiscretizer =
TestUtils.saveAndReload(
- tEnv, kBinsDiscretizer,
tempFolder.newFolder().getAbsolutePath());
+ tEnv,
+ kBinsDiscretizer,
+ tempFolder.newFolder().getAbsolutePath(),
+ KBinsDiscretizer::load);
KBinsDiscretizerModel model = kBinsDiscretizer.fit(trainTable);
- model = TestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
+ model =
+ TestUtils.saveAndReload(
+ tEnv,
+ model,
+ tempFolder.newFolder().getAbsolutePath(),
+ KBinsDiscretizerModel::load);
assertEquals(
Collections.singletonList("binEdges"),
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MaxAbsScalerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MaxAbsScalerTest.java
index ece5789d..6b8edf05 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MaxAbsScalerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MaxAbsScalerTest.java
@@ -207,11 +207,18 @@ public class MaxAbsScalerTest {
MaxAbsScaler maxAbsScaler = new MaxAbsScaler();
MaxAbsScaler loadedMaxAbsScaler =
TestUtils.saveAndReload(
- tEnv, maxAbsScaler,
tempFolder.newFolder().getAbsolutePath());
+ tEnv,
+ maxAbsScaler,
+ tempFolder.newFolder().getAbsolutePath(),
+ MaxAbsScaler::load);
MaxAbsScalerModel model = loadedMaxAbsScaler.fit(trainDataTable);
MaxAbsScalerModel loadedModel =
- TestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv,
+ model,
+ tempFolder.newFolder().getAbsolutePath(),
+ MaxAbsScalerModel::load);
Table output = loadedModel.transform(predictDataTable)[0];
verifyPredictionResult(output, maxAbsScaler.getOutputCol(),
EXPECTED_DATA);
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinHashLSHTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinHashLSHTest.java
index 9de24b88..d92f062d 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinHashLSHTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinHashLSHTest.java
@@ -271,7 +271,8 @@ public class MinHashLSHTest extends AbstractTestBase {
.setNumHashFunctionsPerTable(3);
MinHashLSH loadedLsh =
- TestUtils.saveAndReload(tEnv, lsh,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv, lsh, tempFolder.newFolder().getAbsolutePath(),
MinHashLSH::load);
MinHashLSHModel lshModel = loadedLsh.fit(inputTable);
Assert.assertEquals(
Arrays.asList(
@@ -295,7 +296,11 @@ public class MinHashLSHTest extends AbstractTestBase {
.setNumHashFunctionsPerTable(3);
MinHashLSHModel lshModel = lsh.fit(inputTable);
MinHashLSHModel loadedModel =
- TestUtils.saveAndReload(tEnv, lshModel,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv,
+ lshModel,
+ tempFolder.newFolder().getAbsolutePath(),
+ MinHashLSHModel::load);
Table output =
loadedModel.transform(inputTable)[0].select($(lsh.getOutputCol()));
verifyPredictionResult(output, outputRows);
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
index 7c2ac0bb..32451635 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
@@ -174,10 +174,17 @@ public class MinMaxScalerTest extends AbstractTestBase {
MinMaxScaler minMaxScaler = new MinMaxScaler();
MinMaxScaler loadedMinMaxScaler =
TestUtils.saveAndReload(
- tEnv, minMaxScaler,
tempFolder.newFolder().getAbsolutePath());
+ tEnv,
+ minMaxScaler,
+ tempFolder.newFolder().getAbsolutePath(),
+ MinMaxScaler::load);
MinMaxScalerModel model = loadedMinMaxScaler.fit(trainDataTable);
MinMaxScalerModel loadedModel =
- TestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv,
+ model,
+ tempFolder.newFolder().getAbsolutePath(),
+ MinMaxScalerModel::load);
assertEquals(
Arrays.asList("minVector", "maxVector"),
model.getModelData()[0].getResolvedSchema().getColumnNames());
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/NGramTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/NGramTest.java
index f7f4405c..01ea7609 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/NGramTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/NGramTest.java
@@ -100,7 +100,7 @@ public class NGramTest extends AbstractTestBase {
NGram nGram = new NGram();
NGram loadedNGram =
TestUtils.saveAndReload(
- tEnv, nGram,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ tEnv, nGram,
TEMPORARY_FOLDER.newFolder().getAbsolutePath(), NGram::load);
Table output = loadedNGram.transform(inputDataTable)[0];
verifyOutputResult(output, loadedNGram.getOutputCol());
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/NormalizerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/NormalizerTest.java
index 82f7d689..13e28c68 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/NormalizerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/NormalizerTest.java
@@ -144,7 +144,10 @@ public class NormalizerTest extends AbstractTestBase {
Normalizer loadedNormalizer =
TestUtils.saveAndReload(
- tEnv, normalizer,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ tEnv,
+ normalizer,
+ TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
+ Normalizer::load);
Table output = loadedNormalizer.transform(inputDataTable)[0];
verifyOutputResult(output, loadedNormalizer.getOutputCol(),
EXPECTED_DENSE_OUTPUT);
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
index 76682440..e5a6726a 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
@@ -263,9 +263,18 @@ public class OneHotEncoderTest extends AbstractTestBase {
@Test
public void testSaveLoad() throws Exception {
estimator =
- TestUtils.saveAndReload(tEnv, estimator,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv,
+ estimator,
+ tempFolder.newFolder().getAbsolutePath(),
+ OneHotEncoder::load);
OneHotEncoderModel model = estimator.fit(trainTable);
- model = TestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
+ model =
+ TestUtils.saveAndReload(
+ tEnv,
+ model,
+ tempFolder.newFolder().getAbsolutePath(),
+ OneHotEncoderModel::load);
Table outputTable = model.transform(predictTable)[0];
Map<Double, Vector>[] actualOutput =
executeAndCollect(outputTable, model.getInputCols(),
model.getOutputCols());
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OnlineStandardScalerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OnlineStandardScalerTest.java
index 02798303..303713cd 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OnlineStandardScalerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OnlineStandardScalerTest.java
@@ -374,13 +374,19 @@ public class OnlineStandardScalerTest extends
AbstractTestBase {
standardScaler =
TestUtils.saveAndReload(
- tEnv, standardScaler,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ tEnv,
+ standardScaler,
+ TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
+ OnlineStandardScaler::load);
OnlineStandardScalerModel model =
standardScaler.fit(inputTableWithEventTime);
Table[] modelData = model.getModelData();
model =
TestUtils.saveAndReload(
- tEnv, model,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ tEnv,
+ model,
+ TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
+ OnlineStandardScalerModel::load);
model.setModelData(modelData);
Table output = model.transform(inputTableWithEventTime)[0];
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/PolynomialExpansionTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/PolynomialExpansionTest.java
index 0754fb0b..5454eabf 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/PolynomialExpansionTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/PolynomialExpansionTest.java
@@ -141,7 +141,10 @@ public class PolynomialExpansionTest extends
AbstractTestBase {
PolynomialExpansion loadedPolynomialExpansion =
TestUtils.saveAndReload(
- tEnv, polynomialExpansion,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ tEnv,
+ polynomialExpansion,
+ TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
+ PolynomialExpansion::load);
Table output = loadedPolynomialExpansion.transform(inputDataTable)[0];
verifyOutputResult(output, loadedPolynomialExpansion.getOutputCol(),
EXPECTED_DENSE_OUTPUT);
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RandomSplitterTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RandomSplitterTest.java
index 66f08aed..75ad42ab 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RandomSplitterTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RandomSplitterTest.java
@@ -131,7 +131,10 @@ public class RandomSplitterTest extends AbstractTestBase {
RandomSplitter splitterLoad =
TestUtils.saveAndReload(
- tEnv, randomSplitter,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ tEnv,
+ randomSplitter,
+ TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
+ RandomSplitter::load);
Table[] output = splitterLoad.transform(data);
List<Row> result0 =
IteratorUtils.toList(tEnv.toDataStream(output[0]).executeAndCollect());
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RegexTokenizerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RegexTokenizerTest.java
index 96d284a2..3c357761 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RegexTokenizerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RegexTokenizerTest.java
@@ -148,7 +148,10 @@ public class RegexTokenizerTest extends AbstractTestBase {
Row.of((Object) new String[] {"te,st.", "punct"}));
RegexTokenizer loadedRegexTokenizer =
TestUtils.saveAndReload(
- tEnv, regexTokenizer,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ tEnv,
+ regexTokenizer,
+ TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
+ RegexTokenizer::load);
Table output = loadedRegexTokenizer.transform(inputDataTable)[0];
verifyOutputResult(output, loadedRegexTokenizer.getOutputCol(),
expectedRows);
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RobustScalerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RobustScalerTest.java
index 940dcbe2..e8179024 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RobustScalerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RobustScalerTest.java
@@ -178,10 +178,17 @@ public class RobustScalerTest extends AbstractTestBase {
RobustScaler robustScaler = new RobustScaler();
RobustScaler loadedRobustScaler =
TestUtils.saveAndReload(
- tEnv, robustScaler,
tempFolder.newFolder().getAbsolutePath());
+ tEnv,
+ robustScaler,
+ tempFolder.newFolder().getAbsolutePath(),
+ RobustScaler::load);
RobustScalerModel model = loadedRobustScaler.fit(trainDataTable);
RobustScalerModel loadedModel =
- TestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv,
+ model,
+ tempFolder.newFolder().getAbsolutePath(),
+ RobustScalerModel::load);
assertEquals(
Arrays.asList("medians", "ranges"),
model.getModelData()[0].getResolvedSchema().getColumnNames());
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/SQLTransformerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/SQLTransformerTest.java
index 3e0aac5c..15314881 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/SQLTransformerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/SQLTransformerTest.java
@@ -192,7 +192,10 @@ public class SQLTransformerTest extends AbstractTestBase {
SQLTransformer loadedSQLTransformer =
TestUtils.saveAndReload(
- tEnv, sqlTransformer,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ tEnv,
+ sqlTransformer,
+ TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
+ SQLTransformer::load);
Table outputTable = loadedSQLTransformer.transform(inputTable)[0];
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/StandardScalerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/StandardScalerTest.java
index b36cd885..7d08c70c 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/StandardScalerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/StandardScalerTest.java
@@ -191,10 +191,18 @@ public class StandardScalerTest extends AbstractTestBase {
StandardScaler standardScaler = new StandardScaler();
standardScaler =
TestUtils.saveAndReload(
- tEnv, standardScaler,
tempFolder.newFolder().getAbsolutePath());
+ tEnv,
+ standardScaler,
+ tempFolder.newFolder().getAbsolutePath(),
+ StandardScaler::load);
StandardScalerModel model = standardScaler.fit(denseTable);
- model = TestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
+ model =
+ TestUtils.saveAndReload(
+ tEnv,
+ model,
+ tempFolder.newFolder().getAbsolutePath(),
+ StandardScalerModel::load);
assertEquals(
Arrays.asList("mean", "std"),
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/StopWordsRemoverTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/StopWordsRemoverTest.java
index a8853173..cfd1aea7 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/StopWordsRemoverTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/StopWordsRemoverTest.java
@@ -342,7 +342,10 @@ public class StopWordsRemoverTest extends AbstractTestBase
{
StopWordsRemover loadedRemover =
TestUtils.saveAndReload(
- tEnv, remover,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ tEnv,
+ remover,
+ TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
+ StopWordsRemover::load);
verifyOutputResult(loadedRemover, inputTable);
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/TokenizerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/TokenizerTest.java
index e10d4083..98ec873b 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/TokenizerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/TokenizerTest.java
@@ -95,7 +95,10 @@ public class TokenizerTest extends AbstractTestBase {
Tokenizer tokenizer = new Tokenizer();
Tokenizer loadedTokenizer =
TestUtils.saveAndReload(
- tEnv, tokenizer,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ tEnv,
+ tokenizer,
+ TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
+ Tokenizer::load);
Table output = loadedTokenizer.transform(inputDataTable)[0];
verifyOutputResult(output, loadedTokenizer.getOutputCol(),
EXPECTED_OUTPUT);
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/UnivariateFeatureSelectorTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/UnivariateFeatureSelectorTest.java
index 9851debf..f722a78a 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/UnivariateFeatureSelectorTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/UnivariateFeatureSelectorTest.java
@@ -674,10 +674,18 @@ public class UnivariateFeatureSelectorTest extends
AbstractTestBase {
.setSelectionThreshold(1);
UnivariateFeatureSelector loadSelector =
- TestUtils.saveAndReload(tEnv, selector,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv,
+ selector,
+ tempFolder.newFolder().getAbsolutePath(),
+ UnivariateFeatureSelector::load);
UnivariateFeatureSelectorModel model =
loadSelector.fit(inputANOVATable);
UnivariateFeatureSelectorModel loadedModel =
- TestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv,
+ model,
+ tempFolder.newFolder().getAbsolutePath(),
+ UnivariateFeatureSelectorModel::load);
assertEquals(
Collections.singletonList("indices"),
model.getModelData()[0].getResolvedSchema().getColumnNames());
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java
index 9d584adb..43a893ce 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java
@@ -168,10 +168,17 @@ public class VarianceThresholdSelectorTest extends
AbstractTestBase {
new VarianceThresholdSelector().setVarianceThreshold(8.0);
VarianceThresholdSelector loadedVarianceThresholdSelector =
TestUtils.saveAndReload(
- tEnv, varianceThresholdSelector,
tempFolder.newFolder().getAbsolutePath());
+ tEnv,
+ varianceThresholdSelector,
+ tempFolder.newFolder().getAbsolutePath(),
+ VarianceThresholdSelector::load);
VarianceThresholdSelectorModel model =
loadedVarianceThresholdSelector.fit(trainDataTable);
VarianceThresholdSelectorModel loadedModel =
- TestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv,
+ model,
+ tempFolder.newFolder().getAbsolutePath(),
+ VarianceThresholdSelectorModel::load);
assertEquals(
Arrays.asList("numOfFeatures", "indices"),
model.getModelData()[0].getResolvedSchema().getColumnNames());
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java
index b2e8a65f..f70d95fc 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java
@@ -359,7 +359,10 @@ public class VectorAssemblerTest extends AbstractTestBase {
VectorAssembler loadedVectorAssembler =
TestUtils.saveAndReload(
- tEnv, vectorAssembler,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ tEnv,
+ vectorAssembler,
+ TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
+ VectorAssembler::load);
Table output = loadedVectorAssembler.transform(inputDataTable)[0];
verifyOutputResult(output, loadedVectorAssembler.getOutputCol(), 3);
@@ -383,7 +386,10 @@ public class VectorAssemblerTest extends AbstractTestBase {
VectorAssembler loadedVectorAssembler =
TestUtils.saveAndReload(
- tEnv, vectorAssembler,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ tEnv,
+ vectorAssembler,
+ TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
+ VectorAssembler::load);
Table output = loadedVectorAssembler.transform(inputDataTable)[0];
verifyOutputResult(output, loadedVectorAssembler.getOutputCol(), 3);
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorIndexerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorIndexerTest.java
index 0d4f1789..de69f4d5 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorIndexerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorIndexerTest.java
@@ -198,10 +198,18 @@ public class VectorIndexerTest extends AbstractTestBase {
new
VectorIndexer().setHandleInvalid(HasHandleInvalid.KEEP_INVALID);
vectorIndexer =
TestUtils.saveAndReload(
- tEnv, vectorIndexer,
tempFolder.newFolder().getAbsolutePath());
+ tEnv,
+ vectorIndexer,
+ tempFolder.newFolder().getAbsolutePath(),
+ VectorIndexer::load);
VectorIndexerModel model = vectorIndexer.fit(trainInputTable);
- model = TestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
+ model =
+ TestUtils.saveAndReload(
+ tEnv,
+ model,
+ tempFolder.newFolder().getAbsolutePath(),
+ VectorIndexerModel::load);
assertEquals(
Collections.singletonList("categoryMaps"),
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorSlicerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorSlicerTest.java
index 38d78a25..fb983800 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorSlicerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorSlicerTest.java
@@ -116,7 +116,10 @@ public class VectorSlicerTest extends AbstractTestBase {
new
VectorSlicer().setInputCol("vec").setOutputCol("sliceVec").setIndices(0, 1, 2);
VectorSlicer loadedVectorSlicer =
TestUtils.saveAndReload(
- tEnv, vectorSlicer,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ tEnv,
+ vectorSlicer,
+ TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
+ VectorSlicer::load);
Table output = loadedVectorSlicer.transform(inputDataTable)[0];
verifyOutputResult(output, loadedVectorSlicer.getOutputCol(), false);
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModelTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModelTest.java
index de0ec388..bcf2ce20 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModelTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModelTest.java
@@ -129,7 +129,12 @@ public class IndexToStringModelTest extends
AbstractTestBase {
.setInputCols("inputCol1", "inputCol2")
.setOutputCols("outputCol1", "outputCol2")
.setModelData(modelTable);
- model = TestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
+ model =
+ TestUtils.saveAndReload(
+ tEnv,
+ model,
+ tempFolder.newFolder().getAbsolutePath(),
+ IndexToStringModel::load);
assertEquals(
Collections.singletonList("stringArrays"),
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
index 2b41ddb9..56f353e5 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
@@ -284,10 +284,18 @@ public class StringIndexerTest extends AbstractTestBase {
.setHandleInvalid(StringIndexerParams.KEEP_INVALID);
stringIndexer =
TestUtils.saveAndReload(
- tEnv, stringIndexer,
tempFolder.newFolder().getAbsolutePath());
+ tEnv,
+ stringIndexer,
+ tempFolder.newFolder().getAbsolutePath(),
+ StringIndexer::load);
StringIndexerModel model = stringIndexer.fit(trainTable);
- model = TestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
+ model =
+ TestUtils.saveAndReload(
+ tEnv,
+ model,
+ tempFolder.newFolder().getAbsolutePath(),
+ StringIndexerModel::load);
assertEquals(
Collections.singletonList("stringArrays"),
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/SwingTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/SwingTest.java
index 8c17fdd4..cfa0f9cd 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/SwingTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/SwingTest.java
@@ -222,7 +222,8 @@ public class SwingTest {
public void testSaveLoadAndTransform() throws Exception {
Swing swing = new Swing().setMinUserBehavior(1);
Swing loadedSwing =
- TestUtils.saveAndReload(tEnv, swing,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv, swing, tempFolder.newFolder().getAbsolutePath(),
Swing::load);
Table outputTable = loadedSwing.transform(inputTable)[0];
List<Row> results =
IteratorUtils.toList(outputTable.execute().collect());
compareResultAndExpected(results);
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java
index 9776965f..1c051e17 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java
@@ -187,9 +187,17 @@ public class LinearRegressionTest extends AbstractTestBase
{
LinearRegression linearRegression = new
LinearRegression().setWeightCol("weight");
linearRegression =
TestUtils.saveAndReload(
- tEnv, linearRegression,
tempFolder.newFolder().getAbsolutePath());
+ tEnv,
+ linearRegression,
+ tempFolder.newFolder().getAbsolutePath(),
+ LinearRegression::load);
LinearRegressionModel model = linearRegression.fit(trainDataTable);
- model = TestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
+ model =
+ TestUtils.saveAndReload(
+ tEnv,
+ model,
+ tempFolder.newFolder().getAbsolutePath(),
+ LinearRegressionModel::load);
assertEquals(
Collections.singletonList("coefficient"),
model.getModelData()[0].getResolvedSchema().getColumnNames());
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ANOVATestTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ANOVATestTest.java
index 86ae5bf4..1b227085 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ANOVATestTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ANOVATestTest.java
@@ -394,7 +394,8 @@ public class ANOVATestTest extends AbstractTestBase {
public void testSaveLoadAndTransform() throws Exception {
ANOVATest anovaTest = new ANOVATest();
ANOVATest loadedANOVATest =
- TestUtils.saveAndReload(tEnv, anovaTest,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv, anovaTest,
tempFolder.newFolder().getAbsolutePath(), ANOVATest::load);
Table output = loadedANOVATest.transform(denseInputTable)[0];
verifyTransformationResult(output, EXPECTED_OUTPUT_DENSE);
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ChiSqTestTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ChiSqTestTest.java
index 706edfa7..87a3f376 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ChiSqTestTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ChiSqTestTest.java
@@ -188,7 +188,8 @@ public class ChiSqTestTest extends AbstractTestBase {
ChiSqTest chiSqTest = new
ChiSqTest().setFeaturesCol("features").setLabelCol("label");
ChiSqTest loadedChiSqTest =
- TestUtils.saveAndReload(tEnv, chiSqTest,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv, chiSqTest,
tempFolder.newFolder().getAbsolutePath(), ChiSqTest::load);
Table output1 =
loadedChiSqTest.transform(inputTableWithDoubleLabel)[0];
verifyPredictionResult(output1,
expectedChiSqTestResultWithDoubleLabel);
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/FValueTestTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/FValueTestTest.java
index 156006f1..faf8fe57 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/FValueTestTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/FValueTestTest.java
@@ -412,7 +412,11 @@ public class FValueTestTest extends AbstractTestBase {
public void testSaveLoadAndTransform() throws Exception {
FValueTest fValueTest = new FValueTest();
FValueTest loadedFValueTest =
- TestUtils.saveAndReload(tEnv, fValueTest,
tempFolder.newFolder().getAbsolutePath());
+ TestUtils.saveAndReload(
+ tEnv,
+ fValueTest,
+ tempFolder.newFolder().getAbsolutePath(),
+ FValueTest::load);
Table output = loadedFValueTest.transform(denseInputTable)[0];
verifyTransformationResult(output, EXPECTED_OUTPUT_DENSE);
}