jerqi commented on code in PR #1245: URL: https://github.com/apache/incubator-uniffle/pull/1245#discussion_r1363115246
########## client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssMergeManager.java: ########## @@ -0,0 +1,606 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tez.runtime.library.common.shuffle.orderedgrouped; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + +import com.google.common.annotations.VisibleForTesting; +import org.apache.commons.lang3.StringUtils; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.FileChunk; +import org.apache.hadoop.io.RawComparator; +import org.apache.hadoop.io.compress.CompressionCodec; +import org.apache.hadoop.util.Progressable; +import org.apache.hadoop.util.Time; +import org.apache.tez.common.counters.TaskCounter; +import org.apache.tez.common.counters.TezCounter; +import org.apache.tez.runtime.api.InputContext; +import org.apache.tez.runtime.library.api.TezRuntimeConfiguration; +import org.apache.tez.runtime.library.common.ConfigUtils; +import org.apache.tez.runtime.library.common.Constants; +import org.apache.tez.runtime.library.common.InputAttemptIdentifier; +import org.apache.tez.runtime.library.common.combine.Combiner; +import org.apache.tez.runtime.library.common.sort.impl.IFile; +import org.apache.tez.runtime.library.common.sort.impl.TezMerger; +import org.apache.tez.runtime.library.common.sort.impl.TezRawKeyValueIterator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.uniffle.common.exception.RssException; +import org.apache.uniffle.common.filesystem.HadoopFilesystemProvider; + +import static org.apache.tez.common.RssTezConfig.RSS_REMOTE_SPILL_STORAGE_PATH; + +public class RssMergeManager extends MergeManager { + + private static final Logger LOG = LoggerFactory.getLogger(MergeManager.class); + + private Configuration conf; + private InputContext inputContext; + private ExceptionReporter exceptionReporter; + private RssInMemoryMerger inMemoryMerger; + private Combiner combiner; + private CompressionCodec codec; + private boolean ifileReadAhead; + private int ifileReadAheadLength; + private int ifileBufferSize; + private String appAttemptId; + + private final long initialMemoryAvailable; + private final long memoryLimit; + private final long maxSingleShuffleLimit; + private long mergeThreshold; + private long commitMemory; + private long usedMemory; + + private String spillBasePath; + private FileSystem remoteFS; + + // Variables for stats and logging + private long lastInMemSegmentLogTime = -1L; + private final SegmentStatsTracker statsInMemTotal = new SegmentStatsTracker(); + private final SegmentStatsTracker statsInMemLastLog = new SegmentStatsTracker(); + + private final TezCounter spilledRecordsCounter; + private final TezCounter mergedMapOutputsCounter; + private final TezCounter additionalBytesRead; + + private final AtomicBoolean isShutdown = new AtomicBoolean(false); + private final boolean cleanup; + + private final Progressable progressable = + new Progressable() { + @Override + public void progress() { + inputContext.notifyProgress(); + } + }; + + public RssMergeManager( + Configuration conf, + FileSystem localFS, + InputContext inputContext, + Combiner combiner, + TezCounter spilledRecordsCounter, + TezCounter reduceCombineInputCounter, + TezCounter mergedMapOutputsCounter, + ExceptionReporter exceptionReporter, + long initialMemoryAvailable, + CompressionCodec codec, + boolean ifileReadAheadEnabled, + int ifileReadAheadLength, + Configuration remoteConf, + int replication, + int retries, + String appAttemptId) { + super( + conf, + localFS, + null, + inputContext, + combiner, + spilledRecordsCounter, + reduceCombineInputCounter, + mergedMapOutputsCounter, + exceptionReporter, + initialMemoryAvailable, + codec, + ifileReadAheadEnabled, + ifileReadAheadLength); + this.conf = conf; + this.inputContext = inputContext; + this.exceptionReporter = exceptionReporter; + this.codec = codec; + this.combiner = combiner; + this.initialMemoryAvailable = initialMemoryAvailable; + this.ifileReadAhead = ifileReadAheadEnabled; + if (this.ifileReadAhead) { + this.ifileReadAheadLength = ifileReadAheadLength; + } else { + this.ifileReadAheadLength = 0; + } + this.ifileBufferSize = + conf.getInt( + "io.file.buffer.size", TezRuntimeConfiguration.TEZ_RUNTIME_IFILE_BUFFER_SIZE_DEFAULT); + this.appAttemptId = appAttemptId; + this.cleanup = + conf.getBoolean( + TezRuntimeConfiguration.TEZ_RUNTIME_CLEANUP_FILES_ON_INTERRUPT, + TezRuntimeConfiguration.TEZ_RUNTIME_CLEANUP_FILES_ON_INTERRUPT_DEFAULT); + + // Set memory, here ignore some check which have done in MergeManager + final float maxInMemCopyUse = + conf.getFloat( + TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCH_BUFFER_PERCENT, + TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCH_BUFFER_PERCENT_DEFAULT); + long memLimit = + conf.getLong( + Constants.TEZ_RUNTIME_TASK_MEMORY, + (long) (inputContext.getTotalMemoryAvailableToTask() * maxInMemCopyUse)); + if (this.initialMemoryAvailable < memLimit) { + this.memoryLimit = this.initialMemoryAvailable; + } else { + this.memoryLimit = memLimit; + } + this.mergeThreshold = + (long) + (this.memoryLimit + * conf.getFloat( + TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MERGE_PERCENT, + TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MERGE_PERCENT_DEFAULT)); + final float singleShuffleMemoryLimitPercent = + conf.getFloat( + TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MEMORY_LIMIT_PERCENT, + TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MEMORY_LIMIT_PERCENT_DEFAULT); + this.maxSingleShuffleLimit = + (long) Math.min((memoryLimit * singleShuffleMemoryLimitPercent), Integer.MAX_VALUE); + + // counter + this.spilledRecordsCounter = spilledRecordsCounter; + this.mergedMapOutputsCounter = mergedMapOutputsCounter; + this.additionalBytesRead = + inputContext.getCounters().findCounter(TaskCounter.ADDITIONAL_SPILLS_BYTES_READ); + + // remote fs + Configuration remoteConfCopied = new Configuration(remoteConf); + this.spillBasePath = conf.get(RSS_REMOTE_SPILL_STORAGE_PATH); + try { + remoteConfCopied.setInt("dfs.replication", replication); + remoteConfCopied.setInt("dfs.client.block.write.retries", retries); + this.remoteFS = + HadoopFilesystemProvider.getFilesystem(new Path(spillBasePath), remoteConfCopied); + } catch (Exception e) { + throw new RssException("Cannot init remoteFS on path:" + spillBasePath); + } + if (StringUtils.isBlank(this.spillBasePath)) { + throw new RssException("You must set remote spill path!"); + } + this.inMemoryMerger = createRssInMemoryMerger(); + } + + private RssInMemoryMerger createRssInMemoryMerger() { + return new RssInMemoryMerger( + this, + this.conf, + inputContext, + combiner, + exceptionReporter, + codec, + remoteFS, + spillBasePath, + appAttemptId); + } + + @Override + void configureAndStart() { + this.inMemoryMerger.start(); + } + + @Override + public void waitForInMemoryMerge() throws InterruptedException { + inMemoryMerger.waitForMerge(); + + /** + * Memory released during merge process could have been used by active fetchers and if they + * merge was already in progress, this would not have kicked off another merge and fetchers + * could get into indefinite wait state later. To address this, trigger another merge process if + * needed and wait for it to complete (to release committedMemory & usedMemory). + */ + boolean triggerAdditionalMerge = false; + synchronized (this) { + if (commitMemory >= mergeThreshold) { + startMemToRemoteMerge(); + triggerAdditionalMerge = true; + } + } + if (triggerAdditionalMerge) { + inMemoryMerger.waitForMerge(); + if (LOG.isDebugEnabled()) { + LOG.debug("Additional in-memory merge triggered"); + } + } + } + + private boolean canShuffleToMemory(long requestedSize) { + // TODO: large shuffle data should be store in remote fs directly + return true; + } + + @Override + public synchronized void waitForShuffleToMergeMemory() throws InterruptedException { + long startTime = System.currentTimeMillis(); + while (usedMemory > memoryLimit) { + wait(); + } + if (LOG.isDebugEnabled()) { + LOG.debug( + "Waited for " + + (System.currentTimeMillis() - startTime) + + " for memory to become" + + " available"); + } + } + + private final MapOutput stallShuffle = MapOutput.createWaitMapOutput(null); + + @Override + public synchronized MapOutput reserve( + InputAttemptIdentifier srcAttemptIdentifier, + long requestedSize, + long compressedLength, + int fetcher) + throws IOException { + if (!canShuffleToMemory(requestedSize)) { + if (LOG.isDebugEnabled()) { + LOG.debug( + srcAttemptIdentifier + + ": Shuffling to remote fs since " + + requestedSize + + " is greater than maxSingleShuffleLimit (" + + maxSingleShuffleLimit + + ")"); + } + throw new RssException("Shuffle large date is not implemented!"); + } + + // Stall shuffle if we are above the memory limit + + // It is possible that all threads could just be stalling and not make + // progress at all. This could happen when: + // + // requested size is causing the used memory to go above limit && + // requested size < singleShuffleLimit && + // current used size < mergeThreshold (merge will not get triggered) + // + // To avoid this from happening, we allow exactly one thread to go past + // the memory limit. We check (usedMemory > memoryLimit) and not + // (usedMemory + requestedSize > memoryLimit). When this thread is done + // fetching, this will automatically trigger a merge thereby unlocking + // all the stalled threads + + if (usedMemory > memoryLimit) { + if (LOG.isDebugEnabled()) { + LOG.debug( + srcAttemptIdentifier + + ": Stalling shuffle since usedMemory (" + + usedMemory + + ") is greater than memoryLimit (" + + memoryLimit + + ")." + + " CommitMemory is (" + + commitMemory + + ")"); + } + return stallShuffle; + } + + // Allow the in-memory shuffle to progress + if (LOG.isDebugEnabled()) { + LOG.debug( + srcAttemptIdentifier + + ": Proceeding with shuffle since usedMemory (" + + usedMemory + + ") is lesser than memoryLimit (" + + memoryLimit + + ")." + + "CommitMemory is (" + + commitMemory + + ")"); + } + return unconditionalReserve(srcAttemptIdentifier, requestedSize, true); + } + + private synchronized MapOutput unconditionalReserve( + InputAttemptIdentifier srcAttemptIdentifier, long requestedSize, boolean primaryMapOutput) + throws IOException { + usedMemory += requestedSize; + return MapOutput.createMemoryMapOutput( + srcAttemptIdentifier, this, (int) requestedSize, primaryMapOutput); + } + + @Override + public synchronized void unreserve(long size) { + usedMemory -= size; + if (LOG.isDebugEnabled()) { + LOG.debug( + "Notifying unreserve : size=" + + size + + ", commitMemory=" + + commitMemory + + ", usedMemory=" + + usedMemory + + ", mergeThreshold=" + + mergeThreshold); + } + notifyAll(); + } + + @Override + public synchronized void releaseCommittedMemory(long size) { + commitMemory -= size; + unreserve(size); + } + + @Override + public synchronized void closeInMemoryFile(MapOutput mapOutput) { + inMemoryMapOutputs.add(mapOutput); + trackAndLogCloseInMemoryFile(mapOutput); + commitMemory += mapOutput.getSize(); + if (commitMemory >= mergeThreshold) { + startMemToRemoteMerge(); + } + } + + private void trackAndLogCloseInMemoryFile(MapOutput mapOutput) { + statsInMemTotal.updateStats(mapOutput.getSize()); + + if (LOG.isDebugEnabled()) { + LOG.debug( + "closeInMemoryFile -> map-output of size: " + + mapOutput.getSize() + + ", inMemoryMapOutputs.size() -> " + + inMemoryMapOutputs.size() + + ", commitMemory -> " + + commitMemory + + ", usedMemory ->" + + usedMemory + + ", mapOutput=" + + mapOutput); + } else { Review Comment: ``` if (LOG.isDebugEnabled()) { LOG.debug( "closeInMemoryFile -> map-output of size: " + mapOutput.getSize() + ", inMemoryMapOutputs.size() -> " + inMemoryMapOutputs.size() + ", commitMemory -> " + commitMemory + ", usedMemory ->" + usedMemory + ", mapOutput=" + mapOutput); } statsInMemLastLog.updateStats(mapOutput.getSize()); ``` Maybe we could remove the `else`. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
