This is an automated email from the ASF dual-hosted git repository.

zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 2a599182 [FLINK-31325] Improve performance of Swing
2a599182 is described below

commit 2a5991827f8885d7a17023dee43da016be12dd79
Author: vacaly <[email protected]>
AuthorDate: Wed Apr 12 18:16:14 2023 +0800

    [FLINK-31325] Improve performance of Swing
    
    This closes #220.
---
 .../flink/ml/recommendation/swing/Swing.java       | 123 ++++++++++++++-------
 .../flink/ml/recommendation/swing/SwingParams.java |   3 +-
 .../apache/flink/ml/recommendation/SwingTest.java  |  20 +++-
 flink-ml-python/pyflink/ml/recommendation/swing.py |   5 +-
 .../pyflink/ml/recommendation/tests/test_swing.py  |  16 ++-
 5 files changed, 119 insertions(+), 48 deletions(-)

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/Swing.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/Swing.java
index 5a1ebc0c..022f44b7 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/Swing.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/Swing.java
@@ -49,12 +49,13 @@ import org.apache.flink.util.Preconditions;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Random;
 import java.util.stream.Collectors;
 
 /**
@@ -147,9 +148,11 @@ public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> {
                                 new ComputingSimilarItems(
                                         getK(),
                                         getMaxUserNumPerItem(),
+                                        getMaxUserBehavior(),
                                         getAlpha1(),
                                         getAlpha2(),
-                                        getBeta()));
+                                        getBeta(),
+                                        getSeed()));
 
         return new Table[] {tEnv.fromDataStream(output)};
     }
@@ -182,8 +185,7 @@ public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> {
         private final int maxUserItemInteraction;
 
         // Maps a user id to a set of items. Because ListState cannot keep 
values of type `Set`,
-        // we use `Map<Long, String>` with null values instead. So does 
`userAndPurchasedItems` and
-        // `itemAndPurchasers` in `ComputingSimilarItems`.
+        // we use `Map<Long, String>` with null values instead.
         private Map<Long, Map<Long, String>> userAndPurchasedItems = new 
HashMap<>();
 
         private ListState<Map<Long, Map<Long, String>>> 
userAndPurchasedItemsState;
@@ -259,13 +261,9 @@ public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> {
     private static class ComputingSimilarItems extends 
AbstractStreamOperator<Row>
             implements OneInputStreamOperator<Tuple3<Long, Long, long[]>, 
Row>, BoundedOneInput {
 
-        private Map<Long, Map<Long, String>> userAndPurchasedItems = new 
HashMap<>();
-        private Map<Long, Map<Long, String>> itemAndPurchasers = new 
HashMap<>();
-        private ListState<Map<Long, Map<Long, String>>> 
userAndPurchasedItemsState;
-        private ListState<Map<Long, Map<Long, String>>> itemAndPurchasersState;
-
         private final int k;
         private final int maxUserNumPerItem;
+        private final int maxUserBehavior;
 
         private final int alpha1;
         private final int alpha2;
@@ -274,13 +272,29 @@ public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> {
         private static final Character commaDelimiter = ',';
         private static final Character semicolonDelimiter = ';';
 
+        private final Random random;
+
+        private Map<Long, long[]> userAndPurchasedItems = new HashMap<>();
+        private Map<Long, List<Long>> itemAndPurchasers = new HashMap<>();
+
+        private ListState<Map<Long, long[]>> userAndPurchasedItemsState;
+        private ListState<Map<Long, List<Long>>> itemAndPurchasersState;
+
         private ComputingSimilarItems(
-                int k, int maxUserNumPerItem, int alpha1, int alpha2, double 
beta) {
+                int k,
+                int maxUserNumPerItem,
+                int maxUserBehavior,
+                int alpha1,
+                int alpha2,
+                double beta,
+                long seed) {
             this.k = k;
             this.maxUserNumPerItem = maxUserNumPerItem;
+            this.maxUserBehavior = maxUserBehavior;
             this.alpha1 = alpha1;
             this.alpha2 = alpha2;
             this.beta = beta;
+            this.random = new Random(seed);
         }
 
         @Override
@@ -289,36 +303,40 @@ public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> {
             Map<Long, Double> userWeights = new 
HashMap<>(userAndPurchasedItems.size());
             userAndPurchasedItems.forEach(
                     (k, v) -> {
-                        int count = v.size();
+                        int count = v.length;
                         userWeights.put(k, calculateWeight(count));
                     });
 
+            long[] interaction = new long[maxUserBehavior];
             for (long mainItem : itemAndPurchasers.keySet()) {
-                List<Long> userList =
-                        sampleUserList(itemAndPurchasers.get(mainItem), 
maxUserNumPerItem);
+                List<Long> userList = itemAndPurchasers.get(mainItem);
                 HashMap<Long, Double> id2swing = new HashMap<>();
 
-                for (int i = 0; i < userList.size(); i++) {
+                for (int i = 1; i < userList.size(); i++) {
                     long u = userList.get(i);
+                    int interactionSize;
                     for (int j = i + 1; j < userList.size(); j++) {
                         long v = userList.get(j);
-                        HashSet<Long> interaction =
-                                new 
HashSet<>(userAndPurchasedItems.get(u).keySet());
-                        
interaction.retainAll(userAndPurchasedItems.get(v).keySet());
-                        if (interaction.size() == 0) {
+                        interactionSize =
+                                calculateCommonItems(
+                                        userAndPurchasedItems.get(u),
+                                        userAndPurchasedItems.get(v),
+                                        interaction);
+                        if (interactionSize == 0) {
                             continue;
                         }
                         double similarity =
-                                (userWeights.get(u)
+                                userWeights.get(u)
                                         * userWeights.get(v)
-                                        / (alpha2 + interaction.size()));
-                        for (long simItem : interaction) {
+                                        / (alpha2 + interactionSize);
+                        for (int k = 0; k < interactionSize; k++) {
+                            long simItem = interaction[k];
                             if (simItem == mainItem) {
                                 continue;
                             }
                             double itemSimilarity =
                                     id2swing.getOrDefault(simItem, 0.0) + 
similarity;
-                            id2swing.putIfAbsent(simItem, itemSimilarity);
+                            id2swing.put(simItem, itemSimilarity);
                         }
                     }
                 }
@@ -350,16 +368,22 @@ public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> {
             return (1.0 / Math.pow(alpha1 + size, beta));
         }
 
-        private static List<Long> sampleUserList(Map<Long, String> allUsers, 
int sampleSize) {
-            int totalSize = allUsers.size();
-            List<Long> userList = new ArrayList<>(allUsers.keySet());
-
-            if (totalSize < sampleSize) {
-                return userList;
+        private static int calculateCommonItems(long[] u, long[] v, long[] 
interaction) {
+            int pointerU = 0;
+            int pointerV = 0;
+            int interactionSize = 0;
+            while (pointerU < u.length && pointerV < v.length) {
+                if (u[pointerU] == v[pointerV]) {
+                    interaction[interactionSize++] = u[pointerU];
+                    pointerU++;
+                    pointerV++;
+                } else if (u[pointerU] < v[pointerV]) {
+                    pointerU++;
+                } else {
+                    pointerV++;
+                }
             }
-
-            Collections.shuffle(userList);
-            return userList.subList(0, sampleSize);
+            return interactionSize;
         }
 
         @Override
@@ -367,15 +391,33 @@ public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> {
                 throws Exception {
             Tuple3<Long, Long, long[]> tuple3 = streamRecord.getValue();
             long user = tuple3.f0;
+            long[] userBehavior = tuple3.f2;
             long mainItem = tuple3.f1;
-            Map<Long, String> items = new HashMap<>();
-            for (long item : tuple3.f2) {
-                items.put(item, null);
+
+            if (!userAndPurchasedItems.containsKey(user)) {
+                Arrays.sort(userBehavior);
+                userAndPurchasedItems.put(user, userBehavior);
             }
 
-            userAndPurchasedItems.putIfAbsent(user, items);
-            itemAndPurchasers.putIfAbsent(mainItem, new HashMap<>());
-            itemAndPurchasers.get(mainItem).putIfAbsent(user, null);
+            itemAndPurchasers.putIfAbsent(mainItem, new ArrayList<>());
+            List<Long> purchasers = itemAndPurchasers.get(mainItem);
+            // Use the Reservoir Sampling method to randomly select k 
purchasers from
+            // the stream of records where 1<=k<=maxUserNumPerItem.
+            // See https://en.wikipedia.org/wiki/Reservoir_sampling for more 
information on
+            // Reservoir Sampling.
+            if (purchasers.size() == 0) {
+                purchasers.add(0L);
+            }
+            long total = purchasers.get(0);
+            if (purchasers.size() <= maxUserNumPerItem) {
+                purchasers.add(user);
+            } else {
+                int index = random.nextInt((int) total) + 1;
+                if (index <= maxUserNumPerItem) {
+                    purchasers.set(index, user);
+                }
+            }
+            purchasers.set(0, ++total);
         }
 
         @Override
@@ -388,7 +430,8 @@ public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> {
                                             "userAndPurchasedItemsState",
                                             Types.MAP(
                                                     Types.LONG,
-                                                    Types.MAP(Types.LONG, 
Types.STRING))));
+                                                    PrimitiveArrayTypeInfo
+                                                            
.LONG_PRIMITIVE_ARRAY_TYPE_INFO)));
 
             OperatorStateUtils.getUniqueElement(
                             userAndPurchasedItemsState, 
"userAndPurchasedItemsState")
@@ -399,9 +442,7 @@ public class Swing implements AlgoOperator<Swing>, SwingParams<Swing> {
                             .getListState(
                                     new ListStateDescriptor<>(
                                             "itemAndPurchasersState",
-                                            Types.MAP(
-                                                    Types.LONG,
-                                                    Types.MAP(Types.LONG, 
Types.STRING))));
+                                            Types.MAP(Types.LONG, 
Types.LIST(Types.LONG))));
 
             OperatorStateUtils.getUniqueElement(itemAndPurchasersState, 
"itemAndPurchasersState")
                     .ifPresent(stat -> itemAndPurchasers = stat);
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/SwingParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/SwingParams.java
index 2f36c431..29623b1e 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/SwingParams.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/SwingParams.java
@@ -19,6 +19,7 @@
 package org.apache.flink.ml.recommendation.swing;
 
 import org.apache.flink.ml.common.param.HasOutputCol;
+import org.apache.flink.ml.common.param.HasSeed;
 import org.apache.flink.ml.param.DoubleParam;
 import org.apache.flink.ml.param.IntParam;
 import org.apache.flink.ml.param.Param;
@@ -31,7 +32,7 @@ import org.apache.flink.ml.param.WithParams;
  *
  * @param <T> The class type of this instance.
  */
