This is an automated email from the ASF dual-hosted git repository.
earthchen pushed a commit to branch 3.3
in repository https://gitbox.apache.org/repos/asf/dubbo.git
The following commit(s) were added to refs/heads/3.3 by this push:
new 1f6d441de6 [Feature 3.3] Triple Rest Cors Support (#14073)
1f6d441de6 is described below
commit 1f6d441de6a3de7c9b78868ff1b9f94dc213ebcf
Author: Rawven <[email protected]>
AuthorDate: Thu May 23 11:05:17 2024 +0800
[Feature 3.3] Triple Rest Cors Support (#14073)
* feat(): add cors key
* feat(): add base cors class
* feat(): add cors class in rpc-triple and rest-spring
* feat(): add cors class in rpc-triple and rest-spring
* test(): add rpc-triple cors test
* fix(): fix CorsUtil bug
* fix(): fix CorsUtil bug
* fix(): fix objectUtil bug
* fix(): fix corsmeta set bug
* fix(): fix config load fail bug
* fix(): OPTIONS method lookup fails
* fix(): CorsMeta methods may be null
* fix(): request fails when request headers are not set
* refactor(): improve CorsMeta CorsProcess some code
* fix(): CorsMeta combine priority
* test(): remove rest-spring cors test to sample
* docs(): add docs
* revert(): test version
* fix(): getCorsMeta can be null
* fix(): combine can be null
* fix(): save option and vary bug
* fix(): pom version
* fix(): Spring version causes allowPrivateNetwork resolution error
* fix(): ci
* refactor(): delete useless code
* refactor(): address some SonarCloud issues
* refactor(): add @Nullable to point the CorsMeta Attributes
* refactor(): style
* fix(): fix preflight logic
* fix(): remove credential & privateNetWork report
* refactor(): Move globalMetaMerge in RequestMappingResolver
* refactor(): use arrays instead of strings in CorsConfig
* refactor(): move CorsProcessor to CorsHeaderFilterAdapter
* fix(): fix unit test
* fix(): fix test failure
* fix(): delete useless param
* fix(): fix sonarcloud
* fix(): fix wrong class place & naming
* fix(): fix wrong static global corsMeta
* fix(): refactor CorsUtil from sonar issue
* feat(rest): refine cors support
* feat(rest): refine cors support
* feat(rest): refine cors support bugfix
* fix(): getBoolean throws an exception when the value is null
* fix(rest-spring): fix crossOrigin allowCredentials is string
* fix(): fix globalCorsMeta load null
* fix(): fix vary header bug
* fix(): fix unit test && Fix cors specification
* fix(): fix pom
* fix(): fix combine bug
* fix(): fix some sonar issue
* fix(): fix style
* feat(rest): refine cors support
* fix(): fix style
* fix(): fix needed sonar issue
* refactor(): refactor CorsMeta.combine() and add comment
* fix(): add missing license headers
* fix(): update test
* test(): Refactor the test class and add credential test cases
* test(): Refactor the test class and add credential test cases
* fix(rest): revert api HeaderFilter
* fix(): address Sonar issues
---------
Co-authored-by: Sean Yang <[email protected]>
Co-authored-by: earthchen <[email protected]>
---
.licenserc.yaml | 1 +
.../java/org/apache/dubbo/config/CorsConfig.java | 117 ++++++
.../java/org/apache/dubbo/config/RestConfig.java | 16 +
.../support/jaxrs/JaxrsRequestMappingResolver.java | 9 +
.../spring/SpringMvcRequestMappingResolver.java | 38 +-
.../spring/compatible/SpringDemoServiceImpl.java | 2 +
.../http12/AbstractServerHttpChannelObserver.java | 1 +
.../apache/dubbo/remoting/http12/HttpStatus.java | 1 +
.../tri/h12/AbstractServerTransportListener.java | 4 +-
.../dubbo/rpc/protocol/tri/rest/RestConstants.java | 9 +
.../protocol/tri/rest/cors/CorsHeaderFilter.java | 273 +++++++++++++
.../rpc/protocol/tri/rest/cors/CorsUtils.java | 53 +++
.../tri/rest/filter/RestHeaderFilterAdapter.java | 8 +-
.../protocol/tri/rest/mapping/RequestMapping.java | 55 ++-
.../rest/mapping/condition/MethodsCondition.java | 7 +
.../protocol/tri/rest/mapping/meta/CorsMeta.java | 306 +++++++++++++++
.../internal/org.apache.dubbo.rpc.HeaderFilter | 1 +
.../tri/rest/cors/CorsHeaderFilterTest.java | 423 +++++++++++++++++++++
pom.xml | 2 +
19 files changed, 1301 insertions(+), 25 deletions(-)
diff --git a/.licenserc.yaml b/.licenserc.yaml
index e497f4b10e..98957a25f7 100644
--- a/.licenserc.yaml
+++ b/.licenserc.yaml
@@ -79,6 +79,7 @@ header:
-
'dubbo-config/dubbo-config-spring/src/test/java/org/apache/dubbo/config/spring/EmbeddedZooKeeper.java'
-
'dubbo-test/dubbo-test-common/src/main/java/org/apache/dubbo/test/common/utils/TestSocketUtils.java'
-
'dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/TriHttp2RemoteFlowController.java'
+ -
'dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/cors/CorsHeaderFilter.java'
-
'dubbo-common/src/main/java/org/apache/dubbo/common/threadpool/serial/SerializingExecutor.java'
-
'dubbo-maven-plugin/src/main/java/org/apache/dubbo/maven/plugin/aot/AbstractAotMojo.java'
-
'dubbo-maven-plugin/src/main/java/org/apache/dubbo/maven/plugin/aot/AbstractDependencyFilterMojo.java'
diff --git a/dubbo-common/src/main/java/org/apache/dubbo/config/CorsConfig.java
b/dubbo-common/src/main/java/org/apache/dubbo/config/CorsConfig.java
new file mode 100644
index 0000000000..657f20d039
--- /dev/null
+++ b/dubbo-common/src/main/java/org/apache/dubbo/config/CorsConfig.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.dubbo.config;
+
+import java.io.Serializable;
+
+public class CorsConfig implements Serializable {
+ private static final long serialVersionUID = 1L;
+
+ /**
+ * A list of origins for which cross-origin requests are allowed. Values
may be a specific domain, e.g.
+ * {@code "https://domain1.com"}, or the CORS defined special value {@code
"*"} for all origins.
+ * <p>By default this is not set which means that no origins are allowed.
+ * However, an instance of this class is often initialized further, e.g.
for {@code @CrossOrigin}, via
+ * {@code
org.apache.dubbo.rpc.protocol.tri.rest.mapping.meta.CorsMeta.Builder#applyDefault()}.
+ */
+ private String[] allowedOrigins;
+
+ /**
+ * Set the HTTP methods to allow, e.g. {@code "GET"}, {@code "POST"},
+ * {@code "PUT"}, etc. The special value {@code "*"} allows all methods.
+ * <p>If not set, only {@code "GET"} and {@code "HEAD"} are allowed.
+ * <p>By default this is not set.
+ */
+ private String[] allowedMethods;
+
+ /**
+ * Set the list of headers that a pre-flight request can list as allowed
+ * for use during an actual request. The special value {@code "*"} allows
+ * actual requests to send any header.
+ * <p>By default this is not set.
+ */
+ private String[] allowedHeaders;
+
+ /**
+ * Set the list of response headers that an actual response might have
+ * and can be exposed to the client. The special value {@code "*"}
+ * allows all headers to be exposed.
+ * <p>By default this is not set.
+ */
+ private String[] exposedHeaders;
+
+ /**
+ * Whether user credentials are supported.
+ * <p>By default this is not set (i.e. user credentials are not supported).
+ */
+ private Boolean allowCredentials;
+
+ /**
+ * Configure how long, as a duration, the response from a pre-flight
request
+ * can be cached by clients.
+ */
+ private Long maxAge;
+
+ public String[] getAllowedOrigins() {
+ return allowedOrigins;
+ }
+
+ public void setAllowedOrigins(String[] allowedOrigins) {
+ this.allowedOrigins = allowedOrigins;
+ }
+
+ public String[] getAllowedMethods() {
+ return allowedMethods;
+ }
+
+ public void setAllowedMethods(String[] allowedMethods) {
+ this.allowedMethods = allowedMethods;
+ }
+
+ public String[] getAllowedHeaders() {
+ return allowedHeaders;
+ }
+
+ public void setAllowedHeaders(String[] allowedHeaders) {
+ this.allowedHeaders = allowedHeaders;
+ }
+
+ public String[] getExposedHeaders() {
+ return exposedHeaders;
+ }
+
+ public void setExposedHeaders(String[] exposedHeaders) {
+ this.exposedHeaders = exposedHeaders;
+ }
+
+ public Boolean getAllowCredentials() {
+ return allowCredentials;
+ }
+
+ public void setAllowCredentials(Boolean allowCredentials) {
+ this.allowCredentials = allowCredentials;
+ }
+
+ public Long getMaxAge() {
+ return maxAge;
+ }
+
+ public void setMaxAge(Long maxAge) {
+ this.maxAge = maxAge;
+ }
+}
diff --git a/dubbo-common/src/main/java/org/apache/dubbo/config/RestConfig.java
b/dubbo-common/src/main/java/org/apache/dubbo/config/RestConfig.java
index 6a3674e119..dc7e988928 100644
--- a/dubbo-common/src/main/java/org/apache/dubbo/config/RestConfig.java
+++ b/dubbo-common/src/main/java/org/apache/dubbo/config/RestConfig.java
@@ -16,6 +16,8 @@
*/
package org.apache.dubbo.config;
+import org.apache.dubbo.config.support.Nested;
+
import java.io.Serializable;
/**
@@ -68,6 +70,12 @@ public class RestConfig implements Serializable {
*/
private String formatParameterName;
+ /**
+ * The config is used to set the Global CORS configuration properties.
+ */
+ @Nested
+ private CorsConfig cors;
+
public Integer getMaxBodySize() {
return maxBodySize;
}
@@ -115,4 +123,12 @@ public class RestConfig implements Serializable {
public void setFormatParameterName(String formatParameterName) {
this.formatParameterName = formatParameterName;
}
+
+ public CorsConfig getCors() {
+ return cors;
+ }
+
+ public void setCors(CorsConfig cors) {
+ this.cors = cors;
+ }
}
diff --git
a/dubbo-plugin/dubbo-rest-jaxrs/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/support/jaxrs/JaxrsRequestMappingResolver.java
b/dubbo-plugin/dubbo-rest-jaxrs/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/support/jaxrs/JaxrsRequestMappingResolver.java
index 9d7dc5a143..72cb325351 100644
---
a/dubbo-plugin/dubbo-rest-jaxrs/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/support/jaxrs/JaxrsRequestMappingResolver.java
+++
b/dubbo-plugin/dubbo-rest-jaxrs/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/support/jaxrs/JaxrsRequestMappingResolver.java
@@ -18,12 +18,14 @@ package
org.apache.dubbo.rpc.protocol.tri.rest.support.jaxrs;
import org.apache.dubbo.common.extension.Activate;
import org.apache.dubbo.rpc.model.FrameworkModel;
+import org.apache.dubbo.rpc.protocol.tri.rest.cors.CorsUtils;
import org.apache.dubbo.rpc.protocol.tri.rest.mapping.RequestMapping;
import org.apache.dubbo.rpc.protocol.tri.rest.mapping.RequestMapping.Builder;
import org.apache.dubbo.rpc.protocol.tri.rest.mapping.RequestMappingResolver;
import
org.apache.dubbo.rpc.protocol.tri.rest.mapping.condition.ServiceVersionCondition;
import org.apache.dubbo.rpc.protocol.tri.rest.mapping.meta.AnnotationMeta;
import org.apache.dubbo.rpc.protocol.tri.rest.mapping.meta.AnnotationSupport;
+import org.apache.dubbo.rpc.protocol.tri.rest.mapping.meta.CorsMeta;
import org.apache.dubbo.rpc.protocol.tri.rest.mapping.meta.MethodMeta;
import org.apache.dubbo.rpc.protocol.tri.rest.mapping.meta.ServiceMeta;
import org.apache.dubbo.rpc.protocol.tri.rest.util.RestToolKit;
@@ -31,9 +33,12 @@ import
org.apache.dubbo.rpc.protocol.tri.rest.util.RestToolKit;
@Activate(onClass = "javax.ws.rs.Path")
public class JaxrsRequestMappingResolver implements RequestMappingResolver {
+ private final FrameworkModel frameworkModel;
private final RestToolKit toolKit;
+ private CorsMeta globalCorsMeta;
public JaxrsRequestMappingResolver(FrameworkModel frameworkModel) {
+ this.frameworkModel = frameworkModel;
toolKit = new JaxrsRestToolKit(frameworkModel);
}
@@ -65,10 +70,14 @@ public class JaxrsRequestMappingResolver implements
RequestMappingResolver {
return null;
}
ServiceMeta serviceMeta = methodMeta.getServiceMeta();
+ if (globalCorsMeta == null) {
+ globalCorsMeta = CorsUtils.getGlobalCorsMeta(frameworkModel);
+ }
return builder(methodMeta, path, httpMethod)
.name(methodMeta.getMethod().getName())
.contextPath(methodMeta.getServiceMeta().getContextPath())
.custom(new
ServiceVersionCondition(serviceMeta.getServiceGroup(),
serviceMeta.getServiceVersion()))
+ .cors(globalCorsMeta)
.build();
}
diff --git
a/dubbo-plugin/dubbo-rest-spring/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/support/spring/SpringMvcRequestMappingResolver.java
b/dubbo-plugin/dubbo-rest-spring/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/support/spring/SpringMvcRequestMappingResolver.java
index ca4dd68a3b..e6de4f568b 100644
---
a/dubbo-plugin/dubbo-rest-spring/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/support/spring/SpringMvcRequestMappingResolver.java
+++
b/dubbo-plugin/dubbo-rest-spring/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/support/spring/SpringMvcRequestMappingResolver.java
@@ -16,14 +16,17 @@
*/
package org.apache.dubbo.rpc.protocol.tri.rest.support.spring;
+import org.apache.dubbo.common.constants.CommonConstants;
import org.apache.dubbo.common.extension.Activate;
import org.apache.dubbo.common.utils.StringUtils;
import org.apache.dubbo.rpc.model.FrameworkModel;
+import org.apache.dubbo.rpc.protocol.tri.rest.cors.CorsUtils;
import org.apache.dubbo.rpc.protocol.tri.rest.mapping.RequestMapping;
import org.apache.dubbo.rpc.protocol.tri.rest.mapping.RequestMapping.Builder;
import org.apache.dubbo.rpc.protocol.tri.rest.mapping.RequestMappingResolver;
import
org.apache.dubbo.rpc.protocol.tri.rest.mapping.condition.ServiceVersionCondition;
import org.apache.dubbo.rpc.protocol.tri.rest.mapping.meta.AnnotationMeta;
+import org.apache.dubbo.rpc.protocol.tri.rest.mapping.meta.CorsMeta;
import org.apache.dubbo.rpc.protocol.tri.rest.mapping.meta.MethodMeta;
import org.apache.dubbo.rpc.protocol.tri.rest.mapping.meta.ServiceMeta;
import org.apache.dubbo.rpc.protocol.tri.rest.util.RestToolKit;
@@ -35,6 +38,7 @@ public class SpringMvcRequestMappingResolver implements
RequestMappingResolver {
private final FrameworkModel frameworkModel;
private volatile RestToolKit toolKit;
+ private CorsMeta globalCorsMeta;
public SpringMvcRequestMappingResolver(FrameworkModel frameworkModel) {
this.frameworkModel = frameworkModel;
@@ -62,9 +66,13 @@ public class SpringMvcRequestMappingResolver implements
RequestMappingResolver {
return null;
}
AnnotationMeta<?> responseStatus =
serviceMeta.findMergedAnnotation(Annotations.ResponseStatus);
+ AnnotationMeta<?> crossOrigin =
serviceMeta.findMergedAnnotation(Annotations.CrossOrigin);
+ String[] methods = requestMapping.getStringArray("method");
return builder(requestMapping, responseStatus)
+ .method(methods)
.name(serviceMeta.getType().getSimpleName())
.contextPath(serviceMeta.getContextPath())
+ .cors(buildCorsMeta(crossOrigin, methods))
.build();
}
@@ -80,10 +88,14 @@ public class SpringMvcRequestMappingResolver implements
RequestMappingResolver {
}
ServiceMeta serviceMeta = methodMeta.getServiceMeta();
AnnotationMeta<?> responseStatus =
methodMeta.findMergedAnnotation(Annotations.ResponseStatus);
+ AnnotationMeta<?> crossOrigin =
methodMeta.findMergedAnnotation(Annotations.CrossOrigin);
+ String[] methods = requestMapping.getStringArray("method");
return builder(requestMapping, responseStatus)
+ .method(methods)
.name(methodMeta.getMethod().getName())
.contextPath(serviceMeta.getContextPath())
.custom(new
ServiceVersionCondition(serviceMeta.getServiceGroup(),
serviceMeta.getServiceVersion()))
+ .cors(buildCorsMeta(crossOrigin, methods))
.build();
}
@@ -98,10 +110,34 @@ public class SpringMvcRequestMappingResolver implements
RequestMappingResolver {
}
}
return builder.path(requestMapping.getValueArray())
- .method(requestMapping.getStringArray("method"))
.param(requestMapping.getStringArray("params"))
.header(requestMapping.getStringArray("headers"))
.consume(requestMapping.getStringArray("consumes"))
.produce(requestMapping.getStringArray("produces"));
}
+
+ private CorsMeta buildCorsMeta(AnnotationMeta<?> crossOrigin, String[]
methods) {
+ if (globalCorsMeta == null) {
+ globalCorsMeta = CorsUtils.getGlobalCorsMeta(frameworkModel);
+ }
+ if (crossOrigin == null) {
+ return globalCorsMeta;
+ }
+ String[] allowedMethods = crossOrigin.getStringArray("methods");
+ if (allowedMethods.length == 0) {
+ allowedMethods = methods;
+ if (allowedMethods.length == 0) {
+ allowedMethods = new String[] {CommonConstants.ANY_VALUE};
+ }
+ }
+ CorsMeta corsMeta = CorsMeta.builder()
+ .allowedOrigins(crossOrigin.getStringArray("origins"))
+ .allowedMethods(allowedMethods)
+ .allowedHeaders(crossOrigin.getStringArray("allowedHeaders"))
+ .exposedHeaders(crossOrigin.getStringArray("exposedHeaders"))
+ .allowCredentials(crossOrigin.getString("allowCredentials"))
+ .maxAge(crossOrigin.getNumber("maxAge"))
+ .build();
+ return globalCorsMeta.combine(corsMeta);
+ }
}
diff --git
a/dubbo-plugin/dubbo-rest-spring/src/test/java/org/apache/dubbo/rpc/protocol/tri/rest/support/spring/compatible/SpringDemoServiceImpl.java
b/dubbo-plugin/dubbo-rest-spring/src/test/java/org/apache/dubbo/rpc/protocol/tri/rest/support/spring/compatible/SpringDemoServiceImpl.java
index bf235fed57..46498e5162 100644
---
a/dubbo-plugin/dubbo-rest-spring/src/test/java/org/apache/dubbo/rpc/protocol/tri/rest/support/spring/compatible/SpringDemoServiceImpl.java
+++
b/dubbo-plugin/dubbo-rest-spring/src/test/java/org/apache/dubbo/rpc/protocol/tri/rest/support/spring/compatible/SpringDemoServiceImpl.java
@@ -23,8 +23,10 @@ import java.util.List;
import java.util.Map;
import org.springframework.util.LinkedMultiValueMap;
+import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.ExceptionHandler;
+@CrossOrigin
public class SpringDemoServiceImpl implements SpringRestDemoService {
private static Map<String, Object> context;
private boolean called;
diff --git
a/dubbo-remoting/dubbo-remoting-http12/src/main/java/org/apache/dubbo/remoting/http12/AbstractServerHttpChannelObserver.java
b/dubbo-remoting/dubbo-remoting-http12/src/main/java/org/apache/dubbo/remoting/http12/AbstractServerHttpChannelObserver.java
index 9883a8072e..26ce8b705e 100644
---
a/dubbo-remoting/dubbo-remoting-http12/src/main/java/org/apache/dubbo/remoting/http12/AbstractServerHttpChannelObserver.java
+++
b/dubbo-remoting/dubbo-remoting-http12/src/main/java/org/apache/dubbo/remoting/http12/AbstractServerHttpChannelObserver.java
@@ -87,6 +87,7 @@ public abstract class AbstractServerHttpChannelObserver
implements CustomizableH
public final void onError(Throwable throwable) {
if (throwable instanceof HttpResultPayloadException) {
onNext(((HttpResultPayloadException) throwable).getResult());
+ doOnCompleted(null);
return;
}
try {
diff --git
a/dubbo-remoting/dubbo-remoting-http12/src/main/java/org/apache/dubbo/remoting/http12/HttpStatus.java
b/dubbo-remoting/dubbo-remoting-http12/src/main/java/org/apache/dubbo/remoting/http12/HttpStatus.java
index f91b628028..f28eb9422f 100644
---
a/dubbo-remoting/dubbo-remoting-http12/src/main/java/org/apache/dubbo/remoting/http12/HttpStatus.java
+++
b/dubbo-remoting/dubbo-remoting-http12/src/main/java/org/apache/dubbo/remoting/http12/HttpStatus.java
@@ -20,6 +20,7 @@ public enum HttpStatus {
OK(200),
CREATED(201),
ACCEPTED(202),
+ NO_CONTENT(204),
FOUND(302),
BAD_REQUEST(400),
UNAUTHORIZED(401),
diff --git
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/h12/AbstractServerTransportListener.java
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/h12/AbstractServerTransportListener.java
index 970aa88fa7..fc0961fdc7 100644
---
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/h12/AbstractServerTransportListener.java
+++
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/h12/AbstractServerTransportListener.java
@@ -162,10 +162,10 @@ public abstract class
AbstractServerTransportListener<HEADER extends RequestMeta
protected void logError(Throwable t) {
if (t instanceof HttpStatusException) {
HttpStatusException e = (HttpStatusException) t;
- if (e.getStatusCode() ==
HttpStatus.INTERNAL_SERVER_ERROR.getCode()) {
+ if (e.getStatusCode() >= HttpStatus.BAD_REQUEST.getCode()) {
LOGGER.debug("http status exception", e);
- return;
}
+ return;
}
LOGGER.error(INTERNAL_ERROR, "", "", "server internal error", t);
}
diff --git
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/RestConstants.java
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/RestConstants.java
index 7a431bd02a..7c512278bb 100644
---
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/RestConstants.java
+++
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/RestConstants.java
@@ -57,5 +57,14 @@ public final class RestConstants {
public static final String CASE_SENSITIVE_MATCH_KEY = CONFIG_PREFIX +
"case-sensitive-match";
public static final String FORMAT_PARAMETER_NAME_KEY = CONFIG_PREFIX +
"format-parameter-name";
+ /* Cors Configuration Key */
+ public static final String CORS_CONFIG_PREFIX = CONFIG_PREFIX + "cors.";
+ public static final String ALLOWED_ORIGINS = CORS_CONFIG_PREFIX +
"allowed-origins";
+ public static final String ALLOWED_METHODS = CORS_CONFIG_PREFIX +
"allowed-methods";
+ public static final String ALLOWED_HEADERS = CORS_CONFIG_PREFIX +
"allowed-headers";
+ public static final String ALLOW_CREDENTIALS = CORS_CONFIG_PREFIX +
"allow-credentials";
+ public static final String EXPOSED_HEADERS = CORS_CONFIG_PREFIX +
"exposed-headers";
+ public static final String MAX_AGE = CORS_CONFIG_PREFIX + "max-age";
+
private RestConstants() {}
}
diff --git
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/cors/CorsHeaderFilter.java
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/cors/CorsHeaderFilter.java
new file mode 100644
index 0000000000..f3dbe3eb62
--- /dev/null
+++
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/cors/CorsHeaderFilter.java
@@ -0,0 +1,273 @@
+/*
+ * Copyright 2002-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.dubbo.rpc.protocol.tri.rest.cors;
+
+import org.apache.dubbo.common.constants.CommonConstants;
+import org.apache.dubbo.common.extension.Activate;
+import org.apache.dubbo.common.utils.ArrayUtils;
+import org.apache.dubbo.common.utils.StringUtils;
+import org.apache.dubbo.remoting.http12.HttpMethods;
+import org.apache.dubbo.remoting.http12.HttpRequest;
+import org.apache.dubbo.remoting.http12.HttpResponse;
+import org.apache.dubbo.remoting.http12.HttpResult;
+import org.apache.dubbo.remoting.http12.HttpStatus;
+import org.apache.dubbo.remoting.http12.exception.HttpResultPayloadException;
+import org.apache.dubbo.rpc.Invoker;
+import org.apache.dubbo.rpc.RpcException;
+import org.apache.dubbo.rpc.RpcInvocation;
+import org.apache.dubbo.rpc.protocol.tri.rest.RestConstants;
+import org.apache.dubbo.rpc.protocol.tri.rest.filter.RestHeaderFilterAdapter;
+import org.apache.dubbo.rpc.protocol.tri.rest.mapping.RequestMapping;
+import org.apache.dubbo.rpc.protocol.tri.rest.mapping.meta.CorsMeta;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.regex.Pattern;
+
+import static org.apache.dubbo.common.constants.CommonConstants.ANY_VALUE;
+
+/**
+ * See: <a
href="https://github.com/spring-projects/spring-framework/blob/main/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java">DefaultCorsProcessor</a>
+ */
+@Activate(group = CommonConstants.PROVIDER, order = 1000)
+public class CorsHeaderFilter extends RestHeaderFilterAdapter {
+
+ public static final String VARY = "Vary";
+ public static final String ORIGIN = "Origin";
+ public static final String ACCESS_CONTROL_REQUEST_METHOD =
"Access-Control-Request-Method";
+ public static final String ACCESS_CONTROL_REQUEST_HEADERS =
"Access-Control-Request-Headers";
+ public static final String ACCESS_CONTROL_ALLOW_CREDENTIALS =
"Access-Control-Allow-Credentials";
+ public static final String ACCESS_CONTROL_EXPOSE_HEADERS =
"Access-Control-Expose-Headers";
+ public static final String ACCESS_CONTROL_MAX_AGE =
"Access-Control-Max-Age";
+ public static final String ACCESS_CONTROL_ALLOW_ORIGIN =
"Access-Control-Allow-Origin";
+ public static final String ACCESS_CONTROL_ALLOW_METHODS =
"Access-Control-Allow-Methods";
+ public static final String ACCESS_CONTROL_ALLOW_HEADERS =
"Access-Control-Allow-Headers";
+ public static final String SEP = ", ";
+
+ @Override
+ protected void invoke(Invoker<?> invoker, RpcInvocation invocation,
HttpRequest request, HttpResponse response)
+ throws RpcException {
+ RequestMapping mapping =
request.attribute(RestConstants.MAPPING_ATTRIBUTE);
+ CorsMeta cors = mapping.getCors();
+ String origin = request.header(ORIGIN);
+ if (cors == null) {
+ if (isPreFlightRequest(request, origin)) {
+ throw new HttpResultPayloadException(HttpResult.builder()
+ .status(HttpStatus.FORBIDDEN)
+ .body("Invalid CORS request")
+ .build());
+ }
+ return;
+ }
+
+ if (process(cors, request, response)) {
+ return;
+ }
+
+ throw new HttpResultPayloadException(HttpResult.builder()
+ .status(HttpStatus.FORBIDDEN)
+ .body("Invalid CORS request")
+ .headers(response.headers())
+ .build());
+ }
+
+ private boolean process(CorsMeta cors, HttpRequest request, HttpResponse
response) {
+ setVaryHeader(response);
+
+ String origin = request.header(ORIGIN);
+ if (isNotCorsRequest(request, origin)) {
+ return true;
+ }
+
+ if (response.header(ACCESS_CONTROL_ALLOW_ORIGIN) != null) {
+ return true;
+ }
+
+ String allowOrigin = checkOrigin(cors, origin);
+ if (allowOrigin == null) {
+ return false;
+ }
+
+ boolean preFlight = isPreFlightRequest(request, origin);
+
+ List<String> allowMethods =
+ checkMethods(cors, preFlight ?
request.header(ACCESS_CONTROL_REQUEST_METHOD) : request.method());
+ if (allowMethods == null) {
+ return false;
+ }
+
+ List<String> allowHeaders = null;
+ if (preFlight) {
+ allowHeaders = checkHeaders(cors,
request.headerValues(ACCESS_CONTROL_REQUEST_HEADERS));
+ if (allowHeaders == null) {
+ return false;
+ }
+ }
+
+ response.setHeader(ACCESS_CONTROL_ALLOW_ORIGIN, allowOrigin);
+
+ if (ArrayUtils.isNotEmpty(cors.getExposedHeaders())) {
+ response.setHeader(ACCESS_CONTROL_EXPOSE_HEADERS,
StringUtils.join(cors.getExposedHeaders(), SEP));
+ }
+
+ if (Boolean.TRUE.equals(cors.getAllowCredentials())) {
+ response.setHeader(ACCESS_CONTROL_ALLOW_CREDENTIALS,
Boolean.TRUE.toString());
+ }
+
+ if (preFlight) {
+ response.setHeader(ACCESS_CONTROL_ALLOW_METHODS,
StringUtils.join(allowMethods, SEP));
+
+ if (!allowHeaders.isEmpty()) {
+ response.setHeader(ACCESS_CONTROL_ALLOW_HEADERS,
StringUtils.join(allowHeaders, SEP));
+ }
+ if (cors.getMaxAge() != null) {
+ response.setHeader(ACCESS_CONTROL_MAX_AGE,
cors.getMaxAge().toString());
+ }
+ throw new HttpResultPayloadException(HttpResult.builder()
+ .status(HttpStatus.NO_CONTENT)
+ .headers(response.headers())
+ .build());
+ }
+
+ return true;
+ }
+
+ private static void setVaryHeader(HttpResponse response) {
+ List<String> varyHeaders = response.headerValues(VARY);
+ String varyValue;
+ if (varyHeaders == null) {
+ varyValue = ORIGIN + SEP + ACCESS_CONTROL_REQUEST_METHOD + SEP +
ACCESS_CONTROL_REQUEST_HEADERS;
+ } else {
+ Set<String> varHeadersSet = new LinkedHashSet<>(varyHeaders);
+ varHeadersSet.add(ORIGIN);
+ varHeadersSet.add(ACCESS_CONTROL_REQUEST_METHOD);
+ varHeadersSet.add(ACCESS_CONTROL_REQUEST_HEADERS);
+ varyValue = StringUtils.join(varHeadersSet, SEP);
+ }
+ response.setHeader(VARY, varyValue);
+ }
+
+ private static String checkOrigin(CorsMeta cors, String origin) {
+ if (StringUtils.isBlank(origin)) {
+ return null;
+ }
+ origin = CorsUtils.formatOrigin(origin);
+ String[] allowedOrigins = cors.getAllowedOrigins();
+ if (ArrayUtils.isNotEmpty(allowedOrigins)) {
+ if (ArrayUtils.contains(allowedOrigins, ANY_VALUE)) {
+ if (Boolean.TRUE.equals(cors.getAllowCredentials())) {
+ throw new IllegalArgumentException(
+ "When allowCredentials is true, allowedOrigins
cannot contain the special value \"*\"");
+ }
+ return ANY_VALUE;
+ }
+ for (String allowedOrigin : allowedOrigins) {
+ if (origin.equalsIgnoreCase(allowedOrigin)) {
+ return origin;
+ }
+ }
+ }
+ if (ArrayUtils.isNotEmpty(cors.getAllowedOriginsPatterns())) {
+ for (Pattern pattern : cors.getAllowedOriginsPatterns()) {
+ if (pattern.matcher(origin).matches()) {
+ return origin;
+ }
+ }
+ }
+ return null;
+ }
+
+ private static List<String> checkMethods(CorsMeta cors, String method) {
+ if (method == null) {
+ return null;
+ }
+ String[] allowedMethods = cors.getAllowedMethods();
+ if (ArrayUtils.contains(allowedMethods, ANY_VALUE)) {
+ return Collections.singletonList(method);
+ }
+ for (String allowedMethod : allowedMethods) {
+ if (method.equalsIgnoreCase(allowedMethod)) {
+ return Arrays.asList(allowedMethods);
+ }
+ }
+ return null;
+ }
+
+ private static List<String> checkHeaders(CorsMeta cors, Collection<String>
headers) {
+ if (headers == null || headers.isEmpty()) {
+ return Collections.emptyList();
+ }
+ String[] allowedHeaders = cors.getAllowedHeaders();
+ if (ArrayUtils.isEmpty(allowedHeaders)) {
+ return null;
+ }
+
+ boolean allowAny = ArrayUtils.contains(allowedHeaders, ANY_VALUE);
+ List<String> result = new ArrayList<>(headers.size());
+ for (String header : headers) {
+ if (allowAny) {
+ result.add(header);
+ continue;
+ }
+ for (String allowedHeader : allowedHeaders) {
+ if (header.equalsIgnoreCase(allowedHeader)) {
+ result.add(header);
+ break;
+ }
+ }
+ }
+ return result.isEmpty() ? null : result;
+ }
+
+ private static boolean isNotCorsRequest(HttpRequest request, String
origin) {
+ if (origin == null) {
+ return true;
+ }
+ try {
+ URI uri = new URI(origin);
+ return request.scheme().equals(uri.getScheme())
+ && request.serverName().equals(uri.getHost())
+ && getPort(request.scheme(), request.serverPort()) ==
getPort(uri.getScheme(), uri.getPort());
+ } catch (URISyntaxException e) {
+ return false;
+ }
+ }
+
+ private static boolean isPreFlightRequest(HttpRequest request, String
origin) {
+ return request.method().equals(HttpMethods.OPTIONS.name())
+ && origin != null
+ && request.hasHeader(ACCESS_CONTROL_REQUEST_METHOD);
+ }
+
+ private static int getPort(String scheme, int port) {
+ if (port == -1) {
+ if ("http".equals(scheme)) {
+ return 80;
+ }
+ if ("https".equals(scheme)) {
+ return 443;
+ }
+ }
+ return port;
+ }
+}
diff --git
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/cors/CorsUtils.java
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/cors/CorsUtils.java
new file mode 100644
index 0000000000..a9aab0bd91
--- /dev/null
+++
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/cors/CorsUtils.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.dubbo.rpc.protocol.tri.rest.cors;
+
+import org.apache.dubbo.common.config.Configuration;
+import org.apache.dubbo.common.config.ConfigurationUtils;
+import org.apache.dubbo.common.utils.StringUtils;
+import org.apache.dubbo.rpc.model.FrameworkModel;
+import org.apache.dubbo.rpc.protocol.tri.rest.RestConstants;
+import org.apache.dubbo.rpc.protocol.tri.rest.mapping.meta.CorsMeta;
+
+public class CorsUtils {
+
+ private CorsUtils() {}
+
+ public static CorsMeta getGlobalCorsMeta(FrameworkModel frameworkModel) {
+ Configuration config =
ConfigurationUtils.getGlobalConfiguration(frameworkModel.defaultApplication());
+
+ String maxAge = config.getString(RestConstants.MAX_AGE);
+ return CorsMeta.builder()
+ .allowedOrigins(getValues(config,
RestConstants.ALLOWED_ORIGINS))
+ .allowedMethods(getValues(config,
RestConstants.ALLOWED_METHODS))
+ .allowedHeaders(getValues(config,
RestConstants.ALLOWED_HEADERS))
+
.allowCredentials(config.getString(RestConstants.ALLOW_CREDENTIALS))
+ .exposedHeaders(getValues(config,
RestConstants.EXPOSED_HEADERS))
+ .maxAge(maxAge == null ? null : Long.valueOf(maxAge))
+ .build();
+ }
+
+ private static String[] getValues(Configuration config, String key) {
+ return StringUtils.tokenize(config.getString(key), ',');
+ }
+
+ public static String formatOrigin(String value) {
+ value = value.trim();
+ int last = value.length() - 1;
+ return last > -1 && value.charAt(last) == '/' ? value.substring(0,
last) : value;
+ }
+}
diff --git
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/filter/RestHeaderFilterAdapter.java
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/filter/RestHeaderFilterAdapter.java
index abbf4a7a5a..bb6f0927fb 100644
---
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/filter/RestHeaderFilterAdapter.java
+++
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/filter/RestHeaderFilterAdapter.java
@@ -19,7 +19,6 @@ package org.apache.dubbo.rpc.protocol.tri.rest.filter;
import org.apache.dubbo.remoting.http12.HttpRequest;
import org.apache.dubbo.remoting.http12.HttpResponse;
import org.apache.dubbo.rpc.HeaderFilter;
-import org.apache.dubbo.rpc.Invocation;
import org.apache.dubbo.rpc.Invoker;
import org.apache.dubbo.rpc.RpcException;
import org.apache.dubbo.rpc.RpcInvocation;
@@ -32,11 +31,12 @@ public abstract class RestHeaderFilterAdapter implements
HeaderFilter {
if
(TripleConstant.TRIPLE_HANDLER_TYPE_REST.equals(invocation.get(TripleConstant.HANDLER_TYPE_KEY)))
{
HttpRequest request = (HttpRequest)
invocation.get(TripleConstant.HTTP_REQUEST_KEY);
HttpResponse response = (HttpResponse)
invocation.get(TripleConstant.HTTP_RESPONSE_KEY);
- return invoke(invoker, invocation, request, response);
+ invoke(invoker, invocation, request, response);
}
return invocation;
}
- protected abstract RpcInvocation invoke(
- Invoker<?> invoker, Invocation invocation, HttpRequest request,
HttpResponse response) throws RpcException;
+ protected abstract void invoke(
+ Invoker<?> invoker, RpcInvocation invocation, HttpRequest request,
HttpResponse response)
+ throws RpcException;
}
diff --git
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/mapping/RequestMapping.java
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/mapping/RequestMapping.java
index 2c38367c71..5d00ad8eac 100644
---
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/mapping/RequestMapping.java
+++
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/mapping/RequestMapping.java
@@ -27,6 +27,7 @@ import
org.apache.dubbo.rpc.protocol.tri.rest.mapping.condition.ParamsCondition;
import org.apache.dubbo.rpc.protocol.tri.rest.mapping.condition.PathCondition;
import org.apache.dubbo.rpc.protocol.tri.rest.mapping.condition.PathExpression;
import
org.apache.dubbo.rpc.protocol.tri.rest.mapping.condition.ProducesCondition;
+import org.apache.dubbo.rpc.protocol.tri.rest.mapping.meta.CorsMeta;
import org.apache.dubbo.rpc.protocol.tri.rest.mapping.meta.ResponseMeta;
import java.util.Objects;
@@ -43,6 +44,7 @@ public final class RequestMapping implements
Condition<RequestMapping, HttpReque
private final ConsumesCondition consumesCondition;
private final ProducesCondition producesCondition;
private final ConditionWrapper customCondition;
+ private final CorsMeta cors;
private final ResponseMeta response;
private int hashCode;
@@ -56,6 +58,7 @@ public final class RequestMapping implements
Condition<RequestMapping, HttpReque
ConsumesCondition consumesCondition,
ProducesCondition producesCondition,
Condition<?, HttpRequest> customCondition,
+ CorsMeta cors,
ResponseMeta response) {
this.name = name;
this.pathCondition = pathCondition;
@@ -65,6 +68,7 @@ public final class RequestMapping implements
Condition<RequestMapping, HttpReque
this.consumesCondition = consumesCondition;
this.producesCondition = producesCondition;
this.customCondition = customCondition == null ? null : new
ConditionWrapper(customCondition);
+ this.cors = cors;
this.response = response;
}
@@ -82,12 +86,20 @@ public final class RequestMapping implements
Condition<RequestMapping, HttpReque
ConsumesCondition consumes = combine(consumesCondition,
other.consumesCondition);
ProducesCondition produces = combine(producesCondition,
other.producesCondition);
ConditionWrapper custom = combine(customCondition,
other.customCondition);
+ CorsMeta corsMeta = combine(this.cors, other.cors);
ResponseMeta response = ResponseMeta.combine(this.response,
other.response);
- return new RequestMapping(name, paths, methods, params, headers,
consumes, produces, custom, response);
+ return new RequestMapping(
+ name, paths, methods, params, headers, consumes, produces,
custom, corsMeta, response);
}
- private <T extends Condition<T, HttpRequest>> T combine(T value, T other) {
- return value == null ? other : other == null ? value :
value.combine(other);
+ private <T extends Condition<T, HttpRequest>> T combine(T source, T other)
{
+ return source == null ? other : other == null ? source :
source.combine(other);
+ }
+
+ private CorsMeta combine(CorsMeta source, CorsMeta other) {
+ return source == null || source.isEmpty()
+ ? other == null || other.isEmpty() ? null :
other.applyDefault()
+ : source.combine(other).applyDefault();
}
public RequestMapping match(HttpRequest request, PathExpression path) {
@@ -156,7 +168,7 @@ public final class RequestMapping implements
Condition<RequestMapping, HttpReque
}
}
- return new RequestMapping(name, paths, methods, params, headers,
consumes, produces, custom, response);
+ return new RequestMapping(name, paths, methods, params, headers,
consumes, produces, custom, cors, response);
}
public String getName() {
@@ -171,6 +183,10 @@ public final class RequestMapping implements
Condition<RequestMapping, HttpReque
return producesCondition;
}
+ public CorsMeta getCors() {
+ return cors;
+ }
+
public ResponseMeta getResponse() {
return response;
}
@@ -290,6 +306,9 @@ public final class RequestMapping implements
Condition<RequestMapping, HttpReque
if (response != null) {
sb.append(", response=").append(response);
}
+ if (cors != null) {
+ sb.append(", cors=").append(cors);
+ }
sb.append('}');
return sb.toString();
}
@@ -304,6 +323,7 @@ public final class RequestMapping implements
Condition<RequestMapping, HttpReque
private String[] consumes;
private String[] produces;
private Condition<?, HttpRequest> customCondition;
+ private CorsMeta corsMeta;
private Integer responseStatus;
private String responseReason;
@@ -352,6 +372,11 @@ public final class RequestMapping implements
Condition<RequestMapping, HttpReque
return this;
}
+ public Builder cors(CorsMeta corsMeta) {
+ this.corsMeta = corsMeta;
+ return this;
+ }
+
public Builder responseStatus(int status) {
responseStatus = status;
return this;
@@ -363,23 +388,17 @@ public final class RequestMapping implements
Condition<RequestMapping, HttpReque
}
public RequestMapping build() {
- PathCondition pathCondition = isEmpty(paths) ? null : new
PathCondition(contextPath, paths);
- MethodsCondition methodsCondition = isEmpty(methods) ? null : new
MethodsCondition(methods);
- ParamsCondition paramsCondition = isEmpty(params) ? null : new
ParamsCondition(params);
- HeadersCondition headersCondition = isEmpty(headers) ? null : new
HeadersCondition(headers);
- ConsumesCondition consumesCondition = isEmpty(consumes) ? null :
new ConsumesCondition(consumes);
- ProducesCondition producesCondition = isEmpty(produces) ? null :
new ProducesCondition(produces);
- ResponseMeta response = responseStatus == null ? null : new
ResponseMeta(responseStatus, responseReason);
return new RequestMapping(
name,
- pathCondition,
- methodsCondition,
- paramsCondition,
- headersCondition,
- consumesCondition,
- producesCondition,
+ isEmpty(paths) ? null : new PathCondition(contextPath,
paths),
+ isEmpty(methods) ? null : new MethodsCondition(methods),
+ isEmpty(params) ? null : new ParamsCondition(params),
+ isEmpty(headers) ? null : new HeadersCondition(headers),
+ isEmpty(consumes) ? null : new ConsumesCondition(consumes),
+ isEmpty(produces) ? null : new ProducesCondition(produces),
customCondition,
- response);
+ corsMeta == null ? null : corsMeta,
+ responseStatus == null ? null : new
ResponseMeta(responseStatus, responseReason));
}
}
}
diff --git
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/mapping/condition/MethodsCondition.java
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/mapping/condition/MethodsCondition.java
index fdde62dc61..fa01f67194 100644
---
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/mapping/condition/MethodsCondition.java
+++
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/mapping/condition/MethodsCondition.java
@@ -24,6 +24,7 @@ import java.util.Set;
import static org.apache.dubbo.remoting.http12.HttpMethods.GET;
import static org.apache.dubbo.remoting.http12.HttpMethods.HEAD;
+import static org.apache.dubbo.remoting.http12.HttpMethods.OPTIONS;
public final class MethodsCondition implements Condition<MethodsCondition,
HttpRequest> {
@@ -53,6 +54,12 @@ public final class MethodsCondition implements
Condition<MethodsCondition, HttpR
if (HEAD.name().equals(method) && methods.contains(GET.name())) {
return new MethodsCondition(GET.name());
}
+ if (OPTIONS.name().equals(method)
+ && request.hasHeader("origin")
+ && request.hasHeader("access-control-request-method")) {
+ return new MethodsCondition(OPTIONS.name());
+ }
+
return null;
}
diff --git
a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/mapping/meta/CorsMeta.java
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/mapping/meta/CorsMeta.java
new file mode 100644
index 0000000000..d5545c5858
--- /dev/null
+++
b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/mapping/meta/CorsMeta.java
@@ -0,0 +1,306 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.dubbo.rpc.protocol.tri.rest.mapping.meta;
+
+import org.apache.dubbo.common.utils.CollectionUtils;
+import org.apache.dubbo.common.utils.StringUtils;
+import org.apache.dubbo.remoting.http12.HttpMethods;
+import org.apache.dubbo.rpc.protocol.tri.rest.cors.CorsUtils;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.function.Function;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import static org.apache.dubbo.common.constants.CommonConstants.ANY_VALUE;
+import static org.apache.dubbo.common.utils.StringUtils.EMPTY_STRING_ARRAY;
+
+public class CorsMeta {
+
+ private final String[] allowedOrigins;
+ private final Pattern[] allowedOriginsPatterns;
+ private final String[] allowedMethods;
+ private final String[] allowedHeaders;
+ private final String[] exposedHeaders;
+ private final Boolean allowCredentials;
+ private final Long maxAge;
+
+ private CorsMeta(
+ String[] allowedOrigins,
+ Pattern[] allowedOriginsPatterns,
+ String[] allowedMethods,
+ String[] allowedHeaders,
+ String[] exposedHeaders,
+ Boolean allowCredentials,
+ Long maxAge) {
+ this.allowedOrigins = allowedOrigins;
+ this.allowedOriginsPatterns = allowedOriginsPatterns;
+ this.allowedMethods = allowedMethods;
+ this.allowedHeaders = allowedHeaders;
+ this.exposedHeaders = exposedHeaders;
+ this.allowCredentials = allowCredentials;
+ this.maxAge = maxAge;
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ public String[] getAllowedOrigins() {
+ return allowedOrigins;
+ }
+
+ public Pattern[] getAllowedOriginsPatterns() {
+ return allowedOriginsPatterns;
+ }
+
+ public String[] getAllowedMethods() {
+ return allowedMethods;
+ }
+
+ public String[] getAllowedHeaders() {
+ return allowedHeaders;
+ }
+
+ public String[] getExposedHeaders() {
+ return exposedHeaders;
+ }
+
+ public Boolean getAllowCredentials() {
+ return allowCredentials;
+ }
+
+ public Long getMaxAge() {
+ return maxAge;
+ }
+
+ public boolean isEmpty() {
+ return allowedOrigins.length == 0
+ && allowedMethods.length == 0
+ && allowedHeaders.length == 0
+ && exposedHeaders.length == 0
+ && allowCredentials == null
+ && maxAge == null;
+ }
+
+ public CorsMeta applyDefault() {
+ String[] allowedOriginArray = null;
+ Pattern[] allowedOriginPatternArray = null;
+ if (this.allowedOrigins.length == 0) {
+ allowedOriginArray = new String[] {ANY_VALUE};
+ allowedOriginPatternArray = new Pattern[] {null};
+ }
+
+ String[] allowedMethodArray = null;
+ if (this.allowedMethods.length == 0) {
+ allowedMethodArray =
+ new String[] {HttpMethods.GET.name(),
HttpMethods.HEAD.name(), HttpMethods.POST.name()};
+ }
+
+ String[] allowedHeaderArray = null;
+ if (this.allowedHeaders.length == 0) {
+ allowedHeaderArray = new String[] {ANY_VALUE};
+ }
+
+ Long maxAgeValue = null;
+ if (this.maxAge == null) {
+ maxAgeValue = 1800L;
+ }
+
+ if (allowedOriginArray == null
+ && allowedMethodArray == null
+ && allowedHeaderArray == null
+ && maxAgeValue == null) {
+ return this;
+ }
+
+ return new CorsMeta(
+ allowedOriginArray == null ? this.allowedOrigins :
allowedOriginArray,
+ allowedOriginPatternArray == null ?
this.allowedOriginsPatterns : allowedOriginPatternArray,
+ allowedMethodArray == null ? this.allowedMethods :
allowedMethodArray,
+ allowedHeaderArray == null ? this.allowedHeaders :
allowedHeaderArray,
+ exposedHeaders,
+ allowCredentials,
+ maxAgeValue);
+ }
+
+ public CorsMeta combine(CorsMeta other) {
+ if (other == null || other.isEmpty()) {
+ return this;
+ }
+ return new CorsMeta(
+ combine(allowedOrigins, other.allowedOrigins),
+ merge(allowedOriginsPatterns,
other.allowedOriginsPatterns).toArray(new Pattern[0]),
+ combine(allowedMethods, other.allowedMethods),
+ combine(allowedHeaders, other.allowedHeaders),
+ combine(exposedHeaders, other.exposedHeaders),
+ other.allowCredentials == null ? allowCredentials :
other.allowCredentials,
+ other.maxAge == null ? maxAge : other.maxAge);
+ }
+
+ /**
+ * Merge two arrays of CORS config values, with the other array having
higher priority.
+ */
+ private static String[] combine(String[] source, String[] other) {
+ if (other.length == 0) {
+ return source;
+ }
+ if (source.length == 0 || source[0].equals(ANY_VALUE) ||
other[0].equals(ANY_VALUE)) {
+ return other;
+ }
+ return merge(source, other).toArray(EMPTY_STRING_ARRAY);
+ }
+
+ private static <T> Set<T> merge(T[] source, T[] other) {
+ int size = source.length + other.length;
+ if (size == 0) {
+ return Collections.emptySet();
+ }
+ Set<T> merged = CollectionUtils.newLinkedHashSet(size);
+ Collections.addAll(merged, source);
+ Collections.addAll(merged, other);
+ return merged;
+ }
+
+ @Override
+ public String toString() {
+ return "CorsMeta{"
+ + "allowedOrigins=" + Arrays.toString(allowedOrigins)
+ + ", allowedOriginsPatterns=" +
Arrays.toString(allowedOriginsPatterns)
+ + ", allowedMethods=" + Arrays.toString(allowedMethods)
+ + ", allowedHeaders=" + Arrays.toString(allowedHeaders)
+ + ", exposedHeaders=" + Arrays.toString(exposedHeaders)
+ + ", allowCredentials=" + allowCredentials
+ + ", maxAge=" + maxAge
+ + '}';
+ }
+
+ public static final class Builder {
+
+ private static final Pattern PORTS_PATTERN =
Pattern.compile("(.*):\\[(\\*|\\d+(,\\d+)*)]");
+
+ private final Set<String> allowedOrigins = new LinkedHashSet<>();
+ private final Set<String> allowedMethods = new LinkedHashSet<>();
+ private final Set<String> allowedHeaders = new LinkedHashSet<>();
+ private final Set<String> exposedHeaders = new LinkedHashSet<>();
+ private Boolean allowCredentials;
+ private Long maxAge;
+
+ public Builder allowedOrigins(String... origins) {
+ addValues(allowedOrigins, CorsUtils::formatOrigin, origins);
+ return this;
+ }
+
+ public Builder allowedMethods(String... methods) {
+ addValues(allowedMethods, v -> v.trim().toUpperCase(), methods);
+ return this;
+ }
+
+ public Builder allowedHeaders(String... headers) {
+ addValues(allowedHeaders, String::trim, headers);
+ return this;
+ }
+
+ public Builder exposedHeaders(String... headers) {
+ addValues(exposedHeaders, String::trim, headers);
+ return this;
+ }
+
+ private static void addValues(Set<String> set, Function<String,
String> fn, String... values) {
+ if (values == null || set.contains(ANY_VALUE)) {
+ return;
+ }
+ for (String value : values) {
+ if (StringUtils.isNotEmpty(value)) {
+ value = fn.apply(value);
+ if (value.isEmpty()) {
+ continue;
+ }
+ if (ANY_VALUE.equals(value)) {
+ set.clear();
+ set.add(ANY_VALUE);
+ return;
+ }
+ set.add(value);
+ }
+ }
+ }
+
+ private static Pattern initPattern(String patternValue) {
+ String ports = null;
+ Matcher matcher = PORTS_PATTERN.matcher(patternValue);
+ if (matcher.matches()) {
+ patternValue = matcher.group(1);
+ ports = matcher.group(2);
+ }
+ patternValue = "\\Q" + patternValue + "\\E";
+ patternValue = patternValue.replace("*", "\\E.*\\Q");
+ if (ports != null) {
+ patternValue += (ANY_VALUE.equals(ports) ? "(:\\d+)?" : ":(" +
ports.replace(',', '|') + ")");
+ }
+ return Pattern.compile(patternValue);
+ }
+
+ public Builder allowCredentials(Boolean allowCredentials) {
+ this.allowCredentials = allowCredentials;
+ return this;
+ }
+
+ public Builder allowCredentials(String allowCredentials) {
+ if ("true".equals(allowCredentials)) {
+ this.allowCredentials = true;
+ } else if ("false".equals(allowCredentials)) {
+ this.allowCredentials = false;
+ }
+ return this;
+ }
+
+ public Builder maxAge(Long maxAge) {
+ if (maxAge != null && maxAge > -1) {
+ this.maxAge = maxAge;
+ }
+ return this;
+ }
+
+ public CorsMeta build() {
+ int len = allowedOrigins.size();
+ String[] origins = new String[len];
+ List<Pattern> originsPatterns = new ArrayList<>(len);
+ int i = 0;
+ for (String origin : allowedOrigins) {
+ origins[i++] = origin;
+ if (ANY_VALUE.equals(origin)) {
+ continue;
+ }
+ originsPatterns.add(initPattern(origin));
+ }
+ return new CorsMeta(
+ origins,
+ originsPatterns.toArray(new Pattern[0]),
+ allowedMethods.toArray(EMPTY_STRING_ARRAY),
+ allowedHeaders.toArray(EMPTY_STRING_ARRAY),
+ exposedHeaders.toArray(EMPTY_STRING_ARRAY),
+ allowCredentials,
+ maxAge);
+ }
+ }
+}
diff --git
a/dubbo-rpc/dubbo-rpc-triple/src/main/resources/META-INF/dubbo/internal/org.apache.dubbo.rpc.HeaderFilter
b/dubbo-rpc/dubbo-rpc-triple/src/main/resources/META-INF/dubbo/internal/org.apache.dubbo.rpc.HeaderFilter
new file mode 100644
index 0000000000..fa26599c25
--- /dev/null
+++
b/dubbo-rpc/dubbo-rpc-triple/src/main/resources/META-INF/dubbo/internal/org.apache.dubbo.rpc.HeaderFilter
@@ -0,0 +1 @@
+rest-cors=org.apache.dubbo.rpc.protocol.tri.rest.cors.CorsHeaderFilter
\ No newline at end of file
diff --git
a/dubbo-rpc/dubbo-rpc-triple/src/test/java/org/apache/dubbo/rpc/protocol/tri/rest/cors/CorsHeaderFilterTest.java
b/dubbo-rpc/dubbo-rpc-triple/src/test/java/org/apache/dubbo/rpc/protocol/tri/rest/cors/CorsHeaderFilterTest.java
new file mode 100644
index 0000000000..70ef114483
--- /dev/null
+++
b/dubbo-rpc/dubbo-rpc-triple/src/test/java/org/apache/dubbo/rpc/protocol/tri/rest/cors/CorsHeaderFilterTest.java
@@ -0,0 +1,423 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.dubbo.rpc.protocol.tri.rest.cors;
+
+import org.apache.dubbo.remoting.http12.HttpMethods;
+import org.apache.dubbo.remoting.http12.HttpRequest;
+import org.apache.dubbo.remoting.http12.HttpResponse;
+import org.apache.dubbo.remoting.http12.HttpStatus;
+import org.apache.dubbo.remoting.http12.exception.HttpResultPayloadException;
+import org.apache.dubbo.remoting.http12.message.DefaultHttpResponse;
+import org.apache.dubbo.rpc.protocol.tri.rest.RestConstants;
+import org.apache.dubbo.rpc.protocol.tri.rest.mapping.RequestMapping;
+import org.apache.dubbo.rpc.protocol.tri.rest.mapping.meta.CorsMeta;
+
+import java.util.Arrays;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+
+class CorsHeaderFilterTest {
+
+ private HttpRequest request;
+
+ private HttpResponse response;
+
+ private MockCorsHeaderFilter processor;
+
+ private RequestMapping build;
+
+ static class MockCorsHeaderFilter extends CorsHeaderFilter {
+ public void process(HttpRequest request, HttpResponse response) {
+ invoke(null, null, request, response);
+ }
+
+ public void preLightProcess(HttpRequest request, HttpResponse
response, int code) {
+ try {
+ process(request, response);
+ Assertions.fail();
+ } catch (HttpResultPayloadException e) {
+ Assertions.assertEquals(code, e.getStatusCode());
+ } catch (Exception e) {
+ Assertions.fail();
+ }
+ }
+ }
+
+ private CorsMeta defaultCorsMeta() {
+ return CorsMeta.builder().maxAge(1000L).build();
+ }
+
+ @BeforeEach
+ public void setup() {
+ build = Mockito.mock(RequestMapping.class);
+ request = Mockito.mock(HttpRequest.class);
+
Mockito.when(request.attribute(RestConstants.MAPPING_ATTRIBUTE)).thenReturn(build);
+ Mockito.when(request.uri()).thenReturn("/test.html");
+ Mockito.when(request.serverName()).thenReturn("domain1.example");
+ Mockito.when(request.scheme()).thenReturn("http");
+ Mockito.when(request.serverPort()).thenReturn(80);
+ Mockito.when(request.remoteHost()).thenReturn("127.0.0.1");
+ response = new DefaultHttpResponse();
+ response.setStatus(HttpStatus.OK.getCode());
+ processor = new MockCorsHeaderFilter();
+ }
+
+ @Test
+ void requestWithoutOriginHeader() {
+ Mockito.when(request.method()).thenReturn(HttpMethods.GET.name());
+ Mockito.when(build.getCors()).thenReturn(CorsMeta.builder().build());
+ Mockito.when(build.getCors()).thenReturn(defaultCorsMeta());
+ processor.process(request, response);
+
Assertions.assertFalse(response.hasHeader(CorsHeaderFilter.ACCESS_CONTROL_ALLOW_ORIGIN));
+
Assertions.assertTrue(response.header(CorsHeaderFilter.VARY).contains(CorsHeaderFilter.ORIGIN));
+ Assertions.assertTrue(
+
response.header(CorsHeaderFilter.VARY).contains(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD));
+ Assertions.assertTrue(
+
response.header(CorsHeaderFilter.VARY).contains(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_HEADERS));
+ Assertions.assertEquals(HttpStatus.OK.getCode(), response.status());
+ }
+
+ @Test
+ void sameOriginRequest() {
+ Mockito.when(request.method()).thenReturn(HttpMethods.GET.name());
+
Mockito.when(request.header(CorsHeaderFilter.ORIGIN)).thenReturn("http://domain1.example");
+ Mockito.when(build.getCors()).thenReturn(defaultCorsMeta());
+ processor.process(request, response);
+
Assertions.assertFalse(response.hasHeader(CorsHeaderFilter.ACCESS_CONTROL_ALLOW_ORIGIN));
+
Assertions.assertTrue(response.header(CorsHeaderFilter.VARY).contains(CorsHeaderFilter.ORIGIN));
+ Assertions.assertTrue(
+
response.header(CorsHeaderFilter.VARY).contains(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD));
+ Assertions.assertTrue(
+
response.header(CorsHeaderFilter.VARY).contains(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_HEADERS));
+ Assertions.assertEquals(HttpStatus.OK.getCode(), response.status());
+ }
+
+ @Test
+ void actualRequestWithOriginHeader() {
+ Mockito.when(request.method()).thenReturn(HttpMethods.GET.name());
+
Mockito.when(request.header(CorsHeaderFilter.ORIGIN)).thenReturn("https://domain2.com");
+ Mockito.when(build.getCors()).thenReturn(defaultCorsMeta());
+ Assertions.assertThrows(HttpResultPayloadException.class, () ->
processor.process(request, response));
+ }
+
+ @Test
+ void actualRequestWithOriginHeaderAndNullConfig() {
+ Mockito.when(request.method()).thenReturn(HttpMethods.GET.name());
+
Mockito.when(request.header(CorsHeaderFilter.ORIGIN)).thenReturn("https://domain2.com");
+ Mockito.when(build.getCors()).thenReturn(null);
+ processor.process(request, response);
+
Assertions.assertFalse(response.hasHeader(CorsHeaderFilter.ACCESS_CONTROL_ALLOW_ORIGIN));
+ Assertions.assertEquals(HttpStatus.OK.getCode(), response.status());
+ }
+
+ @Test
+ void actualRequestWithOriginHeaderAndAllowedOrigin() {
+ Mockito.when(request.method()).thenReturn(HttpMethods.GET.name());
+
Mockito.when(request.header(CorsHeaderFilter.ORIGIN)).thenReturn("https://domain2.com");
+
Mockito.when(build.getCors()).thenReturn(CorsMeta.builder().build().applyDefault());
+ processor.process(request, response);
+
Assertions.assertTrue(response.hasHeader(CorsHeaderFilter.ACCESS_CONTROL_ALLOW_ORIGIN));
+ Assertions.assertEquals("*",
response.header(CorsHeaderFilter.ACCESS_CONTROL_ALLOW_ORIGIN));
+
Assertions.assertFalse(response.hasHeader(CorsHeaderFilter.ACCESS_CONTROL_MAX_AGE));
+
Assertions.assertFalse(response.hasHeader(CorsHeaderFilter.ACCESS_CONTROL_EXPOSE_HEADERS));
+
Assertions.assertTrue(response.header(CorsHeaderFilter.VARY).contains(CorsHeaderFilter.ORIGIN));
+ Assertions.assertTrue(
+
response.header(CorsHeaderFilter.VARY).contains(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD));
+ Assertions.assertTrue(
+
response.header(CorsHeaderFilter.VARY).contains(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_HEADERS));
+ Assertions.assertEquals(HttpStatus.OK.getCode(), response.status());
+ }
+
+ @Test
+ void actualRequestCaseInsensitiveOriginMatch() {
+ Mockito.when(request.method()).thenReturn(HttpMethods.GET.name());
+
Mockito.when(request.header(CorsHeaderFilter.ORIGIN)).thenReturn("https://domain2.com");
+ Mockito.when(build.getCors())
+ .thenReturn(CorsMeta.builder()
+ .allowedOrigins("https://DOMAIN2.com")
+ .build()
+ .applyDefault());
+ processor.process(request, response);
+ Assertions.assertEquals(HttpStatus.OK.getCode(), response.status());
+
Assertions.assertTrue(response.hasHeader(CorsHeaderFilter.ACCESS_CONTROL_ALLOW_ORIGIN));
+ }
+
+ @Test
+ void actualRequestTrailingSlashOriginMatch() {
+ Mockito.when(request.method()).thenReturn(HttpMethods.GET.name());
+
Mockito.when(request.header(CorsHeaderFilter.ORIGIN)).thenReturn("https://domain2.com");
+ Mockito.when(build.getCors())
+ .thenReturn(CorsMeta.builder()
+ .allowedOrigins("https://domain2.com/")
+ .build()
+ .applyDefault());
+ processor.process(request, response);
+ Assertions.assertEquals(HttpStatus.OK.getCode(), response.status());
+
Assertions.assertTrue(response.hasHeader(CorsHeaderFilter.ACCESS_CONTROL_ALLOW_ORIGIN));
+ }
+
+ @Test
+ void actualRequestExposedHeaders() {
+ Mockito.when(request.method()).thenReturn(HttpMethods.GET.name());
+
Mockito.when(request.header(CorsHeaderFilter.ORIGIN)).thenReturn("https://domain2.com");
+ Mockito.doReturn(CorsMeta.builder()
+ .allowedOrigins("https://domain2.com")
+ .exposedHeaders("header1", "header2")
+ .build()
+ .applyDefault())
+ .when(build)
+ .getCors();
+ processor.process(request, response);
+
Assertions.assertTrue(response.hasHeader(CorsHeaderFilter.ACCESS_CONTROL_ALLOW_ORIGIN));
+ Assertions.assertEquals("https://domain2.com",
response.header(CorsHeaderFilter.ACCESS_CONTROL_ALLOW_ORIGIN));
+
Assertions.assertTrue(response.hasHeader(CorsHeaderFilter.ACCESS_CONTROL_EXPOSE_HEADERS));
+ Assertions.assertTrue(
+
response.header(CorsHeaderFilter.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header1"));
+ Assertions.assertTrue(
+
response.header(CorsHeaderFilter.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header2"));
+
Assertions.assertTrue(response.header(CorsHeaderFilter.VARY).contains(CorsHeaderFilter.ORIGIN));
+ Assertions.assertTrue(
+
response.header(CorsHeaderFilter.VARY).contains(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD));
+ Assertions.assertTrue(
+
response.header(CorsHeaderFilter.VARY).contains(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_HEADERS));
+ Assertions.assertEquals(HttpStatus.OK.getCode(), response.status());
+ }
+
+ /**
+ * Actual (non-preflight) GET request whose Origin is in an explicit allow-list
+ * and whose CORS config enables credentials. Expects the request origin to be
+ * echoed in Access-Control-Allow-Origin, Access-Control-Allow-Credentials to be
+ * "true", the Vary header to contain ORIGIN, and an OK status.
+ */
+ @Test
+ void actualRequestCredentials() {
+ Mockito.when(request.method()).thenReturn(HttpMethods.GET.name());
+
Mockito.when(request.header(CorsHeaderFilter.ORIGIN)).thenReturn("https://domain2.com");
+ // Resolved CORS metadata: two explicit origins with credentials enabled;
+ // applyDefault() presumably fills the remaining unset fields (confirm in CorsMeta).
+ Mockito.doReturn(CorsMeta.builder()
+ .allowedOrigins("https://domain1.com",
"https://domain2.com")
+ .allowCredentials(true)
+ .build()
+ .applyDefault())
+ .when(build)
+ .getCors();
+ processor.process(request, response);
+
Assertions.assertTrue(response.hasHeader(CorsHeaderFilter.ACCESS_CONTROL_ALLOW_ORIGIN));
+ // With credentials the concrete request origin is expected, not a wildcard.
+ Assertions.assertEquals("https://domain2.com",
response.header(CorsHeaderFilter.ACCESS_CONTROL_ALLOW_ORIGIN));
+
Assertions.assertTrue(response.hasHeader(CorsHeaderFilter.ACCESS_CONTROL_ALLOW_CREDENTIALS));
+ Assertions.assertEquals("true",
response.header(CorsHeaderFilter.ACCESS_CONTROL_ALLOW_CREDENTIALS));
+
Assertions.assertTrue(response.header(CorsHeaderFilter.VARY).contains(CorsHeaderFilter.ORIGIN));
+ Assertions.assertEquals(HttpStatus.OK.getCode(), response.status());
+ }
+
+ /**
+ * Combining allowCredentials(true) with a wildcard ("*") allowed origin must be
+ * rejected: processing an actual GET request is expected to throw
+ * IllegalArgumentException.
+ */
+ @Test
+ void actualRequestCredentialsWithWildcardOrigin() {
+ Mockito.when(request.method()).thenReturn(HttpMethods.GET.name());
+
Mockito.when(request.header(CorsHeaderFilter.ORIGIN)).thenReturn("https://domain2.com");
+ Mockito.doReturn(CorsMeta.builder()
+ .allowedOrigins("*")
+ .allowCredentials(true)
+ .build()
+ .applyDefault())
+ .when(build)
+ .getCors();
+
+ // build()/applyDefault() run above, outside the lambda, so this asserts the
+ // rejection happens inside process(), not while constructing the config.
+ Assertions.assertThrows(IllegalArgumentException.class, () ->
processor.process(request, response));
+ }
+
+ /**
+ * Preflight (OPTIONS) request asking for DELETE against a config that only sets a
+ * wildcard origin (no explicit methods, and applyDefault() is not called).
+ * The preflight is expected to be rejected with FORBIDDEN.
+ */
+ @Test
+ void preflightRequestWrongAllowedMethod() {
+ Mockito.when(request.method()).thenReturn(HttpMethods.OPTIONS.name());
+
Mockito.when(request.header(CorsHeaderFilter.ORIGIN)).thenReturn("https://domain2.com");
+
Mockito.when(request.header(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD))
+ .thenReturn("DELETE");
+
Mockito.when(request.hasHeader(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD))
+ .thenReturn(true);
+ Mockito.when(build.getCors())
+ .thenReturn(CorsMeta.builder().allowedOrigins("*").build());
+ // NOTE(review): preLightProcess is defined outside this chunk; presumably it runs
+ // the preflight and checks the expected status passed as the last argument.
+ processor.preLightProcess(request, response,
HttpStatus.FORBIDDEN.getCode());
+ }
+
+ /**
+ * Preflight (OPTIONS) request asking for GET against an otherwise empty config
+ * with applyDefault() applied. The preflight is expected to succeed with NO_CONTENT.
+ */
+ @Test
+ void preflightRequestMatchedAllowedMethod() {
+ Mockito.when(request.method()).thenReturn(HttpMethods.OPTIONS.name());
+
Mockito.when(request.header(CorsHeaderFilter.ORIGIN)).thenReturn("https://domain2.com");
+
Mockito.when(request.header(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD))
+ .thenReturn("GET");
+
Mockito.when(request.hasHeader(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD))
+ .thenReturn(true);
+ // Relies on the defaults from applyDefault() admitting this origin and GET;
+ // TODO confirm against CorsMeta.applyDefault().
+
Mockito.when(build.getCors()).thenReturn(CorsMeta.builder().build().applyDefault());
+ // NOTE(review): preLightProcess is defined outside this chunk; presumably it runs
+ // the preflight and checks the expected status passed as the last argument.
+ processor.preLightProcess(request, response,
HttpStatus.NO_CONTENT.getCode());
+ }
+
+ /**
+ * Preflight (OPTIONS) request carrying an Origin but no request-method or
+ * request-headers values, checked against defaultCorsMeta(). Expected to be
+ * rejected with FORBIDDEN.
+ */
+ @Test
+ void preflightRequestTestWithOriginButWithoutOtherHeaders() {
+ Mockito.when(request.method()).thenReturn(HttpMethods.OPTIONS.name());
+
Mockito.when(request.header(CorsHeaderFilter.ORIGIN)).thenReturn("https://domain2.com");
+ // NOTE(review): hasHeader(ACCESS_CONTROL_REQUEST_METHOD) is stubbed true but
+ // header(ACCESS_CONTROL_REQUEST_METHOD) is not stubbed; assuming `request` is a
+ // Mockito mock, that value is null, so the header is "present" with no value.
+
Mockito.when(request.hasHeader(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD))
+ .thenReturn(true);
+ Mockito.when(build.getCors()).thenReturn(defaultCorsMeta());
+ // NOTE(review): preLightProcess is defined outside this chunk; presumably it runs
+ // the preflight and checks the expected status passed as the last argument.
+ processor.preLightProcess(request, response,
HttpStatus.FORBIDDEN.getCode());
+ }
+
+ /**
+ * Preflight (OPTIONS) request that supplies request headers ("Header1") but no
+ * request-method value, checked against defaultCorsMeta(). Expected to be
+ * rejected with FORBIDDEN.
+ */
+ @Test
+ void preflightRequestWithoutRequestMethod() {
+ Mockito.when(request.method()).thenReturn(HttpMethods.OPTIONS.name());
+
Mockito.when(request.header(CorsHeaderFilter.ORIGIN)).thenReturn("https://domain2.com");
+
Mockito.when(request.header(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_HEADERS))
+ .thenReturn("Header1");
+ // NOTE(review): hasHeader(...) reports the request-method header as present, but its
+ // value is never stubbed (null if `request` is a Mockito mock); that is what makes
+ // this a "without request method" case.
+
Mockito.when(request.hasHeader(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD))
+ .thenReturn(true);
+ Mockito.when(build.getCors()).thenReturn(defaultCorsMeta());
+ // NOTE(review): preLightProcess is defined outside this chunk; presumably it runs
+ // the preflight and checks the expected status passed as the last argument.
+ processor.preLightProcess(request, response,
HttpStatus.FORBIDDEN.getCode());
+ }
+
+ /**
+ * Well-formed preflight (OPTIONS) request (method GET, header "Header1") checked
+ * against defaultCorsMeta() instead of an explicit config. Expected to be
+ * rejected with FORBIDDEN.
+ */
+ @Test
+ void preflightRequestWithRequestAndMethodHeaderButNoConfig() {
+ Mockito.when(request.method()).thenReturn(HttpMethods.OPTIONS.name());
+
Mockito.when(request.header(CorsHeaderFilter.ORIGIN)).thenReturn("https://domain2.com");
+
Mockito.when(request.header(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD))
+ .thenReturn("GET");
+
Mockito.when(request.header(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_HEADERS))
+ .thenReturn("Header1");
+
Mockito.when(request.hasHeader(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD))
+ .thenReturn(true);
+ // defaultCorsMeta() is defined outside this chunk; presumably it does not allow
+ // this origin/header, which is why FORBIDDEN is expected. TODO confirm.
+ Mockito.when(build.getCors()).thenReturn(defaultCorsMeta());
+ // NOTE(review): preLightProcess is defined outside this chunk; presumably it runs
+ // the preflight and checks the expected status passed as the last argument.
+ processor.preLightProcess(request, response,
HttpStatus.FORBIDDEN.getCode());
+ }
+
+ /**
+ * Preflight (OPTIONS) request for GET with header "Header1" against a config that
+ * allows any origin, GET/PUT, and Header1/Header2 (no applyDefault()). Expected to
+ * succeed with NO_CONTENT.
+ */
+ @Test
+ void preflightRequestValidRequestAndConfig() {
+ Mockito.when(request.method()).thenReturn(HttpMethods.OPTIONS.name());
+
Mockito.when(request.header(CorsHeaderFilter.ORIGIN)).thenReturn("https://domain2.com");
+
Mockito.when(request.header(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD))
+ .thenReturn("GET");
+
Mockito.when(request.header(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_HEADERS))
+ .thenReturn("Header1");
+
Mockito.when(request.hasHeader(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD))
+ .thenReturn(true);
+ // The requested method and header are both in the allowed lists below.
+ Mockito.when(build.getCors())
+ .thenReturn(CorsMeta.builder()
+ .allowedOrigins("*")
+ .allowedMethods("GET", "PUT")
+ .allowedHeaders("Header1", "Header2")
+ .build());
+ // NOTE(review): preLightProcess is defined outside this chunk; presumably it runs
+ // the preflight and checks the expected status passed as the last argument.
+ processor.preLightProcess(request, response,
HttpStatus.NO_CONTENT.getCode());
+ }
+
+ /**
+ * Preflight (OPTIONS) request for GET whose requested headers (Header1, Header2)
+ * exactly match the configured allowed headers for origin https://domain2.com.
+ * Expected to succeed with NO_CONTENT.
+ */
+ @Test
+ void preflightRequestAllowedHeaders() {
+ Mockito.when(request.method()).thenReturn(HttpMethods.OPTIONS.name());
+
Mockito.when(request.header(CorsHeaderFilter.ORIGIN)).thenReturn("https://domain2.com");
+
Mockito.when(request.header(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD))
+ .thenReturn("GET");
+
Mockito.when(request.hasHeader(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD))
+ .thenReturn(true);
+ // NOTE(review): requested headers are supplied via headerValues(...) here, whereas
+ // sibling tests stub header(...); which accessor the processor reads is not visible
+ // in this chunk, so confirm both paths are exercised as intended.
+ Mockito.doReturn(Arrays.asList("Header1", "Header2"))
+ .when(request)
+ .headerValues(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_HEADERS);
+ Mockito.doReturn(CorsMeta.builder()
+ .allowedOrigins("https://domain2.com")
+ .allowedHeaders("Header1", "Header2")
+ .build()
+ .applyDefault())
+ .when(build)
+ .getCors();
+ // NOTE(review): preLightProcess is defined outside this chunk; presumably it runs
+ // the preflight and checks the expected status passed as the last argument.
+ processor.preLightProcess(request, response,
HttpStatus.NO_CONTENT.getCode());
+ }
+
+ /**
+ * Preflight (OPTIONS) request for GET with requested headers Header1/Header2 against
+ * a config whose allowed headers are the wildcard "*". Expected to succeed with
+ * NO_CONTENT.
+ */
+ @Test
+ void preflightRequestAllowsAllHeaders() {
+ Mockito.when(request.method()).thenReturn(HttpMethods.OPTIONS.name());
+
Mockito.when(request.header(CorsHeaderFilter.ORIGIN)).thenReturn("https://domain2.com");
+
Mockito.when(request.header(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD))
+ .thenReturn("GET");
+
+
Mockito.when(request.hasHeader(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD))
+ .thenReturn(true);
+
Mockito.when(request.headerValues(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_HEADERS))
+ .thenReturn(Arrays.asList("Header1", "Header2"));
+ // Wildcard allowed headers: any requested header list should be accepted.
+ Mockito.when(build.getCors())
+ .thenReturn(CorsMeta.builder()
+ .allowedOrigins("https://domain2.com")
+ .allowedHeaders("*")
+ .build()
+ .applyDefault());
+ // NOTE(review): preLightProcess is defined outside this chunk; presumably it runs
+ // the preflight and checks the expected status passed as the last argument.
+ processor.preLightProcess(request, response,
HttpStatus.NO_CONTENT.getCode());
+ }
+
+ /**
+ * Preflight (OPTIONS) request for GET whose request-headers value is the empty
+ * string, against a config allowing all headers. Expected to succeed with
+ * NO_CONTENT.
+ */
+ @Test
+ void preflightRequestWithEmptyHeaders() {
+ Mockito.when(request.method()).thenReturn(HttpMethods.OPTIONS.name());
+
Mockito.when(request.header(CorsHeaderFilter.ORIGIN)).thenReturn("https://domain2.com");
+
Mockito.when(request.header(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD))
+ .thenReturn("GET");
+
Mockito.when(request.hasHeader(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD))
+ .thenReturn(true);
+ // Edge case: the request-headers header is present but empty.
+
Mockito.when(request.header(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_HEADERS))
+ .thenReturn("");
+ Mockito.when(build.getCors())
+ .thenReturn(CorsMeta.builder()
+ .allowedOrigins("https://domain2.com")
+ .allowedHeaders("*")
+ .build()
+ .applyDefault());
+ // NOTE(review): preLightProcess is defined outside this chunk; presumably it runs
+ // the preflight and checks the expected status passed as the last argument.
+ processor.preLightProcess(request, response,
HttpStatus.NO_CONTENT.getCode());
+ }
+
+ /**
+ * Preflight (OPTIONS) request for GET against a minimal config (wildcard origin
+ * only, no applyDefault()). Expected to be rejected with FORBIDDEN.
+ *
+ * NOTE(review): despite the name, getCors() returns a non-null CorsMeta here.
+ * Consider renaming, or stubbing getCors() to return null if a null config is the
+ * intended case. Unlike sibling tests, hasHeader(ACCESS_CONTROL_REQUEST_METHOD) is
+ * also not stubbed (false if `request` is a Mockito mock).
+ */
+ @Test
+ void preflightRequestWithNullConfig() {
+ Mockito.when(request.method()).thenReturn(HttpMethods.OPTIONS.name());
+
Mockito.when(request.header(CorsHeaderFilter.ORIGIN)).thenReturn("https://domain2.com");
+
Mockito.when(request.header(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD))
+ .thenReturn("GET");
+ Mockito.when(build.getCors())
+ .thenReturn(CorsMeta.builder().allowedOrigins("*").build());
+ // NOTE(review): preLightProcess is defined outside this chunk; presumably it runs
+ // the preflight and checks the expected status passed as the last argument.
+ processor.preLightProcess(request, response,
HttpStatus.FORBIDDEN.getCode());
+ }
+
+ /**
+ * Preflight (OPTIONS) request for GET with header "Header1" from an origin in an
+ * explicit allow-list, with credentials enabled. Expected to succeed with
+ * NO_CONTENT.
+ */
+ @Test
+ void preflightRequestCredentials() {
+ Mockito.when(request.method()).thenReturn(HttpMethods.OPTIONS.name());
+
Mockito.when(request.header(CorsHeaderFilter.ORIGIN)).thenReturn("https://domain2.com");
+
Mockito.when(request.header(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD))
+ .thenReturn("GET");
+
Mockito.doReturn(true).when(request).hasHeader(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD);
+
Mockito.when(request.header(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_HEADERS))
+ .thenReturn("Header1");
+ // Explicit origins (not "*") so allowCredentials(true) is a valid combination.
+ Mockito.doReturn(CorsMeta.builder()
+ .allowedOrigins("https://domain1.com",
"https://domain2.com", "http://domain3.example")
+ .allowedHeaders("Header1")
+ .allowCredentials(true)
+ .build()
+ .applyDefault())
+ .when(build)
+ .getCors();
+ // NOTE(review): preLightProcess is defined outside this chunk; presumably it runs
+ // the preflight and checks the expected status passed as the last argument.
+ processor.preLightProcess(request, response,
HttpStatus.NO_CONTENT.getCode());
+ }
+
+ /**
+ * Pre-populates the response Vary header with ORIGIN, ACCESS_CONTROL_REQUEST_METHOD
+ * and ACCESS_CONTROL_REQUEST_HEADERS, runs the processor on a GET request, and
+ * checks that all three names are still present in Vary afterwards.
+ */
+ @Test
+ void preventDuplicatedVaryHeaders() {
+ Mockito.when(request.method()).thenReturn(HttpMethods.GET.name());
+ // Vary already lists all three names before processing.
+ response.setHeader(
+ CorsHeaderFilter.VARY,
+ CorsHeaderFilter.ORIGIN + "," +
CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD + ","
+ + CorsHeaderFilter.ACCESS_CONTROL_REQUEST_HEADERS);
+ // Neither ORIGIN nor build.getCors() is stubbed in this test.
+ processor.process(request, response);
+ // NOTE(review): contains(...) also passes if the names were appended a second time,
+ // so these checks do not prove de-duplication. Comparing the full Vary value (or
+ // counting occurrences) would actually pin the behavior named by this test.
+
Assertions.assertTrue(response.header(CorsHeaderFilter.VARY).contains(CorsHeaderFilter.ORIGIN));
+ Assertions.assertTrue(
+
response.header(CorsHeaderFilter.VARY).contains(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_METHOD));
+ Assertions.assertTrue(
+
response.header(CorsHeaderFilter.VARY).contains(CorsHeaderFilter.ACCESS_CONTROL_REQUEST_HEADERS));
+ }
+}
diff --git a/pom.xml b/pom.xml
index b93069ee27..a8a3b1d694 100644
--- a/pom.xml
+++ b/pom.xml
@@ -571,6 +571,7 @@
**/org/apache/dubbo/test/common/utils/TestSocketUtils.java,
**/org/apache/dubbo/triple/TripleWrapper.java,
**/org/apache/dubbo/rpc/protocol/tri/TriHttp2RemoteFlowController.java,
+
**/org/apache/dubbo/rpc/protocol/tri/rest/cors/CorsHeaderFilter.java,
**/org/apache/dubbo/metrics/aggregate/DubboMergingDigest.java,
**/org/apache/dubbo/metrics/aggregate/DubboAbstractTDigest.java,
**/org/apache/dubbo/common/logger/helpers/FormattingTuple.java,
@@ -842,6 +843,7 @@
<exclude>src/test/java/org/apache/dubbo/config/spring/EmbeddedZooKeeper.java</exclude>
<exclude>src/main/java/org/apache/dubbo/test/common/utils/TestSocketUtils.java</exclude>
<exclude>src/main/java/org/apache/dubbo/rpc/protocol/tri/TriHttp2RemoteFlowController.java</exclude>
+
<exclude>src/main/java/org/apache/dubbo/rpc/protocol/tri/rest/cors/CorsHeaderFilter.java</exclude>
<exclude>src/main/java/org/apache/dubbo/common/threadpool/serial/SerializingExecutor.java</exclude>
<exclude>src/main/java/org/apache/dubbo/maven/plugin/aot/AbstractAotMojo.java</exclude>
<exclude>src/main/java/org/apache/dubbo/maven/plugin/aot/AbstractDependencyFilterMojo.java</exclude>