我写了一个基于时间和数量的trigger,我本来希望实现的是,比如:
1.在5秒内,如果数量达到了5000条,则提前触发,
2.如果到达5秒后,则不管有多少条都触发.
现在的情况是:同一个时间点会出现多次基于时间的触发。我记得对同一个时间点连续注册定时器,最终只会有一个定时器触发,但实际好像不是这样?
请大佬们帮我看一下下面的代码有什么问题?
import lombok.extern.slf4j.Slf4j;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.state.ReducingState;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.streaming.api.windowing.triggers.Trigger;
import org.apache.flink.streaming.api.windowing.triggers.TriggerResult;
import org.apache.flink.streaming.api.windowing.windows.Window;

/**
 * Description:基于时间和数量的触发器
 *
 * @Author:dinghaohao
 * @Create:2024-09-09-10:30
 */
@Slf4j
public class CountWithTimeoutTrigger<T, W extends Window> extends Trigger<T, W> 
{
    private static final long serialVersionUID = 1L;

    private final long maxCount;

    private final ReducingStateDescriptor<Long> stateDesc = new 
ReducingStateDescriptor<>("count", new Sum(),
            LongSerializer.INSTANCE);

    private CountWithTimeoutTrigger(long maxCount) {
        this.maxCount = maxCount;
    }

    public static <T, W extends Window> CountWithTimeoutTrigger<T, W> of(long 
maxCount) {
        return new CountWithTimeoutTrigger<>(maxCount);
    }

    @Override
    public TriggerResult onEventTime(long time, W window, TriggerContext ctx) {
        return TriggerResult.CONTINUE;
    }

    private void registerNextTimer(TriggerContext ctx, W window) {
        // 在注册新定时器之前,先清除可能存在的旧定时器
        ctx.deleteProcessingTimeTimer(window.maxTimestamp());
        ctx.registerProcessingTimeTimer(window.maxTimestamp());
    }

    @Override
    public TriggerResult onElement(T element, long timestamp, W window, 
TriggerContext ctx) throws Exception {
        ReducingState<Long> count = ctx.getPartitionedState(stateDesc);
        count.add(1L);
        
        // 只在第一个元素到达时注册定时器
        if (count.get() == 1L) {
            registerNextTimer(ctx, window);
        }
        
        if (count.get() >= maxCount) {
            log.info("base on count trigger,count:{}", count.get());
            count.clear();
            // 清除旧的定时器
            ctx.deleteProcessingTimeTimer(window.maxTimestamp());
            // 注册新的定时器
            registerNextTimer(ctx, window);
            return TriggerResult.FIRE_AND_PURGE;
        }
        return TriggerResult.CONTINUE;
    }

    @Override
    public TriggerResult onProcessingTime(long time, W window, TriggerContext 
ctx) throws Exception {
        ReducingState<Long> count = ctx.getPartitionedState(stateDesc);
        if (count.get() == null || count.get() == 0L) {
            return TriggerResult.CONTINUE;
        }
        log.info("base on timeout trigger,count:{},time:{}", count.get(), 
window.maxTimestamp());
        count.clear();
        return TriggerResult.FIRE_AND_PURGE;
    }

    @Override
    public void clear(W window, TriggerContext ctx) throws Exception {
        ctx.deleteProcessingTimeTimer(window.maxTimestamp());
        ctx.getPartitionedState(stateDesc).clear();
    }

    @Override
    public boolean canMerge() {
        return true;
    }

    @Override
    public void onMerge(W window, OnMergeContext ctx) throws Exception {
        ctx.mergePartitionedState(stateDesc);
        // only register a timer if the time is not yet past the end of the 
merged
        // window
        // this is in line with the logic in onElement(). If the time is past 
the end of
        // the window onElement() will fire and setting a timer here would fire 
the
        // window twice.
        long windowMaxTimestamp = window.maxTimestamp();
        if (windowMaxTimestamp > ctx.getCurrentProcessingTime()) {
            ctx.registerProcessingTimeTimer(windowMaxTimestamp);
        }
    }

    @Override
    public String toString() {
        return "CountTrigger(" + maxCount + ")";
    }

    private static class Sum implements ReduceFunction<Long> {
        private static final long serialVersionUID = 1L;

        @Override
        public Long reduce(Long value1, Long value2) throws Exception {
            return value1 + value2;
        }
    }
}

回复