This is an automated email from the ASF dual-hosted git repository.
zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 2a599182 [FLINK-31325] Improve performance of Swing
2a599182 is described below
commit 2a5991827f8885d7a17023dee43da016be12dd79
Author: vacaly <[email protected]>
AuthorDate: Wed Apr 12 18:16:14 2023 +0800
[FLINK-31325] Improve performance of Swing
This closes #220.
---
.../flink/ml/recommendation/swing/Swing.java | 123 ++++++++++++++-------
.../flink/ml/recommendation/swing/SwingParams.java | 3 +-
.../apache/flink/ml/recommendation/SwingTest.java | 20 +++-
flink-ml-python/pyflink/ml/recommendation/swing.py | 5 +-
.../pyflink/ml/recommendation/tests/test_swing.py | 16 ++-
5 files changed, 119 insertions(+), 48 deletions(-)
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/Swing.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/Swing.java
index 5a1ebc0c..022f44b7 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/Swing.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/Swing.java
@@ -49,12 +49,13 @@ import org.apache.flink.util.Preconditions;
import java.io.IOException;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
-import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
+import java.util.Random;
import java.util.stream.Collectors;
/**
@@ -147,9 +148,11 @@ public class Swing implements AlgoOperator<Swing>,
SwingParams<Swing> {
new ComputingSimilarItems(
getK(),
getMaxUserNumPerItem(),
+ getMaxUserBehavior(),
getAlpha1(),
getAlpha2(),
- getBeta()));
+ getBeta(),
+ getSeed()));
return new Table[] {tEnv.fromDataStream(output)};
}
@@ -182,8 +185,7 @@ public class Swing implements AlgoOperator<Swing>,
SwingParams<Swing> {
private final int maxUserItemInteraction;
// Maps a user id to a set of items. Because ListState cannot keep
values of type `Set`,
- // we use `Map<Long, String>` with null values instead. So does
`userAndPurchasedItems` and
- // `itemAndPurchasers` in `ComputingSimilarItems`.
+ // we use `Map<Long, String>` with null values instead.
private Map<Long, Map<Long, String>> userAndPurchasedItems = new
HashMap<>();
private ListState<Map<Long, Map<Long, String>>>
userAndPurchasedItemsState;
@@ -259,13 +261,9 @@ public class Swing implements AlgoOperator<Swing>,
SwingParams<Swing> {
private static class ComputingSimilarItems extends
AbstractStreamOperator<Row>
implements OneInputStreamOperator<Tuple3<Long, Long, long[]>,
Row>, BoundedOneInput {
- private Map<Long, Map<Long, String>> userAndPurchasedItems = new
HashMap<>();
- private Map<Long, Map<Long, String>> itemAndPurchasers = new
HashMap<>();
- private ListState<Map<Long, Map<Long, String>>>
userAndPurchasedItemsState;
- private ListState<Map<Long, Map<Long, String>>> itemAndPurchasersState;
-
private final int k;
private final int maxUserNumPerItem;
+ private final int maxUserBehavior;
private final int alpha1;
private final int alpha2;
@@ -274,13 +272,29 @@ public class Swing implements AlgoOperator<Swing>,
SwingParams<Swing> {
private static final Character commaDelimiter = ',';
private static final Character semicolonDelimiter = ';';
+ private final Random random;
+
+ private Map<Long, long[]> userAndPurchasedItems = new HashMap<>();
+ private Map<Long, List<Long>> itemAndPurchasers = new HashMap<>();
+
+ private ListState<Map<Long, long[]>> userAndPurchasedItemsState;
+ private ListState<Map<Long, List<Long>>> itemAndPurchasersState;
+
private ComputingSimilarItems(
- int k, int maxUserNumPerItem, int alpha1, int alpha2, double
beta) {
+ int k,
+ int maxUserNumPerItem,
+ int maxUserBehavior,
+ int alpha1,
+ int alpha2,
+ double beta,
+ long seed) {
this.k = k;
this.maxUserNumPerItem = maxUserNumPerItem;
+ this.maxUserBehavior = maxUserBehavior;
this.alpha1 = alpha1;
this.alpha2 = alpha2;
this.beta = beta;
+ this.random = new Random(seed);
}
@Override
@@ -289,36 +303,40 @@ public class Swing implements AlgoOperator<Swing>,
SwingParams<Swing> {
Map<Long, Double> userWeights = new
HashMap<>(userAndPurchasedItems.size());
userAndPurchasedItems.forEach(
(k, v) -> {
- int count = v.size();
+ int count = v.length;
userWeights.put(k, calculateWeight(count));
});
+ long[] interaction = new long[maxUserBehavior];
for (long mainItem : itemAndPurchasers.keySet()) {
- List<Long> userList =
- sampleUserList(itemAndPurchasers.get(mainItem),
maxUserNumPerItem);
+ List<Long> userList = itemAndPurchasers.get(mainItem);
HashMap<Long, Double> id2swing = new HashMap<>();
- for (int i = 0; i < userList.size(); i++) {
+ for (int i = 1; i < userList.size(); i++) {
long u = userList.get(i);
+ int interactionSize;
for (int j = i + 1; j < userList.size(); j++) {
long v = userList.get(j);
- HashSet<Long> interaction =
- new
HashSet<>(userAndPurchasedItems.get(u).keySet());
-
interaction.retainAll(userAndPurchasedItems.get(v).keySet());
- if (interaction.size() == 0) {
+ interactionSize =
+ calculateCommonItems(
+ userAndPurchasedItems.get(u),
+ userAndPurchasedItems.get(v),
+ interaction);
+ if (interactionSize == 0) {
continue;
}
double similarity =
- (userWeights.get(u)
+ userWeights.get(u)
* userWeights.get(v)
- / (alpha2 + interaction.size()));
- for (long simItem : interaction) {
+ / (alpha2 + interactionSize);
+ for (int k = 0; k < interactionSize; k++) {
+ long simItem = interaction[k];
if (simItem == mainItem) {
continue;
}
double itemSimilarity =
id2swing.getOrDefault(simItem, 0.0) +
similarity;
- id2swing.putIfAbsent(simItem, itemSimilarity);
+ id2swing.put(simItem, itemSimilarity);
}
}
}
@@ -350,16 +368,22 @@ public class Swing implements AlgoOperator<Swing>,
SwingParams<Swing> {
return (1.0 / Math.pow(alpha1 + size, beta));
}
- private static List<Long> sampleUserList(Map<Long, String> allUsers,
int sampleSize) {
- int totalSize = allUsers.size();
- List<Long> userList = new ArrayList<>(allUsers.keySet());
-
- if (totalSize < sampleSize) {
- return userList;
+ private static int calculateCommonItems(long[] u, long[] v, long[]
interaction) {
+ int pointerU = 0;
+ int pointerV = 0;
+ int interactionSize = 0;
+ while (pointerU < u.length && pointerV < v.length) {
+ if (u[pointerU] == v[pointerV]) {
+ interaction[interactionSize++] = u[pointerU];
+ pointerU++;
+ pointerV++;
+ } else if (u[pointerU] < v[pointerV]) {
+ pointerU++;
+ } else {
+ pointerV++;
+ }
}
-
- Collections.shuffle(userList);
- return userList.subList(0, sampleSize);
+ return interactionSize;
}
@Override
@@ -367,15 +391,33 @@ public class Swing implements AlgoOperator<Swing>,
SwingParams<Swing> {
throws Exception {
Tuple3<Long, Long, long[]> tuple3 = streamRecord.getValue();
long user = tuple3.f0;
+ long[] userBehavior = tuple3.f2;
long mainItem = tuple3.f1;
- Map<Long, String> items = new HashMap<>();
- for (long item : tuple3.f2) {
- items.put(item, null);
+
+ if (!userAndPurchasedItems.containsKey(user)) {
+ Arrays.sort(userBehavior);
+ userAndPurchasedItems.put(user, userBehavior);
}
- userAndPurchasedItems.putIfAbsent(user, items);
- itemAndPurchasers.putIfAbsent(mainItem, new HashMap<>());
- itemAndPurchasers.get(mainItem).putIfAbsent(user, null);
+ itemAndPurchasers.putIfAbsent(mainItem, new ArrayList<>());
+ List<Long> purchasers = itemAndPurchasers.get(mainItem);
+ // Use the Reservoir Sampling method to randomly select k
purchasers from
+ // the stream of records where 1<=k<=maxUserNumPerItem.
+ // See https://en.wikipedia.org/wiki/Reservoir_sampling for more
information on
+ // Reservoir Sampling.
+ if (purchasers.size() == 0) {
+ purchasers.add(0L);
+ }
+ long total = purchasers.get(0);
+ if (purchasers.size() <= maxUserNumPerItem) {
+ purchasers.add(user);
+ } else {
+ int index = random.nextInt((int) total) + 1;
+ if (index <= maxUserNumPerItem) {
+ purchasers.set(index, user);
+ }
+ }
+ purchasers.set(0, ++total);
}
@Override
@@ -388,7 +430,8 @@ public class Swing implements AlgoOperator<Swing>,
SwingParams<Swing> {
"userAndPurchasedItemsState",
Types.MAP(
Types.LONG,
- Types.MAP(Types.LONG,
Types.STRING))));
+ PrimitiveArrayTypeInfo
+
.LONG_PRIMITIVE_ARRAY_TYPE_INFO)));
OperatorStateUtils.getUniqueElement(
userAndPurchasedItemsState,
"userAndPurchasedItemsState")
@@ -399,9 +442,7 @@ public class Swing implements AlgoOperator<Swing>,
SwingParams<Swing> {
.getListState(
new ListStateDescriptor<>(
"itemAndPurchasersState",
- Types.MAP(
- Types.LONG,
- Types.MAP(Types.LONG,
Types.STRING))));
+ Types.MAP(Types.LONG,
Types.LIST(Types.LONG))));
OperatorStateUtils.getUniqueElement(itemAndPurchasersState,
"itemAndPurchasersState")
.ifPresent(stat -> itemAndPurchasers = stat);
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/SwingParams.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/SwingParams.java
index 2f36c431..29623b1e 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/SwingParams.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/swing/SwingParams.java
@@ -19,6 +19,7 @@
package org.apache.flink.ml.recommendation.swing;
import org.apache.flink.ml.common.param.HasOutputCol;
+import org.apache.flink.ml.common.param.HasSeed;
import org.apache.flink.ml.param.DoubleParam;
import org.apache.flink.ml.param.IntParam;
import org.apache.flink.ml.param.Param;
@@ -31,7 +32,7 @@ import org.apache.flink.ml.param.WithParams;
*
* @param <T> The class type of this instance.
*/
-public interface SwingParams<T> extends WithParams<T>, HasOutputCol<T> {
+public interface SwingParams<T> extends WithParams<T>, HasOutputCol<T>,
HasSeed<T> {
Param<String> USER_COL =
new StringParam("userCol", "User column name.", "user",
ParamValidators.notNull());
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/SwingTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/SwingTest.java
index cfa0f9cd..b45b6c29 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/SwingTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/SwingTest.java
@@ -42,6 +42,7 @@ import java.util.Comparator;
import java.util.List;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.fail;
/** Tests {@link Swing}. */
@@ -120,6 +121,7 @@ public class SwingTest {
assertEquals(15, swing.getAlpha1());
assertEquals(0, swing.getAlpha2());
assertEquals(0.3, swing.getBeta(), 1e-9);
+ assertEquals(swing.getClass().getName().hashCode(), swing.getSeed());
swing.setItemCol("item_1")
.setUserCol("user_1")
@@ -129,7 +131,8 @@ public class SwingTest {
.setMaxUserBehavior(50)
.setAlpha1(5)
.setAlpha2(1)
- .setBeta(0.35);
+ .setBeta(0.35)
+ .setSeed(1);
assertEquals("item_1", swing.getItemCol());
assertEquals("user_1", swing.getUserCol());
@@ -140,6 +143,7 @@ public class SwingTest {
assertEquals(5, swing.getAlpha1());
assertEquals(1, swing.getAlpha2());
assertEquals(0.35, swing.getBeta(), 1e-9);
+ assertEquals(1, swing.getSeed());
}
@Test
@@ -220,7 +224,7 @@ public class SwingTest {
@Test
public void testSaveLoadAndTransform() throws Exception {
- Swing swing = new Swing().setMinUserBehavior(1);
+ Swing swing = new Swing().setMinUserBehavior(2).setMaxUserBehavior(3);
Swing loadedSwing =
TestUtils.saveAndReload(
tEnv, swing, tempFolder.newFolder().getAbsolutePath(),
Swing::load);
@@ -228,4 +232,16 @@ public class SwingTest {
List<Row> results =
IteratorUtils.toList(outputTable.execute().collect());
compareResultAndExpected(results);
}
+
+ @Test
+ public void testSamplingMethod() {
+ env.setParallelism(1);
+ Swing swing1 = new
Swing().setMinUserBehavior(1).setMaxUserNumPerItem(2).setSeed(3);
+ Swing swing2 = new
Swing().setMinUserBehavior(1).setMaxUserNumPerItem(2);
+ Table[] result1 = swing1.transform(inputTable);
+ Table[] result2 = swing2.transform(inputTable);
+ int result1Size =
IteratorUtils.toList(result1[0].execute().collect()).size();
+ int result2Size =
IteratorUtils.toList(result2[0].execute().collect()).size();
+ assertNotEquals(result1Size, result2Size);
+ }
}
diff --git a/flink-ml-python/pyflink/ml/recommendation/swing.py
b/flink-ml-python/pyflink/ml/recommendation/swing.py
index 41fa1b87..7cdeef8f 100644
--- a/flink-ml-python/pyflink/ml/recommendation/swing.py
+++ b/flink-ml-python/pyflink/ml/recommendation/swing.py
@@ -17,7 +17,7 @@
################################################################################
import typing
-from pyflink.ml.common.param import HasOutputCol
+from pyflink.ml.common.param import HasOutputCol, HasSeed
from pyflink.ml.param import Param, StringParam, IntParam, FloatParam,
ParamValidators
from pyflink.ml.recommendation.common import JavaRecommendationAlgoOperator
from pyflink.ml.wrapper import JavaWithParams
@@ -25,7 +25,8 @@ from pyflink.ml.wrapper import JavaWithParams
class _SwingParams(
JavaWithParams,
- HasOutputCol
+ HasOutputCol,
+ HasSeed
):
"""
Params for :class:`Swing`.
diff --git a/flink-ml-python/pyflink/ml/recommendation/tests/test_swing.py
b/flink-ml-python/pyflink/ml/recommendation/tests/test_swing.py
index daa93575..45446a46 100644
--- a/flink-ml-python/pyflink/ml/recommendation/tests/test_swing.py
+++ b/flink-ml-python/pyflink/ml/recommendation/tests/test_swing.py
@@ -84,16 +84,18 @@ class SwingTest(PyFlinkMLTestCase):
self.assertEqual(15, swing.alpha1)
self.assertEqual(0, swing.alpha2)
self.assertAlmostEqual(0.3, swing.beta, delta=1e-9)
+ self.assertEqual(438758276, swing.seed)
swing.set_item_col("item_1") \
.set_user_col("user_1") \
.set_k(20) \
- .set_max_user_num_per_item(500)\
+ .set_max_user_num_per_item(500) \
.set_min_user_behavior(20) \
.set_max_user_behavior(50) \
.set_alpha1(5) \
.set_alpha2(1) \
- .set_beta(0.35)
+ .set_beta(0.35) \
+ .set_seed(1)
self.assertEqual("item_1", swing.item_col)
self.assertEqual("user_1", swing.user_col)
@@ -104,6 +106,7 @@ class SwingTest(PyFlinkMLTestCase):
self.assertEqual(5, swing.alpha1)
self.assertEqual(1, swing.alpha2)
self.assertAlmostEqual(0.35, swing.beta, delta=1e-9)
+ self.assertEqual(1, swing.seed)
def test_output_schema(self):
swing = Swing() \
@@ -149,3 +152,12 @@ class SwingTest(PyFlinkMLTestCase):
results.append([main_item, item_rank_score])
results.sort(key=lambda x: x[0])
self.assertEqual(expected_result, results)
+
+ def test_sampling_method(self):
+ swing1 =
Swing().set_min_user_behavior(1).set_max_user_num_per_item(2).set_seed(3)
+ swing2 = Swing().set_min_user_behavior(1).set_max_user_num_per_item(2)
+ output1 = swing1.transform(self.input_table)[0]
+ output2 = swing2.transform(self.input_table)[0]
+ result1 = [result for result in
self.t_env.to_data_stream(output1).execute_and_collect()]
+ result2 = [result for result in
self.t_env.to_data_stream(output2).execute_and_collect()]
+ self.assertNotEqual(len(result1), len(result2))