This is an automated email from the ASF dual-hosted git repository.
xuetaoli pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/dubbo-go-pixiu.git
The following commit(s) were added to refs/heads/develop by this push:
new 0170d46d Feature/support opa server (#843)
0170d46d is described below
commit 0170d46ded43da92b1ed094cd4693b7f510d40bb
Author: Sirui Huang <[email protected]>
AuthorDate: Thu Dec 18 11:54:22 2025 +0800
Feature/support opa server (#843)
* support http filter plugins (Open Policy Agent)
---------
Signed-off-by: Aetherance <[email protected]>
Co-authored-by: Sirui Huang <[email protected]>
Co-authored-by: Xuetao Li <[email protected]>
---
pkg/filter/opa/opa.go | 3 +-
pkg/filter/opa/opa_test.go | 224 +++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 226 insertions(+), 1 deletion(-)
diff --git a/pkg/filter/opa/opa.go b/pkg/filter/opa/opa.go
index 100d611c..83fdbb47 100644
--- a/pkg/filter/opa/opa.go
+++ b/pkg/filter/opa/opa.go
@@ -278,12 +278,13 @@ func (f *Filter) evaluateServer(c *contextHttp.HttpContext, input map[string]any
defer resp.Body.Close()
// Check HTTP status code
- if resp.StatusCode != 200 {
+ if resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
logger.Errorf("Failed to read OPA server response body:
%v", err)
body = []byte("")
}
+
logger.Errorf("OPA server returned status %d: %s",
resp.StatusCode, string(body))
errResp := contextHttp.BadGateway.WithError(fmt.Errorf("OPA
server returned status %d", resp.StatusCode))
c.SendLocalReply(errResp.Status, errResp.ToJSON())
diff --git a/pkg/filter/opa/opa_test.go b/pkg/filter/opa/opa_test.go
index 98a50824..76dfe98e 100644
--- a/pkg/filter/opa/opa_test.go
+++ b/pkg/filter/opa/opa_test.go
@@ -22,6 +22,7 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
+ "strings"
"testing"
"time"
)
@@ -538,6 +539,229 @@ func (m *mockFilterChain) OnEncode(ctx *contextHttp.HttpContext) {
// Not needed for testing
}
+// errorRoundTripper is an http.RoundTripper stub whose RoundTrip never
+// performs a request: it always fails with the configured error, which lets
+// tests drive the client-side transport-failure path without a network.
+type errorRoundTripper struct {
+	err error // returned unchanged from every RoundTrip call
+}
+
+// RoundTrip implements http.RoundTripper by returning no response and rt.err.
+func (rt errorRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
+	return nil, rt.err
+}
+
+// newTestContext wraps req in an HttpContext whose Writer is a fresh
+// httptest.ResponseRecorder and whose Ctx is context.Background().
+func newTestContext(req *http.Request) *contextHttp.HttpContext {
+	hc := &contextHttp.HttpContext{Request: req}
+	hc.Writer = httptest.NewRecorder()
+	hc.Ctx = context.Background()
+	return hc
+}
+
+// TestFactoryAndConfig covers the plugin's factory-creation helpers with a
+// simple happy-path check: CreateFilterFactory yields a *FilterFactory whose
+// Config() is non-nil.
+func TestFactoryAndConfig(t *testing.T) {
+	p := &Plugin{}
+	factoryIface, err := p.CreateFilterFactory()
+	assert.NoError(t, err)
+
+	factory, ok := factoryIface.(*FilterFactory)
+	// Stop on a failed assertion: factory would be a nil *FilterFactory and
+	// calling Config() on it could panic and abort the whole test binary.
+	if !assert.True(t, ok, "expected *FilterFactory, got %T", factoryIface) {
+		return
+	}
+	assert.NotNil(t, factory.Config())
+}
+
+// TestApplyEmbeddedSuccessAndDefaultServerTimeout checks that Apply prepares a
+// query for a valid embedded policy and, in server mode, configures a 100ms
+// HTTP client timeout when TimeoutMs is not provided.
+func TestApplyEmbeddedSuccessAndDefaultServerTimeout(t *testing.T) {
+	// Embedded successful init
+	factory := &FilterFactory{
+		cfg: &Config{
+			Policy: `
+			package test
+			default allow := true
+			`,
+			Entrypoint: "data.test.allow",
+		},
+	}
+	err := factory.Apply()
+	assert.NoError(t, err)
+	assert.NotNil(t, factory.preparedQuery)
+
+	// Server mode default timeout 100ms when TimeoutMs not provided
+	serverFactory := &FilterFactory{
+		cfg: &Config{
+			ServerURL:    "http://example.com",
+			DecisionPath: "/v1/data/test",
+		},
+	}
+	err = serverFactory.Apply()
+	assert.NoError(t, err)
+	// Guard the dereference: if Apply failed, httpClient may still be nil and
+	// a nil-pointer panic would abort the remaining tests in the package.
+	if assert.NotNil(t, serverFactory.httpClient) {
+		assert.Equal(t, 100*time.Millisecond, serverFactory.httpClient.Timeout)
+	}
+}
+
+// TestApplyEmbeddedPrepareFailure checks that Apply returns an error mentioning
+// "failed to prepare" when the embedded policy is malformed.
+func TestApplyEmbeddedPrepareFailure(t *testing.T) {
+	factory := &FilterFactory{
+		cfg: &Config{
+			// The module below is deliberately malformed (dangling "==").
+			// Note that "//" is not a Rego comment (Rego uses "#"), so the
+			// trailing text is part of the broken expression as well.
+			Policy: `
+			package test
+			allow = 1 == // malformed to force prepare error
+			`,
+			Entrypoint: "data.test.allow",
+		},
+	}
+	err := factory.Apply()
+	// Only inspect the message when an error was returned: err.Error() on a
+	// nil error would panic if Apply unexpectedly succeeded.
+	if assert.Error(t, err) {
+		assert.True(t, strings.Contains(err.Error(), "failed to prepare"), "unexpected error: %v", err)
+	}
+}
+
+// TestPrepareFilterChainEmbeddedAndUninitialized checks two things:
+//   - after Apply on an embedded-policy factory, PrepareFilterChain leaves
+//     exactly one *Filter carrying the prepared query in the mock chain;
+//   - a factory that was never applied makes PrepareFilterChain return an error.
+func TestPrepareFilterChainEmbeddedAndUninitialized(t *testing.T) {
+	// Embedded branch
+	factory := &FilterFactory{
+		cfg: &Config{
+			Policy: `
+			package test
+			default allow := true
+			`,
+			Entrypoint: "data.test.allow",
+		},
+	}
+	err := factory.Apply()
+	assert.NoError(t, err)
+
+	chain := &mockFilterChain{}
+	ctx := newTestContext(httptest.NewRequest("GET", "/demo", nil))
+	err = factory.PrepareFilterChain(ctx, chain)
+	assert.NoError(t, err)
+	// Only inspect the filter when exactly one is present and it is a
+	// *Filter; otherwise the index or the type assertion would panic and
+	// abort the whole test binary.
+	if assert.Len(t, chain.filters, 1) {
+		f, ok := chain.filters[0].(*Filter)
+		if assert.True(t, ok, "expected *Filter, got %T", chain.filters[0]) {
+			assert.NotNil(t, f.preparedQuery)
+		}
+	}
+
+	// Uninitialized factory should error
+	emptyFactory := &FilterFactory{cfg: &Config{}}
+	err = emptyFactory.PrepareFilterChain(ctx, chain)
+	assert.Error(t, err)
+}
+
+// TestDecodeNotInitialized checks that Decode on a Filter with neither a
+// prepared query nor an HTTP client stops the chain and replies 500.
+func TestDecodeNotInitialized(t *testing.T) {
+	uninitialized := &Filter{cfg: &Config{}}
+	hc := newTestContext(httptest.NewRequest("GET", "/not-init", nil))
+
+	assert.Equal(t, filter.Stop, uninitialized.Decode(hc))
+	assert.Equal(t, http.StatusInternalServerError, hc.GetStatusCode())
+}
+
+func TestEvaluateEmbeddedEmptyAndError(t *testing.T) {
+ // Empty results branch
+ rEmpty := rego.New(
+ rego.Query("data.test.undefined"),
+ rego.Module("policy.rego", `package test`),
+ )
+ preparedEmpty, err := rEmpty.PrepareForEval(context.Background())
+ assert.NoError(t, err)
+ fEmpty := &Filter{
+ cfg: &Config{Policy: "package test", Entrypoint:
"data.test.undefined"},
+ preparedQuery: &preparedEmpty,
+ }
+ ctx := newTestContext(httptest.NewRequest("GET", "/empty", nil))
+ result := fEmpty.Decode(ctx)
+ assert.Equal(t, filter.Stop, result)
+ assert.Equal(t, http.StatusForbidden, ctx.GetStatusCode())
+
+ // Eval error branch using unsupported input type to trigger value
conversion error
+ rErr := rego.New(
+ rego.Query("data.test.allow"),
+ rego.Module("policy.rego", `
+ package test
+ default allow := true
+ `),
+ )
+ preparedErr, err := rErr.PrepareForEval(context.Background())
+ assert.NoError(t, err)
+ fErr := &Filter{
+ cfg: &Config{Policy: "package test", Entrypoint:
"data.test.allow"},
+ preparedQuery: &preparedErr,
+ }
+ ctxErr := newTestContext(httptest.NewRequest("GET", "/err", nil))
+ ctxErr.Params = map[string]any{"bad": make(chan int)} // unsupported
type for input conversion
+
+ result = fErr.Decode(ctxErr)
+ assert.Equal(t, filter.Stop, result)
+ assert.Equal(t, http.StatusInternalServerError, ctxErr.GetStatusCode())
+}
+
+// TestEvaluateServerTransportAndDecodeErrors covers two server-mode failure
+// paths: a transport error from the HTTP client (expected 503) and an OPA
+// response body that is not valid JSON (expected 502).
+func TestEvaluateServerTransportAndDecodeErrors(t *testing.T) {
+	// Transport error (non-timeout) -> 503. errorRoundTripper fails every
+	// request itself, so no connection to "http://opa" is attempted.
+	fTransport := &Filter{
+		cfg: &Config{
+			ServerURL:    "http://opa",
+			DecisionPath: "/v1/data/test",
+		},
+		httpClient: &http.Client{Transport: errorRoundTripper{err: assert.AnError}},
+	}
+	req := httptest.NewRequest("GET", "/transport", nil)
+	ctx := newTestContext(req)
+	status := fTransport.Decode(ctx)
+	assert.Equal(t, filter.Stop, status)
+	assert.Equal(t, http.StatusServiceUnavailable, ctx.GetStatusCode())
+
+	// Decode error from invalid JSON
+	serverBadJSON := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		if _, err := w.Write([]byte("not-json")); err != nil {
+			// t.Errorf (unlike t.Fatal) may be called from the handler goroutine.
+			t.Errorf("writing stub response: %v", err)
+		}
+	}))
+	defer serverBadJSON.Close()
+
+	fDecode := &Filter{
+		cfg: &Config{
+			ServerURL:    serverBadJSON.URL,
+			DecisionPath: "/v1/data/test",
+		},
+		httpClient: &http.Client{Timeout: time.Second},
+	}
+	ctxBad := newTestContext(httptest.NewRequest("GET", "/decode", nil))
+	status = fDecode.Decode(ctxBad)
+	assert.Equal(t, filter.Stop, status)
+	assert.Equal(t, http.StatusBadGateway, ctxBad.GetStatusCode())
+}
+
+// TestEvaluateServerMissingResult checks that a 200 OPA response whose JSON
+// body lacks a "result" key stops the chain with 502 Bad Gateway.
+func TestEvaluateServerMissingResult(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// Valid JSON, but the decision sits under "decision", not "result".
+		if err := json.NewEncoder(w).Encode(map[string]any{
+			"decision": true,
+		}); err != nil {
+			// t.Errorf (unlike t.Fatal) may be called from the handler goroutine.
+			t.Errorf("encoding stub response: %v", err)
+		}
+	}))
+	defer server.Close()
+
+	f := &Filter{
+		cfg: &Config{
+			ServerURL:    server.URL,
+			DecisionPath: "/v1/data/test",
+		},
+		httpClient: &http.Client{Timeout: time.Second},
+	}
+
+	ctx := newTestContext(httptest.NewRequest("GET", "/missing", nil))
+	status := f.Decode(ctx)
+	assert.Equal(t, filter.Stop, status)
+	assert.Equal(t, http.StatusBadGateway, ctx.GetStatusCode())
+}
+
+// TestExtractDecisionVariants checks how extractDecision interprets several
+// result shapes: a bare bool, a map with an "allow" key, a map with a
+// "result" key, and an unsupported string value (which must yield false).
+func TestExtractDecisionVariants(t *testing.T) {
+	cases := []struct {
+		name  string
+		input any
+		want  bool
+	}{
+		{name: "direct-bool", input: true, want: true},
+		{name: "allow-field", input: map[string]any{"allow": true}, want: true},
+		{name: "nested-result-field", input: map[string]any{"result": true}, want: true},
+		{name: "unknown-type", input: "unexpected", want: false},
+	}
+
+	for i := range cases {
+		c := cases[i]
+		t.Run(c.name, func(t *testing.T) {
+			assert.Equal(t, c.want, extractDecision(c.input))
+		})
+	}
+}
+
// Backward compatibility test names
func TestAllowedRule(t *testing.T) {
TestEmbeddedAllowedRule(t)