lindong28 commented on a change in pull request #73: URL: https://github.com/apache/flink-ml/pull/73#discussion_r841012351
########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScalerParams.java ########## @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.standardscaler; + +import org.apache.flink.ml.common.param.HasFeaturesCol; +import org.apache.flink.ml.common.param.HasPredictionCol; +import org.apache.flink.ml.param.BooleanParam; +import org.apache.flink.ml.param.Param; + +/** + * Params for {@link StandardScaler}. + * + * @param <T> The class type of this instance. + */ +public interface StandardScalerParams<T> extends HasFeaturesCol<T>, HasPredictionCol<T> { Review comment: Spark's `StandardScaler` uses `inputCol` and `outputCol` as its parameters. Would it be better to follow the same approach here? Note that this `StandardScaler` is used to transform features rather than making predictions. Thus it does not seem intuitive to name its output as `predictionCol`. We can submit a follow-up PR to fix this problem. `HasFeaturesCol` should probably be renamed to `HasFeatureCol` as it represents just one column. And `HasInputCol` is more consistent with `HasOutputCol` than `HasFeatureCol` is. 
########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/feature/StandardScalerTest.java ########## @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.feature.standardscaler.StandardScaler; +import org.apache.flink.ml.feature.standardscaler.StandardScalerModel; +import org.apache.flink.ml.feature.standardscaler.StandardScalerModelData; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import 
org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +/** Tests {@link StandardScaler} and {@link StandardScalerModel}. */ +public class StandardScalerTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table denseTable; + + private final List<Row> denseInput = + Arrays.asList( + Row.of(Vectors.dense(-2.5, 9, 1)), + Row.of(Vectors.dense(1.4, -5, 1)), + Row.of(Vectors.dense(2, -1, -2))); + + private final List<DenseVector> expectedResWithMean = + Arrays.asList( + Vectors.dense(-2.8, 8, 1), + Vectors.dense(1.1, -6, 1), + Vectors.dense(1.7, -2, -2)); + + private final List<DenseVector> expectedResWithStd = + Arrays.asList( + Vectors.dense(-1.0231819, 1.2480754, 0.5773502), + Vectors.dense(0.5729819, -0.6933752, 0.5773503), + Vectors.dense(0.8185455, -0.1386750, -1.1547005)); + + private final List<DenseVector> expectedResWithMeanAndStd = + Arrays.asList( + Vectors.dense(-1.1459637, 1.1094004, 0.5773503), + Vectors.dense(0.45020003, -0.8320503, 0.5773503), + Vectors.dense(0.69576368, -0.2773501, -1.1547005)); + + private final double[] expectedMean = new double[] {0.3, 1, 0}; + private final double[] expectedStd = new double[] {2.4433583, 7.2111026, 1.7320508}; + private static final double TOLERANCE = 1e-7; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + 
tEnv = StreamTableEnvironment.create(env); + denseTable = tEnv.fromDataStream(env.fromCollection(denseInput)).as("features"); + } + + @SuppressWarnings("unchecked") + private void verifyPredictionResult( + List<DenseVector> expectedOutput, Table output, String predictionCol) throws Exception { + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + List<DenseVector> predictions = new ArrayList<>(collectedResult.size()); + + for (Row r : collectedResult) { + Vector vec = (Vector) r.getField(predictionCol); + predictions.add(vec.toDense()); + } + + assertEquals(expectedOutput.size(), predictions.size()); + + predictions.sort( + (vec1, vec2) -> { + int size = Math.min(vec1.size(), vec2.size()); + for (int i = 0; i < size; i++) { + int cmp = Double.compare(vec1.get(i), vec2.get(i)); + if (cmp != 0) { + return cmp; + } + } + return 0; + }); + + for (int i = 0; i < predictions.size(); i++) { + assertArrayEquals(expectedOutput.get(i).values, predictions.get(i).values, TOLERANCE); + } + } + + @Test + public void testParam() { + StandardScaler standardScaler = new StandardScaler(); + + assertEquals("features", standardScaler.getFeaturesCol()); + assertEquals(false, standardScaler.getWithMean()); + assertEquals(true, standardScaler.getWithStd()); + assertEquals("prediction", standardScaler.getPredictionCol()); + + standardScaler + .setFeaturesCol("test_features") + .setWithMean(true) + .setWithStd(false) + .setPredictionCol("test_prediction"); + + assertEquals("test_features", standardScaler.getFeaturesCol()); + assertEquals(true, standardScaler.getWithMean()); + assertEquals(false, standardScaler.getWithStd()); + assertEquals("test_prediction", standardScaler.getPredictionCol()); + } + + @Test + public void testFeaturePredictionParam() { Review comment: nit: if we change the parameters to inputCol and outputCol, it might not be intuitive to name this method `testFeaturePredictionParam`. 
How about changing it to `testOutputSchema()`? We can use this name consistently across other tests. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