-public interface SwingParams<T> extends WithParams<T>, HasOutputCol<T> {
+public interface SwingParams<T> extends WithParams<T>, HasOutputCol<T>, 
HasSeed<T> {
     Param<String> USER_COL =
             new StringParam("userCol", "User column name.", "user", 
ParamValidators.notNull());
 
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/SwingTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/SwingTest.java
index cfa0f9cd..b45b6c29 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/SwingTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/SwingTest.java
@@ -42,6 +42,7 @@ import java.util.Comparator;
 import java.util.List;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.fail;
 
 /** Tests {@link Swing}. */
@@ -120,6 +121,7 @@ public class SwingTest {
         assertEquals(15, swing.getAlpha1());
         assertEquals(0, swing.getAlpha2());
         assertEquals(0.3, swing.getBeta(), 1e-9);
+        assertEquals(swing.getClass().getName().hashCode(), swing.getSeed());
 
         swing.setItemCol("item_1")
                 .setUserCol("user_1")
@@ -129,7 +131,8 @@ public class SwingTest {
                 .setMaxUserBehavior(50)
                 .setAlpha1(5)
                 .setAlpha2(1)
-                .setBeta(0.35);
+                .setBeta(0.35)
+                .setSeed(1);
 
         assertEquals("item_1", swing.getItemCol());
         assertEquals("user_1", swing.getUserCol());
@@ -140,6 +143,7 @@ public class SwingTest {
         assertEquals(5, swing.getAlpha1());
         assertEquals(1, swing.getAlpha2());
         assertEquals(0.35, swing.getBeta(), 1e-9);
+        assertEquals(1, swing.getSeed());
     }
 
     @Test
@@ -220,7 +224,7 @@ public class SwingTest {
 
     @Test
     public void testSaveLoadAndTransform() throws Exception {
-        Swing swing = new Swing().setMinUserBehavior(1);
+        Swing swing = new Swing().setMinUserBehavior(2).setMaxUserBehavior(3);
         Swing loadedSwing =
                 TestUtils.saveAndReload(
                         tEnv, swing, tempFolder.newFolder().getAbsolutePath(), 
Swing::load);
@@ -228,4 +232,16 @@ public class SwingTest {
         List<Row> results = 
IteratorUtils.toList(outputTable.execute().collect());
         compareResultAndExpected(results);
     }
+
+    @Test
+    public void testSamplingMethod() {
+        env.setParallelism(1);
+        Swing swing1 = new 
Swing().setMinUserBehavior(1).setMaxUserNumPerItem(2).setSeed(3);
+        Swing swing2 = new 
Swing().setMinUserBehavior(1).setMaxUserNumPerItem(2);
+        Table[] result1 = swing1.transform(inputTable);
+        Table[] result2 = swing2.transform(inputTable);
+        int result1Size = 
IteratorUtils.toList(result1[0].execute().collect()).size();
+        int result2Size = 
IteratorUtils.toList(result2[0].execute().collect()).size();
+        assertNotEquals(result1Size, result2Size);
+    }
 }
diff --git a/flink-ml-python/pyflink/ml/recommendation/swing.py b/flink-ml-python/pyflink/ml/recommendation/swing.py
index 41fa1b87..7cdeef8f 100644
--- a/flink-ml-python/pyflink/ml/recommendation/swing.py
+++ b/flink-ml-python/pyflink/ml/recommendation/swing.py
@@ -17,7 +17,7 @@
 
################################################################################
 import typing
 
-from pyflink.ml.common.param import HasOutputCol
+from pyflink.ml.common.param import HasOutputCol, HasSeed
 from pyflink.ml.param import Param, StringParam, IntParam, FloatParam, 
ParamValidators
 from pyflink.ml.recommendation.common import JavaRecommendationAlgoOperator
 from pyflink.ml.wrapper import JavaWithParams
@@ -25,7 +25,8 @@ from pyflink.ml.wrapper import JavaWithParams
 
 class _SwingParams(
     JavaWithParams,
-    HasOutputCol
+    HasOutputCol,
+    HasSeed
 ):
     """
     Params for :class:`Swing`.
diff --git a/flink-ml-python/pyflink/ml/recommendation/tests/test_swing.py b/flink-ml-python/pyflink/ml/recommendation/tests/test_swing.py
index daa93575..45446a46 100644
--- a/flink-ml-python/pyflink/ml/recommendation/tests/test_swing.py
+++ b/flink-ml-python/pyflink/ml/recommendation/tests/test_swing.py
@@ -84,16 +84,18 @@ class SwingTest(PyFlinkMLTestCase):
         self.assertEqual(15, swing.alpha1)
         self.assertEqual(0, swing.alpha2)
         self.assertAlmostEqual(0.3, swing.beta, delta=1e-9)
+        self.assertEqual(438758276, swing.seed)
 
         swing.set_item_col("item_1") \
             .set_user_col("user_1") \
             .set_k(20) \
-            .set_max_user_num_per_item(500)\
+            .set_max_user_num_per_item(500) \
             .set_min_user_behavior(20) \
             .set_max_user_behavior(50) \
             .set_alpha1(5) \
             .set_alpha2(1) \
-            .set_beta(0.35)
+            .set_beta(0.35) \
+            .set_seed(1)
 
         self.assertEqual("item_1", swing.item_col)
         self.assertEqual("user_1", swing.user_col)
@@ -104,6 +106,7 @@ class SwingTest(PyFlinkMLTestCase):
         self.assertEqual(5, swing.alpha1)
         self.assertEqual(1, swing.alpha2)
         self.assertAlmostEqual(0.35, swing.beta, delta=1e-9)
+        self.assertEqual(1, swing.seed)
 
     def test_output_schema(self):
         swing = Swing() \
@@ -149,3 +152,12 @@ class SwingTest(PyFlinkMLTestCase):
             results.append([main_item, item_rank_score])
         results.sort(key=lambda x: x[0])
         self.assertEqual(expected_result, results)
+
+    def test_sampling_method(self):
+        swing1 = 
Swing().set_min_user_behavior(1).set_max_user_num_per_item(2).set_seed(3)
+        swing2 = Swing().set_min_user_behavior(1).set_max_user_num_per_item(2)
+        output1 = swing1.transform(self.input_table)[0]
+        output2 = swing2.transform(self.input_table)[0]
+        result1 = [result for result in 
self.t_env.to_data_stream(output1).execute_and_collect()]
+        result2 = [result for result in 
self.t_env.to_data_stream(output2).execute_and_collect()]
+        self.assertNotEqual(len(result1), len(result2))

Reply via email to