lindong28 commented on code in PR #196: URL: https://github.com/apache/flink-ml/pull/196#discussion_r1093183274
########## docs/content/docs/operators/feature/onlinestandardscaler.md: ########## @@ -0,0 +1,261 @@ +--- +title: "OnlineStandardScaler" +weight: 1 +type: docs +aliases: +- /operators/feature/onlinestandardscaler.html +--- + +<!-- +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +--> + +## OnlineStandardScaler + +An Estimator which implements the online standard scaling algorithm, which +is the online version of StandardScaler. + +OnlineStandardScaler splits the input data by the user-specified window strategy. +For each window, it computes the mean and standard deviation using the data seen +so far (i.e., not only the data in the current window, but also the history data). +The model data generated by OnlineStandardScaler is a model stream. +There is one model data for each window. + +During the inference phase (i.e., using OnlineStandardScalerModel for prediction), +users could output the model version that is used for predicting each data point. +Moreover, +- When the train data and test data both contain event time, users could +specify the maximum difference between the timestamps of the input and model data, +which enforces to use a relatively fresh model for prediction. 
+- Otherwise, the prediction process always uses the current model data for prediction. + + +### Input Columns + +| Param name | Type | Default | Description | +|:-----------|:-------|:----------|:-----------------------| +| inputCol | Vector | `"input"` | Features to be scaled. | + +### Output Columns + +| Param name | Type | Default | Description | +|:----------------|:-------|:-----------|:--------------------------------------------------------------------------------------| +| outputCol | Vector | `"output"` | Scaled features. | +| modelVersionCol | String | `version` | The version (in long format) of the model data that the input data is predicted with. | + +### Parameters + +Below are the parameters required by `OnlineStandardScalerModel`. + +| Key | Default | Type | Required | Description | +|------------------------|------------|---------|----------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| inputCol | `"input"` | String | no | Input column name. | +| outputCol | `"output"` | String | no | Output column name. | +| withMean | `false` | Boolean | no | Whether centers the data with mean before scaling. | +| withStd | `true` | Boolean | no | Whether scales the data with standard deviation. | +| modelVersionCol | `null` | String | no | The version (in long format) of the model data that the input data is predicted with. | +| maxAllowedModelDelayMs | `0L` | Long | no | The maximum difference allowed between the timestamps of the input record and the model data that is used to predict that input record. This param only works when the input contains event time. | + +`OnlineStandardScaler` needs parameters above and also below. 
+ +| Key | Default | Type | Required | Description | +|---------|-------------------------------|---------|----------|--------------------------------------------------------------------------------| +| windows | `GlobalWindows.getInstance()` | Windows | no | Windowing strategy that determines how to create mini-batches from input data. | + + +### Examples + +{{< tabs examples >}} + +{{< tab "Java">}} + +```java +import org.apache.flink.api.common.eventtime.SerializableTimestampAssigner; +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.time.Time; +import org.apache.flink.ml.common.window.EventTimeTumblingWindows; +import org.apache.flink.ml.feature.standardscaler.OnlineStandardScaler; +import org.apache.flink.ml.feature.standardscaler.OnlineStandardScalerModel; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; + +import java.util.Arrays; +import java.util.List; + +/** Simple program that trains a OnlineStandardScaler model and uses it for feature engineering. */ +public class OnlineStandardScalerExample { + public static void main(String[] args) { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + + // Generates input data. 
+ List<Row> inputData = + Arrays.asList( + Row.of(0L, Vectors.dense(-2.5, 9, 1)), + Row.of(1000L, Vectors.dense(1.4, -5, 1)), + Row.of(2000L, Vectors.dense(2, -1, -2)), + Row.of(6000L, Vectors.dense(0.7, 3, 1)), + Row.of(7000L, Vectors.dense(0, 1, 1)), + Row.of(8000L, Vectors.dense(0.5, 0, -2)), + Row.of(9000L, Vectors.dense(0.4, 1, 1)), + Row.of(10000L, Vectors.dense(0.3, 2, 1)), + Row.of(11000L, Vectors.dense(0.5, 1, -2))); + + DataStream<Row> inputStream = env.fromCollection(inputData); + + DataStream<Row> inputStreamWithEventTime = + inputStream.assignTimestampsAndWatermarks( + WatermarkStrategy.<Row>forMonotonousTimestamps() + .withTimestampAssigner( + (SerializableTimestampAssigner<Row>) + (element, recordTimestamp) -> + element.getFieldAs(0))); + + Table inputTable = + tEnv.fromDataStream( + inputStreamWithEventTime, + Schema.newBuilder() + .column("f0", DataTypes.BIGINT()) + .column("f1", DataTypes.RAW(DenseVectorTypeInfo.INSTANCE)) + .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)") + .watermark("rowtime", "SOURCE_WATERMARK()") + .build()) + .as("id", "input"); + + // Creates an OnlineStandardScaler object and initializes its parameters. + long windowSizeMs = 3000; + OnlineStandardScaler onlineStandardScaler = + new OnlineStandardScaler() + .setWindows(EventTimeTumblingWindows.of(Time.milliseconds(windowSizeMs))) + .setModelVersionCol("modelVersionCol"); Review Comment: Now that we have a default value for this parameter, would it be simpler to skip setting it? It would also be more consistent with the pattern of not explicitly setting input/output columns. Same for `OnlineStandardScalerExample`. ########## docs/content/docs/operators/feature/onlinestandardscaler.md: ########## @@ -0,0 +1,261 @@ +--- +title: "OnlineStandardScaler" +weight: 1 +type: docs +aliases: +- /operators/feature/onlinestandardscaler.html +--- + +<!-- +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements.
See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +--> + +## OnlineStandardScaler + +An Estimator which implements the online standard scaling algorithm, which +is the online version of StandardScaler. + +OnlineStandardScaler splits the input data by the user-specified window strategy. +For each window, it computes the mean and standard deviation using the data seen +so far (i.e., not only the data in the current window, but also the history data). +The model data generated by OnlineStandardScaler is a model stream. +There is one model data for each window. + +During the inference phase (i.e., using OnlineStandardScalerModel for prediction), +users could output the model version that is used for predicting each data point. +Moreover, +- When the train data and test data both contain event time, users could +specify the maximum difference between the timestamps of the input and model data, +which enforces to use a relatively fresh model for prediction. +- Otherwise, the prediction process always uses the current model data for prediction. + + +### Input Columns + +| Param name | Type | Default | Description | +|:-----------|:-------|:----------|:-----------------------| +| inputCol | Vector | `"input"` | Features to be scaled. 
| + +### Output Columns + +| Param name | Type | Default | Description | +|:----------------|:-------|:-----------|:--------------------------------------------------------------------------------------| +| outputCol | Vector | `"output"` | Scaled features. | +| modelVersionCol | String | `version` | The version (in long format) of the model data that the input data is predicted with. | + +### Parameters + +Below are the parameters required by `OnlineStandardScalerModel`. + +| Key | Default | Type | Required | Description | +|------------------------|------------|---------|----------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| inputCol | `"input"` | String | no | Input column name. | +| outputCol | `"output"` | String | no | Output column name. | +| withMean | `false` | Boolean | no | Whether centers the data with mean before scaling. | +| withStd | `true` | Boolean | no | Whether scales the data with standard deviation. | +| modelVersionCol | `null` | String | no | The version (in long format) of the model data that the input data is predicted with. | +| maxAllowedModelDelayMs | `0L` | Long | no | The maximum difference allowed between the timestamps of the input record and the model data that is used to predict that input record. This param only works when the input contains event time. | + +`OnlineStandardScaler` needs parameters above and also below. + +| Key | Default | Type | Required | Description | +|---------|-------------------------------|---------|----------|--------------------------------------------------------------------------------| +| windows | `GlobalWindows.getInstance()` | Windows | no | Windowing strategy that determines how to create mini-batches from input data. 
| + + +### Examples + +{{< tabs examples >}} + +{{< tab "Java">}} + +```java +import org.apache.flink.api.common.eventtime.SerializableTimestampAssigner; +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.time.Time; +import org.apache.flink.ml.common.window.EventTimeTumblingWindows; +import org.apache.flink.ml.feature.standardscaler.OnlineStandardScaler; +import org.apache.flink.ml.feature.standardscaler.OnlineStandardScalerModel; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; + +import java.util.Arrays; +import java.util.List; + +/** Simple program that trains a OnlineStandardScaler model and uses it for feature engineering. */ +public class OnlineStandardScalerExample { + public static void main(String[] args) { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + + // Generates input data. 
+ List<Row> inputData = + Arrays.asList( + Row.of(0L, Vectors.dense(-2.5, 9, 1)), + Row.of(1000L, Vectors.dense(1.4, -5, 1)), + Row.of(2000L, Vectors.dense(2, -1, -2)), + Row.of(6000L, Vectors.dense(0.7, 3, 1)), + Row.of(7000L, Vectors.dense(0, 1, 1)), + Row.of(8000L, Vectors.dense(0.5, 0, -2)), + Row.of(9000L, Vectors.dense(0.4, 1, 1)), + Row.of(10000L, Vectors.dense(0.3, 2, 1)), + Row.of(11000L, Vectors.dense(0.5, 1, -2))); + + DataStream<Row> inputStream = env.fromCollection(inputData); + + DataStream<Row> inputStreamWithEventTime = + inputStream.assignTimestampsAndWatermarks( + WatermarkStrategy.<Row>forMonotonousTimestamps() + .withTimestampAssigner( + (SerializableTimestampAssigner<Row>) + (element, recordTimestamp) -> + element.getFieldAs(0))); + + Table inputTable = + tEnv.fromDataStream( + inputStreamWithEventTime, + Schema.newBuilder() + .column("f0", DataTypes.BIGINT()) + .column("f1", DataTypes.RAW(DenseVectorTypeInfo.INSTANCE)) + .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)") + .watermark("rowtime", "SOURCE_WATERMARK()") + .build()) + .as("id", "input"); + + // Creates an OnlineStandardScaler object and initializes its parameters. + long windowSizeMs = 3000; + OnlineStandardScaler onlineStandardScaler = + new OnlineStandardScaler() + .setWindows(EventTimeTumblingWindows.of(Time.milliseconds(windowSizeMs))) + .setModelVersionCol("modelVersionCol"); + + // Trains the OnlineStandardScaler Model. + OnlineStandardScalerModel model = onlineStandardScaler.fit(inputTable); + + // Uses the OnlineStandardScaler Model for predictions. + Table outputTable = model.transform(inputTable)[0]; + + // Extracts and displays the results. 
+ for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) { + Row row = it.next(); + DenseVector inputValue = (DenseVector) row.getField(onlineStandardScaler.getInputCol()); + DenseVector outputValue = + (DenseVector) row.getField(onlineStandardScaler.getOutputCol()); + long modelVersion = row.getFieldAs(onlineStandardScaler.getModelVersionCol()); + System.out.printf( + "Input Value: %s\tOutput Value: %s\tModel Version: %s\n", + inputValue, outputValue, modelVersion); + } + } +} + +``` + +{{< /tab>}} + +{{< tab "Python">}} + +```python +# Simple program that trains an OnlineStandardScaler model and uses it for feature +# engineering. + +from pyflink.common import Types +from pyflink.common.time import Time, Instant +from pyflink.java_gateway import get_gateway +from pyflink.table import Schema +from pyflink.datastream import StreamExecutionEnvironment +from pyflink.table import StreamTableEnvironment +from pyflink.table.expressions import col + +from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo +from pyflink.ml.lib.feature.onlinestandardscaler import OnlineStandardScaler +from pyflink.ml.core.windows import EventTimeTumblingWindows + +# Creates a new StreamExecutionEnvironment. +env = StreamExecutionEnvironment.get_execution_environment() + +# Creates a StreamTableEnvironment. +t_env = StreamTableEnvironment.create(env) + +# Generates input data. 
+dense_vector_serializer = get_gateway().jvm.org.apache.flink.table.types.logical.RawType( + get_gateway().jvm.org.apache.flink.ml.linalg.DenseVector(0).getClass(), + get_gateway().jvm.org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer() +).getSerializerString() + +schema = Schema.new_builder() + .column("ts", "TIMESTAMP_LTZ(3)") + .column("input", "RAW('org.apache.flink.ml.linalg.DenseVector', '{serializer}')" + .format(serializer=dense_vector_serializer)) + .watermark("ts", "ts - INTERVAL '1' SECOND") + .build() + +input_data = t_env.from_data_stream( + env.from_collection([ + (Instant.of_epoch_milli(0), Vectors.dense(-2.5, 9, 1),), + (Instant.of_epoch_milli(1000), Vectors.dense(1.4, -5, 1),), + (Instant.of_epoch_milli(2000), Vectors.dense(2, -1, -2),), + (Instant.of_epoch_milli(6000), Vectors.dense(0.7, 3, 1),), + (Instant.of_epoch_milli(7000), Vectors.dense(0, 1, 1),), + (Instant.of_epoch_milli(8000), Vectors.dense(0.5, 0, -2),), + (Instant.of_epoch_milli(9000), Vectors.dense(0.4, 1, 1),), + (Instant.of_epoch_milli(10000), Vectors.dense(0.3, 2, 1),), + (Instant.of_epoch_milli(11000), Vectors.dense(0.5, 1, -2),) + ], + type_info=Types.ROW_NAMED( + ['ts', 'input'], + [Types.INSTANT(), DenseVectorTypeInfo()])), + schema) + +# Creates an online standard-scaler object and initialize its parameters. +standard_scaler = OnlineStandardScaler() + .set_windows(EventTimeTumblingWindows.of(Time.milliseconds(3000))) + .set_model_version_col('model_version_col') Review Comment: Remove this? ########## flink-ml-python/pyflink/ml/lib/feature/tests/test_onlinestandardscaler.py: ########## @@ -0,0 +1,216 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from pyflink.common import Types +from pyflink.common.time import Time, Instant +from pyflink.java_gateway import get_gateway +from pyflink.table import Schema +from pyflink.table.types import DataTypes +from pyflink.table.expressions import col + +from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo +from pyflink.ml.lib.feature.onlinestandardscaler import OnlineStandardScaler, \ + OnlineStandardScalerModel +from pyflink.ml.tests.test_utils import PyFlinkMLTestCase, update_existing_params +from pyflink.ml.core.windows import GlobalWindows, EventTimeTumblingWindows + + +class OnlineStandardScalerTest(PyFlinkMLTestCase): + def setUp(self): + super(OnlineStandardScalerTest, self).setUp() + + dense_vector_serializer = get_gateway().jvm.org.apache.flink.table.types.logical.RawType( + get_gateway().jvm.org.apache.flink.ml.linalg.DenseVector(0).getClass(), + get_gateway().jvm.org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer() + ).getSerializerString() + + schema = Schema.new_builder() \ + .column("ts_in_long", DataTypes.BIGINT()) \ + .column("ts", "TIMESTAMP_LTZ(3)") \ + .column("input", "RAW('org.apache.flink.ml.linalg.DenseVector', '{serializer}')" + .format(serializer=dense_vector_serializer)) \ + .watermark("ts", "ts - INTERVAL '1' SECOND") \ + .build() + + self.input_table = 
self.t_env.from_data_stream( + self.env.from_collection([ + (0, Instant.of_epoch_milli(0), Vectors.dense(-2.5, 9, 1),), + (1000, Instant.of_epoch_milli(1000), Vectors.dense(1.4, -5, 1),), + (2000, Instant.of_epoch_milli(2000), Vectors.dense(2, -1, -2),), + (6000, Instant.of_epoch_milli(6000), Vectors.dense(0.7, 3, 1),), + (7000, Instant.of_epoch_milli(7000), Vectors.dense(0, 1, 1),), + (8000, Instant.of_epoch_milli(8000), Vectors.dense(0.5, 0, -2),), + (9000, Instant.of_epoch_milli(9000), Vectors.dense(0.4, 1, 1),), + (10000, Instant.of_epoch_milli(10000), Vectors.dense(0.3, 2, 1),), + (11000, Instant.of_epoch_milli(11000), Vectors.dense(0.5, 1, -2),) + ], + type_info=Types.ROW_NAMED( + ['ts_in_long', 'ts', 'input'], + [Types.LONG(), Types.INSTANT(), DenseVectorTypeInfo()])), + schema) + + self.window_size_ms = 3000 + + self.expected_model_data = [ + [ + Vectors.dense(0.3, 1, 0), + Vectors.dense(2.4433583, 7.2111026, 1.7320508), + 0, + 2999 + ], + [ + Vectors.dense(0.35, 1.1666667, 0), + Vectors.dense(1.5630099, 4.6654760, 1.5491933), + 1, + 8999 + ], + [ + Vectors.dense(0.3666667, 1.2222222, 0), + Vectors.dense(1.2369316, 3.7006005, 1.5), + 2, + 11999 + ] + ] + + self.tolerance = 1e-7 + + def test_param(self): + standard_scaler = OnlineStandardScaler() + + self.assertEqual('input', standard_scaler.input_col) + self.assertEqual(False, standard_scaler.with_mean) + self.assertEqual(True, standard_scaler.with_std) + self.assertEqual('output', standard_scaler.output_col) + self.assertEqual('version', standard_scaler.model_version_col) + self.assertEqual(GlobalWindows(), standard_scaler.windows) + self.assertEqual(0, standard_scaler.max_allowed_model_delay_ms) + + standard_scaler.set_input_col('test_input') \ + .set_with_mean(True) \ + .set_with_std(False) \ + .set_output_col('test_output') \ + .set_model_version_col('model_version_col') \ Review Comment: Change the value to something that is not the default value? 
########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScalerModel.java: ########## @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.standardscaler; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.metrics.MetricGroup; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.metrics.MLMetrics; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; 
+import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** A Model which transforms data using the model data computed by {@link OnlineStandardScaler}. */ +public class OnlineStandardScalerModel + implements Model<OnlineStandardScalerModel>, + OnlineStandardScalerModelParams<OnlineStandardScalerModel> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table modelDataTable; + + public OnlineStandardScalerModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + String modelVersionCol = getModelVersionCol(); + + TypeInformation<?>[] outputTypes; + String[] outputNames; + if (modelVersionCol == null) { + outputTypes = ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE); + outputNames = ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()); + } else { + outputTypes = + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE, Types.LONG); + outputNames = + ArrayUtils.addAll( + inputTypeInfo.getFieldNames(), getOutputCol(), modelVersionCol); + } + RowTypeInfo outputTypeInfo = new RowTypeInfo(outputTypes, outputNames); + + DataStream<Row> predictionResult = + tEnv.toDataStream(inputs[0]) + .connect( + StandardScalerModelData.getModelDataStream(modelDataTable) + .broadcast()) + .transform( + "PredictionOperator", + outputTypeInfo, + new PredictionOperator( + inputTypeInfo, + getInputCol(), + getWithMean(), + getWithStd(), + getMaxAllowedModelDelayMs(), + getModelVersionCol())); + + return new Table[] {tEnv.fromDataStream(predictionResult)}; + } + + /** A utility operator used for prediction. */ + @SuppressWarnings({"unchecked", "rawtypes"}) + private static class PredictionOperator extends AbstractStreamOperator<Row> + implements TwoInputStreamOperator<Row, StandardScalerModelData, Row> { + private final RowTypeInfo inputTypeInfo; + + private final String inputCol; + + private final boolean withMean; + + private final boolean withStd; + + private final long maxAllowedModelDelayMs; + + private final String modelVersionCol; + + private ListState<StreamRecord> bufferedPointsState; + + private ListState<StandardScalerModelData> modelDataState; + + /** Model data for inference. 
*/ + private StandardScalerModelData modelData; + + private DenseVector mean; + + /** Inverse of standard deviation. */ + private DenseVector scale; + + private long modelVersion; + + private long modelTimeStamp; + + public PredictionOperator( + RowTypeInfo inputTypeInfo, + String inputCol, + boolean withMean, + boolean withStd, + long maxAllowedModelDelayMs, + String modelVersionCol) { + this.inputTypeInfo = inputTypeInfo; + this.inputCol = inputCol; + this.withMean = withMean; + this.withStd = withStd; + this.maxAllowedModelDelayMs = maxAllowedModelDelayMs; + this.modelVersionCol = modelVersionCol; + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + bufferedPointsState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<StreamRecord>( + "bufferedPoints", + new StreamElementSerializer( + inputTypeInfo.createSerializer( + getExecutionConfig())))); + + modelDataState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "modelData", + TypeInformation.of(StandardScalerModelData.class))); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + if (modelData != null) { + modelDataState.clear(); + modelDataState.add(modelData); + } + } + + @Override + public void open() throws Exception { + super.open(); + MetricGroup mlModelMetricGroup = + getRuntimeContext() + .getMetricGroup() + .addGroup(MLMetrics.ML_GROUP) + .addGroup( + MLMetrics.ML_MODEL_GROUP, + OnlineStandardScalerModel.class.getSimpleName()); + mlModelMetricGroup.gauge(MLMetrics.TIMESTAMP, (Gauge<Long>) () -> modelTimeStamp); + mlModelMetricGroup.gauge(MLMetrics.VERSION, (Gauge<Long>) () -> modelVersion); + } + + @Override + public void processElement1(StreamRecord<Row> dataPoint) throws Exception { + if (dataPoint.getTimestamp() - maxAllowedModelDelayMs <= modelTimeStamp + && mean != null) { 
Review Comment: Is this line needed to make sure that we have received model data before processing data? Do we still need it if we set the default value of modelTimeStamp to `Long.MIN_VALUE`? ########## flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/OnlineStandardScalerExample.java: ########## @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.ml.examples.feature; + +import org.apache.flink.api.common.eventtime.SerializableTimestampAssigner; +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.time.Time; +import org.apache.flink.ml.common.window.EventTimeTumblingWindows; +import org.apache.flink.ml.feature.standardscaler.OnlineStandardScaler; +import org.apache.flink.ml.feature.standardscaler.OnlineStandardScalerModel; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; + +import java.util.Arrays; +import java.util.List; + +/** Simple program that trains a OnlineStandardScaler model and uses it for feature engineering. */ +public class OnlineStandardScalerExample { + public static void main(String[] args) { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + + // Generates input data. 
+ List<Row> inputData = + Arrays.asList( + Row.of(0L, Vectors.dense(-2.5, 9, 1)), + Row.of(1000L, Vectors.dense(1.4, -5, 1)), + Row.of(2000L, Vectors.dense(2, -1, -2)), + Row.of(6000L, Vectors.dense(0.7, 3, 1)), + Row.of(7000L, Vectors.dense(0, 1, 1)), + Row.of(8000L, Vectors.dense(0.5, 0, -2)), + Row.of(9000L, Vectors.dense(0.4, 1, 1)), + Row.of(10000L, Vectors.dense(0.3, 2, 1)), + Row.of(11000L, Vectors.dense(0.5, 1, -2))); + + DataStream<Row> inputStream = env.fromCollection(inputData); + + DataStream<Row> inputStreamWithEventTime = + inputStream.assignTimestampsAndWatermarks( + WatermarkStrategy.<Row>forMonotonousTimestamps() + .withTimestampAssigner( + (SerializableTimestampAssigner<Row>) + (element, recordTimestamp) -> + element.getFieldAs(0))); + + Table inputTable = + tEnv.fromDataStream( + inputStreamWithEventTime, + Schema.newBuilder() + .column("f0", DataTypes.BIGINT()) + .column("f1", DataTypes.RAW(DenseVectorTypeInfo.INSTANCE)) + .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)") + .watermark("rowtime", "SOURCE_WATERMARK()") + .build()) + .as("id", "input"); + + // Creates an OnlineStandardScaler object and initializes its parameters. + long windowSizeMs = 3000; + OnlineStandardScaler onlineStandardScaler = + new OnlineStandardScaler() + .setWindows(EventTimeTumblingWindows.of(Time.milliseconds(windowSizeMs))) + .setModelVersionCol("modelVersionCol"); Review Comment: Remove this line? ########## flink-ml-python/pyflink/examples/ml/feature/online_standardscaler_example.py: ########## @@ -0,0 +1,91 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +# Simple program that trains an OnlineStandardScaler model and uses it for feature +# engineering. + +from pyflink.common import Types +from pyflink.common.time import Time, Instant +from pyflink.java_gateway import get_gateway +from pyflink.table import Schema +from pyflink.datastream import StreamExecutionEnvironment +from pyflink.table import StreamTableEnvironment +from pyflink.table.expressions import col + +from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo +from pyflink.ml.lib.feature.onlinestandardscaler import OnlineStandardScaler +from pyflink.ml.core.windows import EventTimeTumblingWindows + +# Creates a new StreamExecutionEnvironment. +env = StreamExecutionEnvironment.get_execution_environment() + +# Creates a StreamTableEnvironment. +t_env = StreamTableEnvironment.create(env) + +# Generates input data. 
+dense_vector_serializer = get_gateway().jvm.org.apache.flink.table.types.logical.RawType( + get_gateway().jvm.org.apache.flink.ml.linalg.DenseVector(0).getClass(), + get_gateway().jvm.org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer() +).getSerializerString() + +schema = Schema.new_builder() \ + .column("ts", "TIMESTAMP_LTZ(3)") \ + .column("input", "RAW('org.apache.flink.ml.linalg.DenseVector', '{serializer}')" + .format(serializer=dense_vector_serializer)) \ + .watermark("ts", "ts - INTERVAL '1' SECOND") \ + .build() + +input_data = t_env.from_data_stream( + env.from_collection([ + (Instant.of_epoch_milli(0), Vectors.dense(-2.5, 9, 1),), + (Instant.of_epoch_milli(1000), Vectors.dense(1.4, -5, 1),), + (Instant.of_epoch_milli(2000), Vectors.dense(2, -1, -2),), + (Instant.of_epoch_milli(6000), Vectors.dense(0.7, 3, 1),), + (Instant.of_epoch_milli(7000), Vectors.dense(0, 1, 1),), + (Instant.of_epoch_milli(8000), Vectors.dense(0.5, 0, -2),), + (Instant.of_epoch_milli(9000), Vectors.dense(0.4, 1, 1),), + (Instant.of_epoch_milli(10000), Vectors.dense(0.3, 2, 1),), + (Instant.of_epoch_milli(11000), Vectors.dense(0.5, 1, -2),) + ], + type_info=Types.ROW_NAMED( + ['ts', 'input'], + [Types.INSTANT(), DenseVectorTypeInfo()])), + schema) + +# Creates an online standard-scaler object and initialize its parameters. +standard_scaler = OnlineStandardScaler() \ + .set_windows(EventTimeTumblingWindows.of(Time.milliseconds(3000))) \ + .set_model_version_col('model_version_col') \ Review Comment: Remove this line? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasModelVersionCol.java: ########## @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared model version column param. */ +public interface HasModelVersionCol<T> extends WithParams<T> { + Param<String> MODEL_VERSION_COL = + new StringParam( + "modelVersionCol", + "The version (in long format) of the model data that the input data is predicted with.", Review Comment: The doc seems confusing. The semantics of this doc is actually "The version column name", not the version of the model. How about using the following doc: The name of the column which contains the version of the model data that the input data is predicted with. The version should be a 64-bit integer. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScalerModel.java: ########## @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.standardscaler; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.metrics.MetricGroup; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.metrics.MLMetrics; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import 
org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** A Model which transforms data using the model data computed by {@link OnlineStandardScaler}. */ +public class OnlineStandardScalerModel + implements Model<OnlineStandardScalerModel>, + OnlineStandardScalerModelParams<OnlineStandardScalerModel> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table modelDataTable; + + public OnlineStandardScalerModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + String modelVersionCol = getModelVersionCol(); + + TypeInformation<?>[] outputTypes; + String[] outputNames; + if (modelVersionCol == null) { + outputTypes = ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE); + outputNames = ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()); + } else { + outputTypes = + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE, Types.LONG); + outputNames = + ArrayUtils.addAll( + inputTypeInfo.getFieldNames(), getOutputCol(), modelVersionCol); + } + RowTypeInfo outputTypeInfo = new RowTypeInfo(outputTypes, outputNames); + + DataStream<Row> predictionResult = + tEnv.toDataStream(inputs[0]) + .connect( + StandardScalerModelData.getModelDataStream(modelDataTable) + .broadcast()) + .transform( + "PredictionOperator", + outputTypeInfo, + new PredictionOperator( + inputTypeInfo, + getInputCol(), + getWithMean(), + getWithStd(), + 
getMaxAllowedModelDelayMs(), + getModelVersionCol())); + + return new Table[] {tEnv.fromDataStream(predictionResult)}; + } + + /** A utility operator used for prediction. */ + @SuppressWarnings({"unchecked", "rawtypes"}) + private static class PredictionOperator extends AbstractStreamOperator<Row> + implements TwoInputStreamOperator<Row, StandardScalerModelData, Row> { + private final RowTypeInfo inputTypeInfo; + + private final String inputCol; + + private final boolean withMean; + + private final boolean withStd; + + private final long maxAllowedModelDelayMs; + + private final String modelVersionCol; + + private ListState<StreamRecord> bufferedPointsState; + + private ListState<StandardScalerModelData> modelDataState; + + /** Model data for inference. */ + private StandardScalerModelData modelData; + + private DenseVector mean; + + /** Inverse of standard deviation. */ + private DenseVector scale; + + private long modelVersion; + + private long modelTimeStamp; + + public PredictionOperator( + RowTypeInfo inputTypeInfo, + String inputCol, + boolean withMean, + boolean withStd, + long maxAllowedModelDelayMs, + String modelVersionCol) { + this.inputTypeInfo = inputTypeInfo; + this.inputCol = inputCol; + this.withMean = withMean; + this.withStd = withStd; + this.maxAllowedModelDelayMs = maxAllowedModelDelayMs; + this.modelVersionCol = modelVersionCol; + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + bufferedPointsState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<StreamRecord>( + "bufferedPoints", + new StreamElementSerializer( + inputTypeInfo.createSerializer( + getExecutionConfig())))); + + modelDataState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "modelData", + TypeInformation.of(StandardScalerModelData.class))); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws 
Exception { + super.snapshotState(context); + if (modelData != null) { + modelDataState.clear(); + modelDataState.add(modelData); + } + } + + @Override + public void open() throws Exception { + super.open(); + MetricGroup mlModelMetricGroup = + getRuntimeContext() + .getMetricGroup() + .addGroup(MLMetrics.ML_GROUP) + .addGroup( + MLMetrics.ML_MODEL_GROUP, + OnlineStandardScalerModel.class.getSimpleName()); + mlModelMetricGroup.gauge(MLMetrics.TIMESTAMP, (Gauge<Long>) () -> modelTimeStamp); + mlModelMetricGroup.gauge(MLMetrics.VERSION, (Gauge<Long>) () -> modelVersion); + } + + @Override + public void processElement1(StreamRecord<Row> dataPoint) throws Exception { + if (dataPoint.getTimestamp() - maxAllowedModelDelayMs <= modelTimeStamp Review Comment: What is the value of `modelTimeStamp` if `processElement2()` has not been invoked? Should we explicitly give it a default value? In general it is a good practice to make sure all variables are initialized in the constructor (or `open()` in this case). Maybe do it for other variables as well. ########## flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OnlineStandardScalerTest.java: ########## @@ -0,0 +1,452 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.eventtime.SerializableTimestampAssigner; +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.time.Time; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.common.window.CountTumblingWindows; +import org.apache.flink.ml.common.window.EventTimeTumblingWindows; +import org.apache.flink.ml.common.window.GlobalWindows; +import org.apache.flink.ml.common.window.ProcessingTimeTumblingWindows; +import org.apache.flink.ml.feature.standardscaler.OnlineStandardScaler; +import org.apache.flink.ml.feature.standardscaler.OnlineStandardScalerModel; +import org.apache.flink.ml.feature.standardscaler.StandardScalerModelData; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; 
+import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** Tests {@link OnlineStandardScaler} and {@link OnlineStandardScalerModel}. */ +public class OnlineStandardScalerTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private final List<Row> inputData = + Arrays.asList( + Row.of(0L, Vectors.dense(-2.5, 9, 1)), + Row.of(1000L, Vectors.dense(1.4, -5, 1)), + Row.of(2000L, Vectors.dense(2, -1, -2)), + Row.of(6000L, Vectors.dense(0.7, 3, 1)), + Row.of(7000L, Vectors.dense(0, 1, 1)), + Row.of(8000L, Vectors.dense(0.5, 0, -2)), + Row.of(9000L, Vectors.dense(0.4, 1, 1)), + Row.of(10000L, Vectors.dense(0.3, 2, 1)), + Row.of(11000L, Vectors.dense(0.5, 1, -2))); + + private final List<StandardScalerModelData> expectedModelData = + Arrays.asList( + new StandardScalerModelData( + Vectors.dense(0.3, 1, 0), + Vectors.dense(2.4433583, 7.2111026, 1.7320508), + 0L, + 2999L), + new StandardScalerModelData( + Vectors.dense(0.35, 1.1666667, 0), + Vectors.dense(1.5630099, 4.6654760, 1.5491933), + 1L, + 8999L), + new StandardScalerModelData( + Vectors.dense(0.3666667, 1.2222222, 0), + Vectors.dense(1.2369316, 3.7006005, 1.5), + 2L, + 11999L)); + + private static final double TOLERANCE = 1e-7; + + private Table inputTable; + + private Table inputTableWithProcessingTime; + + private Table inputTableWithEventTime; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.getConfig().enableObjectReuse(); + env.setParallelism(4); + 
env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + DataStream<Row> inputStream = env.fromCollection(inputData); + inputTable = + tEnv.fromDataStream( + inputStream, + Schema.newBuilder() + .column("f0", DataTypes.BIGINT()) + .column("f1", DataTypes.RAW(DenseVectorTypeInfo.INSTANCE)) + .build()) + .as("id", "input"); + + DataStream<Row> inputStreamWithProcessingTimeGap = + inputStream + .map( + new MapFunction<Row, Row>() { + private int count = 0; + + @Override + public Row map(Row value) throws Exception { + count++; + if (count % 3 == 0) { + Thread.sleep(1000); + } + return value; + } + }, + new RowTypeInfo( + new TypeInformation[] { + Types.LONG, DenseVectorTypeInfo.INSTANCE + }, + new String[] {"id", "input"})) + .setParallelism(1); + + inputTableWithProcessingTime = tEnv.fromDataStream(inputStreamWithProcessingTimeGap); + + DataStream<Row> inputStreamWithEventTime = + inputStream.assignTimestampsAndWatermarks( + WatermarkStrategy.<Row>forMonotonousTimestamps() + .withTimestampAssigner( + (SerializableTimestampAssigner<Row>) + (element, recordTimestamp) -> + element.getFieldAs(0))); + inputTableWithEventTime = + tEnv.fromDataStream( + inputStreamWithEventTime, + Schema.newBuilder() + .column("f0", DataTypes.BIGINT()) + .column("f1", DataTypes.RAW(DenseVectorTypeInfo.INSTANCE)) + .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)") + .watermark("rowtime", "SOURCE_WATERMARK()") + .build()) + .as("id", "input"); + } + + @Test + public void testParam() { + OnlineStandardScaler standardScaler = new OnlineStandardScaler(); + + assertEquals("input", standardScaler.getInputCol()); + assertEquals(false, standardScaler.getWithMean()); + assertEquals(true, standardScaler.getWithStd()); + assertEquals("output", standardScaler.getOutputCol()); + assertEquals("version", standardScaler.getModelVersionCol()); + assertEquals(GlobalWindows.getInstance(), standardScaler.getWindows()); + 
assertEquals(0L, standardScaler.getMaxAllowedModelDelayMs()); + + standardScaler + .setInputCol("test_input") + .setWithMean(true) + .setWithStd(false) + .setOutputCol("test_output") + .setModelVersionCol("model_version_col") Review Comment: Use non-default value here? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
