weibozhao commented on code in PR #83: URL: https://github.com/apache/flink-ml/pull/83#discussion_r881266473
########## flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java: ########## @@ -0,0 +1,434 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.logisticregression; + +import org.apache.flink.api.common.functions.FilterFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import 
org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan + * McMahan et al. + * + * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click + * prediction: a view from the trenches.</a> + */ +public class OnlineLogisticRegression + implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>, + OnlineLogisticRegressionParams<OnlineLogisticRegression> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table initModelDataTable; + + public OnlineLogisticRegression() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + @SuppressWarnings("unchecked") + public OnlineLogisticRegressionModel fit(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<LogisticRegressionModelData> modelDataStream = + LogisticRegressionModelData.getModelDataStream(initModelDataTable); + + DataStream<Row> points = + tEnv.toDataStream(inputs[0]) + .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol())); + + DataStream<DenseVector> initModelData = + modelDataStream.map( + (MapFunction<LogisticRegressionModelData, DenseVector>) + value -> value.coefficient); + + initModelData.getTransformation().setParallelism(1); + + IterationBody body = + new FtrlIterationBody( + getGlobalBatchSize(), + getAlpha(), + getBeta(), + getReg(), + getElasticNet(), + getModelSaveInterval()); + + DataStream<LogisticRegressionModelData> onlineModelData = + Iterations.iterateUnboundedStreams( + DataStreamList.of(initModelData), DataStreamList.of(points), body) + .get(0); + + Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData); + OnlineLogisticRegressionModel model = + new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable); + ReadWriteUtils.updateExistingParams(model, paramMap); + return model; + } + + private static class FeaturesExtractor implements MapFunction<Row, Row> { + private final String featuresCol; + private final String labelCol; + + private FeaturesExtractor(String featuresCol, String labelCol) { + this.featuresCol = featuresCol; + this.labelCol = labelCol; + } + + @Override + public Row map(Row row) throws Exception { + return Row.of(row.getField(featuresCol), row.getField(labelCol)); + } + } + + /** + * Implementation of ftrl optimizer. In this implementation, gradients are calculated in + * distributed workers and reduce to one gradient. The reduced gradient is used to update model + * by ftrl method. 
+ * + * <p>See https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl + */ + private static class FtrlIterationBody implements IterationBody { + private final int batchSize; + private final double alpha; + private final double beta; + private final double l1; + private final double l2; + private long modelVersion = 1L; + private final int modelSaveInterval; + + public FtrlIterationBody( + int batchSize, + double alpha, + double beta, + double reg, + double elasticNet, + int modelSaveInterval) { + this.batchSize = batchSize; + this.alpha = alpha; + this.beta = beta; + this.l1 = elasticNet * reg; + this.l2 = (1 - elasticNet) * reg; + this.modelSaveInterval = modelSaveInterval; + } + + @Override + public IterationBodyResult process( + DataStreamList variableStreams, DataStreamList dataStreams) { + DataStream<DenseVector> modelData = variableStreams.get(0); + + DataStream<Row> points = dataStreams.get(0); + int parallelism = points.getParallelism(); + Preconditions.checkState( + parallelism <= batchSize, + "There are more subtasks in the training process than the number " + + "of elements in each batch. 
Some subtasks might be idling forever."); + + DataStream<DenseVector[]> newGradient = + DataStreamUtils.generateBatchData(points, parallelism, batchSize) + .connect(modelData.broadcast()) + .transform( + "LocalGradientCalculator", + TypeInformation.of(DenseVector[].class), + new CalculateLocalGradient()) + .setParallelism(parallelism) + .countWindowAll(parallelism) + .reduce( + (ReduceFunction<DenseVector[]>) + (gradientInfo, newGradientInfo) -> { + for (int i = 0; + i < newGradientInfo[1].size(); + ++i) { + newGradientInfo[0].values[i] = + gradientInfo[0].values[i] + + newGradientInfo[0].values[i]; + newGradientInfo[1].values[i] = + gradientInfo[1].values[i] + + newGradientInfo[1].values[i]; + if (newGradientInfo[2] == null) { + newGradientInfo[2] = gradientInfo[2]; + } + } + return newGradientInfo; + }); + DataStream<DenseVector> feedbackModelData = + newGradient + .transform( + "ModelDataUpdater", + TypeInformation.of(DenseVector.class), + new UpdateModel(alpha, beta, l1, l2)) + .setParallelism(1); + + DataStream<LogisticRegressionModelData> outputModelData = + feedbackModelData + .filter( + new FilterFunction<DenseVector>() { + private int step = 0; + + @Override + public boolean filter(DenseVector denseVector) { + step++; + return step % modelSaveInterval == 0; + } + }) + .setParallelism(1) Review Comment: I think this is not needed now. If more than one algorithm uses this part of the code, we should extract it into a function or class. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
