jiangxin369 commented on code in PR #192:
URL: https://github.com/apache/flink-ml/pull/192#discussion_r1094587044


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/Swing.java:
##########
@@ -0,0 +1,452 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.recommendation.swing;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Random;
+import java.util.Set;
+
+/**
+ * An AlgoOperator which implements the Swing algorithm.
+ *
+ * <p>Swing is an item recall model. The topology of user-item graph usually 
can be described as
+ * user-item-user or item-user-item, which are like 'swing'. For example, if 
both user <em>u</em>
+ * and user <em>v</em> have purchased the same commodity <em>i</em> , they 
will form a relationship
+ * diagram similar to a swing. If <em>u</em> and <em>v</em> have purchased 
commodity <em>j</em> in
+ * addition to <em>i</em>, it is supposed <em>i</em> and <em>j</em> are 
similar. The formula of
+ * Swing is
+ *
+ * <p>$$ w_{(i,j)}=\sum_{u\in U_i\cap U_j}\sum_{v\in U_i\cap
+ * 
U_j}{\frac{1}{{(I_u+\alpha_1)}^\beta}}*{\frac{1}{{(I_v+\alpha_1)}^\beta}}*{\frac{1}{\alpha\_2+|I_u\cap
+ * I_v|}} $$
+ *
+ * <p>This implementation is based on the algorithm proposed in the paper: 
"Large Scale Product
+ * Graph Construction for Recommendation in E-commerce" by Xiaoyong Yang, 
Yadong Zhu and Yi Zhang.
+ * (https://arxiv.org/pdf/2010.05525.pdf)
+ */
+public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public Swing() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+
+        final String userCol = getUserCol();
+        final String itemCol = getItemCol();
+        Preconditions.checkArgument(inputs.length == 1);
+        final ResolvedSchema schema = inputs[0].getResolvedSchema();
+
+        if (!(Types.LONG.equals(TableUtils.getTypeInfoByName(schema, userCol))
+                && Types.LONG.equals(TableUtils.getTypeInfoByName(schema, 
itemCol)))) {
+            throw new IllegalArgumentException("The types of user and item 
columns must be Long.");
+        }
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+
+        SingleOutputStreamOperator<Tuple2<Long, Long>> itemUsers =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                row -> {
+                                    if (row.getFieldAs(userCol) == null
+                                            || row.getFieldAs(itemCol) == 
null) {
+                                        throw new RuntimeException(
+                                                "Data of user and item column 
must not be null");
+                                    }
+                                    return Tuple2.of(
+                                            ((Number) 
row.getFieldAs(userCol)).longValue(),
+                                            ((Number) 
row.getFieldAs(itemCol)).longValue());
+                                })
+                        .returns(Types.TUPLE(Types.LONG, Types.LONG));
+
+        SingleOutputStreamOperator<Tuple3<Long, Long, List<Long>>> 
userAllItemsStream =
+                itemUsers
+                        .keyBy(tuple -> tuple.f0)
+                        .transform(
+                                "fillUserItemsTable",
+                                Types.TUPLE(Types.LONG, Types.LONG, 
Types.LIST(Types.LONG)),
+                                new CollectingUserBehavior(
+                                        getMinUserBehavior(), 
getMaxUserBehavior()));
+
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        new TypeInformation[] {
+                            BasicTypeInfo.LONG_TYPE_INFO, 
BasicTypeInfo.STRING_TYPE_INFO
+                        },
+                        new String[] {getItemCol(), getOutputCol()});
+
+        DataStream<Row> output =
+                userAllItemsStream
+                        .keyBy(tuple -> tuple.f1)
+                        .transform(
+                                "computingSimilarItems",
+                                outputTypeInfo,
+                                new ComputingSimilarItems(
+                                        getK(),
+                                        getMaxUserNumPerItem(),
+                                        getAlpha1(),
+                                        getAlpha2(),
+                                        getBeta()));
+
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static Swing load(StreamTableEnvironment tEnv, String path) throws 
IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    /**
+     * Appends one column, that records all items the user has purchased, to 
the input table.
+     *
+     * <p>During the process, this operator collect users and all items a user 
has purchased into a
+     * map of list. When the input is finished, this operator appends the 
certain
+     * user-purchased-items list to each row.
+     */
+    private static class CollectingUserBehavior
+            extends AbstractStreamOperator<Tuple3<Long, Long, List<Long>>>
+            implements OneInputStreamOperator<Tuple2<Long, Long>, Tuple3<Long, 
Long, List<Long>>>,
+                    BoundedOneInput {
+        private final int minUserItemInteraction;
+        private final int maxUserItemInteraction;
+
+        private Map<Long, Set<Long>> userItemsMap = new HashMap<>();
+
+        private ListState<Map<Long, List<Long>>> userAllItemsMapState;
+
+        private CollectingUserBehavior(int minUserItemInteraction, int 
maxUserItemInteraction) {
+            this.minUserItemInteraction = minUserItemInteraction;
+            this.maxUserItemInteraction = maxUserItemInteraction;
+        }
+
+        @Override
+        public void endInput() {
+
+            for (Entry<Long, Set<Long>> entry : userItemsMap.entrySet()) {
+                List<Long> items = new ArrayList<>(entry.getValue());
+                Long user = entry.getKey();
+                if (items.size() < minUserItemInteraction
+                        || items.size() > maxUserItemInteraction) {
+                    continue;
+                }
+                for (Long item : items) {
+                    output.collect(new StreamRecord<>(new Tuple3<>(user, item, 
items)));
+                }
+            }
+
+            userAllItemsMapState.clear();
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple2<Long, Long>> element) {
+            Tuple2<Long, Long> userAndItem = element.getValue();
+            long user = userAndItem.f0;
+            long item = userAndItem.f1;
+            Set<Long> items = userItemsMap.get(user);
+
+            if (items == null) {
+                Set<Long> value = new LinkedHashSet<>();
+                value.add(item);
+                userItemsMap.put(user, value);
+            } else {
+                if (items.size() <= maxUserItemInteraction) {
+                    items.add(item);
+                }
+            }
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws 
Exception {
+            super.initializeState(context);
+            userAllItemsMapState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "userAllItemsMapState",
+                                            Types.MAP(Types.LONG, 
Types.LIST(Types.LONG))));
+
+            OperatorStateUtils.getUniqueElement(userAllItemsMapState, 
"userAllItemsMapState")
+                    .ifPresent(
+                            x -> {
+                                userItemsMap = new HashMap<>(x.size());
+                                for (long user : x.keySet()) {
+                                    List<Long> itemList = x.get(user);
+                                    userItemsMap.put(user, new 
LinkedHashSet<>(itemList));
+                                }
+                            });
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws 
Exception {
+            super.snapshotState(context);
+            Map<Long, List<Long>> userItemsList = new 
HashMap<>(userItemsMap.size());
+            for (Entry<Long, Set<Long>> entry : userItemsMap.entrySet()) {
+                userItemsList.put(entry.getKey(), new 
ArrayList<>(entry.getValue()));
+            }
+            
userAllItemsMapState.update(Collections.singletonList(userItemsList));
+        }
+    }
+
+    /** Calculates top N similar items of each item. */
+    private static class ComputingSimilarItems extends 
AbstractStreamOperator<Row>
+            implements OneInputStreamOperator<Tuple3<Long, Long, List<Long>>, 
Row>,
+                    BoundedOneInput {
+
+        private Map<Long, HashSet<Long>> userItemsMap = new HashMap<>();
+        private Map<Long, HashSet<Long>> itemUsersMap = new HashMap<>();
+        private ListState<Map<Long, List<Long>>> userLocalItemsMapState;
+        private ListState<Map<Long, List<Long>>> itemUsersMapState;
+
+        private final int k;
+        private final int maxUserNumPerItem;
+        private final int alpha1;
+        private final int alpha2;
+        private final double beta;
+
+        private static Character commaDelimiter = ',';
+        private static Character semicolonDelimiter = ';';
+
+        private ComputingSimilarItems(
+                int k, int maxUserNumPerItem, int alpha1, int alpha2, double 
beta) {
+            this.k = k;
+            this.maxUserNumPerItem = maxUserNumPerItem;
+            this.alpha1 = alpha1;
+            this.alpha2 = alpha2;
+            this.beta = beta;
+        }
+
+        @Override
+        public void endInput() throws Exception {
+
+            Map<Long, Double> userWeights = new HashMap<>(userItemsMap.size());
+            userItemsMap.forEach(
+                    (k, v) -> {
+                        int count = v.size();
+                        userWeights.put(k, calculateWeight(count));
+                    });
+
+            for (long mainItem : itemUsersMap.keySet()) {
+                List<Long> userList = 
sampleUserList(itemUsersMap.get(mainItem), maxUserNumPerItem);
+                HashMap<Long, Double> id2swing = new HashMap<>();
+
+                for (int i = 0; i < userList.size(); i++) {
+                    long u = userList.get(i);
+                    for (int j = i + 1; j < userList.size(); j++) {
+                        long v = userList.get(j);
+                        HashSet<Long> interaction = (HashSet<Long>) 
userItemsMap.get(u).clone();
+                        interaction.retainAll(userItemsMap.get(v));
+                        if (interaction.size() == 0) {
+                            continue;
+                        }
+                        double similarity =
+                                (userWeights.get(u)
+                                        * userWeights.get(v)
+                                        / (alpha2 + interaction.size()));
+                        for (long simItem : interaction) {
+                            if (simItem == mainItem) {
+                                continue;
+                            }
+                            double itemSimilarity =
+                                    id2swing.getOrDefault(simItem, 0.0) + 
similarity;
+                            id2swing.putIfAbsent(simItem, itemSimilarity);
+                        }
+                    }
+                }
+
+                ArrayList<Tuple2<Long, Double>> itemAndScore = new 
ArrayList<>();
+                id2swing.forEach((key, value) -> 
itemAndScore.add(Tuple2.of(key, value)));
+
+                itemAndScore.sort((o1, o2) -> Double.compare(o2.f1, o1.f1));
+
+                if (itemAndScore.size() == 0) {
+                    continue;
+                }
+
+                int itemNums = Math.min(k, itemAndScore.size());
+                StringBuilder sbd = new StringBuilder();
+                for (int i = 0; i < itemNums; i++) {
+                    sbd.append(itemAndScore.get(i).f0).append(commaDelimiter);
+                    
sbd.append(itemAndScore.get(i).f1).append(semicolonDelimiter);
+                }
+                String itemList = sbd.substring(0, sbd.length() - 1);
+
+                output.collect(new StreamRecord<>(Row.of(mainItem, itemList)));
+            }
+
+            userLocalItemsMapState.clear();
+            itemUsersMapState.clear();
+        }
+
+        private double calculateWeight(int size) {
+            return (1.0 / Math.pow(alpha1 + size, beta));
+        }
+
+        private static List<Long> sampleUserList(Set<Long> allUsers, int 
sampleSize) {
+            int totalSize = allUsers.size();
+            if (totalSize < sampleSize) {
+                return new ArrayList(allUsers);
+            }
+
+            List<Long> userList = new ArrayList<>(totalSize);
+            double prob = (double) sampleSize / totalSize;
+            Random rand = new Random();
+
+            for (long u : allUsers) {
+                double guess = rand.nextDouble();
+                if (guess < prob) {
+                    userList.add(u);
+                    sampleSize--;
+                }
+                totalSize--;
+                prob = (double) sampleSize / totalSize;
+            }
+
+            return userList;
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple3<Long, Long, 
List<Long>>> streamRecord)
+                throws Exception {
+            Tuple3<Long, Long, List<Long>> tuple3 = streamRecord.getValue();
+            long user = tuple3.f0;
+            long item = tuple3.f1;
+            List<Long> items = tuple3.f2;
+
+            if (!userItemsMap.containsKey(user)) {
+                HashSet<Long> itemSet = new HashSet<>(items.size());
+                itemSet.addAll(items);

Review Comment:
   You can construct the HashSet directly from the list with `new HashSet<>(items)` to 
simplify the code. The same applies to the other places where a HashSet is constructed.



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/SwingTest.java:
##########
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.recommendation;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.recommendation.swing.Swing;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link Swing}. */
+public class SwingTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+    RowTypeInfo trainDataTypeInfo =
+            new RowTypeInfo(
+                    new TypeInformation[] {
+                        BasicTypeInfo.LONG_TYPE_INFO, 
BasicTypeInfo.LONG_TYPE_INFO
+                    },
+                    new String[] {"user_id", "item_id"});
+    private static final List<Row> trainRows =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(0L, 10L),
+                            Row.of(0L, 11L),
+                            Row.of(0L, 12L),
+                            Row.of(1L, 13L),
+                            Row.of(1L, 12L),
+                            Row.of(2L, 10L),
+                            Row.of(2L, 11L),
+                            Row.of(2L, 12L),
+                            Row.of(3L, 13L),
+                            Row.of(3L, 12L)));
+
+    private static final List<Row> expectedScoreRows =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(10L, 
"11,0.058845768947156235;12,0.058845768947156235"),
+                            Row.of(11L, 
"10,0.058845768947156235;12,0.058845768947156235"),
+                            Row.of(
+                                    12L,
+                                    
"13,0.09134833828228624;10,0.058845768947156235;11,0.058845768947156235"),
+                            Row.of(13L, "12,0.09134833828228624")));
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.getConfig().enableObjectReuse();
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        DataStream<Row> dataStream = env.fromCollection(trainRows, 
trainDataTypeInfo);
+        trainData = tEnv.fromDataStream(dataStream);
+    }
+
+    private void compareResultAndExpected(List<Row> results) {
+        results.sort((o1, o2) -> Long.compare(o1.getFieldAs(0), 
o2.getFieldAs(0)));
+
+        for (int i = 0; i < results.size(); i++) {
+            Row result = results.get(i);
+            String itemRankScore = result.getFieldAs(1);
+            Row expect = expectedScoreRows.get(i);
+            Assert.assertEquals(result.getField(0), expect.getField(0));
+            Assert.assertEquals(itemRankScore, expect.getField(1));
+        }
+    }
+
+    @Test
+    public void testParam() {
+        Swing swing = new Swing();
+
+        assertEquals("item", swing.getItemCol());
+        assertEquals("user", swing.getUserCol());
+        assertEquals(100, swing.getK());
+        assertEquals(10, swing.getMinUserBehavior());
+        assertEquals(1000, swing.getMaxUserBehavior());
+        assertEquals(15, swing.getAlpha1());
+        assertEquals(0, swing.getAlpha2());
+        assertEquals(0.3, swing.getBeta(), 1e-9);
+
+        swing.setItemCol("item_1")
+                .setUserCol("user_1")
+                .setK(20)
+                .setMinUserBehavior(10)
+                .setMaxUserBehavior(50)
+                .setAlpha1(5)
+                .setAlpha2(1)
+                .setBeta(0.35);
+
+        assertEquals("item_1", swing.getItemCol());
+        assertEquals("user_1", swing.getUserCol());
+        assertEquals(20, swing.getK());
+        assertEquals(10, swing.getMinUserBehavior());
+        assertEquals(50, swing.getMaxUserBehavior());
+        assertEquals(5, swing.getAlpha1());
+        assertEquals(1, swing.getAlpha2());
+        assertEquals(0.35, swing.getBeta(), 1e-9);
+    }
+
+    @Test(expected = IllegalArgumentException.class)
+    public void testDataType() {
+        List<Row> rows =
+                new ArrayList<>(Arrays.asList(Row.of(0, "10"), Row.of(1, 
"11"), Row.of(2, "")));
+        DataStream<Row> dataStream =
+                env.fromCollection(
+                        rows,
+                        new RowTypeInfo(
+                                new TypeInformation[] {
+                                    BasicTypeInfo.LONG_TYPE_INFO, 
BasicTypeInfo.STRING_TYPE_INFO
+                                },
+                                new String[] {"user_id", "item_id"}));
+        Table data = tEnv.fromDataStream(dataStream);
+        Table[] swingResultTables =
+                new Swing()
+                        .setItemCol("item_id")
+                        .setUserCol("user_id")
+                        .setOutputCol("item_score")
+                        .setMinUserBehavior(1)
+                        .transform(data);
+
+        swingResultTables[0].execute().collect();
+    }
+
+    @Test(expected = RuntimeException.class)
+    public void testNumberFormat() {
+        List<Row> rows =
+                new ArrayList<>(
+                        Arrays.asList(
+                                Row.of(0L, 10L),
+                                Row.of(null, 12L),
+                                Row.of(1L, 13L),
+                                Row.of(3L, 12L)));
+        DataStream<Row> dataStream = env.fromCollection(rows, 
trainDataTypeInfo);
+        Table data = tEnv.fromDataStream(dataStream);
+        Swing swing = new 
Swing().setItemCol("item_id").setUserCol("user_id").setMinUserBehavior(1);
+        Table[] swingResultTables = swing.transform(data);
+        swingResultTables[0].execute().print();
+    }
+
+    @Test
+    public void testOutputSchema() {
+        Swing swing =
+                new Swing()
+                        .setItemCol("item_id")
+                        .setUserCol("user_id")
+                        .setOutputCol("item_score")
+                        .setMinUserBehavior(1);
+        Table[] swingResultTables = swing.transform(trainData);
+        Table output = swingResultTables[0];
+
+        assertEquals(
+                Arrays.asList("item_id", "item_score"),
+                output.getResolvedSchema().getColumnNames());
+    }
+
+    @Test
+    public void testFewerItemCase() {
+        Swing swing = new 
Swing().setItemCol("item_id").setUserCol("user_id").setMinUserBehavior(5);
+        Table[] swingResultTables = swing.transform(trainData);
+        Table output = swingResultTables[0];
+        List<Row> results = IteratorUtils.toList(output.execute().collect());
+        Assert.assertEquals(0, results.size());
+    }
+
+    @Test
+    public void testTransform() throws Exception {

Review Comment:
   The `throws Exception` clause can be removed.



##########
flink-ml-python/pyflink/ml/lib/recommendation/tests/test_swing.py:
##########
@@ -0,0 +1,165 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from pyflink.common import Types
+from pyflink.table import Table
+from typing import List
+from py4j.protocol import Py4JJavaError:
+
+from pyflink.ml.lib.recommendation.swing import Swing
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
+# Tests Swing. 
+class SwingTest(PyFlinkMLTestCase):
+    def setUp(self):
+        super(SwingTest, self).setUp()
+        self.train_data = self.t_env.from_data_stream(
+            self.env.from_collection([
+                (0, 10),
+                (0, 11),
+                (0, 12),
+                (1, 13),
+                (1, 12),
+                (2, 10),
+                (2, 11),
+                (2, 12),
+                (3, 13),
+                (3, 12)
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['user', 'item'],
+                    [Types.LONG(), Types.LONG()])
+            ))
+
+        self.wrong_type_data = self.t_env.from_data_stream(
+            self.env.from_collection([
+                (0, 10),
+                (1, 11),
+                (2, 12)
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['user', 'item'],
+                    [Types.INT(), Types.LONG()])
+            ))
+
+        self.none_value_data = self.t_env.from_data_stream(
+            self.env.from_collection([
+                (0, 10),
+                (None, 11),
+                (2, 12)
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['user', 'item'],
+                    [Types.LONG(), Types.LONG()])
+            ))
+
+        self.expected_data = [
+            [10, '11,0.058845768947156235;12,0.058845768947156235'],
+            [11, '10,0.058845768947156235;12,0.058845768947156235'],
+            [12, 
'13,0.09134833828228624;10,0.058845768947156235;11,0.058845768947156235'],
+            [13, '12,0.09134833828228624']]
+
+    def test_param(self):
+        swing = Swing()
+        self.assertEqual("item", swing.item_col)
+        self.assertEqual("user", swing.user_col)
+        self.assertEqual(100, swing.k)
+        self.assertEqual(10, swing.min_user_items)
+        self.assertEqual(1000, swing.max_user_items)
+        self.assertEqual(15, swing.alpha1)
+        self.assertEqual(0, swing.alpha2)
+        self.assertEqual(0.3, swing.beta, delta=1e-9)
+
+        swing.set_item_col("item_1") \
+            .set_user_col("user_1") \
+            .set_k(20) \
+            .set_min_user_behavior(20) \
+            .set_max_user_behavior(50) \
+            .set_alpha1(5) \
+            .set_alpha2(1) \
+            .set_beta(0.35)
+
+        self.assertEqual("item", swing.item_col)
+        self.assertEqual("user", swing.user_col)
+        self.assertEqual(20, swing.k)
+        self.assertEqual(20, swing.min_user_items)
+        self.assertEqual(50, swing.max_user_items)
+        self.assertEqual(5, swing.alpha1)
+        self.assertEqual(1, swing.alpha2)
+        self.assertEqual(0.35, swing.beta, delta=1e-9)
+
+    def test_output_schema(self):
+        swing = Swing() \
+            .set_item_col('test_item') \
+            .set_user_col('test_user') \
+            .set_output_col("item_score")
+
+        output = swing.transform(self.train_data.alias(['test_user', 
'test_item']))[0]
+        self.assertEqual(
+            ['test_item', 'item_score'],
+            output.get_schema().get_field_names())
+
+    def test_transform(self):
+        swing = Swing().set_min_user_behavior(1)
+        output = swing.transform(self.train_data)[0]
+        self.verify_output_result(
+            output,
+            swing.get_item_col(),
+            output.get_schema().get_field_names(),
+            self.expected_data)
+
+    def test_save_load_and_transform(self):
+        swing = Swing().set_min_user_behavior(1)
+        reloaded_swing = self.save_and_reload(swing)
+        output = reloaded_swing.transform(self.train_data)[0]
+        self.verify_output_result(
+            output,
+            swing.get_item_col(),
+            output.get_schema().get_field_names(),
+            self.expected_data)
+
+    def test_data_type(self):
+        try:
+            swing = Swing().set_min_user_behavior(1)
+            output = swing.transform(self.wrong_type_data)[0]
+            self.t_env.to_data_stream(output).execute_and_collect()
+        except Py4JJavaError:
+            pass
+
+    def test_number_format(self):

Review Comment:
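   It seems that `test_data_type` and `test_number_format` do not actually test anything: 
the expected exception is caught and discarded (e.g. `except Py4JJavaError: pass`) without 
checking the error message, so the tests would still pass if no exception were raised at all. 
By the way, I'm not sure whether we need these tests in Python, because this behavior is 
already covered by the Java tests.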



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/Swing.java:
##########
@@ -0,0 +1,452 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.recommendation.swing;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Random;
+import java.util.Set;
+
+/**
+ * An AlgoOperator which implements the Swing algorithm.
+ *
+ * <p>Swing is an item recall model. The topology of a user-item graph can usually be described as
+ * user-item-user or item-user-item, which looks like a 'swing'. For example, if both user
+ * <em>u</em> and user <em>v</em> have purchased the same commodity <em>i</em>, they form a
+ * relationship diagram similar to a swing. If <em>u</em> and <em>v</em> have also purchased
+ * commodity <em>j</em> in addition to <em>i</em>, then <em>i</em> and <em>j</em> are considered
+ * similar. The formula of Swing is
+ *
+ * <p>$$ w_{(i,j)}=\sum_{u\in U_i\cap U_j}\sum_{v\in U_i\cap U_j}
+ * \frac{1}{(|I_u|+\alpha_1)^\beta}\cdot\frac{1}{(|I_v|+\alpha_1)^\beta}\cdot
+ * \frac{1}{\alpha_2+|I_u\cap I_v|} $$
+ *
+ * <p>where U_i is the set of users who have purchased item <em>i</em>, and I_u is the set of
+ * items purchased by user <em>u</em>.
+ *
+ * <p>This implementation is based on the algorithm proposed in the paper: 
"Large Scale Product
+ * Graph Construction for Recommendation in E-commerce" by Xiaoyong Yang, 
Yadong Zhu and Yi Zhang.
+ * (https://arxiv.org/pdf/2010.05525.pdf)
+ */
+public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> {
+    /** Parameter values of this operator, keyed by parameter. */
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    /** Creates a Swing operator whose parameter map is filled with the default values. */
+    public Swing() {
+        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+    }
+
+    /**
+     * Computes, for every item, its most similar items according to the Swing formula.
+     *
+     * @param inputs exactly one table containing the user and item columns, both of type Long
+     * @return a single table with the item column and a string column listing the similar items
+     *     of that item together with their scores
+     * @throws IllegalArgumentException if the user or item column is not of type Long
+     */
+    @Override
+    public Table[] transform(Table... inputs) {
+        final String userCol = getUserCol();
+        final String itemCol = getItemCol();
+        Preconditions.checkArgument(inputs.length == 1);
+        final Table input = inputs[0];
+        final ResolvedSchema schema = input.getResolvedSchema();
+
+        if (!Types.LONG.equals(TableUtils.getTypeInfoByName(schema, userCol))
+                || !Types.LONG.equals(TableUtils.getTypeInfoByName(schema, itemCol))) {
+            throw new IllegalArgumentException("The types of user and item columns must be Long.");
+        }
+
+        final StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) input).getTableEnvironment();
+
+        // (user, item) pairs extracted from the input rows; rows with a null id are rejected.
+        final SingleOutputStreamOperator<Tuple2<Long, Long>> userItemPairs =
+                tEnv.toDataStream(input)
+                        .map(
+                                row -> {
+                                    Object userValue = row.getFieldAs(userCol);
+                                    Object itemValue = row.getFieldAs(itemCol);
+                                    if (userValue == null || itemValue == null) {
+                                        throw new RuntimeException(
+                                                "Data of user and item column must not be null");
+                                    }
+                                    return Tuple2.of(
+                                            ((Number) userValue).longValue(),
+                                            ((Number) itemValue).longValue());
+                                })
+                        .returns(Types.TUPLE(Types.LONG, Types.LONG));
+
+        // Keyed by user: attaches the user's full item list to each of the user's pairs.
+        final SingleOutputStreamOperator<Tuple3<Long, Long, List<Long>>> pairsWithUserItems =
+                userItemPairs
+                        .keyBy(pair -> pair.f0)
+                        .transform(
+                                "fillUserItemsTable",
+                                Types.TUPLE(Types.LONG, Types.LONG, Types.LIST(Types.LONG)),
+                                new CollectingUserBehavior(
+                                        getMinUserBehavior(), getMaxUserBehavior()));
+
+        final RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        new TypeInformation[] {
+                            BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO
+                        },
+                        new String[] {itemCol, getOutputCol()});
+
+        // Keyed by item: computes the most similar items of each item.
+        final DataStream<Row> similarItems =
+                pairsWithUserItems
+                        .keyBy(record -> record.f1)
+                        .transform(
+                                "computingSimilarItems",
+                                outputTypeInfo,
+                                new ComputingSimilarItems(
+                                        getK(),
+                                        getMaxUserNumPerItem(),
+                                        getAlpha1(),
+                                        getAlpha2(),
+                                        getBeta()));
+
+        return new Table[] {tEnv.fromDataStream(similarItems)};
+    }
+
+    /** Returns the map holding this operator's parameter values. */
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.paramMap;
+    }
+
+    /**
+     * Persists this operator's metadata via {@link ReadWriteUtils#saveMetadata}.
+     *
+     * @param path the location to write the metadata to
+     */
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static Swing load(StreamTableEnvironment tEnv, String path) throws 
IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    /**
+     * Collects, for each user, the distinct items the user has interacted with, and attaches that
+     * item list to each of the user's (user, item) records.
+     *
+     * <p>During the process, this operator collects users and their items into a map of sets.
+     * When the input is finished, for every user whose number of distinct items lies within
+     * [minUserItemInteraction, maxUserItemInteraction], it emits one (user, item, all items of the
+     * user) tuple per item of that user. Other users are dropped.
+     */
+    private static class CollectingUserBehavior
+            extends AbstractStreamOperator<Tuple3<Long, Long, List<Long>>>
+            implements OneInputStreamOperator<Tuple2<Long, Long>, Tuple3<Long, Long, List<Long>>>,
+                    BoundedOneInput {
+        private final int minUserItemInteraction;
+        private final int maxUserItemInteraction;
+
+        /** Distinct items of each user, in first-seen order. */
+        private Map<Long, Set<Long>> userItemsMap = new HashMap<>();
+
+        // NOTE(review): operator (non-keyed) state holding a single map per subtask; confirm how
+        // restoring with a different parallelism interacts with getUniqueElement below.
+        private ListState<Map<Long, List<Long>>> userAllItemsMapState;
+
+        private CollectingUserBehavior(int minUserItemInteraction, int maxUserItemInteraction) {
+            this.minUserItemInteraction = minUserItemInteraction;
+            this.maxUserItemInteraction = maxUserItemInteraction;
+        }
+
+        @Override
+        public void endInput() {
+            for (Entry<Long, Set<Long>> entry : userItemsMap.entrySet()) {
+                List<Long> items = new ArrayList<>(entry.getValue());
+                Long user = entry.getKey();
+                if (items.size() < minUserItemInteraction
+                        || items.size() > maxUserItemInteraction) {
+                    continue;
+                }
+                for (Long item : items) {
+                    output.collect(new StreamRecord<>(new Tuple3<>(user, item, items)));
+                }
+            }
+
+            // Release the collected data. Clearing only the state is not enough: snapshotState()
+            // rebuilds the state from userItemsMap, so a checkpoint taken after the input has
+            // finished would otherwise persist the whole map again.
+            userItemsMap.clear();
+            userAllItemsMapState.clear();
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple2<Long, Long>> element) {
+            Tuple2<Long, Long> userAndItem = element.getValue();
+            long user = userAndItem.f0;
+            long item = userAndItem.f1;
+
+            Set<Long> items = userItemsMap.computeIfAbsent(user, ignored -> new LinkedHashSet<>());
+            // Stop collecting once a user has more than maxUserItemInteraction items: such users
+            // are dropped in endInput(), so one extra item is enough to detect them.
+            if (items.size() <= maxUserItemInteraction) {
+                items.add(item);
+            }
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            userAllItemsMapState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "userAllItemsMapState",
+                                            Types.MAP(Types.LONG, Types.LIST(Types.LONG))));
+
+            OperatorStateUtils.getUniqueElement(userAllItemsMapState, "userAllItemsMapState")
+                    .ifPresent(
+                            x -> {
+                                userItemsMap = new HashMap<>(x.size());
+                                for (long user : x.keySet()) {
+                                    List<Long> itemList = x.get(user);
+                                    userItemsMap.put(user, new LinkedHashSet<>(itemList));
+                                }
+                            });
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            // Sets are stored as lists so that the declared MAP<LONG, LIST<LONG>> type info applies.
+            Map<Long, List<Long>> userItemsList = new HashMap<>(userItemsMap.size());
+            for (Entry<Long, Set<Long>> entry : userItemsMap.entrySet()) {
+                userItemsList.put(entry.getKey(), new ArrayList<>(entry.getValue()));
+            }
+            userAllItemsMapState.update(Collections.singletonList(userItemsList));
+        }
+    }
+
+    /** Calculates top N similar items of each item. */
+    private static class ComputingSimilarItems extends 
AbstractStreamOperator<Row>
+            implements OneInputStreamOperator<Tuple3<Long, Long, List<Long>>, 
Row>,
+                    BoundedOneInput {
+
+        /** All items of each user, as received from the upstream operator. */
+        private Map<Long, HashSet<Long>> userItemsMap = new HashMap<>();
+        /** Users that interacted with each item received by this operator. */
+        private Map<Long, HashSet<Long>> itemUsersMap = new HashMap<>();
+        private ListState<Map<Long, List<Long>>> userLocalItemsMapState;
+        private ListState<Map<Long, List<Long>>> itemUsersMapState;
+
+        private final int k;
+        private final int maxUserNumPerItem;
+        private final int alpha1;
+        private final int alpha2;
+        private final double beta;
+
+        // Separators of the output string "item,score;item,score;...". These are constants, so
+        // they are final (and primitive, avoiding needless boxing).
+        private static final char commaDelimiter = ',';
+        private static final char semicolonDelimiter = ';';
+
+        /**
+         * Creates the operator with the Swing hyper-parameters.
+         *
+         * @param k maximum number of similar items emitted per item
+         * @param maxUserNumPerItem maximum number of users sampled per item
+         * @param alpha1 term added to a user's item count when computing the user weight
+         * @param alpha2 term added to the number of items two users have in common
+         * @param beta exponent used when computing the user weight
+         */
+        private ComputingSimilarItems(
+                int k, int maxUserNumPerItem, int alpha1, int alpha2, double beta) {
+            this.beta = beta;
+            this.alpha2 = alpha2;
+            this.alpha1 = alpha1;
+            this.maxUserNumPerItem = maxUserNumPerItem;
+            this.k = k;
+        }
+
+        /**
+         * Computes the Swing similarity between items once all input has been received, and emits
+         * the top {@code k} similar items of each item.
+         *
+         * <p>Each output row is (item, "simItem,score;simItem,score;..."), ordered by descending
+         * score. Items without any similar item are not emitted.
+         */
+        @Override
+        public void endInput() throws Exception {
+            // Weight of every user: 1 / (alpha1 + number of items of the user)^beta.
+            Map<Long, Double> userWeights = new HashMap<>(userItemsMap.size());
+            userItemsMap.forEach(
+                    (user, items) -> userWeights.put(user, calculateWeight(items.size())));
+
+            for (long mainItem : itemUsersMap.keySet()) {
+                List<Long> userList = sampleUserList(itemUsersMap.get(mainItem), maxUserNumPerItem);
+                Map<Long, Double> id2swing = new HashMap<>();
+
+                for (int i = 0; i < userList.size(); i++) {
+                    long u = userList.get(i);
+                    HashSet<Long> itemsOfU = userItemsMap.get(u);
+                    for (int j = i + 1; j < userList.size(); j++) {
+                        long v = userList.get(j);
+                        @SuppressWarnings("unchecked")
+                        HashSet<Long> interaction = (HashSet<Long>) itemsOfU.clone();
+                        interaction.retainAll(userItemsMap.get(v));
+                        if (interaction.isEmpty()) {
+                            continue;
+                        }
+                        double similarity =
+                                userWeights.get(u)
+                                        * userWeights.get(v)
+                                        / (alpha2 + interaction.size());
+                        for (long simItem : interaction) {
+                            if (simItem == mainItem) {
+                                continue;
+                            }
+                            // Sum the contributions of all user pairs that share simItem;
+                            // putIfAbsent() would keep only the first pair's contribution.
+                            id2swing.merge(simItem, similarity, Double::sum);
+                        }
+                    }
+                }
+
+                int itemNums = Math.min(k, id2swing.size());
+                if (itemNums <= 0) {
+                    // No similar item (or k <= 0): nothing to emit for this item.
+                    continue;
+                }
+
+                List<Tuple2<Long, Double>> itemAndScore = new ArrayList<>(id2swing.size());
+                id2swing.forEach((item, score) -> itemAndScore.add(Tuple2.of(item, score)));
+                itemAndScore.sort((o1, o2) -> Double.compare(o2.f1, o1.f1));
+
+                StringBuilder sbd = new StringBuilder();
+                for (int i = 0; i < itemNums; i++) {
+                    sbd.append(itemAndScore.get(i).f0).append(commaDelimiter);
+                    sbd.append(itemAndScore.get(i).f1).append(semicolonDelimiter);
+                }
+                // Drop the trailing semicolon delimiter.
+                String itemList = sbd.substring(0, sbd.length() - 1);
+
+                output.collect(new StreamRecord<>(Row.of(mainItem, itemList)));
+            }
+
+            userLocalItemsMapState.clear();
+            itemUsersMapState.clear();
+        }
+
+        /** Returns the weight of a user owning {@code size} items: 1 / (alpha1 + size)^beta. */
+        private double calculateWeight(int size) {
+            double denominator = Math.pow(alpha1 + size, beta);
+            return 1.0 / denominator;
+        }
+
+        /**
+         * Samples at most {@code sampleSize} distinct users uniformly at random.
+         *
+         * <p>If {@code allUsers} has at most {@code sampleSize} elements, all of them are returned.
+         * Otherwise selection sampling is used: each user is kept with probability (users still to
+         * pick) / (users not yet visited), which yields exactly {@code sampleSize} users.
+         */
+        private static List<Long> sampleUserList(Set<Long> allUsers, int sampleSize) {
+            int totalSize = allUsers.size();
+            if (totalSize <= sampleSize) {
+                return new ArrayList<>(allUsers);
+            }
+
+            List<Long> userList = new ArrayList<>(Math.max(sampleSize, 0));
+            Random rand = new Random();
+            int remainingToPick = sampleSize;
+            int remainingToVisit = totalSize;
+            for (long u : allUsers) {
+                if (remainingToPick <= 0) {
+                    break;
+                }
+                // Computed before visiting, so remainingToVisit is always > 0 here.
+                double prob = (double) remainingToPick / remainingToVisit;
+                if (rand.nextDouble() < prob) {
+                    userList.add(u);
+                    remainingToPick--;
+                }
+                remainingToVisit--;
+            }
+            return userList;
+        }
+
+        /**
+         * Records the item list of the record's user (first occurrence only) and registers the user
+         * as an interactor of the record's item.
+         */
+        @Override
+        public void processElement(StreamRecord<Tuple3<Long, Long, List<Long>>> streamRecord)
+                throws Exception {
+            Tuple3<Long, Long, List<Long>> record = streamRecord.getValue();
+            long user = record.f0;
+            long item = record.f1;
+
+            // Upstream attaches the same item list to every record of a user, so the first one
+            // seen is sufficient.
+            if (!userItemsMap.containsKey(user)) {
+                List<Long> items = record.f2;
+                HashSet<Long> itemSet = new HashSet<>(items.size());
+                itemSet.addAll(items);
+                userItemsMap.put(user, itemSet);
+            }
+
+            itemUsersMap.computeIfAbsent(item, ignored -> new HashSet<>()).add(user);
+        }
+
+        /** Registers the operator states and restores both maps from them, if present. */
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            userLocalItemsMapState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "userLocalItemsMapState",
+                                            Types.MAP(Types.LONG, Types.LIST(Types.LONG))));
+            OperatorStateUtils.getUniqueElement(userLocalItemsMapState, "userLocalItemsMapState")
+                    .ifPresent(
+                            restored -> {
+                                userItemsMap = new HashMap<>(restored.size());
+                                for (Entry<Long, List<Long>> entry : restored.entrySet()) {
+                                    List<Long> items = entry.getValue();
+                                    HashSet<Long> itemSet = new HashSet<>(items.size());
+                                    itemSet.addAll(items);
+                                    userItemsMap.put(entry.getKey(), itemSet);
+                                }
+                            });
+
+            itemUsersMapState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "itemUsersMapState",
+                                            Types.MAP(Types.LONG, Types.LIST(Types.LONG))));
+            OperatorStateUtils.getUniqueElement(itemUsersMapState, "itemUsersMapState")
+                    .ifPresent(
+                            restored -> {
+                                itemUsersMap = new HashMap<>(restored.size());
+                                for (Entry<Long, List<Long>> entry : restored.entrySet()) {
+                                    List<Long> users = entry.getValue();
+                                    HashSet<Long> userSet = new HashSet<>(users.size());
+                                    userSet.addAll(users);
+                                    itemUsersMap.put(entry.getKey(), userSet);
+                                }
+                            });
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws 
Exception {
+            super.snapshotState(context);
+            Map tmpUerItemsMap = new HashMap(userItemsMap.size());

Review Comment:
   Let's replace the raw `Map` with `Map<Long, List<Long>>` to avoid raw types. The same applies 
to the other raw uses of parameterized classes, such as `ArrayList`, `Map` and `HashMap`.



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/SwingTest.java:
##########
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.recommendation;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.recommendation.swing.Swing;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link Swing}. */
+public class SwingTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+    RowTypeInfo trainDataTypeInfo =
+            new RowTypeInfo(
+                    new TypeInformation[] {
+                        BasicTypeInfo.LONG_TYPE_INFO, 
BasicTypeInfo.LONG_TYPE_INFO
+                    },
+                    new String[] {"user_id", "item_id"});
+    private static final List<Row> trainRows =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(0L, 10L),
+                            Row.of(0L, 11L),
+                            Row.of(0L, 12L),
+                            Row.of(1L, 13L),
+                            Row.of(1L, 12L),
+                            Row.of(2L, 10L),
+                            Row.of(2L, 11L),
+                            Row.of(2L, 12L),
+                            Row.of(3L, 13L),
+                            Row.of(3L, 12L)));
+
+    private static final List<Row> expectedScoreRows =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(10L, 
"11,0.058845768947156235;12,0.058845768947156235"),
+                            Row.of(11L, 
"10,0.058845768947156235;12,0.058845768947156235"),
+                            Row.of(
+                                    12L,
+                                    
"13,0.09134833828228624;10,0.058845768947156235;11,0.058845768947156235"),
+                            Row.of(13L, "12,0.09134833828228624")));
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.getConfig().enableObjectReuse();
+        // Fail fast if any type falls back to generic (Kryo) serialization, which is inefficient.
+        env.getConfig().disableGenericTypes();
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        DataStream<Row> dataStream = env.fromCollection(trainRows, trainDataTypeInfo);
+        trainData = tEnv.fromDataStream(dataStream);
+    }
+
+    /**
+     * Asserts that {@code results}, once sorted by item id, equal {@link #expectedScoreRows}. The
+     * row counts must match as well, so missing or extra rows fail the test.
+     */
+    private void compareResultAndExpected(List<Row> results) {
+        assertEquals(expectedScoreRows.size(), results.size());
+        results.sort((o1, o2) -> Long.compare(o1.getFieldAs(0), o2.getFieldAs(0)));
+
+        for (int i = 0; i < results.size(); i++) {
+            Row result = results.get(i);
+            String itemRankScore = result.getFieldAs(1);
+            Row expect = expectedScoreRows.get(i);
+            // JUnit convention: expected value first, actual value second.
+            assertEquals(expect.getField(0), result.getField(0));
+            assertEquals(expect.getField(1), itemRankScore);
+        }
+    }
+
+    @Test
+    public void testParam() {
+        Swing swing = new Swing();
+
+        assertEquals("item", swing.getItemCol());
+        assertEquals("user", swing.getUserCol());
+        assertEquals(100, swing.getK());
+        assertEquals(10, swing.getMinUserBehavior());
+        assertEquals(1000, swing.getMaxUserBehavior());
+        assertEquals(15, swing.getAlpha1());
+        assertEquals(0, swing.getAlpha2());
+        assertEquals(0.3, swing.getBeta(), 1e-9);
+
+        // Every value differs from its default, so each setter is actually verified.
+        swing.setItemCol("item_1")
+                .setUserCol("user_1")
+                .setK(20)
+                .setMinUserBehavior(20)
+                .setMaxUserBehavior(50)
+                .setAlpha1(5)
+                .setAlpha2(1)
+                .setBeta(0.35);
+
+        assertEquals("item_1", swing.getItemCol());
+        assertEquals("user_1", swing.getUserCol());
+        assertEquals(20, swing.getK());
+        assertEquals(20, swing.getMinUserBehavior());
+        assertEquals(50, swing.getMaxUserBehavior());
+        assertEquals(5, swing.getAlpha1());
+        assertEquals(1, swing.getAlpha2());
+        assertEquals(0.35, swing.getBeta(), 1e-9);
+    }
+
+    /**
+     * Checks that a non-Long item column is rejected with the expected message. A try/fail/catch
+     * is used instead of {@code expected = ...}, so an unrelated IllegalArgumentException cannot
+     * make the test pass.
+     */
+    @Test
+    public void testDataType() {
+        List<Row> rows =
+                new ArrayList<>(
+                        Arrays.asList(Row.of(0L, "10"), Row.of(1L, "11"), Row.of(2L, "")));
+        DataStream<Row> dataStream =
+                env.fromCollection(
+                        rows,
+                        new RowTypeInfo(
+                                new TypeInformation[] {
+                                    BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO
+                                },
+                                new String[] {"user_id", "item_id"}));
+        Table data = tEnv.fromDataStream(dataStream);
+        try {
+            Table[] swingResultTables =
+                    new Swing()
+                            .setItemCol("item_id")
+                            .setUserCol("user_id")
+                            .setOutputCol("item_score")
+                            .setMinUserBehavior(1)
+                            .transform(data);
+            swingResultTables[0].execute().collect();
+            Assert.fail("Expected an IllegalArgumentException for a non-Long item column.");
+        } catch (IllegalArgumentException e) {
+            assertEquals("The types of user and item columns must be Long.", e.getMessage());
+        }
+    }
+
+    @Test(expected = RuntimeException.class)

Review Comment:
   Would it be better to use `try { ...; fail(); } catch (...) { assert errorMsg; }` instead of 
`expected = RuntimeException.class`? It avoids the situation where an unexpected RuntimeException 
is thrown but the test still passes. You can refer to `testFitOnEmptyData` in 
`VarianceThresholdSelectorTest`.



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/SwingTest.java:
##########
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.recommendation;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.recommendation.swing.Swing;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link Swing}. */
+public class SwingTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+    RowTypeInfo trainDataTypeInfo =
+            new RowTypeInfo(
+                    new TypeInformation[] {
+                        BasicTypeInfo.LONG_TYPE_INFO, 
BasicTypeInfo.LONG_TYPE_INFO
+                    },
+                    new String[] {"user_id", "item_id"});
+    private static final List<Row> trainRows =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(0L, 10L),
+                            Row.of(0L, 11L),
+                            Row.of(0L, 12L),
+                            Row.of(1L, 13L),
+                            Row.of(1L, 12L),
+                            Row.of(2L, 10L),
+                            Row.of(2L, 11L),
+                            Row.of(2L, 12L),
+                            Row.of(3L, 13L),
+                            Row.of(3L, 12L)));
+
+    private static final List<Row> expectedScoreRows =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(10L, 
"11,0.058845768947156235;12,0.058845768947156235"),
+                            Row.of(11L, 
"10,0.058845768947156235;12,0.058845768947156235"),
+                            Row.of(
+                                    12L,
+                                    
"13,0.09134833828228624;10,0.058845768947156235;11,0.058845768947156235"),
+                            Row.of(13L, "12,0.09134833828228624")));
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.getConfig().enableObjectReuse();

Review Comment:
   Please append `.disableGenericTypes()` to the execution config to ensure that no inefficient 
generic (Kryo-based) serialization is used. We will add this config for all algorithms in the 
future.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/Swing.java:
##########
@@ -0,0 +1,452 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.recommendation.swing;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Random;
+import java.util.Set;
+
+/**
+ * An AlgoOperator which implements the Swing algorithm.
+ *
+ * <p>Swing is an item recall model. The topology of the user-item graph can usually be described
+ * as user-item-user or item-user-item, which looks like a 'swing'. For example, if both user
+ * <em>u</em> and user <em>v</em> have purchased the same commodity <em>i</em>, they will form a
+ * relationship diagram similar to a swing. If <em>u</em> and <em>v</em> have also purchased
+ * commodity <em>j</em> in addition to <em>i</em>, then <em>i</em> and <em>j</em> are assumed to be
+ * similar. The formula of Swing is
+ *
+ * <p>$$ w_{(i,j)}=\sum_{u\in U_i\cap U_j}\sum_{v\in U_i\cap
+ * U_j}{\frac{1}{{(|I_u|+\alpha_1)}^\beta}}*{\frac{1}{{(|I_v|+\alpha_1)}^\beta}}*{\frac{1}{\alpha_2+|I_u\cap
+ * I_v|}} $$
+ *
+ * <p>where <em>U_i</em> is the set of users who have purchased item <em>i</em>, and <em>I_u</em>
+ * is the set of items purchased by user <em>u</em>.
+ *
+ * <p>This implementation is based on the algorithm proposed in the paper: "Large Scale Product
+ * Graph Construction for Recommendation in E-commerce" by Xiaoyong Yang, Yadong Zhu and Yi Zhang.
+ * (https://arxiv.org/pdf/2010.05525.pdf)
+ */
+public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> {
+    /** Values of all parameters of this stage, keyed by parameter. */
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    /** Creates a Swing instance whose parameter map is initialized with default values. */
+    public Swing() {
+        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+
+        final String userCol = getUserCol();
+        final String itemCol = getItemCol();
+        Preconditions.checkArgument(inputs.length == 1);
+        final ResolvedSchema schema = inputs[0].getResolvedSchema();
+
+        if (!(Types.LONG.equals(TableUtils.getTypeInfoByName(schema, userCol))
+                && Types.LONG.equals(TableUtils.getTypeInfoByName(schema, 
itemCol)))) {
+            throw new IllegalArgumentException("The types of user and item 
columns must be Long.");
+        }
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+
+        SingleOutputStreamOperator<Tuple2<Long, Long>> itemUsers =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                row -> {
+                                    if (row.getFieldAs(userCol) == null
+                                            || row.getFieldAs(itemCol) == 
null) {
+                                        throw new RuntimeException(
+                                                "Data of user and item column 
must not be null");
+                                    }
+                                    return Tuple2.of(
+                                            ((Number) 
row.getFieldAs(userCol)).longValue(),
+                                            ((Number) 
row.getFieldAs(itemCol)).longValue());

Review Comment:
   Would it be simpler to replace `((Number) row.getFieldAs(userCol)).longValue()` 
with `(Long) row.getFieldAs(userCol)` (and likewise for `itemCol`)? Both column 
types have already been checked to be Long above.



##########
flink-ml-python/pyflink/examples/ml/recommendation/swing_example.py:
##########
@@ -0,0 +1,75 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+################################################################################
+
+# Simple program that creates a Swing instance and uses it to compute the most
+# similar items of each item from user-item interaction data.
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.table import StreamTableEnvironment, Schema
+from pyflink.table.types import DataTypes
+
+from pyflink.ml.lib.recommendation.swing import Swing
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input data: each row is a (user, item) interaction.
+input_table = t_env.from_data_stream(
+    env.from_collection([
+        (0, 10),
+        (0, 11),
+        (0, 12),
+        (1, 13),
+        (1, 12),
+        (2, 10),
+        (2, 11),
+        (2, 12),
+        (3, 13),
+        (3, 12)
+    ],
+        type_info=Types.ROW_NAMED(
+        ['user', 'item'],
+        [Types.LONG(), Types.LONG()])
+    ))
+
+# Creates a Swing object and initializes its parameters.
+swing = Swing()\
+    .set_item_col('item')\
+    .set_user_col("user")\
+    .set_min_user_behavior(1)
+
+# Transforms the input data and gets the Swing algorithm result.
+swingTable = swing.transform(input_table)

Review Comment:
   How about following the existing naming style and naming it `output`? Also, 
almost all the other Python examples contain a docstring like `Simple program 
that creates a ...`; could you add one here as well?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/SwingParams.java:
##########
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.recommendation.swing;
+
+import org.apache.flink.ml.common.param.HasOutputCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Params for {@link Swing}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface SwingParams<T> extends WithParams<T>, HasOutputCol<T> {
+    /** Name of the user column; the column must be of type Long. */
+    Param<String> USER_COL =
+            new StringParam("userCol", "Name of user column.", "user", ParamValidators.notNull());
+
+    /** Name of the item column; the column must be of type Long. */
+    Param<String> ITEM_COL =
+            new StringParam("itemCol", "Name of item column.", "item", ParamValidators.notNull());
+
+    Param<Integer> MAX_USER_NUM_PER_ITEM =
+            new IntParam(
+                    "maxUserNumPerItem",
+                    "The max number of users that have purchased each item. If the number of users "
+                            + "that have purchased an item is larger than this value, then only "
+                            + "maxUserNumPerItem users will be sampled and used in the computation.",
+                    1000,
+                    ParamValidators.gt(0));
+
+    Param<Integer> K =
+            new IntParam(
+                    "k",
+                    "The max number of similar items to output for each item.",
+                    100,
+                    ParamValidators.gt(0));
+
+    // Users whose number of distinct items is below this value are dropped before similarity
+    // computation.
+    Param<Integer> MIN_USER_BEHAVIOR =
+            new IntParam(
+                    "minUserBehavior",
+                    "The min number of distinct items a user has interacted with. Users with fewer "
+                            + "items are filtered out and not used in the computation.",
+                    10,
+                    ParamValidators.gt(0));
+
+    // Users whose number of distinct items is above this value are dropped before similarity
+    // computation.
+    Param<Integer> MAX_USER_BEHAVIOR =
+            new IntParam(
+                    "maxUserBehavior",
+                    "The max number of distinct items a user has interacted with. Users with more "
+                            + "items are filtered out and not used in the computation.",
+                    1000,
+                    ParamValidators.gt(0));
+
+    Param<Integer> ALPHA1 =
+            new IntParam(
+                    "alpha1",
+                    "This parameter is used to calculate weight of each user. "
+                            + "The higher alpha1 is, the smaller weight each user gets.",
+                    15,
+                    ParamValidators.gtEq(0));
+
+    Param<Integer> ALPHA2 =
+            new IntParam(
+                    "alpha2",
+                    "This parameter is used to calculate similarity of users. "
+                            + "The higher alpha2 is, the lower the similarity score is.",
+                    0,
+                    ParamValidators.gtEq(0));
+
+    Param<Double> BETA =
+            new DoubleParam(
+                    "beta",
+                    "This parameter is used to calculate weight of each user. "
+                            + "The higher beta is, the smaller weight each user gets.",
+                    0.3,
+                    ParamValidators.gtEq(0));
+
+    /** Returns the name of the user column. */
+    default String getUserCol() {
+        return this.get(USER_COL);
+    }
+
+    /** Sets the name of the user column and returns this instance. */
+    default T setUserCol(String value) {
+        return this.set(USER_COL, value);
+    }
+
+    /** Returns the name of the item column. */
+    default String getItemCol() {
+        return this.get(ITEM_COL);
+    }
+
+    /** Sets the name of the item column and returns this instance. */
+    default T setItemCol(String value) {
+        return this.set(ITEM_COL, value);
+    }
+
+    /** Returns the max number of similar items to output for each item. */
+    default int getK() {
+        return this.get(K);
+    }
+
+    /** Sets the max number of similar items to output for each item. */
+    default T setK(Integer value) {
+        return this.set(K, value);
+    }
+
+    /** Returns the max number of users used per item. */
+    default int getMaxUserNumPerItem() {
+        return this.get(MAX_USER_NUM_PER_ITEM);
+    }
+
+    default T setMaxUserNumPerItem(Integer value) {

Review Comment:
   Please add tests for `setMaxUserNumPerItem` in both `test_param()` (Python) 
and `testParam()` (Java).



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/Swing.java:
##########
@@ -0,0 +1,452 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.recommendation.swing;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Random;
+import java.util.Set;
+
+/**
+ * An AlgoOperator which implements the Swing algorithm.
+ *
+ * <p>Swing is an item recall model. The topology of the user-item graph can usually be described
+ * as user-item-user or item-user-item, which looks like a 'swing'. For example, if both user
+ * <em>u</em> and user <em>v</em> have purchased the same commodity <em>i</em>, they will form a
+ * relationship diagram similar to a swing. If <em>u</em> and <em>v</em> have also purchased
+ * commodity <em>j</em> in addition to <em>i</em>, then <em>i</em> and <em>j</em> are assumed to be
+ * similar. The formula of Swing is
+ *
+ * <p>$$ w_{(i,j)}=\sum_{u\in U_i\cap U_j}\sum_{v\in U_i\cap
+ * U_j}{\frac{1}{{(|I_u|+\alpha_1)}^\beta}}*{\frac{1}{{(|I_v|+\alpha_1)}^\beta}}*{\frac{1}{\alpha_2+|I_u\cap
+ * I_v|}} $$
+ *
+ * <p>where <em>U_i</em> is the set of users who have purchased item <em>i</em>, and <em>I_u</em>
+ * is the set of items purchased by user <em>u</em>.
+ *
+ * <p>This implementation is based on the algorithm proposed in the paper: "Large Scale Product
+ * Graph Construction for Recommendation in E-commerce" by Xiaoyong Yang, Yadong Zhu and Yi Zhang.
+ * (https://arxiv.org/pdf/2010.05525.pdf)
+ */
+public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> {
+    /** Values of all parameters of this stage, keyed by parameter. */
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    /** Creates a Swing instance whose parameter map is initialized with default values. */
+    public Swing() {
+        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+    }
+
+    /**
+     * Computes the most similar items of each item with the Swing algorithm.
+     *
+     * <p>The input table must contain the user and item columns configured by {@link
+     * #getUserCol()} and {@link #getItemCol()}, both of type Long and without null values. The
+     * output table contains the item column and the output column, in which the similar items of
+     * that item and their scores are encoded as a string.
+     *
+     * @param inputs exactly one table of user-item interactions.
+     * @return a single table with the item column and the output column.
+     * @throws IllegalArgumentException if the user or item column is not of type Long.
+     */
+    @Override
+    public Table[] transform(Table... inputs) {
+        final String userCol = getUserCol();
+        final String itemCol = getItemCol();
+        Preconditions.checkArgument(inputs.length == 1);
+        final ResolvedSchema schema = inputs[0].getResolvedSchema();
+
+        if (!(Types.LONG.equals(TableUtils.getTypeInfoByName(schema, userCol))
+                && Types.LONG.equals(TableUtils.getTypeInfoByName(schema, itemCol)))) {
+            throw new IllegalArgumentException("The types of user and item columns must be Long.");
+        }
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        // One (user, item) tuple per input row. Both columns were verified to be Long above, so
+        // each field is read once and used directly as a Long.
+        SingleOutputStreamOperator<Tuple2<Long, Long>> userItems =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                row -> {
+                                    Long user = row.getFieldAs(userCol);
+                                    Long item = row.getFieldAs(itemCol);
+                                    if (user == null || item == null) {
+                                        throw new RuntimeException(
+                                                "Data of user and item column must not be null");
+                                    }
+                                    return Tuple2.of(user, item);
+                                })
+                        .returns(Types.TUPLE(Types.LONG, Types.LONG));
+
+        // Keyed by user: once the input ends, emits (user, item, all items of the user) for every
+        // user that passes the min/max user-behavior filter.
+        SingleOutputStreamOperator<Tuple3<Long, Long, List<Long>>> userAllItemsStream =
+                userItems
+                        .keyBy(tuple -> tuple.f0)
+                        .transform(
+                                "fillUserItemsTable",
+                                Types.TUPLE(Types.LONG, Types.LONG, Types.LIST(Types.LONG)),
+                                new CollectingUserBehavior(
+                                        getMinUserBehavior(), getMaxUserBehavior()));
+
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        new TypeInformation[] {
+                            BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO
+                        },
+                        new String[] {itemCol, getOutputCol()});
+
+        // Keyed by item: computes the top-k similar items of each item.
+        DataStream<Row> output =
+                userAllItemsStream
+                        .keyBy(tuple -> tuple.f1)
+                        .transform(
+                                "computingSimilarItems",
+                                outputTypeInfo,
+                                new ComputingSimilarItems(
+                                        getK(),
+                                        getMaxUserNumPerItem(),
+                                        getAlpha1(),
+                                        getAlpha2(),
+                                        getBeta()));
+
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    /** Returns the map holding this stage's parameter values; the map itself is returned. */
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.paramMap;
+    }
+
+    /** Saves the metadata (parameters) of this stage to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    /**
+     * Loads a Swing instance whose parameters were saved under the given path.
+     *
+     * @param tEnv the table environment; not used by this implementation.
+     * @param path the directory previously passed to {@link #save(String)}.
+     */
+    public static Swing load(StreamTableEnvironment tEnv, String path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    /**
+     * Collects all items each user has interacted with and attaches that item list to every
+     * (user, item) record.
+     *
+     * <p>During processing, this operator collects each user and the distinct items the user has
+     * purchased into a map. When the input is finished, every user whose number of distinct items
+     * lies within [minUserItemInteraction, maxUserItemInteraction] emits one (user, item, all items
+     * of the user) record per item; all other users are filtered out.
+     */
+    private static class CollectingUserBehavior
+            extends AbstractStreamOperator<Tuple3<Long, Long, List<Long>>>
+            implements OneInputStreamOperator<Tuple2<Long, Long>, Tuple3<Long, Long, List<Long>>>,
+                    BoundedOneInput {
+        private final int minUserItemInteraction;
+        private final int maxUserItemInteraction;
+
+        /** Distinct items of each user, kept in first-seen order. */
+        private Map<Long, Set<Long>> userItemsMap = new HashMap<>();
+
+        private ListState<Map<Long, List<Long>>> userAllItemsMapState;
+
+        private CollectingUserBehavior(int minUserItemInteraction, int maxUserItemInteraction) {
+            this.minUserItemInteraction = minUserItemInteraction;
+            this.maxUserItemInteraction = maxUserItemInteraction;
+        }
+
+        @Override
+        public void endInput() {
+            for (Entry<Long, Set<Long>> entry : userItemsMap.entrySet()) {
+                Set<Long> itemSet = entry.getValue();
+                // Filters users with too few or too many distinct items before copying the set.
+                if (itemSet.size() < minUserItemInteraction
+                        || itemSet.size() > maxUserItemInteraction) {
+                    continue;
+                }
+                Long user = entry.getKey();
+                List<Long> items = new ArrayList<>(itemSet);
+                for (Long item : items) {
+                    output.collect(new StreamRecord<>(new Tuple3<>(user, item, items)));
+                }
+            }
+
+            userAllItemsMapState.clear();
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple2<Long, Long>> element) {
+            Tuple2<Long, Long> userAndItem = element.getValue();
+            Set<Long> items =
+                    userItemsMap.computeIfAbsent(userAndItem.f0, k -> new LinkedHashSet<>());
+            // Stops growing the set once it holds more than maxUserItemInteraction items: one
+            // extra item is enough for endInput() to filter the user out, and it bounds memory.
+            if (items.size() <= maxUserItemInteraction) {
+                items.add(userAndItem.f1);
+            }
+        }
+
+        /** Registers the operator state and restores the per-user item sets from it, if any. */
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            userAllItemsMapState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "userAllItemsMapState",
+                                            Types.MAP(Types.LONG, Types.LIST(Types.LONG))));
+
+            OperatorStateUtils.getUniqueElement(userAllItemsMapState, "userAllItemsMapState")
+                    .ifPresent(
+                            restored -> {
+                                userItemsMap = new HashMap<>(restored.size());
+                                for (Entry<Long, List<Long>> entry : restored.entrySet()) {
+                                    userItemsMap.put(
+                                            entry.getKey(), new LinkedHashSet<>(entry.getValue()));
+                                }
+                            });
+        }
+
+        /** Stores the per-user item sets as lists, matching the declared state type. */
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            Map<Long, List<Long>> userItemsList = new HashMap<>(userItemsMap.size());
+            for (Entry<Long, Set<Long>> entry : userItemsMap.entrySet()) {
+                userItemsList.put(entry.getKey(), new ArrayList<>(entry.getValue()));
+            }
+            userAllItemsMapState.update(Collections.singletonList(userItemsList));
+        }
+    }
+
+    /** Calculates top N similar items of each item. */
+    private static class ComputingSimilarItems extends 
AbstractStreamOperator<Row>
+            implements OneInputStreamOperator<Tuple3<Long, Long, List<Long>>, 
Row>,
+                    BoundedOneInput {
+
+        private Map<Long, HashSet<Long>> userItemsMap = new HashMap<>();
+        private Map<Long, HashSet<Long>> itemUsersMap = new HashMap<>();
+        private ListState<Map<Long, List<Long>>> userLocalItemsMapState;
+        private ListState<Map<Long, List<Long>>> itemUsersMapState;
+
+        private final int k;
+        private final int maxUserNumPerItem;
+        private final int alpha1;
+        private final int alpha2;
+        private final double beta;
+
+        private static Character commaDelimiter = ',';
+        private static Character semicolonDelimiter = ';';

Review Comment:
   Let's declare these two delimiters as `final`, or simply remove the 
declarations and inline the characters, since they are only used in one place.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to