zhipeng93 commented on code in PR #192: URL: https://github.com/apache/flink-ml/pull/192#discussion_r1103961097
########## flink-ml-examples/src/main/java/org/apache/flink/ml/examples/recommendation/SwingExample.java: ########## @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.examples.recommendation; + +import org.apache.flink.ml.recommendation.swing.Swing; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; + +/** Simple program that creates a Swing instance and uses it to give recommendations for items. */ Review Comment: nit: ... and uses it to `generate` recommendations for items. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/Swing.java: ########## @@ -0,0 +1,415 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.swing; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; 
+import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * An AlgoOperator which implements the Swing algorithm. + * + * <p>Swing is an item recall algorithm. The topology of user-item graph usually can be described as + * user-item-user or item-user-item, which are like 'swing'. For example, if both user <em>u</em> + * and user <em>v</em> have purchased the same commodity <em>i</em> , they will form a relationship + * diagram similar to a swing. If <em>u</em> and <em>v</em> have purchased commodity <em>j</em> in + * addition to <em>i</em>, it is supposed <em>i</em> and <em>j</em> are similar. The similarity + * between items in Swing is defined as + * + * <p>$$ w_{(i,j)}=\sum_{u\in U_i\cap U_j}\sum_{v\in U_i\cap + * U_j}{\frac{1}{{(I_u+\alpha_1)}^\beta}}*{\frac{1}{{(I_v+\alpha_1)}^\beta}}*{\frac{1}{\alpha_2+|I_u\cap + * I_v|}} $$ + * + * <p>This implementation is based on the algorithm proposed in the paper: "Large Scale Product + * Graph Construction for Recommendation in E-commerce" by Xiaoyong Yang, Yadong Zhu and Yi Zhang. + * (<a href="https://arxiv.org/pdf/2010.05525.pdf">https://arxiv.org/pdf/2010.05525.pdf</a>) + */ +public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public Swing() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... 
inputs) { + + final String userCol = getUserCol(); + final String itemCol = getItemCol(); + Preconditions.checkArgument(inputs.length == 1); + final ResolvedSchema schema = inputs[0].getResolvedSchema(); + + if (!(Types.LONG.equals(TableUtils.getTypeInfoByName(schema, userCol)) + && Types.LONG.equals(TableUtils.getTypeInfoByName(schema, itemCol)))) { + throw new IllegalArgumentException("The types of user and item columns must be Long."); + } + + if (getMaxUserBehavior() < getMinUserBehavior()) { + throw new IllegalArgumentException( + "The maxUserBehavior must be larger or equal to minUserBehavior."); + } + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + SingleOutputStreamOperator<Tuple2<Long, Long>> itemUsers = Review Comment: Is `purchasingBehavior` more explainable than `itemUsers`? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/Swing.java: ########## @@ -0,0 +1,415 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.recommendation.swing; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * An AlgoOperator which implements the Swing algorithm. + * + * <p>Swing is an item recall algorithm. 
The topology of user-item graph usually can be described as + * user-item-user or item-user-item, which are like 'swing'. For example, if both user <em>u</em> + * and user <em>v</em> have purchased the same commodity <em>i</em> , they will form a relationship + * diagram similar to a swing. If <em>u</em> and <em>v</em> have purchased commodity <em>j</em> in + * addition to <em>i</em>, it is supposed <em>i</em> and <em>j</em> are similar. The similarity + * between items in Swing is defined as + * + * <p>$$ w_{(i,j)}=\sum_{u\in U_i\cap U_j}\sum_{v\in U_i\cap + * U_j}{\frac{1}{{(I_u+\alpha_1)}^\beta}}*{\frac{1}{{(I_v+\alpha_1)}^\beta}}*{\frac{1}{\alpha_2+|I_u\cap + * I_v|}} $$ + * + * <p>This implementation is based on the algorithm proposed in the paper: "Large Scale Product + * Graph Construction for Recommendation in E-commerce" by Xiaoyong Yang, Yadong Zhu and Yi Zhang. + * (<a href="https://arxiv.org/pdf/2010.05525.pdf">https://arxiv.org/pdf/2010.05525.pdf</a>) + */ +public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public Swing() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... inputs) { + + final String userCol = getUserCol(); + final String itemCol = getItemCol(); + Preconditions.checkArgument(inputs.length == 1); + final ResolvedSchema schema = inputs[0].getResolvedSchema(); + + if (!(Types.LONG.equals(TableUtils.getTypeInfoByName(schema, userCol)) + && Types.LONG.equals(TableUtils.getTypeInfoByName(schema, itemCol)))) { + throw new IllegalArgumentException("The types of user and item columns must be Long."); + } + + if (getMaxUserBehavior() < getMinUserBehavior()) { + throw new IllegalArgumentException( + "The maxUserBehavior must be larger or equal to minUserBehavior."); Review Comment: Let's also print the values of `maxUserBehavior` and `minUserBehavior` in the error message to make debugging easier.
########## flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/Swing.java: ########## @@ -0,0 +1,415 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.swing; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import 
org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * An AlgoOperator which implements the Swing algorithm. + * + * <p>Swing is an item recall algorithm. The topology of user-item graph usually can be described as + * user-item-user or item-user-item, which are like 'swing'. For example, if both user <em>u</em> + * and user <em>v</em> have purchased the same commodity <em>i</em> , they will form a relationship + * diagram similar to a swing. If <em>u</em> and <em>v</em> have purchased commodity <em>j</em> in + * addition to <em>i</em>, it is supposed <em>i</em> and <em>j</em> are similar. The similarity + * between items in Swing is defined as + * + * <p>$$ w_{(i,j)}=\sum_{u\in U_i\cap U_j}\sum_{v\in U_i\cap + * U_j}{\frac{1}{{(I_u+\alpha_1)}^\beta}}*{\frac{1}{{(I_v+\alpha_1)}^\beta}}*{\frac{1}{\alpha_2+|I_u\cap + * I_v|}} $$ + * + * <p>This implementation is based on the algorithm proposed in the paper: "Large Scale Product + * Graph Construction for Recommendation in E-commerce" by Xiaoyong Yang, Yadong Zhu and Yi Zhang. 
+ * (<a href="https://arxiv.org/pdf/2010.05525.pdf">https://arxiv.org/pdf/2010.05525.pdf</a>) + */ +public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public Swing() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... inputs) { + + final String userCol = getUserCol(); + final String itemCol = getItemCol(); + Preconditions.checkArgument(inputs.length == 1); + final ResolvedSchema schema = inputs[0].getResolvedSchema(); + + if (!(Types.LONG.equals(TableUtils.getTypeInfoByName(schema, userCol)) + && Types.LONG.equals(TableUtils.getTypeInfoByName(schema, itemCol)))) { + throw new IllegalArgumentException("The types of user and item columns must be Long."); + } + + if (getMaxUserBehavior() < getMinUserBehavior()) { + throw new IllegalArgumentException( + "The maxUserBehavior must be larger or equal to minUserBehavior."); + } + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + SingleOutputStreamOperator<Tuple2<Long, Long>> itemUsers = + tEnv.toDataStream(inputs[0]) + .map( + row -> { + if (row.getFieldAs(userCol) == null + || row.getFieldAs(itemCol) == null) { + throw new RuntimeException( + "Data of user and item column must not be null."); + } + return Tuple2.of( + (Long) row.getFieldAs(userCol), + (Long) row.getFieldAs(itemCol)); + }) + .returns(Types.TUPLE(Types.LONG, Types.LONG)); + + SingleOutputStreamOperator<Tuple3<Long, Long, Map<Long, String>>> userAllItemsStream = + itemUsers + .keyBy(tuple -> tuple.f0) + .transform( + "collectingUserBehavior", + Types.TUPLE( + Types.LONG, + Types.LONG, + Types.MAP(Types.LONG, Types.STRING)), + new CollectingUserBehavior( + getMinUserBehavior(), getMaxUserBehavior())); + + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + new TypeInformation[] {Types.LONG, Types.STRING}, + new String[] {getItemCol(), 
getOutputCol()}); + + DataStream<Row> output = + userAllItemsStream + .keyBy(tuple -> tuple.f1) + .transform( + "computingSimilarItems", + outputTypeInfo, + new ComputingSimilarItems( + getK(), + getMaxUserNumPerItem(), + getAlpha1(), + getAlpha2(), + getBeta())); + + return new Table[] {tEnv.fromDataStream(output)}; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static Swing load(StreamTableEnvironment tEnv, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + /** + * Collects user behavior data and appends to the input table. + * + * <p>During the process, this operator collects users and all items he/she has purchased, and + * its input table must be bounded. Because Flink doesn't support type info of `Set` officially, + * The appended column is `Map` contains items as key and maps null value. + */ + private static class CollectingUserBehavior + extends AbstractStreamOperator<Tuple3<Long, Long, Map<Long, String>>> + implements OneInputStreamOperator< + Tuple2<Long, Long>, Tuple3<Long, Long, Map<Long, String>>>, + BoundedOneInput { + private final int minUserItemInteraction; + private final int maxUserItemInteraction; + + // Maps a user id to a set of items. Because ListState cannot keep values of type `Set`, + // we use `Map<Long, String>` with null values instead. So does `userItemsMap` and + // `itemUsersMap` in `ComputingSimilarItems`. 
+ private Map<Long, Map<Long, String>> userItemsMap = new HashMap<>(); + + private ListState<Map<Long, Map<Long, String>>> userAllItemsMapState; + + private CollectingUserBehavior(int minUserItemInteraction, int maxUserItemInteraction) { + this.minUserItemInteraction = minUserItemInteraction; + this.maxUserItemInteraction = maxUserItemInteraction; + } + + @Override + public void endInput() { + + userItemsMap.forEach( + (user, items) -> { + if (items.size() >= minUserItemInteraction + && items.size() <= maxUserItemInteraction) { + items.forEach( + (item, nullValue) -> + output.collect( + new StreamRecord<>( + new Tuple3<>(user, item, items)))); + } + }); + + userAllItemsMapState.clear(); + } + + @Override + public void processElement(StreamRecord<Tuple2<Long, Long>> element) { + Tuple2<Long, Long> userAndItem = element.getValue(); + long user = userAndItem.f0; + long item = userAndItem.f1; + Map<Long, String> items = userItemsMap.get(user); Review Comment: Is `userAndPurchasedItems` more intuitive? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/Swing.java: ########## @@ -0,0 +1,415 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.recommendation.swing; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * An AlgoOperator which implements the Swing algorithm. + * + * <p>Swing is an item recall algorithm. 
The topology of user-item graph usually can be described as + * user-item-user or item-user-item, which are like 'swing'. For example, if both user <em>u</em> + * and user <em>v</em> have purchased the same commodity <em>i</em> , they will form a relationship + * diagram similar to a swing. If <em>u</em> and <em>v</em> have purchased commodity <em>j</em> in + * addition to <em>i</em>, it is supposed <em>i</em> and <em>j</em> are similar. The similarity + * between items in Swing is defined as + * + * <p>$$ w_{(i,j)}=\sum_{u\in U_i\cap U_j}\sum_{v\in U_i\cap + * U_j}{\frac{1}{{(I_u+\alpha_1)}^\beta}}*{\frac{1}{{(I_v+\alpha_1)}^\beta}}*{\frac{1}{\alpha_2+|I_u\cap + * I_v|}} $$ + * + * <p>This implementation is based on the algorithm proposed in the paper: "Large Scale Product + * Graph Construction for Recommendation in E-commerce" by Xiaoyong Yang, Yadong Zhu and Yi Zhang. + * (<a href="https://arxiv.org/pdf/2010.05525.pdf">https://arxiv.org/pdf/2010.05525.pdf</a>) + */ +public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public Swing() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... 
inputs) { + + final String userCol = getUserCol(); + final String itemCol = getItemCol(); + Preconditions.checkArgument(inputs.length == 1); + final ResolvedSchema schema = inputs[0].getResolvedSchema(); + + if (!(Types.LONG.equals(TableUtils.getTypeInfoByName(schema, userCol)) + && Types.LONG.equals(TableUtils.getTypeInfoByName(schema, itemCol)))) { + throw new IllegalArgumentException("The types of user and item columns must be Long."); + } + + if (getMaxUserBehavior() < getMinUserBehavior()) { + throw new IllegalArgumentException( + "The maxUserBehavior must be larger or equal to minUserBehavior."); + } + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + SingleOutputStreamOperator<Tuple2<Long, Long>> itemUsers = + tEnv.toDataStream(inputs[0]) + .map( + row -> { + if (row.getFieldAs(userCol) == null + || row.getFieldAs(itemCol) == null) { + throw new RuntimeException( + "Data of user and item column must not be null."); + } + return Tuple2.of( + (Long) row.getFieldAs(userCol), + (Long) row.getFieldAs(itemCol)); + }) + .returns(Types.TUPLE(Types.LONG, Types.LONG)); + + SingleOutputStreamOperator<Tuple3<Long, Long, Map<Long, String>>> userAllItemsStream = + itemUsers + .keyBy(tuple -> tuple.f0) + .transform( + "collectingUserBehavior", + Types.TUPLE( + Types.LONG, + Types.LONG, + Types.MAP(Types.LONG, Types.STRING)), + new CollectingUserBehavior( + getMinUserBehavior(), getMaxUserBehavior())); + + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + new TypeInformation[] {Types.LONG, Types.STRING}, + new String[] {getItemCol(), getOutputCol()}); + + DataStream<Row> output = + userAllItemsStream + .keyBy(tuple -> tuple.f1) + .transform( + "computingSimilarItems", + outputTypeInfo, + new ComputingSimilarItems( + getK(), + getMaxUserNumPerItem(), + getAlpha1(), + getAlpha2(), + getBeta())); + + return new Table[] {tEnv.fromDataStream(output)}; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + 
return paramMap; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static Swing load(StreamTableEnvironment tEnv, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + /** + * Collects user behavior data and appends to the input table. + * + * <p>During the process, this operator collects users and all items he/she has purchased, and + * its input table must be bounded. Because Flink doesn't support type info of `Set` officially, + * The appended column is `Map` contains items as key and maps null value. + */ + private static class CollectingUserBehavior + extends AbstractStreamOperator<Tuple3<Long, Long, Map<Long, String>>> + implements OneInputStreamOperator< + Tuple2<Long, Long>, Tuple3<Long, Long, Map<Long, String>>>, + BoundedOneInput { + private final int minUserItemInteraction; + private final int maxUserItemInteraction; + + // Maps a user id to a set of items. Because ListState cannot keep values of type `Set`, + // we use `Map<Long, String>` with null values instead. So does `userItemsMap` and + // `itemUsersMap` in `ComputingSimilarItems`. + private Map<Long, Map<Long, String>> userItemsMap = new HashMap<>(); + + private ListState<Map<Long, Map<Long, String>>> userAllItemsMapState; + + private CollectingUserBehavior(int minUserItemInteraction, int maxUserItemInteraction) { + this.minUserItemInteraction = minUserItemInteraction; + this.maxUserItemInteraction = maxUserItemInteraction; + } + + @Override + public void endInput() { + + userItemsMap.forEach( Review Comment: It seems that we should output a set instead of a map. Am I understanding it correctly? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/SwingParams.java: ########## @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.swing; + +import org.apache.flink.ml.common.param.HasOutputCol; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** + * Params for {@link Swing}. + * + * @param <T> The class type of this instance. + */ +public interface SwingParams<T> extends WithParams<T>, HasOutputCol<T> { + Param<String> USER_COL = + new StringParam("userCol", "Name of user column.", "user", ParamValidators.notNull()); + + Param<String> ITEM_COL = + new StringParam("itemCol", "Name of item column.", "item", ParamValidators.notNull()); + + Param<Integer> MAX_USER_NUM_PER_ITEM = + new IntParam( + "maxUserNumPerItem", + "The max number of users that has purchased for each item. If the number of users that have " Review Comment: Has --> Have ########## flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/Swing.java: ########## @@ -0,0 +1,415 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.swing; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; 
+import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * An AlgoOperator which implements the Swing algorithm. + * + * <p>Swing is an item recall algorithm. The topology of user-item graph usually can be described as + * user-item-user or item-user-item, which are like 'swing'. For example, if both user <em>u</em> + * and user <em>v</em> have purchased the same commodity <em>i</em> , they will form a relationship + * diagram similar to a swing. If <em>u</em> and <em>v</em> have purchased commodity <em>j</em> in + * addition to <em>i</em>, it is supposed <em>i</em> and <em>j</em> are similar. The similarity + * between items in Swing is defined as + * + * <p>$$ w_{(i,j)}=\sum_{u\in U_i\cap U_j}\sum_{v\in U_i\cap + * U_j}{\frac{1}{{(I_u+\alpha_1)}^\beta}}*{\frac{1}{{(I_v+\alpha_1)}^\beta}}*{\frac{1}{\alpha_2+|I_u\cap + * I_v|}} $$ + * + * <p>This implementation is based on the algorithm proposed in the paper: "Large Scale Product + * Graph Construction for Recommendation in E-commerce" by Xiaoyong Yang, Yadong Zhu and Yi Zhang. + * (<a href="https://arxiv.org/pdf/2010.05525.pdf">https://arxiv.org/pdf/2010.05525.pdf</a>) + */ +public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public Swing() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... 
inputs) { + + final String userCol = getUserCol(); + final String itemCol = getItemCol(); + Preconditions.checkArgument(inputs.length == 1); Review Comment: Let's move this check to the beginning of the method. In general, we want error messages to show up as early as possible. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/SwingParams.java: ########## @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.swing; + +import org.apache.flink.ml.common.param.HasOutputCol; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** + * Params for {@link Swing}. + * + * @param <T> The class type of this instance.
+ */ +public interface SwingParams<T> extends WithParams<T>, HasOutputCol<T> { + Param<String> USER_COL = + new StringParam("userCol", "Name of user column.", "user", ParamValidators.notNull()); + + Param<String> ITEM_COL = + new StringParam("itemCol", "Name of item column.", "item", ParamValidators.notNull()); + + Param<Integer> MAX_USER_NUM_PER_ITEM = + new IntParam( + "maxUserNumPerItem", + "The max number of users that has purchased for each item. If the number of users that have " + + "purchased this item is larger than this value, then only maxUserNumPerItem users will " + + "be sampled and used in the computation logic.", + 1000, + ParamValidators.gt(0)); + + Param<Integer> K = + new IntParam( + "k", + "The max number of similar items to output for each item.", + 100, + ParamValidators.gt(0)); + + Param<Integer> MIN_USER_BEHAVIOR = + new IntParam( + "minUserBehavior", + "The min number of interaction behavior between item and user.", + 10, + ParamValidators.gt(0)); + + Param<Integer> MAX_USER_BEHAVIOR = + new IntParam( + "maxUserBehavior", + "The max number of interaction behavior between item and user. " + + "The algorithm filters out activate users.", + 1000, + ParamValidators.gt(0)); + + Param<Integer> ALPHA1 = + new IntParam( + "alpha1", + "This parameter is used to calculate weight of each user. " + + "The higher alpha1 is, the smaller weight each user gets.", + 15, + ParamValidators.gtEq(0)); + + Param<Integer> ALPHA2 = + new IntParam( + "alpha2", + "This parameter is used to calculate similarity of users. " + + "The higher alpha2 is, the less the similarity score is.", + 0, + ParamValidators.gtEq(0)); + + Param<Double> BETA = + new DoubleParam( + "beta", + "This parameter is used to calculate weight of each user. " Review Comment: Decay factor for number of users that have purchased one item. The higher beta is, the less purchasing behavior contributes to the similarity score. 
########## flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/Swing.java: ########## @@ -0,0 +1,415 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.swing; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import 
org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * An AlgoOperator which implements the Swing algorithm. + * + * <p>Swing is an item recall algorithm. The topology of user-item graph usually can be described as + * user-item-user or item-user-item, which are like 'swing'. For example, if both user <em>u</em> + * and user <em>v</em> have purchased the same commodity <em>i</em> , they will form a relationship + * diagram similar to a swing. If <em>u</em> and <em>v</em> have purchased commodity <em>j</em> in + * addition to <em>i</em>, it is supposed <em>i</em> and <em>j</em> are similar. The similarity + * between items in Swing is defined as + * + * <p>$$ w_{(i,j)}=\sum_{u\in U_i\cap U_j}\sum_{v\in U_i\cap + * U_j}{\frac{1}{{(I_u+\alpha_1)}^\beta}}*{\frac{1}{{(I_v+\alpha_1)}^\beta}}*{\frac{1}{\alpha_2+|I_u\cap + * I_v|}} $$ + * + * <p>This implementation is based on the algorithm proposed in the paper: "Large Scale Product + * Graph Construction for Recommendation in E-commerce" by Xiaoyong Yang, Yadong Zhu and Yi Zhang. 
+ * (<a href="https://arxiv.org/pdf/2010.05525.pdf">https://arxiv.org/pdf/2010.05525.pdf</a>) + */ +public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public Swing() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... inputs) { + + final String userCol = getUserCol(); + final String itemCol = getItemCol(); + Preconditions.checkArgument(inputs.length == 1); + final ResolvedSchema schema = inputs[0].getResolvedSchema(); + + if (!(Types.LONG.equals(TableUtils.getTypeInfoByName(schema, userCol)) + && Types.LONG.equals(TableUtils.getTypeInfoByName(schema, itemCol)))) { + throw new IllegalArgumentException("The types of user and item columns must be Long."); + } + + if (getMaxUserBehavior() < getMinUserBehavior()) { + throw new IllegalArgumentException( + "The maxUserBehavior must be larger or equal to minUserBehavior."); + } + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + SingleOutputStreamOperator<Tuple2<Long, Long>> itemUsers = + tEnv.toDataStream(inputs[0]) + .map( + row -> { + if (row.getFieldAs(userCol) == null + || row.getFieldAs(itemCol) == null) { + throw new RuntimeException( + "Data of user and item column must not be null."); + } + return Tuple2.of( + (Long) row.getFieldAs(userCol), + (Long) row.getFieldAs(itemCol)); + }) + .returns(Types.TUPLE(Types.LONG, Types.LONG)); + + SingleOutputStreamOperator<Tuple3<Long, Long, Map<Long, String>>> userAllItemsStream = + itemUsers + .keyBy(tuple -> tuple.f0) + .transform( + "collectingUserBehavior", + Types.TUPLE( + Types.LONG, + Types.LONG, + Types.MAP(Types.LONG, Types.STRING)), + new CollectingUserBehavior( + getMinUserBehavior(), getMaxUserBehavior())); + + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + new TypeInformation[] {Types.LONG, Types.STRING}, + new String[] {getItemCol(), 
getOutputCol()}); + + DataStream<Row> output = + userAllItemsStream + .keyBy(tuple -> tuple.f1) + .transform( + "computingSimilarItems", + outputTypeInfo, + new ComputingSimilarItems( + getK(), + getMaxUserNumPerItem(), + getAlpha1(), + getAlpha2(), + getBeta())); + + return new Table[] {tEnv.fromDataStream(output)}; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static Swing load(StreamTableEnvironment tEnv, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + /** + * Collects user behavior data and appends to the input table. + * + * <p>During the process, this operator collects users and all items he/she has purchased, and + * its input table must be bounded. Because Flink doesn't support type info of `Set` officially, + * The appended column is `Map` contains items as key and maps null value. + */ + private static class CollectingUserBehavior + extends AbstractStreamOperator<Tuple3<Long, Long, Map<Long, String>>> + implements OneInputStreamOperator< + Tuple2<Long, Long>, Tuple3<Long, Long, Map<Long, String>>>, + BoundedOneInput { + private final int minUserItemInteraction; + private final int maxUserItemInteraction; + + // Maps a user id to a set of items. Because ListState cannot keep values of type `Set`, + // we use `Map<Long, String>` with null values instead. So does `userItemsMap` and + // `itemUsersMap` in `ComputingSimilarItems`. 
+ private Map<Long, Map<Long, String>> userItemsMap = new HashMap<>(); + + private ListState<Map<Long, Map<Long, String>>> userAllItemsMapState; + + private CollectingUserBehavior(int minUserItemInteraction, int maxUserItemInteraction) { + this.minUserItemInteraction = minUserItemInteraction; + this.maxUserItemInteraction = maxUserItemInteraction; + } + + @Override + public void endInput() { + + userItemsMap.forEach( + (user, items) -> { + if (items.size() >= minUserItemInteraction + && items.size() <= maxUserItemInteraction) { + items.forEach( + (item, nullValue) -> + output.collect( + new StreamRecord<>( + new Tuple3<>(user, item, items)))); + } + }); + + userAllItemsMapState.clear(); + } + + @Override + public void processElement(StreamRecord<Tuple2<Long, Long>> element) { + Tuple2<Long, Long> userAndItem = element.getValue(); + long user = userAndItem.f0; + long item = userAndItem.f1; + Map<Long, String> items = userItemsMap.get(user); + + if (items == null) { + items = new LinkedHashMap<>(); + } + + if (items.size() <= maxUserItemInteraction) { + items.put(item, null); + } + + userItemsMap.put(user, items); + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + userAllItemsMapState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "userAllItemsMapState", + Types.MAP( + Types.LONG, + Types.MAP(Types.LONG, Types.STRING)))); + + OperatorStateUtils.getUniqueElement(userAllItemsMapState, "userAllItemsMapState") + .ifPresent( + stat -> { + userItemsMap = stat; + }); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + userAllItemsMapState.update(Collections.singletonList(userItemsMap)); + } + } + + /** Calculates similarity between items and keep top k similar items of each target item. 
*/ + private static class ComputingSimilarItems extends AbstractStreamOperator<Row> + implements OneInputStreamOperator<Tuple3<Long, Long, Map<Long, String>>, Row>, + BoundedOneInput { + + private Map<Long, Map<Long, String>> userItemsMap = new HashMap<>(); Review Comment: Would `userAndPurchasedItems` and `itemAndPurchasedUsers` be more intuitive names? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/SwingParams.java: ########## @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.swing; + +import org.apache.flink.ml.common.param.HasOutputCol; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** + * Params for {@link Swing}. + * + * @param <T> The class type of this instance.
+ */ +public interface SwingParams<T> extends WithParams<T>, HasOutputCol<T> { + Param<String> USER_COL = + new StringParam("userCol", "Name of user column.", "user", ParamValidators.notNull()); + + Param<String> ITEM_COL = + new StringParam("itemCol", "Name of item column.", "item", ParamValidators.notNull()); + + Param<Integer> MAX_USER_NUM_PER_ITEM = + new IntParam( + "maxUserNumPerItem", + "The max number of users that has purchased for each item. If the number of users that have " + + "purchased this item is larger than this value, then only maxUserNumPerItem users will " + + "be sampled and used in the computation logic.", + 1000, + ParamValidators.gt(0)); + + Param<Integer> K = + new IntParam( + "k", + "The max number of similar items to output for each item.", + 100, + ParamValidators.gt(0)); + + Param<Integer> MIN_USER_BEHAVIOR = + new IntParam( + "minUserBehavior", + "The min number of interaction behavior between item and user.", + 10, + ParamValidators.gt(0)); + + Param<Integer> MAX_USER_BEHAVIOR = + new IntParam( + "maxUserBehavior", + "The max number of interaction behavior between item and user. " + + "The algorithm filters out activate users.", + 1000, + ParamValidators.gt(0)); + + Param<Integer> ALPHA1 = + new IntParam( + "alpha1", + "This parameter is used to calculate weight of each user. " + + "The higher alpha1 is, the smaller weight each user gets.", + 15, + ParamValidators.gtEq(0)); + + Param<Integer> ALPHA2 = + new IntParam( + "alpha2", + "This parameter is used to calculate similarity of users. " + + "The higher alpha2 is, the less the similarity score is.", + 0, Review Comment: `alpha2` should be greater than zero, right? ########## flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/SwingTest.java: ########## @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation; + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.recommendation.swing.Swing; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.apache.commons.lang3.exception.ExceptionUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +/** Tests {@link Swing}. 
*/ +public class SwingTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table trainData; + RowTypeInfo trainDataTypeInfo = + new RowTypeInfo( + new TypeInformation[] { + BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.LONG_TYPE_INFO + }, + new String[] {"user_id", "item_id"}); + private static final List<Row> trainRows = + new ArrayList<>( + Arrays.asList( + Row.of(0L, 10L), + Row.of(0L, 11L), + Row.of(0L, 12L), + Row.of(1L, 13L), + Row.of(1L, 12L), + Row.of(2L, 10L), + Row.of(2L, 11L), + Row.of(2L, 12L), + Row.of(3L, 13L), + Row.of(3L, 12L))); + + private static final List<Row> expectedScoreRows = + new ArrayList<>( + Arrays.asList( + Row.of(10L, "11,0.058845768947156235;12,0.058845768947156235"), + Row.of(11L, "10,0.058845768947156235;12,0.058845768947156235"), + Row.of( + 12L, + "13,0.09134833828228624;10,0.058845768947156235;11,0.058845768947156235"), + Row.of(13L, "12,0.09134833828228624"))); + + @Before + public void before() { + env = TestUtils.getExecutionEnvironment(); + tEnv = StreamTableEnvironment.create(env); + trainData = tEnv.fromDataStream(env.fromCollection(trainRows, trainDataTypeInfo)); + } + + private void compareResultAndExpected(List<Row> results) { + results.sort((o1, o2) -> Long.compare(o1.getFieldAs(0), o2.getFieldAs(0))); + + for (int i = 0; i < results.size(); i++) { + Row result = results.get(i); + String itemRankScore = result.getFieldAs(1); + Row expect = expectedScoreRows.get(i); + assertEquals(expect.getField(0), result.getField(0)); + assertEquals(expect.getField(1), itemRankScore); + } + } + + @Test + public void testParam() { + Swing swing = new Swing(); + + assertEquals("item", swing.getItemCol()); + assertEquals("user", swing.getUserCol()); + assertEquals(100, swing.getK()); + assertEquals(1000, swing.getMaxUserNumPerItem()); + assertEquals(10, swing.getMinUserBehavior()); + assertEquals(1000, 
swing.getMaxUserBehavior()); + assertEquals(15, swing.getAlpha1()); + assertEquals(0, swing.getAlpha2()); + assertEquals(0.3, swing.getBeta(), 1e-9); + + swing.setItemCol("item_1") + .setUserCol("user_1") + .setK(20) + .setMaxUserNumPerItem(500) + .setMinUserBehavior(10) + .setMaxUserBehavior(50) + .setAlpha1(5) + .setAlpha2(1) + .setBeta(0.35); + + assertEquals("item_1", swing.getItemCol()); + assertEquals("user_1", swing.getUserCol()); + assertEquals(20, swing.getK()); + assertEquals(500, swing.getMaxUserNumPerItem()); + assertEquals(10, swing.getMinUserBehavior()); + assertEquals(50, swing.getMaxUserBehavior()); + assertEquals(5, swing.getAlpha1()); + assertEquals(1, swing.getAlpha2()); + assertEquals(0.35, swing.getBeta(), 1e-9); + } + + @Test + public void testDataType() { + List<Row> rows = + new ArrayList<>(Arrays.asList(Row.of(0, "10"), Row.of(1, "11"), Row.of(2, ""))); + + DataStream<Row> dataStream = + env.fromCollection( + rows, + new RowTypeInfo( + new TypeInformation[] { + BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO + }, + new String[] {"user_id", "item_id"})); + Table data = tEnv.fromDataStream(dataStream); + + try { + Table[] swingResultTables = + new Swing() + .setItemCol("item_id") + .setUserCol("user_id") + .setOutputCol("item_score") + .setMinUserBehavior(1) + .transform(data); + swingResultTables[0].execute().print(); + fail(); + } catch (RuntimeException e) { + assertEquals(IllegalArgumentException.class, e.getClass()); + assertEquals("The types of user and item columns must be Long.", e.getMessage()); + } + } + + @Test + public void testNumberFormat() { + List<Row> rows = + new ArrayList<>( + Arrays.asList( + Row.of(0L, 10L), + Row.of(null, 12L), + Row.of(1L, 13L), + Row.of(3L, 12L))); + + DataStream<Row> dataStream = env.fromCollection(rows, trainDataTypeInfo); + Table data = tEnv.fromDataStream(dataStream); + Swing swing = new Swing().setItemCol("item_id").setUserCol("user_id").setMinUserBehavior(1); + Table[] 
swingResultTables = swing.transform(data); + + try { + swingResultTables[0].execute().print(); + fail(); + } catch (RuntimeException e) { + Throwable exception = ExceptionUtils.getRootCause(e); + assertEquals(RuntimeException.class, exception.getClass()); + assertEquals("Data of user and item column must not be null.", exception.getMessage()); + } + } + + @Test + public void testOutputSchema() { + Swing swing = + new Swing() + .setItemCol("item_id") + .setUserCol("user_id") + .setOutputCol("item_score") + .setMinUserBehavior(1); + Table[] swingResultTables = swing.transform(trainData); + Table output = swingResultTables[0]; + + assertEquals( + Arrays.asList("item_id", "item_score"), + output.getResolvedSchema().getColumnNames()); + } + + @Test + public void testFewerItemCase() { Review Comment: Let's also test `setMaxUserBehavior`. Could this move to `testTransform`? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/Swing.java: ########## @@ -0,0 +1,415 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.recommendation.swing; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * An AlgoOperator which implements the Swing algorithm. + * + * <p>Swing is an item recall algorithm. 
The topology of user-item graph usually can be described as + * user-item-user or item-user-item, which are like 'swing'. For example, if both user <em>u</em> + * and user <em>v</em> have purchased the same commodity <em>i</em> , they will form a relationship + * diagram similar to a swing. If <em>u</em> and <em>v</em> have purchased commodity <em>j</em> in + * addition to <em>i</em>, it is supposed <em>i</em> and <em>j</em> are similar. The similarity + * between items in Swing is defined as + * + * <p>$$ w_{(i,j)}=\sum_{u\in U_i\cap U_j}\sum_{v\in U_i\cap + * U_j}{\frac{1}{{(I_u+\alpha_1)}^\beta}}*{\frac{1}{{(I_v+\alpha_1)}^\beta}}*{\frac{1}{\alpha_2+|I_u\cap + * I_v|}} $$ + * + * <p>This implementation is based on the algorithm proposed in the paper: "Large Scale Product + * Graph Construction for Recommendation in E-commerce" by Xiaoyong Yang, Yadong Zhu and Yi Zhang. + * (<a href="https://arxiv.org/pdf/2010.05525.pdf">https://arxiv.org/pdf/2010.05525.pdf</a>) + */ +public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public Swing() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... 
inputs) { + + final String userCol = getUserCol(); + final String itemCol = getItemCol(); + Preconditions.checkArgument(inputs.length == 1); + final ResolvedSchema schema = inputs[0].getResolvedSchema(); + + if (!(Types.LONG.equals(TableUtils.getTypeInfoByName(schema, userCol)) + && Types.LONG.equals(TableUtils.getTypeInfoByName(schema, itemCol)))) { + throw new IllegalArgumentException("The types of user and item columns must be Long."); + } + + if (getMaxUserBehavior() < getMinUserBehavior()) { + throw new IllegalArgumentException( + "The maxUserBehavior must be larger or equal to minUserBehavior."); + } + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + SingleOutputStreamOperator<Tuple2<Long, Long>> itemUsers = + tEnv.toDataStream(inputs[0]) + .map( + row -> { + if (row.getFieldAs(userCol) == null Review Comment: We should probably avoid repeatedly accessing a field by name since it is not that efficient. Can you store it in a local variable and use it in the return statement? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/Swing.java: ########## @@ -0,0 +1,415 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.recommendation.swing; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * An AlgoOperator which implements the Swing algorithm. + * + * <p>Swing is an item recall algorithm. 
The topology of user-item graph usually can be described as + * user-item-user or item-user-item, which are like 'swing'. For example, if both user <em>u</em> + * and user <em>v</em> have purchased the same commodity <em>i</em> , they will form a relationship + * diagram similar to a swing. If <em>u</em> and <em>v</em> have purchased commodity <em>j</em> in + * addition to <em>i</em>, it is supposed <em>i</em> and <em>j</em> are similar. The similarity + * between items in Swing is defined as + * + * <p>$$ w_{(i,j)}=\sum_{u\in U_i\cap U_j}\sum_{v\in U_i\cap + * U_j}{\frac{1}{{(I_u+\alpha_1)}^\beta}}*{\frac{1}{{(I_v+\alpha_1)}^\beta}}*{\frac{1}{\alpha_2+|I_u\cap + * I_v|}} $$ + * + * <p>This implementation is based on the algorithm proposed in the paper: "Large Scale Product + * Graph Construction for Recommendation in E-commerce" by Xiaoyong Yang, Yadong Zhu and Yi Zhang. + * (<a href="https://arxiv.org/pdf/2010.05525.pdf">https://arxiv.org/pdf/2010.05525.pdf</a>) + */ +public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public Swing() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... 
inputs) { + + final String userCol = getUserCol(); + final String itemCol = getItemCol(); + Preconditions.checkArgument(inputs.length == 1); + final ResolvedSchema schema = inputs[0].getResolvedSchema(); + + if (!(Types.LONG.equals(TableUtils.getTypeInfoByName(schema, userCol)) + && Types.LONG.equals(TableUtils.getTypeInfoByName(schema, itemCol)))) { + throw new IllegalArgumentException("The types of user and item columns must be Long."); + } + + if (getMaxUserBehavior() < getMinUserBehavior()) { + throw new IllegalArgumentException( + "The maxUserBehavior must be larger or equal to minUserBehavior."); + } + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + SingleOutputStreamOperator<Tuple2<Long, Long>> itemUsers = + tEnv.toDataStream(inputs[0]) + .map( + row -> { + if (row.getFieldAs(userCol) == null + || row.getFieldAs(itemCol) == null) { + throw new RuntimeException( + "Data of user and item column must not be null."); + } + return Tuple2.of( + (Long) row.getFieldAs(userCol), + (Long) row.getFieldAs(itemCol)); + }) + .returns(Types.TUPLE(Types.LONG, Types.LONG)); + + SingleOutputStreamOperator<Tuple3<Long, Long, Map<Long, String>>> userAllItemsStream = + itemUsers + .keyBy(tuple -> tuple.f0) + .transform( + "collectingUserBehavior", + Types.TUPLE( + Types.LONG, + Types.LONG, + Types.MAP(Types.LONG, Types.STRING)), + new CollectingUserBehavior( + getMinUserBehavior(), getMaxUserBehavior())); + + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + new TypeInformation[] {Types.LONG, Types.STRING}, + new String[] {getItemCol(), getOutputCol()}); + + DataStream<Row> output = + userAllItemsStream + .keyBy(tuple -> tuple.f1) + .transform( + "computingSimilarItems", + outputTypeInfo, + new ComputingSimilarItems( + getK(), + getMaxUserNumPerItem(), + getAlpha1(), + getAlpha2(), + getBeta())); + + return new Table[] {tEnv.fromDataStream(output)}; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + 
return paramMap; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static Swing load(StreamTableEnvironment tEnv, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + /** + * Collects user behavior data and appends to the input table. + * + * <p>During the process, this operator collects users and all items he/she has purchased, and + * its input table must be bounded. Because Flink doesn't support type info of `Set` officially, + * The appended column is `Map` contains items as key and maps null value. + */ + private static class CollectingUserBehavior + extends AbstractStreamOperator<Tuple3<Long, Long, Map<Long, String>>> + implements OneInputStreamOperator< + Tuple2<Long, Long>, Tuple3<Long, Long, Map<Long, String>>>, + BoundedOneInput { + private final int minUserItemInteraction; + private final int maxUserItemInteraction; + + // Maps a user id to a set of items. Because ListState cannot keep values of type `Set`, + // we use `Map<Long, String>` with null values instead. So does `userItemsMap` and + // `itemUsersMap` in `ComputingSimilarItems`. + private Map<Long, Map<Long, String>> userItemsMap = new HashMap<>(); Review Comment: Is `userAndPurchasedItems` more intuitive? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/Swing.java: ########## @@ -0,0 +1,415 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.swing; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.types.Row; +import 
org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * An AlgoOperator which implements the Swing algorithm. + * + * <p>Swing is an item recall algorithm. The topology of user-item graph usually can be described as + * user-item-user or item-user-item, which are like 'swing'. For example, if both user <em>u</em> + * and user <em>v</em> have purchased the same commodity <em>i</em> , they will form a relationship + * diagram similar to a swing. If <em>u</em> and <em>v</em> have purchased commodity <em>j</em> in + * addition to <em>i</em>, it is supposed <em>i</em> and <em>j</em> are similar. The similarity + * between items in Swing is defined as + * + * <p>$$ w_{(i,j)}=\sum_{u\in U_i\cap U_j}\sum_{v\in U_i\cap + * U_j}{\frac{1}{{(I_u+\alpha_1)}^\beta}}*{\frac{1}{{(I_v+\alpha_1)}^\beta}}*{\frac{1}{\alpha_2+|I_u\cap + * I_v|}} $$ + * + * <p>This implementation is based on the algorithm proposed in the paper: "Large Scale Product + * Graph Construction for Recommendation in E-commerce" by Xiaoyong Yang, Yadong Zhu and Yi Zhang. + * (<a href="https://arxiv.org/pdf/2010.05525.pdf">https://arxiv.org/pdf/2010.05525.pdf</a>) + */ +public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public Swing() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... 
inputs) { + + final String userCol = getUserCol(); + final String itemCol = getItemCol(); + Preconditions.checkArgument(inputs.length == 1); + final ResolvedSchema schema = inputs[0].getResolvedSchema(); + + if (!(Types.LONG.equals(TableUtils.getTypeInfoByName(schema, userCol)) + && Types.LONG.equals(TableUtils.getTypeInfoByName(schema, itemCol)))) { + throw new IllegalArgumentException("The types of user and item columns must be Long."); + } + + if (getMaxUserBehavior() < getMinUserBehavior()) { + throw new IllegalArgumentException( + "The maxUserBehavior must be larger or equal to minUserBehavior."); + } + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + SingleOutputStreamOperator<Tuple2<Long, Long>> itemUsers = + tEnv.toDataStream(inputs[0]) + .map( + row -> { + if (row.getFieldAs(userCol) == null + || row.getFieldAs(itemCol) == null) { + throw new RuntimeException( + "Data of user and item column must not be null."); + } + return Tuple2.of( + (Long) row.getFieldAs(userCol), + (Long) row.getFieldAs(itemCol)); + }) + .returns(Types.TUPLE(Types.LONG, Types.LONG)); + + SingleOutputStreamOperator<Tuple3<Long, Long, Map<Long, String>>> userAllItemsStream = + itemUsers + .keyBy(tuple -> tuple.f0) + .transform( + "collectingUserBehavior", + Types.TUPLE( + Types.LONG, + Types.LONG, + Types.MAP(Types.LONG, Types.STRING)), + new CollectingUserBehavior( + getMinUserBehavior(), getMaxUserBehavior())); + + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + new TypeInformation[] {Types.LONG, Types.STRING}, + new String[] {getItemCol(), getOutputCol()}); + + DataStream<Row> output = + userAllItemsStream + .keyBy(tuple -> tuple.f1) + .transform( + "computingSimilarItems", + outputTypeInfo, + new ComputingSimilarItems( + getK(), + getMaxUserNumPerItem(), + getAlpha1(), + getAlpha2(), + getBeta())); + + return new Table[] {tEnv.fromDataStream(output)}; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + 
return paramMap; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static Swing load(StreamTableEnvironment tEnv, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + /** + * Collects user behavior data and appends to the input table. + * + * <p>During the process, this operator collects users and all items he/she has purchased, and + * its input table must be bounded. Because Flink doesn't support type info of `Set` officially, + * The appended column is `Map` contains items as key and maps null value. + */ + private static class CollectingUserBehavior + extends AbstractStreamOperator<Tuple3<Long, Long, Map<Long, String>>> + implements OneInputStreamOperator< + Tuple2<Long, Long>, Tuple3<Long, Long, Map<Long, String>>>, + BoundedOneInput { + private final int minUserItemInteraction; + private final int maxUserItemInteraction; + + // Maps a user id to a set of items. Because ListState cannot keep values of type `Set`, + // we use `Map<Long, String>` with null values instead. So does `userItemsMap` and + // `itemUsersMap` in `ComputingSimilarItems`. 
+ private Map<Long, Map<Long, String>> userItemsMap = new HashMap<>(); + + private ListState<Map<Long, Map<Long, String>>> userAllItemsMapState; + + private CollectingUserBehavior(int minUserItemInteraction, int maxUserItemInteraction) { + this.minUserItemInteraction = minUserItemInteraction; + this.maxUserItemInteraction = maxUserItemInteraction; + } + + @Override + public void endInput() { + + userItemsMap.forEach( + (user, items) -> { + if (items.size() >= minUserItemInteraction + && items.size() <= maxUserItemInteraction) { + items.forEach( + (item, nullValue) -> + output.collect( + new StreamRecord<>( + new Tuple3<>(user, item, items)))); + } + }); + + userAllItemsMapState.clear(); + } + + @Override + public void processElement(StreamRecord<Tuple2<Long, Long>> element) { + Tuple2<Long, Long> userAndItem = element.getValue(); + long user = userAndItem.f0; + long item = userAndItem.f1; + Map<Long, String> items = userItemsMap.get(user); + + if (items == null) { + items = new LinkedHashMap<>(); + } + + if (items.size() <= maxUserItemInteraction) { + items.put(item, null); + } + + userItemsMap.put(user, items); Review Comment: It seems that we should remove this line and put the items in Line#220. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/Swing.java: ########## @@ -0,0 +1,415 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.swing; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.types.Row; +import 
org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * An AlgoOperator which implements the Swing algorithm. + * + * <p>Swing is an item recall algorithm. The topology of user-item graph usually can be described as + * user-item-user or item-user-item, which are like 'swing'. For example, if both user <em>u</em> + * and user <em>v</em> have purchased the same commodity <em>i</em> , they will form a relationship + * diagram similar to a swing. If <em>u</em> and <em>v</em> have purchased commodity <em>j</em> in + * addition to <em>i</em>, it is supposed <em>i</em> and <em>j</em> are similar. The similarity + * between items in Swing is defined as + * + * <p>$$ w_{(i,j)}=\sum_{u\in U_i\cap U_j}\sum_{v\in U_i\cap + * U_j}{\frac{1}{{(I_u+\alpha_1)}^\beta}}*{\frac{1}{{(I_v+\alpha_1)}^\beta}}*{\frac{1}{\alpha_2+|I_u\cap + * I_v|}} $$ + * + * <p>This implementation is based on the algorithm proposed in the paper: "Large Scale Product + * Graph Construction for Recommendation in E-commerce" by Xiaoyong Yang, Yadong Zhu and Yi Zhang. + * (<a href="https://arxiv.org/pdf/2010.05525.pdf">https://arxiv.org/pdf/2010.05525.pdf</a>) + */ +public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public Swing() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... 
inputs) { + + final String userCol = getUserCol(); + final String itemCol = getItemCol(); + Preconditions.checkArgument(inputs.length == 1); + final ResolvedSchema schema = inputs[0].getResolvedSchema(); + + if (!(Types.LONG.equals(TableUtils.getTypeInfoByName(schema, userCol)) + && Types.LONG.equals(TableUtils.getTypeInfoByName(schema, itemCol)))) { + throw new IllegalArgumentException("The types of user and item columns must be Long."); + } + + if (getMaxUserBehavior() < getMinUserBehavior()) { + throw new IllegalArgumentException( + "The maxUserBehavior must be larger or equal to minUserBehavior."); + } + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + SingleOutputStreamOperator<Tuple2<Long, Long>> itemUsers = + tEnv.toDataStream(inputs[0]) + .map( + row -> { + if (row.getFieldAs(userCol) == null + || row.getFieldAs(itemCol) == null) { + throw new RuntimeException( + "Data of user and item column must not be null."); + } + return Tuple2.of( + (Long) row.getFieldAs(userCol), + (Long) row.getFieldAs(itemCol)); + }) + .returns(Types.TUPLE(Types.LONG, Types.LONG)); + + SingleOutputStreamOperator<Tuple3<Long, Long, Map<Long, String>>> userAllItemsStream = Review Comment: Is `userBehavior` easier for understanding the semantic of the computation logic? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/SwingParams.java: ########## @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.swing; + +import org.apache.flink.ml.common.param.HasOutputCol; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** + * Params for {@link Swing}. + * + * @param <T> The class type of this instance. + */ +public interface SwingParams<T> extends WithParams<T>, HasOutputCol<T> { + Param<String> USER_COL = + new StringParam("userCol", "Name of user column.", "user", ParamValidators.notNull()); Review Comment: Let's update the descriptions to `User column name` and `Item column name`, following the existing conventions like `HasFeaturesCol`. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/SwingParams.java: ########## @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.swing; + +import org.apache.flink.ml.common.param.HasOutputCol; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** + * Params for {@link Swing}. + * + * @param <T> The class type of this instance. + */ +public interface SwingParams<T> extends WithParams<T>, HasOutputCol<T> { + Param<String> USER_COL = + new StringParam("userCol", "Name of user column.", "user", ParamValidators.notNull()); + + Param<String> ITEM_COL = + new StringParam("itemCol", "Name of item column.", "item", ParamValidators.notNull()); + + Param<Integer> MAX_USER_NUM_PER_ITEM = + new IntParam( + "maxUserNumPerItem", + "The max number of users that has purchased for each item. If the number of users that have " + + "purchased this item is larger than this value, then only maxUserNumPerItem users will " + + "be sampled and used in the computation logic.", + 1000, + ParamValidators.gt(0)); + + Param<Integer> K = + new IntParam( + "k", + "The max number of similar items to output for each item.", + 100, + ParamValidators.gt(0)); + + Param<Integer> MIN_USER_BEHAVIOR = + new IntParam( + "minUserBehavior", + "The min number of interaction behavior between item and user.", Review Comment: How about using the following description: `The min number of items that a user purchases. 
If the number of items purchased by a user is smaller than this value, then this user is filtered out and will not be used in the computation logic.` ########## flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/SwingParams.java: ########## @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.swing; + +import org.apache.flink.ml.common.param.HasOutputCol; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** + * Params for {@link Swing}. + * + * @param <T> The class type of this instance.
+ */ +public interface SwingParams<T> extends WithParams<T>, HasOutputCol<T> { + Param<String> USER_COL = + new StringParam("userCol", "Name of user column.", "user", ParamValidators.notNull()); + + Param<String> ITEM_COL = + new StringParam("itemCol", "Name of item column.", "item", ParamValidators.notNull()); + + Param<Integer> MAX_USER_NUM_PER_ITEM = + new IntParam( + "maxUserNumPerItem", + "The max number of users that has purchased for each item. If the number of users that have " + + "purchased this item is larger than this value, then only maxUserNumPerItem users will " + + "be sampled and used in the computation logic.", + 1000, + ParamValidators.gt(0)); + + Param<Integer> K = + new IntParam( + "k", + "The max number of similar items to output for each item.", + 100, + ParamValidators.gt(0)); + + Param<Integer> MIN_USER_BEHAVIOR = + new IntParam( + "minUserBehavior", + "The min number of interaction behavior between item and user.", + 10, + ParamValidators.gt(0)); + + Param<Integer> MAX_USER_BEHAVIOR = + new IntParam( + "maxUserBehavior", + "The max number of interaction behavior between item and user. " Review Comment: How about using the following description: `The max number of items that a user purchases. If the number of items purchased by a user is larger than this value, then this user is filtered out and will not be used in the computation logic.` ########## flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/SwingTest.java: ########## @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation; + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.recommendation.swing.Swing; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.apache.commons.lang3.exception.ExceptionUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +/** Tests {@link Swing}. */ +public class SwingTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table trainData; + RowTypeInfo trainDataTypeInfo = Review Comment: Let's make `trainDataTypeInfo` and `trainRows` local variables since they are used only once. 
########## flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/SwingTest.java: ########## @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation; + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.recommendation.swing.Swing; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.apache.commons.lang3.exception.ExceptionUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +/** Tests {@link 
Swing}. */ +public class SwingTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table trainData; + RowTypeInfo trainDataTypeInfo = + new RowTypeInfo( + new TypeInformation[] { + BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.LONG_TYPE_INFO + }, + new String[] {"user_id", "item_id"}); + private static final List<Row> trainRows = + new ArrayList<>( + Arrays.asList( + Row.of(0L, 10L), + Row.of(0L, 11L), + Row.of(0L, 12L), + Row.of(1L, 13L), + Row.of(1L, 12L), + Row.of(2L, 10L), + Row.of(2L, 11L), + Row.of(2L, 12L), + Row.of(3L, 13L), + Row.of(3L, 12L))); + + private static final List<Row> expectedScoreRows = + new ArrayList<>( + Arrays.asList( + Row.of(10L, "11,0.058845768947156235;12,0.058845768947156235"), + Row.of(11L, "10,0.058845768947156235;12,0.058845768947156235"), + Row.of( + 12L, + "13,0.09134833828228624;10,0.058845768947156235;11,0.058845768947156235"), + Row.of(13L, "12,0.09134833828228624"))); + + @Before + public void before() { + env = TestUtils.getExecutionEnvironment(); + tEnv = StreamTableEnvironment.create(env); + trainData = tEnv.fromDataStream(env.fromCollection(trainRows, trainDataTypeInfo)); + } + + private void compareResultAndExpected(List<Row> results) { + results.sort((o1, o2) -> Long.compare(o1.getFieldAs(0), o2.getFieldAs(0))); + + for (int i = 0; i < results.size(); i++) { + Row result = results.get(i); + String itemRankScore = result.getFieldAs(1); + Row expect = expectedScoreRows.get(i); + assertEquals(expect.getField(0), result.getField(0)); + assertEquals(expect.getField(1), itemRankScore); + } + } + + @Test + public void testParam() { + Swing swing = new Swing(); + + assertEquals("item", swing.getItemCol()); + assertEquals("user", swing.getUserCol()); + assertEquals(100, swing.getK()); + assertEquals(1000, swing.getMaxUserNumPerItem()); + assertEquals(10, swing.getMinUserBehavior()); + assertEquals(1000, 
swing.getMaxUserBehavior()); + assertEquals(15, swing.getAlpha1()); + assertEquals(0, swing.getAlpha2()); + assertEquals(0.3, swing.getBeta(), 1e-9); + + swing.setItemCol("item_1") + .setUserCol("user_1") + .setK(20) + .setMaxUserNumPerItem(500) + .setMinUserBehavior(10) + .setMaxUserBehavior(50) + .setAlpha1(5) + .setAlpha2(1) + .setBeta(0.35); + + assertEquals("item_1", swing.getItemCol()); + assertEquals("user_1", swing.getUserCol()); + assertEquals(20, swing.getK()); + assertEquals(500, swing.getMaxUserNumPerItem()); + assertEquals(10, swing.getMinUserBehavior()); + assertEquals(50, swing.getMaxUserBehavior()); + assertEquals(5, swing.getAlpha1()); + assertEquals(1, swing.getAlpha2()); + assertEquals(0.35, swing.getBeta(), 1e-9); + } + + @Test + public void testDataType() { + List<Row> rows = + new ArrayList<>(Arrays.asList(Row.of(0, "10"), Row.of(1, "11"), Row.of(2, ""))); + + DataStream<Row> dataStream = + env.fromCollection( + rows, + new RowTypeInfo( + new TypeInformation[] { + BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO + }, + new String[] {"user_id", "item_id"})); + Table data = tEnv.fromDataStream(dataStream); + + try { + Table[] swingResultTables = + new Swing() + .setItemCol("item_id") + .setUserCol("user_id") + .setOutputCol("item_score") + .setMinUserBehavior(1) + .transform(data); + swingResultTables[0].execute().print(); + fail(); + } catch (RuntimeException e) { + assertEquals(IllegalArgumentException.class, e.getClass()); + assertEquals("The types of user and item columns must be Long.", e.getMessage()); + } + } + + @Test + public void testNumberFormat() { Review Comment: How about renaming it as `testInputWithNull`? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/SwingParams.java: ########## @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.swing; + +import org.apache.flink.ml.common.param.HasOutputCol; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** + * Params for {@link Swing}. + * + * @param <T> The class type of this instance. + */ +public interface SwingParams<T> extends WithParams<T>, HasOutputCol<T> { + Param<String> USER_COL = + new StringParam("userCol", "Name of user column.", "user", ParamValidators.notNull()); + + Param<String> ITEM_COL = + new StringParam("itemCol", "Name of item column.", "item", ParamValidators.notNull()); + + Param<Integer> MAX_USER_NUM_PER_ITEM = + new IntParam( + "maxUserNumPerItem", + "The max number of users that has purchased for each item. 
If the number of users that have " + + "purchased this item is larger than this value, then only maxUserNumPerItem users will " + + "be sampled and used in the computation logic.", + 1000, + ParamValidators.gt(0)); + + Param<Integer> K = + new IntParam( + "k", + "The max number of similar items to output for each item.", + 100, + ParamValidators.gt(0)); + + Param<Integer> MIN_USER_BEHAVIOR = + new IntParam( + "minUserBehavior", + "The min number of interaction behavior between item and user.", + 10, + ParamValidators.gt(0)); + + Param<Integer> MAX_USER_BEHAVIOR = + new IntParam( + "maxUserBehavior", + "The max number of interaction behavior between item and user. " + + "The algorithm filters out activate users.", + 1000, + ParamValidators.gt(0)); + + Param<Integer> ALPHA1 = + new IntParam( + "alpha1", + "This parameter is used to calculate weight of each user. " + + "The higher alpha1 is, the smaller weight each user gets.", + 15, + ParamValidators.gtEq(0)); + + Param<Integer> ALPHA2 = + new IntParam( + "alpha2", + "This parameter is used to calculate similarity of users. " Review Comment: Smooth factor for number of users that have purchased the two target items. The higher alpha2 is, the less purchasing behavior contributes to the similarity score. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/SwingParams.java: ########## @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.swing; + +import org.apache.flink.ml.common.param.HasOutputCol; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** + * Params for {@link Swing}. + * + * @param <T> The class type of this instance. + */ +public interface SwingParams<T> extends WithParams<T>, HasOutputCol<T> { + Param<String> USER_COL = + new StringParam("userCol", "Name of user column.", "user", ParamValidators.notNull()); + + Param<String> ITEM_COL = + new StringParam("itemCol", "Name of item column.", "item", ParamValidators.notNull()); + + Param<Integer> MAX_USER_NUM_PER_ITEM = + new IntParam( + "maxUserNumPerItem", + "The max number of users that has purchased for each item. 
If the number of users that have " + + "purchased this item is larger than this value, then only maxUserNumPerItem users will " + + "be sampled and used in the computation logic.", + 1000, + ParamValidators.gt(0)); + + Param<Integer> K = + new IntParam( + "k", + "The max number of similar items to output for each item.", + 100, + ParamValidators.gt(0)); + + Param<Integer> MIN_USER_BEHAVIOR = + new IntParam( + "minUserBehavior", + "The min number of interaction behavior between item and user.", + 10, + ParamValidators.gt(0)); + + Param<Integer> MAX_USER_BEHAVIOR = + new IntParam( + "maxUserBehavior", + "The max number of interaction behavior between item and user. " + + "The algorithm filters out activate users.", + 1000, + ParamValidators.gt(0)); + + Param<Integer> ALPHA1 = + new IntParam( + "alpha1", + "This parameter is used to calculate weight of each user. " + + "The higher alpha1 is, the smaller weight each user gets.", + 15, Review Comment: As discussed offline, we should set the default values properly. Could you double-check the default values of these parameters? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/SwingParams.java: ########## @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.swing; + +import org.apache.flink.ml.common.param.HasOutputCol; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** + * Params for {@link Swing}. + * + * @param <T> The class type of this instance. + */ +public interface SwingParams<T> extends WithParams<T>, HasOutputCol<T> { + Param<String> USER_COL = + new StringParam("userCol", "Name of user column.", "user", ParamValidators.notNull()); + + Param<String> ITEM_COL = + new StringParam("itemCol", "Name of item column.", "item", ParamValidators.notNull()); + + Param<Integer> MAX_USER_NUM_PER_ITEM = + new IntParam( + "maxUserNumPerItem", + "The max number of users that has purchased for each item. If the number of users that have " + + "purchased this item is larger than this value, then only maxUserNumPerItem users will " + + "be sampled and used in the computation logic.", + 1000, + ParamValidators.gt(0)); + + Param<Integer> K = + new IntParam( + "k", + "The max number of similar items to output for each item.", + 100, + ParamValidators.gt(0)); + + Param<Integer> MIN_USER_BEHAVIOR = + new IntParam( + "minUserBehavior", + "The min number of interaction behavior between item and user.", + 10, + ParamValidators.gt(0)); + + Param<Integer> MAX_USER_BEHAVIOR = + new IntParam( + "maxUserBehavior", + "The max number of interaction behavior between item and user. " + + "The algorithm filters out activate users.", + 1000, + ParamValidators.gt(0)); + + Param<Integer> ALPHA1 = + new IntParam( + "alpha1", + "This parameter is used to calculate weight of each user. " Review Comment: `Smooth factor for number of users that have purchased one item. 
The higher alpha1 is, the less purchasing behavior contributes to the similarity score.` ########## flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/SwingTest.java: ########## @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.recommendation; + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.recommendation.swing.Swing; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.apache.commons.lang3.exception.ExceptionUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +/** Tests {@link Swing}. 
*/ +public class SwingTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table trainData; + RowTypeInfo trainDataTypeInfo = + new RowTypeInfo( + new TypeInformation[] { + BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.LONG_TYPE_INFO + }, + new String[] {"user_id", "item_id"}); + private static final List<Row> trainRows = + new ArrayList<>( + Arrays.asList( + Row.of(0L, 10L), + Row.of(0L, 11L), + Row.of(0L, 12L), + Row.of(1L, 13L), + Row.of(1L, 12L), + Row.of(2L, 10L), + Row.of(2L, 11L), + Row.of(2L, 12L), + Row.of(3L, 13L), + Row.of(3L, 12L))); + + private static final List<Row> expectedScoreRows = + new ArrayList<>( + Arrays.asList( + Row.of(10L, "11,0.058845768947156235;12,0.058845768947156235"), + Row.of(11L, "10,0.058845768947156235;12,0.058845768947156235"), + Row.of( + 12L, + "13,0.09134833828228624;10,0.058845768947156235;11,0.058845768947156235"), + Row.of(13L, "12,0.09134833828228624"))); + + @Before + public void before() { + env = TestUtils.getExecutionEnvironment(); + tEnv = StreamTableEnvironment.create(env); + trainData = tEnv.fromDataStream(env.fromCollection(trainRows, trainDataTypeInfo)); + } + + private void compareResultAndExpected(List<Row> results) { + results.sort((o1, o2) -> Long.compare(o1.getFieldAs(0), o2.getFieldAs(0))); + + for (int i = 0; i < results.size(); i++) { + Row result = results.get(i); + String itemRankScore = result.getFieldAs(1); + Row expect = expectedScoreRows.get(i); + assertEquals(expect.getField(0), result.getField(0)); + assertEquals(expect.getField(1), itemRankScore); + } + } + + @Test + public void testParam() { + Swing swing = new Swing(); + + assertEquals("item", swing.getItemCol()); + assertEquals("user", swing.getUserCol()); + assertEquals(100, swing.getK()); + assertEquals(1000, swing.getMaxUserNumPerItem()); + assertEquals(10, swing.getMinUserBehavior()); + assertEquals(1000, 
swing.getMaxUserBehavior()); + assertEquals(15, swing.getAlpha1()); + assertEquals(0, swing.getAlpha2()); + assertEquals(0.3, swing.getBeta(), 1e-9); + + swing.setItemCol("item_1") + .setUserCol("user_1") + .setK(20) + .setMaxUserNumPerItem(500) + .setMinUserBehavior(10) + .setMaxUserBehavior(50) + .setAlpha1(5) + .setAlpha2(1) + .setBeta(0.35); + + assertEquals("item_1", swing.getItemCol()); + assertEquals("user_1", swing.getUserCol()); + assertEquals(20, swing.getK()); + assertEquals(500, swing.getMaxUserNumPerItem()); + assertEquals(10, swing.getMinUserBehavior()); + assertEquals(50, swing.getMaxUserBehavior()); + assertEquals(5, swing.getAlpha1()); + assertEquals(1, swing.getAlpha2()); + assertEquals(0.35, swing.getBeta(), 1e-9); + } + + @Test + public void testDataType() { Review Comment: Let's give it a more informative name, e.g., `testInputWithIllegalDataType`. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
