This is an automated email from the ASF dual-hosted git repository.

guohao pushed a commit to branch 3.0
in repository https://gitbox.apache.org/repos/asf/dubbo.git


The following commit(s) were added to refs/heads/3.0 by this push:
     new 7c2f52d  [3.0-Triple] support streamObserver cancel (#8946)
7c2f52d is described below

commit 7c2f52d4b59200a4a1f345c3718c4d95171e5848
Author: earthchen <[email protected]>
AuthorDate: Fri Oct 8 04:49:01 2021 -0500

    [3.0-Triple] support streamObserver cancel (#8946)
    
    * support observer cancel
    
    * remove context where needed
    
    * Rename to CancelableStreamObserver and remove method in triple util
    
    * support casting to CancelableStreamObserver for cancellation
    
    * Change how the server context is used
    
    * remove unused code
    
    * remove unused code
    
    * remove server stream execute
    
    * fix typo
    
    * remove unused code
    
    Co-authored-by: guohao <[email protected]>
---
 .../main/java/org/apache/dubbo/rpc/RpcContext.java | 11 +++
 .../java/org/apache/dubbo/rpc/RpcInvocation.java   | 10 ---
 .../rpc/protocol/tri/AbstractClientStream.java     | 22 ++++-
 .../rpc/protocol/tri/AbstractServerStream.java     | 18 ++--
 .../dubbo/rpc/protocol/tri/AbstractStream.java     | 24 ++++--
 .../rpc/protocol/tri/CancelableStreamObserver.java | 41 +++++++++
 .../dubbo/rpc/protocol/tri/ClientStream.java       |  4 +-
 .../dubbo/rpc/protocol/tri/ServerStream.java       | 97 +++++++++++++++++-----
 .../protocol/tri/TripleClientRequestHandler.java   | 52 ++++++------
 .../tri/TripleHttp2FrameServerHandler.java         |  9 +-
 .../dubbo/rpc/protocol/tri/TripleInvoker.java      |  5 --
 11 files changed, 204 insertions(+), 89 deletions(-)

diff --git 
a/dubbo-rpc/dubbo-rpc-api/src/main/java/org/apache/dubbo/rpc/RpcContext.java 
b/dubbo-rpc/dubbo-rpc-api/src/main/java/org/apache/dubbo/rpc/RpcContext.java
index d1e15ee..e39f421 100644
--- a/dubbo-rpc/dubbo-rpc-api/src/main/java/org/apache/dubbo/rpc/RpcContext.java
+++ b/dubbo-rpc/dubbo-rpc-api/src/main/java/org/apache/dubbo/rpc/RpcContext.java
@@ -84,6 +84,9 @@ public class RpcContext {
         }
     };
 
+    /**
+     * use by cancel call
+     */
     private static final InternalThreadLocal<CancellationContext> 
CANCELLATION_CONTEXT = new InternalThreadLocal<CancellationContext>() {
         @Override
         protected CancellationContext initialValue() {
@@ -96,6 +99,14 @@ public class RpcContext {
         return CANCELLATION_CONTEXT.get();
     }
 
+    public static void removeCancellationContext() {
+        CANCELLATION_CONTEXT.remove();
+    }
+
+    public static void restoreCancellationContext(CancellationContext 
oldContext) {
+        CANCELLATION_CONTEXT.set(oldContext);
+    }
+
 
     private boolean remove = true;
 
diff --git 
a/dubbo-rpc/dubbo-rpc-api/src/main/java/org/apache/dubbo/rpc/RpcInvocation.java 
b/dubbo-rpc/dubbo-rpc-api/src/main/java/org/apache/dubbo/rpc/RpcInvocation.java
index 56ccb61..1b7b412 100644
--- 
a/dubbo-rpc/dubbo-rpc-api/src/main/java/org/apache/dubbo/rpc/RpcInvocation.java
+++ 
b/dubbo-rpc/dubbo-rpc-api/src/main/java/org/apache/dubbo/rpc/RpcInvocation.java
@@ -86,16 +86,6 @@ public class RpcInvocation implements Invocation, 
Serializable {
 
     private transient InvokeMode invokeMode;
 
-    private transient CancellationContext cancellationContext;
-
-    public CancellationContext getCancellationContext() {
-        return cancellationContext;
-    }
-
-    public void setCancellationContext(CancellationContext 
cancellationContext) {
-        this.cancellationContext = cancellationContext;
-    }
-
     public RpcInvocation() {
     }
 
diff --git 
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/AbstractClientStream.java
 
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/AbstractClientStream.java
index 9d34461..164e799 100644
--- 
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/AbstractClientStream.java
+++ 
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/AbstractClientStream.java
@@ -19,9 +19,9 @@ package org.apache.dubbo.rpc.protocol.tri;
 
 import org.apache.dubbo.common.URL;
 import org.apache.dubbo.common.constants.CommonConstants;
-import org.apache.dubbo.common.stream.StreamObserver;
 import org.apache.dubbo.remoting.api.Connection;
 import org.apache.dubbo.remoting.exchange.support.DefaultFuture2;
+import org.apache.dubbo.rpc.CancellationContext;
 import org.apache.dubbo.rpc.RpcInvocation;
 import org.apache.dubbo.rpc.model.ConsumerModel;
 import org.apache.dubbo.triple.TripleWrapper;
@@ -50,10 +50,20 @@ public abstract class AbstractClientStream extends 
AbstractStream implements Str
         return new UnaryClientStream(url);
     }
 
-    public static AbstractClientStream stream(URL url) {
+    public static ClientStream stream(URL url) {
         return new ClientStream(url);
     }
 
+    public static AbstractClientStream newClientStream(URL url, boolean unary) 
{
+        AbstractClientStream stream = unary ? unary(url) : stream(url);
+        final CancellationContext cancellationContext = 
stream.getCancellationContext();
+        // for client cancel,send rst frame to server
+        cancellationContext.addListener(context -> {
+            stream.asTransportObserver().onReset(Http2Error.CANCEL);
+        });
+        return stream;
+    }
+
     public AbstractClientStream service(ConsumerModel model) {
         this.consumerModel = model;
         return this;
@@ -166,7 +176,7 @@ public abstract class AbstractClientStream extends 
AbstractStream implements Str
         return metadata;
     }
 
-    protected class ClientStreamObserver implements StreamObserver<Object> {
+    protected class ClientStreamObserver extends 
CancelableStreamObserver<Object> {
 
         @Override
         public void onNext(Object data) {
@@ -179,7 +189,6 @@ public abstract class AbstractClientStream extends 
AbstractStream implements Str
 
         @Override
         public void onError(Throwable throwable) {
-
         }
 
         @Override
@@ -192,4 +201,9 @@ public abstract class AbstractClientStream extends 
AbstractStream implements Str
     protected void cancelByRemoteReset(Http2Error http2Error) {
         DefaultFuture2.getFuture(getRequest().getId()).cancel();
     }
+
+    @Override
+    protected void cancelByLocal(Throwable throwable) {
+        getCancellationContext().cancel(throwable);
+    }
 }
diff --git 
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/AbstractServerStream.java
 
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/AbstractServerStream.java
index 7069ab9..71891b4 100644
--- 
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/AbstractServerStream.java
+++ 
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/AbstractServerStream.java
@@ -23,7 +23,6 @@ import 
org.apache.dubbo.common.threadpool.manager.ExecutorRepository;
 import org.apache.dubbo.remoting.Constants;
 import org.apache.dubbo.rpc.HeaderFilter;
 import org.apache.dubbo.rpc.Invoker;
-import org.apache.dubbo.rpc.RpcContext;
 import org.apache.dubbo.rpc.RpcInvocation;
 import org.apache.dubbo.rpc.model.FrameworkServiceRepository;
 import org.apache.dubbo.rpc.model.MethodDescriptor;
@@ -82,14 +81,18 @@ public abstract class AbstractServerStream extends 
AbstractStream implements Str
         return executor;
     }
 
-    public static AbstractServerStream unary(URL url) {
+    public static UnaryServerStream unary(URL url) {
         return new UnaryServerStream(url);
     }
 
-    public static AbstractServerStream stream(URL url) {
+    public static ServerStream stream(URL url) {
         return new ServerStream(url);
     }
 
+    public static AbstractServerStream newServerStream(URL url, boolean unary) 
{
+        return unary ? unary(url) : stream(url);
+    }
+
     private static ProviderModel lookupProviderModel(URL url) {
         FrameworkServiceRepository repo = 
ScopeModelUtil.getFrameworkModel(url.getScopeModel()).getServiceRepository();
         final ProviderModel model = 
repo.lookupExportedService(url.getServiceKey());
@@ -133,9 +136,6 @@ public abstract class AbstractServerStream extends 
AbstractStream implements Str
         for (HeaderFilter headerFilter : getHeaderFilters()) {
             inv = headerFilter.invoke(getInvoker(), inv);
         }
-        if (getCancellationContext() == null) {
-            setCancellationContext(RpcContext.getCancellationContext());
-        }
         return inv;
     }
 
@@ -234,4 +234,10 @@ public abstract class AbstractServerStream extends 
AbstractStream implements Str
     protected void cancelByRemoteReset(Http2Error http2Error) {
         getCancellationContext().cancel(null);
     }
+
+
+    @Override
+    protected void cancelByLocal(Throwable throwable) {
+        asTransportObserver().onReset(Http2Error.CANCEL);
+    }
 }
diff --git 
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/AbstractStream.java
 
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/AbstractStream.java
index ebb48e9..e1f6c74 100644
--- 
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/AbstractStream.java
+++ 
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/AbstractStream.java
@@ -26,6 +26,7 @@ import org.apache.dubbo.common.utils.StringUtils;
 import org.apache.dubbo.config.Constants;
 import org.apache.dubbo.remoting.exchange.Request;
 import org.apache.dubbo.rpc.CancellationContext;
+import org.apache.dubbo.rpc.RpcContext;
 import org.apache.dubbo.rpc.model.MethodDescriptor;
 import org.apache.dubbo.rpc.model.ServiceDescriptor;
 import org.apache.dubbo.rpc.protocol.tri.GrpcStatus.Code;
@@ -76,7 +77,7 @@ public abstract class AbstractStream implements Stream {
     private StreamObserver<Object> streamSubscriber;
     private TransportObserver transportSubscriber;
 
-    private CancellationContext cancellationContext;
+    private final CancellationContext cancellationContext;
     private boolean cancelled = false;
 
     public boolean isCancelled() {
@@ -91,18 +92,15 @@ public abstract class AbstractStream implements Stream {
         return cancellationContext;
     }
 
-    protected void setCancellationContext(CancellationContext 
cancellationContext) {
-        this.cancellationContext = cancellationContext;
-    }
-
     protected AbstractStream(URL url, Executor executor) {
         this.url = url;
         this.executor = executor;
         final String value = 
url.getParameter(Constants.MULTI_SERIALIZATION_KEY, 
CommonConstants.DEFAULT_KEY);
         this.multipleSerialization = 
url.getOrDefaultFrameworkModel().getExtensionLoader(MultipleSerialization.class)
                 .getExtension(value);
-        this.streamObserver = createStreamObserver();
+        this.cancellationContext = new CancellationContext();
         this.transportObserver = createTransportObserver();
+        this.streamObserver = createStreamObserver();
     }
 
     private static Executor allocateCallbackExecutor() {
@@ -142,8 +140,14 @@ public abstract class AbstractStream implements Stream {
      *
      * @param cause cancel case
      */
-    protected void cancel(Throwable cause) {
-        getCancellationContext().cancel(cause);
+    protected final void cancel(Throwable cause) {
+        cancel();
+        cancelByLocal(cause);
+    }
+
+    private void cancel() {
+        cancelled = true;
+        execute(RpcContext::removeCancellationContext);
     }
 
     /**
@@ -152,12 +156,14 @@ public abstract class AbstractStream implements Stream {
      * @param http2Error {@link Http2Error}
      */
     protected final void cancelByRemote(Http2Error http2Error) {
-        cancelled = true;
+        cancel();
         cancelByRemoteReset(http2Error);
     }
 
     protected abstract void cancelByRemoteReset(Http2Error http2Error);
 
+    protected abstract void cancelByLocal(Throwable throwable);
+
     protected abstract StreamObserver<Object> createStreamObserver();
 
     protected abstract TransportObserver createTransportObserver();
diff --git 
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/CancelableStreamObserver.java
 
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/CancelableStreamObserver.java
new file mode 100644
index 0000000..09eee4e
--- /dev/null
+++ 
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/CancelableStreamObserver.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.dubbo.rpc.protocol.tri;
+
+import org.apache.dubbo.common.stream.StreamObserver;
+import org.apache.dubbo.rpc.CancellationContext;
+
+public abstract class CancelableStreamObserver<T> implements StreamObserver<T> 
{
+
+    private CancellationContext cancellationContext;
+
+    public CancellationContext getCancellationContext() {
+        return cancellationContext;
+    }
+
+    public void setCancellationContext(CancellationContext 
cancellationContext) {
+        this.cancellationContext = cancellationContext;
+    }
+
+    public final void cancel(Throwable throwable) {
+        if (cancellationContext == null) {
+            return;
+        }
+        cancellationContext.cancel(throwable);
+    }
+}
diff --git 
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/ClientStream.java
 
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/ClientStream.java
index c4242f0..f2502d3 100644
--- 
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/ClientStream.java
+++ 
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/ClientStream.java
@@ -29,7 +29,7 @@ public class ClientStream extends AbstractClientStream 
implements Stream {
 
     @Override
     protected StreamObserver<Object> createStreamObserver() {
-        return new ClientStreamObserver() {
+        ClientStreamObserver clientStreamObserver = new ClientStreamObserver() 
{
             boolean metaSent;
 
             @Override
@@ -48,6 +48,8 @@ public class ClientStream extends AbstractClientStream 
implements Stream {
                 transportError(throwable);
             }
         };
+        clientStreamObserver.setCancellationContext(getCancellationContext());
+        return clientStreamObserver;
     }
 
     @Override
diff --git 
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/ServerStream.java
 
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/ServerStream.java
index e7901ed..b4374f8 100644
--- 
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/ServerStream.java
+++ 
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/ServerStream.java
@@ -20,6 +20,7 @@ package org.apache.dubbo.rpc.protocol.tri;
 import org.apache.dubbo.common.URL;
 import org.apache.dubbo.common.stream.StreamObserver;
 import org.apache.dubbo.rpc.Result;
+import org.apache.dubbo.rpc.RpcContext;
 import org.apache.dubbo.rpc.RpcInvocation;
 import org.apache.dubbo.rpc.model.MethodDescriptor;
 
@@ -54,8 +55,8 @@ public class ServerStream extends AbstractServerStream 
implements Stream {
         @Override
         public void onError(Throwable throwable) {
             final GrpcStatus status = 
GrpcStatus.fromCode(GrpcStatus.Code.INTERNAL)
-                .withCause(throwable)
-                .withDescription("Biz exception");
+                    .withCause(throwable)
+                    .withDescription("Biz exception");
             transportError(status);
         }
 
@@ -70,20 +71,40 @@ public class ServerStream extends AbstractServerStream 
implements Stream {
 
     private class StreamTransportObserver extends AbstractTransportObserver 
implements TransportObserver {
 
+        /**
+         * for server stream the method only save header
+         * <p>
+         * for bi stream run api impl code and put observer to streamSubscriber
+         *
+         * <pre class="code">
+         * public StreamObserver<GreeterRequest> 
biStream(StreamObserver<GreeterReply> replyStream) {
+         *      // happen on this
+         *      // you can add cancel listener on use {@link 
RpcContext#getCancellationContext()}
+         *      return new StreamObserver<GreeterRequest>() {
+         *          // ...
+         *      };
+         * }
+         * </pre>
+         */
         @Override
         public void onMetadata(Metadata metadata, boolean endStream) {
             super.onMetadata(metadata, endStream);
             if (getMethodDescriptor().getRpcType() == 
MethodDescriptor.RpcType.SERVER_STREAM) {
                 return;
             }
-            final RpcInvocation inv = buildInvocation(metadata);
-            inv.setArguments(new Object[]{asStreamObserver()});
-            final Result result = getInvoker().invoke(inv);
             try {
-                subscribe((StreamObserver<Object>) result.getValue());
-            } catch (Throwable t) {
-                transportError(GrpcStatus.fromCode(GrpcStatus.Code.INTERNAL)
-                    .withDescription("Failed to create server's observer"));
+                
RpcContext.restoreCancellationContext(getCancellationContext());
+                final RpcInvocation inv = buildInvocation(metadata);
+                inv.setArguments(new Object[]{asStreamObserver()});
+                final Result result = getInvoker().invoke(inv);
+                try {
+                    subscribe((StreamObserver<Object>) result.getValue());
+                } catch (Throwable t) {
+                    
transportError(GrpcStatus.fromCode(GrpcStatus.Code.INTERNAL)
+                            .withDescription("Failed to create server's 
observer"));
+                }
+            } finally {
+                RpcContext.removeCancellationContext();
             }
         }
 
@@ -91,25 +112,57 @@ public class ServerStream extends AbstractServerStream 
implements Stream {
         public void onData(byte[] in, boolean endStream) {
             try {
                 if (getMethodDescriptor().getRpcType() == 
MethodDescriptor.RpcType.SERVER_STREAM) {
-                    RpcInvocation inv = buildInvocation(getHeaders());
-                    final Object[] arguments = deserializeRequest(in);
-                    if (arguments != null) {
-                        inv.setArguments(new Object[]{arguments[0], 
asStreamObserver()});
-                        getInvoker().invoke(inv);
-                    }
-                } else {
-                    final Object[] arguments = deserializeRequest(in);
-                    if (arguments != null) {
-                        getStreamSubscriber().onNext(arguments[0]);
-                    }
+                    serverStreamOnData(in);
+                    return;
                 }
+                biStreamOnData(in);
             } catch (Throwable t) {
                 transportError(GrpcStatus.fromCode(GrpcStatus.Code.INTERNAL)
-                    .withDescription("Deserialize request failed")
-                    .withCause(t));
+                        .withDescription("Deserialize request failed")
+                        .withCause(t));
+            }
+        }
+
+        /**
+         * call observer onNext
+         */
+        private void biStreamOnData(byte[] in) {
+            final Object[] arguments = deserializeRequest(in);
+            if (arguments != null) {
+                getStreamSubscriber().onNext(arguments[0]);
+            }
+        }
+
+        /**
+         * call api impl code
+         *
+         * <pre class="code">
+         * public void cancelServerStream(GreeterRequest request, 
StreamObserver<GreeterReply> replyStream) {
+         *      // happen on this
+         *      // you can add cancel listener on use {@link 
RpcContext#getCancellationContext()}
+         *      // if you want listener cancel,plz do not call onCompleted()
+         *     }
+         * </pre>
+         */
+        private void serverStreamOnData(byte[] in) {
+            try {
+                
RpcContext.restoreCancellationContext(getCancellationContext());
+                RpcInvocation inv = buildInvocation(getHeaders());
+                final Object[] arguments = deserializeRequest(in);
+                if (arguments != null) {
+                    inv.setArguments(new Object[]{arguments[0], 
asStreamObserver()});
+                    getInvoker().invoke(inv);
+                }
+            } finally {
+                RpcContext.removeCancellationContext();
             }
         }
 
+        /**
+         * for server stream the method do nothing
+         * <p>
+         * for bi stream call onCompleted
+         */
         @Override
         public void onComplete() {
             if (getMethodDescriptor().getRpcType() == 
MethodDescriptor.RpcType.SERVER_STREAM) {
diff --git 
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/TripleClientRequestHandler.java
 
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/TripleClientRequestHandler.java
index 2254343..0149799 100644
--- 
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/TripleClientRequestHandler.java
+++ 
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/TripleClientRequestHandler.java
@@ -36,7 +36,6 @@ import org.apache.dubbo.rpc.model.MethodDescriptor;
 import io.netty.channel.ChannelDuplexHandler;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.ChannelPromise;
-import io.netty.handler.codec.http2.Http2Error;
 
 import java.util.Arrays;
 import java.util.List;
@@ -64,33 +63,22 @@ public class TripleClientRequestHandler extends 
ChannelDuplexHandler {
         final URL url = inv.getInvoker().getUrl();
         ConsumerModel consumerModel = inv.getServiceModel() != null ? 
(ConsumerModel) inv.getServiceModel() : (ConsumerModel) url.getServiceModel();
 
-        MethodDescriptor methodDescriptor = 
getTriMethodDescriptor(consumerModel,inv);
+        MethodDescriptor methodDescriptor = 
getTriMethodDescriptor(consumerModel, inv);
 
         ClassLoadUtil.switchContextLoader(consumerModel.getClassLoader());
-        AbstractClientStream stream;
-        if (methodDescriptor.isUnary()) {
-            stream = AbstractClientStream.unary(url);
-        } else {
-            stream = AbstractClientStream.stream(url);
-        }
-        final CancellationContext cancellationContext = 
inv.getCancellationContext();
-        // for client cancel,send rst frame to server
-        cancellationContext.addListener(context -> {
-            stream.asTransportObserver().onReset(Http2Error.CANCEL);;
-        });
-        stream.setCancellationContext(cancellationContext);
+        final AbstractClientStream stream = 
AbstractClientStream.newClientStream(url, methodDescriptor.isUnary());
 
         String ssl = url.getParameter(CommonConstants.SSL_ENABLED_KEY);
         if (StringUtils.isNotEmpty(ssl)) {
             
ctx.channel().attr(TripleConstant.SSL_ATTRIBUTE_KEY).set(Boolean.parseBoolean(ssl));
         }
         stream.service(consumerModel)
-            .connection(Connection.getConnectionFromChannel(ctx.channel()))
-            .method(methodDescriptor)
-            .methodName(methodDescriptor.getMethodName())
-            .request(req)
-            .serialize((String) 
inv.getObjectAttachment(Constants.SERIALIZATION_KEY))
-            .subscribe(new ClientTransportObserver(ctx, stream, promise));
+                .connection(Connection.getConnectionFromChannel(ctx.channel()))
+                .method(methodDescriptor)
+                .methodName(methodDescriptor.getMethodName())
+                .request(req)
+                .serialize((String) 
inv.getObjectAttachment(Constants.SERIALIZATION_KEY))
+                .subscribe(new ClientTransportObserver(ctx, stream, promise));
 
         if (methodDescriptor.isUnary()) {
             stream.asStreamObserver().onNext(inv);
@@ -100,13 +88,15 @@ public class TripleClientRequestHandler extends 
ChannelDuplexHandler {
             AppResponse result;
             // the stream method params is fixed
             if (methodDescriptor.getRpcType() == 
MethodDescriptor.RpcType.BIDIRECTIONAL_STREAM
-                || methodDescriptor.getRpcType() == 
MethodDescriptor.RpcType.CLIENT_STREAM) {
-                final StreamObserver<Object> streamObserver = 
(StreamObserver<Object>) inv.getArguments()[0];
-                stream.subscribe(streamObserver);
+                    || methodDescriptor.getRpcType() == 
MethodDescriptor.RpcType.CLIENT_STREAM) {
+                StreamObserver<Object> obServer = (StreamObserver<Object>) 
inv.getArguments()[0];
+                obServer = attachCancelContext(obServer, 
stream.getCancellationContext());
+                stream.subscribe(obServer);
                 result = new AppResponse(stream.asStreamObserver());
             } else {
-                final StreamObserver<Object> streamObserver = 
(StreamObserver<Object>) inv.getArguments()[1];
-                stream.subscribe(streamObserver);
+                StreamObserver<Object> obServer = (StreamObserver<Object>) 
inv.getArguments()[1];
+                obServer = attachCancelContext(obServer, 
stream.getCancellationContext());
+                stream.subscribe(obServer);
                 result = new AppResponse();
                 stream.asStreamObserver().onNext(inv.getArguments()[0]);
                 stream.asStreamObserver().onCompleted();
@@ -117,7 +107,7 @@ public class TripleClientRequestHandler extends 
ChannelDuplexHandler {
     }
 
     /**
-     * Get the trI protocol special MethodDescriptor
+     * Get the tri protocol special MethodDescriptor
      */
     private MethodDescriptor getTriMethodDescriptor(ConsumerModel 
consumerModel, RpcInvocation inv) {
         List<MethodDescriptor> methodDescriptors = 
consumerModel.getServiceModel().getMethods(inv.getMethodName());
@@ -131,4 +121,14 @@ public class TripleClientRequestHandler extends 
ChannelDuplexHandler {
         }
         throw new IllegalStateException("methodDescriptors must not be null 
method=" + inv.getMethodName());
     }
+
+
+    public <T> StreamObserver<T> attachCancelContext(StreamObserver<T> 
observer, CancellationContext context) {
+        if (observer instanceof CancelableStreamObserver) {
+            CancelableStreamObserver<T> streamObserver = 
((CancelableStreamObserver<T>) observer);
+            streamObserver.setCancellationContext(context);
+            return streamObserver;
+        }
+        return observer;
+    }
 }
diff --git 
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/TripleHttp2FrameServerHandler.java
 
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/TripleHttp2FrameServerHandler.java
index e3d0421..2c86a95 100644
--- 
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/TripleHttp2FrameServerHandler.java
+++ 
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/TripleHttp2FrameServerHandler.java
@@ -208,12 +208,9 @@ public class TripleHttp2FrameServerHandler extends 
ChannelDuplexHandler {
                 methodDescriptor = methodDescriptors.get(0);
             }
         }
-        final AbstractServerStream stream;
-        if (methodDescriptor != null && methodDescriptor.isStream()) {
-            stream = AbstractServerStream.stream(invoker.getUrl());
-        } else {
-            stream = AbstractServerStream.unary(invoker.getUrl());
-        }
+        boolean isUnary = methodDescriptor != null && 
methodDescriptor.isUnary();
+        final AbstractServerStream stream = 
AbstractServerStream.newServerStream(invoker.getUrl(), isUnary);
+
         Channel channel = ctx.channel();
         // You can add listeners to ChannelPromise here if you want to listen 
for the result of sending a frame
         stream.service(providerModel.getServiceModel())
diff --git 
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/TripleInvoker.java
 
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/TripleInvoker.java
index e0afab2..2011a04 100644
--- 
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/TripleInvoker.java
+++ 
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/TripleInvoker.java
@@ -29,7 +29,6 @@ import org.apache.dubbo.remoting.exchange.Response;
 import org.apache.dubbo.remoting.exchange.support.DefaultFuture2;
 import org.apache.dubbo.rpc.AppResponse;
 import org.apache.dubbo.rpc.AsyncRpcResult;
-import org.apache.dubbo.rpc.CancellationContext;
 import org.apache.dubbo.rpc.FutureContext;
 import org.apache.dubbo.rpc.Invocation;
 import org.apache.dubbo.rpc.Invoker;
@@ -80,10 +79,6 @@ public class TripleInvoker<T> extends AbstractInvoker<T> {
     protected Result doInvoke(final Invocation invocation) throws Throwable {
         RpcInvocation inv = (RpcInvocation) invocation;
 
-        // set cancel context to RpcInvocation to transport to stream
-        final CancellationContext cancellationContext = 
RpcContext.getCancellationContext();
-        inv.setCancellationContext(cancellationContext);
-
         final String methodName = RpcUtils.getMethodName(invocation);
         
inv.setServiceModel(RpcContext.getServiceContext().getConsumerUrl().getServiceModel());
         inv.setAttachment(PATH_KEY, getUrl().getPath());

Reply via email to