jiangxin369 commented on code in PR #196: URL: https://github.com/apache/flink-ml/pull/196#discussion_r1064248240
########## flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxAllowedModelDelayMs.java: ########## @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.LongParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared max allowed model delay in milliseconds param. */ +public interface HasMaxAllowedModelDelayMs<T> extends WithParams<T> { + Param<Long> MAX_ALLOWED_MODEL_DELAY_MS = + new LongParam( + "maxAllowedModelDelayMs", + "The maximum difference between timestamp of the input record and model data when " Review Comment: ```suggestion "The maximum difference between the timestamps of the input record and model data when " ``` So as in Python and Markdown files. ########## flink-ml-python/pyflink/ml/lib/param.py: ########## @@ -581,3 +581,48 @@ def get_flatten(self) -> bool: @property def flatten(self): return self.get(self.FLATTEN) + + +class HasModelVersionCol(WithParams, ABC): Review Comment: Could you add tests of these two params in `test_param.py`? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScaler.java: ########## @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.feature.standardscaler; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.functions.windowing.ProcessAllWindowFunction; +import org.apache.flink.streaming.api.windowing.windows.Window; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * An Estimator which implements the online standard scaling algorithm, which is the online version + * of {@link StandardScaler}. + * + * <p>OnlineStandardScaler splits the input data by the user-specified window strategy (i.e., {@link + * org.apache.flink.ml.common.param.HasWindows}). For each window, it computes the mean and standard + * deviation using the data seen so far (i.e., not only the data in the current window, but also the + * history data). The model data generated by OnlineStandardScaler is a model stream. There is one + * model data for each window. + * + * <p>During the inference phase (i.e., using {@link OnlineStandardScalerModel} for prediction), + * users could output the model version that is used for predicting each data point. Moreover, + * + * <ul> + * <li>When the train data and test data both contains event time, users could specify the maximum + * difference between timestamp of the input and model data ({@link + * org.apache.flink.ml.common.param.HasMaxAllowedModelDelayMs}), which enforces to use a + * relatively fresh model for prediction. + * <li>Otherwise, the prediction process always use the current model data for prediction. + * </ul> + */ +public class OnlineStandardScaler + implements Estimator<OnlineStandardScaler, OnlineStandardScalerModel>, + OnlineStandardScalerParams<OnlineStandardScaler> { + + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public OnlineStandardScaler() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public OnlineStandardScalerModel fit(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<StandardScalerModelData> modelData = + DataStreamUtils.windowAllAndProcess( + tEnv.toDataStream(inputs[0]), + getWindows(), + new ComputeModelDataFunction<>(getInputCol())); + + OnlineStandardScalerModel model = + new OnlineStandardScalerModel().setModelData(tEnv.fromDataStream(modelData)); + ReadWriteUtils.updateExistingParams(model, paramMap); + return model; + } + + private static class ComputeModelDataFunction<W extends Window> + extends ProcessAllWindowFunction<Row, StandardScalerModelData, W> { + + private final String inputCol; + + public ComputeModelDataFunction(String inputCol) { + this.inputCol = inputCol; + } + + @Override + public void process( + ProcessAllWindowFunction<Row, StandardScalerModelData, W>.Context context, + Iterable<Row> iterable, + Collector<StandardScalerModelData> collector) + throws Exception { + ListState<DenseVector> sumState = + context.globalState() + .getListState( + new ListStateDescriptor<>( + "sumState", DenseVectorTypeInfo.INSTANCE)); + ListState<DenseVector> squaredSumState = + context.globalState() + .getListState( + new ListStateDescriptor<>( + "squaredSumState", DenseVectorTypeInfo.INSTANCE)); + ListState<Long> numElementsState = + context.globalState() + .getListState( + new ListStateDescriptor<>("numElementsState", Types.LONG)); + ListState<Long> modelVersionState = + context.globalState() + .getListState( + new ListStateDescriptor<>("modelVersionState", Types.LONG)); + DenseVector sum = + OperatorStateUtils.getUniqueElement(sumState, "sumState").orElse(null); + DenseVector squaredSum = + OperatorStateUtils.getUniqueElement(squaredSumState, "squaredSumState") + .orElse(null); + long numElements = + OperatorStateUtils.getUniqueElement(numElementsState, "numElementsState") + .orElse(0L); + long modelVersion = + OperatorStateUtils.getUniqueElement(modelVersionState, "modelVersionState") + .orElse(0L); + + long numElementsBefore = numElements; + for (Row element : iterable) { + Vector inputVec = + ((Vector) Objects.requireNonNull(element.getField(inputCol))).clone(); Review Comment: Why do we need to clone the input vector? ########## flink-ml-python/pyflink/ml/lib/feature/tests/test_onlinestandardscaler.py: ########## @@ -0,0 +1,210 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ + +from pyflink.common import Types +from pyflink.common.time import Time, Instant +from pyflink.java_gateway import get_gateway +from pyflink.table import Schema +from pyflink.table.types import DataTypes +from pyflink.table.expressions import col + +from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo +from pyflink.ml.lib.feature.onlinestandardscaler import OnlineStandardScaler, \ + OnlineStandardScalerModel +from pyflink.ml.tests.test_utils import PyFlinkMLTestCase, update_existing_params +from pyflink.ml.core.windows import GlobalWindows, EventTimeTumblingWindows + + +class OnlineStandardScalerTest(PyFlinkMLTestCase): + def setUp(self): + super(OnlineStandardScalerTest, self).setUp() + + dense_vector_serializer = get_gateway().jvm.org.apache.flink.table.types.logical.RawType( + get_gateway().jvm.org.apache.flink.ml.linalg.DenseVector(0).getClass(), + get_gateway().jvm.org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer() + ).getSerializerString() + + schema = Schema.new_builder() \ + .column("ts_in_long", DataTypes.BIGINT()) \ + .column("ts", "TIMESTAMP_LTZ(3)") \ + .column("input", "RAW('org.apache.flink.ml.linalg.DenseVector', '{serializer}')" + .format(serializer=dense_vector_serializer)) \ + .watermark("ts", "ts - INTERVAL '1' SECOND") \ + .build() + + self.input_table = self.t_env.from_data_stream( + self.env.from_collection([ + (0, Instant.of_epoch_milli(0), Vectors.dense(-2.5, 9, 1),), + (1000, Instant.of_epoch_milli(1000), Vectors.dense(1.4, -5, 1),), + (2000, Instant.of_epoch_milli(2000), Vectors.dense(2, -1, -2),), + (6000, Instant.of_epoch_milli(6000), Vectors.dense(0.7, 3, 1),), + (7000, Instant.of_epoch_milli(7000), Vectors.dense(0, 1, 1),), + (8000, Instant.of_epoch_milli(8000), Vectors.dense(0.5, 0, -2),), + (9000, Instant.of_epoch_milli(9000), Vectors.dense(0.4, 1, 1),), + (10000, Instant.of_epoch_milli(10000), Vectors.dense(0.3, 2, 1),), + (11000, Instant.of_epoch_milli(11000), Vectors.dense(0.5, 1, -2),) + ], + type_info=Types.ROW_NAMED( + ['ts_in_long', 'ts', 'input'], + [Types.LONG(), Types.INSTANT(), DenseVectorTypeInfo()])), + schema) + + self.expected_model_data = [ + [ + Vectors.dense(0.3, 1, 0), + Vectors.dense(2.4433583, 7.2111026, 1.7320508), + 0, + 2999 + ], + [ + Vectors.dense(0.35, 1.1666667, 0), + Vectors.dense(1.5630099, 4.6654760, 1.5491933), + 1, + 8999 + ], + [ + Vectors.dense(0.3666667, 1.2222222, 0), + Vectors.dense(1.2369316, 3.7006005, 1.5), + 2, + 11999 + ] + ] + + self.tolerance = 1e-7 + + def test_param(self): + standard_scaler = OnlineStandardScaler() + + self.assertEqual('input', standard_scaler.input_col) + self.assertEqual(False, standard_scaler.with_mean) + self.assertEqual(True, standard_scaler.with_std) + self.assertEqual('output', standard_scaler.output_col) + self.assertIsNone(standard_scaler.model_version_col) + self.assertEqual(GlobalWindows(), standard_scaler.windows) + self.assertEqual(0, standard_scaler.max_allowed_model_delay_ms) + + standard_scaler.set_input_col('test_input') \ + .set_with_mean(True) \ + .set_with_std(False) \ + .set_output_col('test_output') \ + .set_model_version_col('model_version_col') \ + .set_windows(EventTimeTumblingWindows.of(Time.milliseconds(3000))) \ + .set_max_allowed_model_delay_ms(3000) + + self.assertEqual('test_input', standard_scaler.input_col) + self.assertEqual(True, standard_scaler.with_mean) + self.assertEqual(False, standard_scaler.with_std) + self.assertEqual('test_output', 
standard_scaler.output_col) + self.assertEqual("model_version_col", standard_scaler.model_version_col) + self.assertEqual(EventTimeTumblingWindows.of(Time.milliseconds(3000)), + standard_scaler.windows) + self.assertEqual(3000, standard_scaler.max_allowed_model_delay_ms) + + def test_output_schema(self): + temp_table = self.input_table.alias('ts_in_long', 'ts', "test_input") + + standard_scaler = OnlineStandardScaler() \ + .set_input_col('test_input') \ + .set_output_col('test_output') \ + .set_model_version_col('model_version_col') + + output = standard_scaler.fit(temp_table).transform(temp_table)[0] + + self.assertEqual(['ts_in_long', 'ts', 'test_input', 'test_output', 'model_version_col'], + output.get_schema().get_field_names()) + + def test_fit_and_predict(self): + standard_scaler = OnlineStandardScaler() + window_size_ms = 3000 + + standard_scaler \ + .set_windows(EventTimeTumblingWindows.of(Time.milliseconds(window_size_ms))) \ + .set_model_version_col("model_version_col") + output = standard_scaler.fit(self.input_table).transform(self.input_table)[0] + self.verify_used_model_version(output, standard_scaler.model_version_col, + standard_scaler.max_allowed_model_delay_ms) + + def test_get_model_data(self): + standard_scaler = OnlineStandardScaler() + window_size_ms = 3000 + + standard_scaler \ + .set_windows(EventTimeTumblingWindows.of(Time.milliseconds(window_size_ms))) \ + .set_model_version_col("model_version_col") + + model_data = standard_scaler.fit(self.input_table).get_model_data()[0] + self.assertEqual(["mean", "std", "version", "timestamp"], + model_data.get_schema().get_field_names()) + + model_rows = [result for result in + self.t_env.to_data_stream(model_data).execute_and_collect()] + self.assertEqual(len(self.expected_model_data), len(model_rows)) + for idx in range(len(self.expected_model_data)): + self.assertListAlmostEqual(self.expected_model_data[idx][0].to_array(), + model_rows[idx][0].to_array(), + delta=self.tolerance) + self.assertListAlmostEqual(self.expected_model_data[idx][1].to_array(), + model_rows[idx][1].to_array(), + delta=self.tolerance) + self.assertEqual(self.expected_model_data[idx][2], model_rows[idx][2]) + self.assertEqual(self.expected_model_data[idx][3], model_rows[idx][3]) + + def test_set_model_data(self): + standard_scaler = OnlineStandardScaler() + window_size_ms = 3000 + + model = standard_scaler \ + .set_windows(EventTimeTumblingWindows.of(Time.milliseconds(window_size_ms))) \ + .set_model_version_col("model_version_col") \ + .fit(self.input_table) + model_data = model.get_model_data()[0] + + new_model = OnlineStandardScalerModel().set_model_data(model_data) + update_existing_params(new_model, model) + output = new_model.transform(self.input_table)[0] + + self.verify_used_model_version(output, model.model_version_col, + model.max_allowed_model_delay_ms) + + def test_save_load_and_predict(self): + window_size_ms = 3000 Review Comment: Would it be better if declare the `window_size_ms` as a shared constant? ########## flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OnlineStandardScalerTest.java: ########## @@ -0,0 +1,396 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.eventtime.SerializableTimestampAssigner; +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.time.Time; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.common.window.CountTumblingWindows; +import org.apache.flink.ml.common.window.EventTimeTumblingWindows; +import org.apache.flink.ml.common.window.GlobalWindows; +import org.apache.flink.ml.feature.standardscaler.OnlineStandardScaler; +import org.apache.flink.ml.feature.standardscaler.OnlineStandardScalerModel; +import org.apache.flink.ml.feature.standardscaler.StandardScalerModelData; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +/** Tests {@link OnlineStandardScaler} and {@link OnlineStandardScalerModel}. 
*/ +public class OnlineStandardScalerTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private final List<Row> inputData = + Arrays.asList( + Row.of(0L, Vectors.dense(-2.5, 9, 1)), + Row.of(1000L, Vectors.dense(1.4, -5, 1)), + Row.of(2000L, Vectors.dense(2, -1, -2)), + Row.of(6000L, Vectors.dense(0.7, 3, 1)), + Row.of(7000L, Vectors.dense(0, 1, 1)), + Row.of(8000L, Vectors.dense(0.5, 0, -2)), + Row.of(9000L, Vectors.dense(0.4, 1, 1)), + Row.of(10000L, Vectors.dense(0.3, 2, 1)), + Row.of(11000L, Vectors.dense(0.5, 1, -2))); + + private final List<StandardScalerModelData> expectedModelData = + Arrays.asList( + new StandardScalerModelData( + Vectors.dense(0.3, 1, 0), + Vectors.dense(2.4433583, 7.2111026, 1.7320508), + 0L, + 2999L), + new StandardScalerModelData( + Vectors.dense(0.35, 1.1666667, 0), + Vectors.dense(1.5630099, 4.6654760, 1.5491933), + 1L, + 8999L), + new StandardScalerModelData( + Vectors.dense(0.3666667, 1.2222222, 0), + Vectors.dense(1.2369316, 3.7006005, 1.5), + 2L, + 11999L)); + + private static final double TOLERANCE = 1e-7; + + private Table inputTable; + + private Table inputTableWithEventTime; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.getConfig().enableObjectReuse(); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + DataStream<Row> inputStream = env.fromCollection(inputData); + inputTable = + tEnv.fromDataStream( + inputStream, + Schema.newBuilder() + .column("f0", DataTypes.BIGINT()) + .column("f1", DataTypes.RAW(DenseVectorTypeInfo.INSTANCE)) + .build()) + .as("id", "input"); + + DataStream<Row> inputStreamWithEventTime = + inputStream.assignTimestampsAndWatermarks( + WatermarkStrategy.<Row>forMonotonousTimestamps() + .withTimestampAssigner( + (SerializableTimestampAssigner<Row>) + (element, recordTimestamp) -> + element.getFieldAs(0))); + inputTableWithEventTime = + tEnv.fromDataStream( + inputStreamWithEventTime, + Schema.newBuilder() + .column("f0", DataTypes.BIGINT()) + .column("f1", DataTypes.RAW(DenseVectorTypeInfo.INSTANCE)) + .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)") + .watermark("rowtime", "SOURCE_WATERMARK()") + .build()) + .as("id", "input"); + } + + @Test + public void testParam() { + OnlineStandardScaler standardScaler = new OnlineStandardScaler(); + + assertEquals("input", standardScaler.getInputCol()); + assertEquals(false, standardScaler.getWithMean()); + assertEquals(true, standardScaler.getWithStd()); + assertEquals("output", standardScaler.getOutputCol()); + assertNull(standardScaler.getModelVersionCol()); + assertEquals(GlobalWindows.getInstance(), standardScaler.getWindows()); + assertEquals(0L, standardScaler.getMaxAllowedModelDelayMs()); + + standardScaler + .setInputCol("test_input") + .setWithMean(true) + .setWithStd(false) + .setOutputCol("test_output") + .setModelVersionCol("model_version_col") + .setWindows(EventTimeTumblingWindows.of(Time.milliseconds(3000))) + .setMaxAllowedModelDelayMs(3000L); + + assertEquals("test_input", standardScaler.getInputCol()); + assertEquals(true, standardScaler.getWithMean()); + assertEquals(false, standardScaler.getWithStd()); + 
assertEquals("test_output", standardScaler.getOutputCol()); + assertEquals("model_version_col", standardScaler.getModelVersionCol()); + assertEquals( + EventTimeTumblingWindows.of(Time.milliseconds(3000)), standardScaler.getWindows()); + assertEquals(3000L, standardScaler.getMaxAllowedModelDelayMs()); + } + + @Test + public void testOutputSchema() { + Table renamedTable = inputTable.as("test_id", "test_input"); + OnlineStandardScaler standardScaler = + new OnlineStandardScaler() + .setInputCol("test_input") + .setOutputCol("test_output") + .setModelVersionCol("model_version_col"); + Table output = standardScaler.fit(renamedTable).transform(renamedTable)[0]; + + assertEquals( + Arrays.asList("test_id", "test_input", "test_output", "model_version_col"), + output.getResolvedSchema().getColumnNames()); + } + + @Test + public void testFitAndPredictWithEventTimeWindow() throws Exception { + OnlineStandardScaler standardScaler = new OnlineStandardScaler(); + Table output; + int windowSizeMs = 3000; + + // Tests event time window with maxAllowedModelDelayMs as 0. + standardScaler + .setWindows(EventTimeTumblingWindows.of(Time.milliseconds(windowSizeMs))) + .setModelVersionCol("modelVersionCol"); + output = standardScaler.fit(inputTableWithEventTime).transform(inputTableWithEventTime)[0]; + verifyUsedModelVersion( Review Comment: Could you also verify the outputs of the prediction? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScalerParams.java: ########## @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.standardscaler; + +import org.apache.flink.ml.common.param.HasWindows; + +/** Params for {@link OnlineStandardScaler}. */ +public interface OnlineStandardScalerParams<T> + extends HasWindows<T>, OnlineStandardScalerModelParams<T> {} Review Comment: From my testing, not all `Windows` are supported, e.g., when using `ProcessingTimeTumblingWindows`, the `OnlineStandardScaler` cannot train any models. Is that expected behavior? If so, maybe we can add a validator and document the windows we supported. ########## docs/content/docs/operators/feature/onlinestandardscaler.md: ########## @@ -0,0 +1,261 @@ +--- +title: "OnlineStandardScaler" +weight: 1 +type: docs +aliases: +- /operators/feature/onlinestandardscaler.html +--- + +<!-- +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. 
You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +--> + +## OnlineStandardScaler + +An Estimator which implements the online standard scaling algorithm, which +is the online version of StandardScaler. + +OnlineStandardScaler splits the input data by the user-specified window strategy. +For each window, it computes the mean and standard deviation using the data seen +so far (i.e., not only the data in the current window, but also the history data). +The model data generated by OnlineStandardScaler is a model stream. +There is one model data for each window. + +During the inference phase (i.e., using OnlineStandardScalerModel for prediction), +users could output the model version that is used for predicting each data point. +Moreover, +- When the train data and test data both contains event time, users could +specify the maximum difference between timestamp of the input and model data, +which enforces to use a relatively fresh model for prediction. +- Otherwise, the prediction process always use the current model data for prediction. Review Comment: ```suggestion - Otherwise, the prediction process always uses the current model data for prediction. ``` ########## docs/content/docs/operators/feature/onlinestandardscaler.md: ########## @@ -0,0 +1,261 @@ +--- +title: "OnlineStandardScaler" +weight: 1 +type: docs +aliases: +- /operators/feature/onlinestandardscaler.html +--- + +<!-- +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +--> + +## OnlineStandardScaler + +An Estimator which implements the online standard scaling algorithm, which +is the online version of StandardScaler. + +OnlineStandardScaler splits the input data by the user-specified window strategy. +For each window, it computes the mean and standard deviation using the data seen +so far (i.e., not only the data in the current window, but also the history data). +The model data generated by OnlineStandardScaler is a model stream. +There is one model data for each window. + +During the inference phase (i.e., using OnlineStandardScalerModel for prediction), +users could output the model version that is used for predicting each data point. +Moreover, +- When the train data and test data both contains event time, users could Review Comment: ```suggestion - When the train data and test data both contain event time, users could ``` The same fix applies to the Javadoc and the Python docstring.
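To make the `maxAllowedModelDelayMs` discussion above easier to follow, here is a minimal, self-contained sketch of the documented rule (a record may only be predicted with a model whose timestamp lags the record's event timestamp by at most the configured delay). This is an illustration only, not the operator's actual implementation; the `ModelDelayRuleSketch` class and `pickModelVersion` method are hypothetical names used just for this example, and the timestamps are the ones from the PR's tests.

```java
import java.util.Arrays;
import java.util.List;

/**
 * Illustrative sketch only -- NOT the operator's actual code. It demonstrates the
 * documented meaning of maxAllowedModelDelayMs: a record with event timestamp recordTs
 * may be scaled with a model whose timestamp modelTs satisfies
 * recordTs - modelTs <= maxAllowedModelDelayMs.
 */
public class ModelDelayRuleSketch {

    /**
     * Returns the version (index) of the first model, in ascending timestamp order,
     * that is fresh enough for the given record, or -1 if no such model exists yet.
     */
    static int pickModelVersion(
            List<Long> modelTimestamps, long recordTs, long maxAllowedModelDelayMs) {
        for (int version = 0; version < modelTimestamps.size(); version++) {
            if (recordTs - modelTimestamps.get(version) <= maxAllowedModelDelayMs) {
                return version;
            }
        }
        return -1; // The prediction has to wait until a fresh enough model arrives.
    }

    public static void main(String[] args) {
        // Model timestamps of the windowed model stream in this PR's tests: 2999, 8999, 11999.
        List<Long> modelTimestamps = Arrays.asList(2999L, 8999L, 11999L);

        // With maxAllowedModelDelayMs = 0, a record at t = 5000 must wait for
        // model version 1 (timestamp 8999).
        System.out.println(pickModelVersion(modelTimestamps, 5000L, 0L)); // prints 1

        // With maxAllowedModelDelayMs = 3000, the older model version 0 (timestamp 2999)
        // is already acceptable, since 5000 - 2999 = 2001 <= 3000.
        System.out.println(pickModelVersion(modelTimestamps, 5000L, 3000L)); // prints 0
    }
}
```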
########## docs/content/docs/operators/feature/onlinestandardscaler.md: ########## @@ -0,0 +1,261 @@ +--- +title: "OnlineStandardScaler" +weight: 1 +type: docs +aliases: +- /operators/feature/onlinestandardscaler.html +--- + +<!-- +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +--> + +## OnlineStandardScaler + +An Estimator which implements the online standard scaling algorithm, which +is the online version of StandardScaler. + +OnlineStandardScaler splits the input data by the user-specified window strategy. +For each window, it computes the mean and standard deviation using the data seen +so far (i.e., not only the data in the current window, but also the history data). +The model data generated by OnlineStandardScaler is a model stream. +There is one model data for each window. + +During the inference phase (i.e., using OnlineStandardScalerModel for prediction), +users could output the model version that is used for predicting each data point. +Moreover, +- When the train data and test data both contains event time, users could +specify the maximum difference between timestamp of the input and model data, +which enforces to use a relatively fresh model for prediction. +- Otherwise, the prediction process always use the current model data for prediction. + + +### Input Columns + +| Param name | Type | Default | Description | +|:-----------|:-------|:----------|:-----------------------| +| inputCol | Vector | `"input"` | Features to be scaled. | + +### Output Columns + +| Param name | Type | Default | Description | +|:-----------|:-------|:-----------|:-----------------| +| outputCol | Vector | `"output"` | Scaled features. | + +### Parameters + +Below are the parameters required by `OnlineStandardScalerModel`. + +| Key | Default | Type | Required | Description | +|------------------------|------------|---------|----------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| inputCol | `"input"` | String | no | Input column name. | +| outputCol | `"output"` | String | no | Output column name. | +| withMean | `false` | Boolean | no | Whether centers the data with mean before scaling. | +| withStd | `true` | Boolean | no | Whether scales the data with standard deviation. | + | modelVersionCol | `null` | String | no | The version of the model data that the input data is predicted with. | + | maxAllowedModelDelayMs | `0L` | Long | no | The maximum difference between timestamp of the input record and model data when using the model data to predict that input record. This param only works when the input contains event time." | + +`OnlineStandardScaler` needs parameters above and also below. 
+ +| Key | Default | Type | Required | Description | +|---------|-------------------------------|---------|----------|--------------------------------------------------------------------------------| +| windows | `GlobalWindows.getInstance()` | Windows | no | Windowing strategy that determines how to create mini-batches from input data. | + + +### Examples + +{{< tabs examples >}} + +{{< tab "Java">}} + +```java +import org.apache.flink.api.common.eventtime.SerializableTimestampAssigner; +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.time.Time; +import org.apache.flink.ml.common.window.EventTimeTumblingWindows; +import org.apache.flink.ml.feature.standardscaler.OnlineStandardScaler; +import org.apache.flink.ml.feature.standardscaler.OnlineStandardScalerModel; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; + +import java.util.Arrays; +import java.util.List; + +/** Simple program that trains a OnlineStandardScaler model and uses it for feature engineering. */ +public class OnlineStandardScalerExample { + public static void main(String[] args) { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + + // Generates input data. + List<Row> inputData = + Arrays.asList( + Row.of(0L, Vectors.dense(-2.5, 9, 1)), + Row.of(1000L, Vectors.dense(1.4, -5, 1)), + Row.of(2000L, Vectors.dense(2, -1, -2)), + Row.of(6000L, Vectors.dense(0.7, 3, 1)), + Row.of(7000L, Vectors.dense(0, 1, 1)), + Row.of(8000L, Vectors.dense(0.5, 0, -2)), + Row.of(9000L, Vectors.dense(0.4, 1, 1)), + Row.of(10000L, Vectors.dense(0.3, 2, 1)), + Row.of(11000L, Vectors.dense(0.5, 1, -2))); + + DataStream<Row> inputStream = env.fromCollection(inputData); + + DataStream<Row> inputStreamWithEventTime = + inputStream.assignTimestampsAndWatermarks( + WatermarkStrategy.<Row>forMonotonousTimestamps() + .withTimestampAssigner( + (SerializableTimestampAssigner<Row>) + (element, recordTimestamp) -> + element.getFieldAs(0))); + + Table inputTable = + tEnv.fromDataStream( + inputStreamWithEventTime, + Schema.newBuilder() + .column("f0", DataTypes.BIGINT()) + .column("f1", DataTypes.RAW(DenseVectorTypeInfo.INSTANCE)) + .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)") + .watermark("rowtime", "SOURCE_WATERMARK()") + .build()) + .as("id", "input"); + + // Creates an OnlineStandardScaler object and initializes its parameters. + long windowSizeMs = 3000; + OnlineStandardScaler onlineStandardScaler = + new OnlineStandardScaler() + .setWindows(EventTimeTumblingWindows.of(Time.milliseconds(windowSizeMs))) + .setModelVersionCol("modelVersionCol"); + + // Trains the OnlineStandardScaler Model. + OnlineStandardScalerModel model = onlineStandardScaler.fit(inputTable); + + // Uses the OnlineStandardScaler Model for predictions. + Table outputTable = model.transform(inputTable)[0]; + + // Extracts and displays the results. 
+ for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) { + Row row = it.next(); + DenseVector inputValue = (DenseVector) row.getField(onlineStandardScaler.getInputCol()); + DenseVector outputValue = + (DenseVector) row.getField(onlineStandardScaler.getOutputCol()); + long modelVersion = row.getFieldAs(onlineStandardScaler.getModelVersionCol()); + System.out.printf( + "Input Value: %s\tOutput Value: %s\tModel Version: %s\n", + inputValue, outputValue, modelVersion); + } + } +} + Review Comment: Please remove the redundant empty line. ########## docs/content/docs/operators/feature/onlinestandardscaler.md: ########## @@ -0,0 +1,261 @@ +--- +title: "OnlineStandardScaler" +weight: 1 +type: docs +aliases: +- /operators/feature/onlinestandardscaler.html +--- + +<!-- +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +--> + +## OnlineStandardScaler + +An Estimator which implements the online standard scaling algorithm, which +is the online version of StandardScaler. + +OnlineStandardScaler splits the input data by the user-specified window strategy. +For each window, it computes the mean and standard deviation using the data seen +so far (i.e., not only the data in the current window, but also the history data). +The model data generated by OnlineStandardScaler is a model stream. +There is one model data for each window. + +During the inference phase (i.e., using OnlineStandardScalerModel for prediction), +users could output the model version that is used for predicting each data point. +Moreover, +- When the train data and test data both contains event time, users could +specify the maximum difference between timestamp of the input and model data, +which enforces to use a relatively fresh model for prediction. +- Otherwise, the prediction process always use the current model data for prediction. + + +### Input Columns + +| Param name | Type | Default | Description | +|:-----------|:-------|:----------|:-----------------------| +| inputCol | Vector | `"input"` | Features to be scaled. | + +### Output Columns + +| Param name | Type | Default | Description | +|:-----------|:-------|:-----------|:-----------------| +| outputCol | Vector | `"output"` | Scaled features. | + +### Parameters + +Below are the parameters required by `OnlineStandardScalerModel`. + +| Key | Default | Type | Required | Description | +|------------------------|------------|---------|----------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| inputCol | `"input"` | String | no | Input column name. | +| outputCol | `"output"` | String | no | Output column name. | +| withMean | `false` | Boolean | no | Whether centers the data with mean before scaling. 
| +| withStd | `true` | Boolean | no | Whether scales the data with standard deviation. | + | modelVersionCol | `null` | String | no | The version of the model data that the input data is predicted with. | Review Comment: ```suggestion | modelVersionCol | `null` | String | no | The version of the model data that the input data is predicted with. | ``` So as the below line. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScaler.java: ########## @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.standardscaler; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.functions.windowing.ProcessAllWindowFunction; +import org.apache.flink.streaming.api.windowing.windows.Window; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * An Estimator which implements the online standard scaling algorithm, which is the online version + * of {@link StandardScaler}. + * + * <p>OnlineStandardScaler splits the input data by the user-specified window strategy (i.e., {@link + * org.apache.flink.ml.common.param.HasWindows}). For each window, it computes the mean and standard + * deviation using the data seen so far (i.e., not only the data in the current window, but also the + * history data). The model data generated by OnlineStandardScaler is a model stream. There is one + * model data for each window. + * + * <p>During the inference phase (i.e., using {@link OnlineStandardScalerModel} for prediction), + * users could output the model version that is used for predicting each data point. 
Moreover, + * + * <ul> + * <li>When the train data and test data both contains event time, users could specify the maximum + * difference between timestamp of the input and model data ({@link + * org.apache.flink.ml.common.param.HasMaxAllowedModelDelayMs}), which enforces to use a + * relatively fresh model for prediction. + * <li>Otherwise, the prediction process always use the current model data for prediction. + * </ul> + */ +public class OnlineStandardScaler + implements Estimator<OnlineStandardScaler, OnlineStandardScalerModel>, + OnlineStandardScalerParams<OnlineStandardScaler> { + + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public OnlineStandardScaler() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public OnlineStandardScalerModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<StandardScalerModelData> modelData = + DataStreamUtils.windowAllAndProcess( + tEnv.toDataStream(inputs[0]), + getWindows(), + new ComputeModelDataFunction<>(getInputCol())); + + OnlineStandardScalerModel model = + new OnlineStandardScalerModel().setModelData(tEnv.fromDataStream(modelData)); + ReadWriteUtils.updateExistingParams(model, paramMap); + return model; + } + + private static class ComputeModelDataFunction<W extends Window> + extends ProcessAllWindowFunction<Row, StandardScalerModelData, W> { + + private final String inputCol; + + public ComputeModelDataFunction(String inputCol) { + this.inputCol = inputCol; + } + + @Override + public void process( + ProcessAllWindowFunction<Row, StandardScalerModelData, W>.Context context, + Iterable<Row> iterable, + Collector<StandardScalerModelData> collector) + throws Exception { + ListState<DenseVector> sumState = + context.globalState() + .getListState( + new ListStateDescriptor<>( + "sumState", DenseVectorTypeInfo.INSTANCE)); + ListState<DenseVector> squaredSumState = + context.globalState() + .getListState( + new ListStateDescriptor<>( + "squaredSumState", DenseVectorTypeInfo.INSTANCE)); + ListState<Long> numElementsState = + context.globalState() + .getListState( + new ListStateDescriptor<>("numElementsState", Types.LONG)); + ListState<Long> modelVersionState = + context.globalState() + .getListState( + new ListStateDescriptor<>("modelVersionState", Types.LONG)); + DenseVector sum = + OperatorStateUtils.getUniqueElement(sumState, "sumState").orElse(null); + DenseVector squaredSum = + OperatorStateUtils.getUniqueElement(squaredSumState, "squaredSumState") + .orElse(null); + long numElements = + OperatorStateUtils.getUniqueElement(numElementsState, "numElementsState") + .orElse(0L); + long modelVersion = + OperatorStateUtils.getUniqueElement(modelVersionState, "modelVersionState") + .orElse(0L); + + long numElementsBefore = numElements; + for (Row element : iterable) { + Vector inputVec = + ((Vector) Objects.requireNonNull(element.getField(inputCol))).clone(); + if (numElements == 0) { Review Comment: How about moving the `if` block out of the `for` loop? ########## flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/OnlineStandardScalerExample.java: ########## @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.examples.feature; + +import org.apache.flink.api.common.eventtime.SerializableTimestampAssigner; +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.time.Time; +import org.apache.flink.ml.common.window.EventTimeTumblingWindows; +import org.apache.flink.ml.feature.standardscaler.OnlineStandardScaler; +import org.apache.flink.ml.feature.standardscaler.OnlineStandardScalerModel; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; + +import java.util.Arrays; +import java.util.List; + +/** Simple program that trains a OnlineStandardScaler model and uses it for feature engineering. */ +public class OnlineStandardScalerExample { + public static void main(String[] args) { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + + // Generates input data. + List<Row> inputData = + Arrays.asList( + Row.of(0L, Vectors.dense(-2.5, 9, 1)), + Row.of(1000L, Vectors.dense(1.4, -5, 1)), + Row.of(2000L, Vectors.dense(2, -1, -2)), + Row.of(6000L, Vectors.dense(0.7, 3, 1)), + Row.of(7000L, Vectors.dense(0, 1, 1)), + Row.of(8000L, Vectors.dense(0.5, 0, -2)), + Row.of(9000L, Vectors.dense(0.4, 1, 1)), + Row.of(10000L, Vectors.dense(0.3, 2, 1)), + Row.of(11000L, Vectors.dense(0.5, 1, -2))); + + DataStream<Row> inputStream = env.fromCollection(inputData); + + DataStream<Row> inputStreamWithEventTime = + inputStream.assignTimestampsAndWatermarks( + WatermarkStrategy.<Row>forMonotonousTimestamps() + .withTimestampAssigner( + (SerializableTimestampAssigner<Row>) + (element, recordTimestamp) -> + element.getFieldAs(0))); + + Table inputTable = + tEnv.fromDataStream( + inputStreamWithEventTime, + Schema.newBuilder() + .column("f0", DataTypes.BIGINT()) + .column("f1", DataTypes.RAW(DenseVectorTypeInfo.INSTANCE)) + .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)") + .watermark("rowtime", "SOURCE_WATERMARK()") + .build()) + .as("id", "input"); + + // Creates an OnlineStandardScaler object and initializes its parameters. + long windowSizeMs = 3000; + OnlineStandardScaler onlineStandardScaler = + new OnlineStandardScaler() + .setWindows(EventTimeTumblingWindows.of(Time.milliseconds(windowSizeMs))) + .setModelVersionCol("modelVersionCol"); + + // Trains the OnlineStandardScaler Model. 
+ OnlineStandardScalerModel model = onlineStandardScaler.fit(inputTable); + + // Uses the OnlineStandardScaler Model for predictions. + Table outputTable = model.transform(inputTable)[0]; + + // Extracts and displays the results. + for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) { + Row row = it.next(); + DenseVector inputValue = (DenseVector) row.getField(onlineStandardScaler.getInputCol()); + DenseVector outputValue = + (DenseVector) row.getField(onlineStandardScaler.getOutputCol()); + long modelVersion = row.getFieldAs(onlineStandardScaler.getModelVersionCol()); + System.out.printf( + "Input Value: %s\tOutput Value: %s\tModel Version: %s\n", Review Comment: Would it be better if we align the output of the `Model Version` column? E.g., `%-65s\t`. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
