Davis-Zhang-Onehouse commented on code in PR #13489: URL: https://github.com/apache/hudi/pull/13489#discussion_r2220008258
########## hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/data/TestHoodieJavaPairRDDDynamicRepartition.java: ########## @@ -0,0 +1,435 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.hudi.data; + +import org.apache.hudi.client.common.HoodieSparkEngineContext; +import org.apache.hudi.common.data.HoodiePairData; +import org.apache.hudi.common.util.Option; +import org.apache.hudi.metadata.HoodieBackedTestDelayedTableMetadata; + +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; + +import scala.Tuple2; +import scala.Tuple3; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class TestHoodieJavaPairRDDDynamicRepartition { + + private static final Logger LOG = LoggerFactory.getLogger(HoodieBackedTestDelayedTableMetadata.class); + + private JavaSparkContext jsc; + + /** + * Generates a random RDD with unbalanced data distribution across partitions. + * + * @param sc Spark context + * @param maxValueByKey Map of key to maximum number of values + * @param partitionWeights List of weights for each partition + * @param seed seed used for randomization + * @return RDD with weighted partition distribution + */ + public static JavaPairRDD<Integer, String> generateRandomRDDWithWeightedPartitions( + JavaSparkContext sc, + Map<Integer, Long> maxValueByKey, + List<Double> partitionWeights, + long seed) { + + // Generate all possible pairs of key and value in a single list. + List<Tuple2<Integer, String>> allPairs = new ArrayList<>(); + for (Map.Entry<Integer, Long> e : maxValueByKey.entrySet()) { + for (long v = 1; v <= e.getValue(); v++) { + allPairs.add(new Tuple2<>(e.getKey(), Long.toString(v))); + } + } + + Collections.shuffle(allPairs, new Random(seed)); + + int total = allPairs.size(); + List<JavaPairRDD<Integer, String>> rdds = new ArrayList<>(); + int start = 0; + + // Split the list into partitions based on the weights. + for (int i = 0; i < partitionWeights.size(); i++) { + int end = (i == partitionWeights.size() - 1) + ? total + : Math.min(total, start + (int) Math.round(partitionWeights.get(i) * total)); + + List<Tuple2<Integer, String>> slice = allPairs.subList(start, end); + JavaPairRDD<Integer, String> sliceRdd = sc.parallelize(slice, 1).mapToPair(t -> t); + rdds.add(sliceRdd); + start = end; + if (start >= total) { + break; + } + } + + // Combine all the partitions into a single RDD. + JavaPairRDD<Integer, String> combined = rdds.get(0); + for (int i = 1; i < rdds.size(); i++) { + combined = combined.union(rdds.get(i)); + } + + return combined; + } + + /** + * Validates various properties of a repartitioned RDD, including: + * 1. Each key is in exactly one partition. + * 2. The keys are sorted within each partition. + * 3. For partitions containing entries of the same key, the value ranges are not overlapping. + * 4. Number of keys per partition is probably at most maxKeyPerBucket. + * + * @param originalRdd Original RDD + * @param repartitionedRdd Repartitioned RDD + * @param maxPartitionCountByKey Map of key to maximum number of partitions + * @throws AssertionError if any check fails + */ + private static Map<Integer, Map<Integer, List<String>>> validateRepartitionedRDDProperties( + HoodiePairData<Integer, String> originalRdd, + HoodiePairData<Integer, String> repartitionedRdd, + Option<Map<Integer, Integer>> maxPartitionCountByKey) { + JavaPairRDD<Integer, String> javaPairRDD = HoodieJavaPairRDD.getJavaPairRDD(repartitionedRdd); + + Map<Integer, Map<Integer, List<String>>> actualPartitionContents = dumpRDDContent(javaPairRDD); + + try { + // Values in each partition are sorted. + for (Map.Entry<Integer, Map<Integer, List<String>>> p : actualPartitionContents.entrySet()) { + int partitionId = p.getKey(); + Map<Integer, List<String>> keyToValues = p.getValue(); + + if (keyToValues.size() != 1) { + assertEquals(1, keyToValues.size(), + "Each partition should contain exactly one key, but found keys " + keyToValues.keySet() + + " in partition " + partitionId); + logRDDContent("validation failure, original rdd ", originalRdd); + logRDDContent("validation failure, repartitioned rdd ", repartitionedRdd); + } + + for (Map.Entry<Integer, List<String>> kv : keyToValues.entrySet()) { + List<String> values = kv.getValue(); + List<String> sorted = new ArrayList<>(values); + Collections.sort(sorted); + if (!values.equals(sorted)) { + throw new AssertionError( + "Partition " + partitionId + ", key " + kv.getKey() + + " has unsorted values: " + values); + } + } + } + + // Build key → list<(partitionId, min, max)> + Map<Integer, List<Tuple3<Integer, String, String>>> keyToPartitionRanges = new HashMap<>(); + + for (Map.Entry<Integer, Map<Integer, List<String>>> p : actualPartitionContents.entrySet()) { + int partitionId = p.getKey(); + for (Map.Entry<Integer, List<String>> kv : p.getValue().entrySet()) { + List<String> sorted = new ArrayList<>(kv.getValue()); + Collections.sort(sorted); + keyToPartitionRanges + .computeIfAbsent(kv.getKey(), k -> new ArrayList<>()) + .add(new Tuple3<>(partitionId, sorted.get(0), sorted.get(sorted.size() - 1))); + } + } + + // Range-overlap check and expected-partition-count check + for (Map.Entry<Integer, List<Tuple3<Integer, String, String>>> e : keyToPartitionRanges.entrySet()) { + int key = e.getKey(); + List<Tuple3<Integer, String, String>> ranges = e.getValue(); + + // Confirm expected #partitions + if (maxPartitionCountByKey.isPresent()) { + Integer maxPartitionCnt = maxPartitionCountByKey.get().get(key); + if (maxPartitionCnt == null) { + throw new AssertionError("Unexpected key " + key + + " appeared in RDD but not in expectedPartitionsPerKey map"); + } + if (ranges.size() > maxPartitionCnt) { + throw new AssertionError("Key " + key + " should occupy at most " + maxPartitionCnt + + " partitions but actually occupies " + ranges.size()); + } + } + + // Check that ranges do not overlap (string order) + ranges.sort(Comparator.comparing(t -> t._2())); // sort by min + for (int i = 1; i < ranges.size(); i++) { + Tuple3<Integer, String, String> prev = ranges.get(i - 1); + Tuple3<Integer, String, String> curr = ranges.get(i); + if (curr._2().compareTo(prev._3()) <= 0) { + throw new AssertionError( + String.format( + "Key %d has overlapping ranges: partition %d [%s-%s] vs partition %d [%s-%s]", + key, + prev._1(), prev._2(), prev._3(), + curr._1(), curr._2(), curr._3())); + } + } + } + + // Verify no key is missing from actual data + if (maxPartitionCountByKey.isPresent()) { + for (Integer expectedKey : maxPartitionCountByKey.get().keySet()) { + if (!keyToPartitionRanges.containsKey(expectedKey)) { + throw new AssertionError("Expected key " + expectedKey + " never appeared in the RDD"); + } + } + } + } catch (AssertionError e) { + logRDDContent("Original RDD", originalRdd); + logRDDContent("Repartitioned RDD", repartitionedRdd); + LOG.error("Validation failed: " + e.getMessage(), e); + throw e; // rethrow to fail the test + } + return actualPartitionContents; // handy for unit-test callers + } + + /** + * Dumps the content of an RDD to a map of partition id to key to values. + * + * @param javaPairRDD RDD to dump + * @return Map of partition id to key to values + */ + private static Map<Integer, Map<Integer, List<String>>> dumpRDDContent(JavaPairRDD<Integer, String> javaPairRDD) { + Map<Integer, Map<Integer, List<String>>> actualPartitionContents = new HashMap<>(); + + javaPairRDD + .mapPartitionsWithIndex((idx, iter) -> { + Map<Integer, List<String>> keyToValues = new HashMap<>(); + while (iter.hasNext()) { + Tuple2<Integer, String> row = iter.next(); + keyToValues + .computeIfAbsent(row._1(), k -> new ArrayList<>()) + .add(row._2()); + } + return Collections.singletonList(new Tuple2<>(idx, keyToValues)).iterator(); + }, true) + .collect() + .forEach(t -> actualPartitionContents.put(t._1(), t._2())); + return actualPartitionContents; + } + + /** + * Logs the content of an RDD to the console. + * + * @param label Label for the RDD + * @param pairData RDD to log + */ + private static void logRDDContent(String label, HoodiePairData<Integer, String> pairData) { + JavaPairRDD<Integer, String> rdd = HoodieJavaPairRDD.getJavaPairRDD(pairData); + + LOG.info("===== {} =====", label); + rdd + .mapPartitionsWithIndex((idx, iter) -> { + StringBuilder builder = new StringBuilder(); + builder.append("Partition ").append(idx).append(": ["); + + while (iter.hasNext()) { + Tuple2<Integer, String> kv = iter.next(); + builder.append("(").append(kv._1).append(", ").append(kv._2).append(")").append(", "); + } + builder.append("]"); + return Collections.singletonList(builder.toString()).iterator(); + }, true) + .collect() + .forEach(LOG::info); + LOG.info("============================\n"); + } + + @BeforeEach + public void setUp() { + jsc = new JavaSparkContext("local[2]", "test"); + } + + @AfterEach + public void tearDown() { + if (jsc != null) { + jsc.stop(); + jsc = null; + } + } + + @Test + public void testRangeBasedRepartitionForEachKey() { Review Comment: moved, thanks for pointing it out -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
