This is an automated email from the ASF dual-hosted git repository. lahirujayathilake pushed a commit to branch privileges-impl in repository https://gitbox.apache.org/repos/asf/airavata-custos.git
commit 6d4e4ce17c285f574eb1bf292d07697771cbe4f9 Author: lahiruj <[email protected]> AuthorDate: Thu May 28 05:24:52 2026 -0400 Add admin privileges framework with privileges:grant gate --- cmd/server/main.go | 6 + .../db/migrations/000006_user_privileges.down.sql | 18 + .../db/migrations/000006_user_privileges.up.sql | 37 ++ internal/server/auth.go | 149 ++++++++ internal/server/integration_common_test.go | 128 +++++++ internal/server/privilege.go | 143 ++++++++ internal/server/privilege_integration_test.go | 267 ++++++++++++++ internal/server/server.go | 17 +- internal/store/store.go | 22 ++ internal/store/user_privilege_store.go | 127 +++++++ pkg/models/privilege.go | 69 ++++ pkg/service/integration_common_test.go | 155 +++++++++ pkg/service/interface.go | 15 + pkg/service/mock.go | 385 +++++++++++++++++++++ pkg/service/service.go | 4 + pkg/service/user_privilege.go | 280 +++++++++++++++ pkg/service/user_privilege_integration_test.go | 278 +++++++++++++++ 17 files changed, 2097 insertions(+), 3 deletions(-) diff --git a/cmd/server/main.go b/cmd/server/main.go index 14a954b30..602ffe89b 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -80,6 +80,12 @@ func run() error { ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() + if email := os.Getenv("CUSTOS_BOOTSTRAP_ADMIN_EMAIL"); email != "" { + if err := svc.BootstrapPrivilegeGrant(ctx, email, "env:CUSTOS_BOOTSTRAP_ADMIN_EMAIL"); err != nil { + slog.Warn("bootstrap privilege grant failed", "email", email, "error", err) + } + } + // Tracks every background goroutine spawned by connectors so we can wait // for them to drain on shutdown instead of killing them mid-flight. 
var connectorsWG sync.WaitGroup diff --git a/internal/db/migrations/000006_user_privileges.down.sql b/internal/db/migrations/000006_user_privileges.down.sql new file mode 100644 index 000000000..162b3d7bc --- /dev/null +++ b/internal/db/migrations/000006_user_privileges.down.sql @@ -0,0 +1,18 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +DROP TABLE IF EXISTS user_privileges; diff --git a/internal/db/migrations/000006_user_privileges.up.sql b/internal/db/migrations/000006_user_privileges.up.sql new file mode 100644 index 000000000..351bafb52 --- /dev/null +++ b/internal/db/migrations/000006_user_privileges.up.sql @@ -0,0 +1,37 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. 
You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +SET NAMES utf8mb4; +SET time_zone = '+00:00'; + +-- Privileges held by a user. Only active grants live here; revoke is DELETE. +-- The full grant/revoke history (who, when, why) is recorded in audit_events. +CREATE TABLE IF NOT EXISTS user_privileges +( + id VARCHAR(255) NOT NULL, + user_id VARCHAR(255) NOT NULL, + privilege VARCHAR(64) NOT NULL, + granted_by VARCHAR(255) NULL, + granted_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + reason TEXT NULL, + PRIMARY KEY (id), + UNIQUE KEY uq_user_privileges (user_id, privilege), + KEY idx_user_privileges_user (user_id), + KEY idx_user_privileges_priv (privilege), + CONSTRAINT fk_user_privileges_user FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE, + CONSTRAINT fk_user_privileges_granted_by FOREIGN KEY (granted_by) REFERENCES users (id) ON DELETE SET NULL +) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COLLATE = utf8mb4_unicode_ci; diff --git a/internal/server/auth.go b/internal/server/auth.go new file mode 100644 index 000000000..b3b3b5c5f --- /dev/null +++ b/internal/server/auth.go @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package server + +import ( + "context" + "errors" + "net/http" + "sync" + "time" + + "github.com/apache/airavata-custos/pkg/models" +) + +// callerHeader identifies the caller. TODO: a JWT-verification middleware should set +// this header from the verified `sub` claim after validating the token +// against the IdP's JWKS endpoint. Until that lands, the value +// is supplied directly by the caller. +const callerHeader = "X-Custos-User-Id" + +// authProfileTTL bounds how long the middleware will trust a cached +// privilege set before re-reading the DB. +// +// TODO: make configurable via env (e.g. AUTH_CACHE_TTL_SECONDS), cap at 60s. +// TODO: support caching for multi-instance deployments; the cache is per-process. +const authProfileTTL = 5 * time.Second + +// authProfile is the cached snapshot of a user's effective privileges. +type authProfile struct { + privileges map[models.PrivilegeKey]struct{} +} + +func (p *authProfile) has(privilege models.PrivilegeKey) bool { + if p == nil { + return false + } + _, ok := p.privileges[privilege] + return ok +} + +// authProfileCache is a tiny in-process TTL cache for "userID -> privilege set". +// An empty profile (zero privileges) is still cached so users with no grants do not read the DB every time. 
+type authProfileCache struct { + mu sync.Mutex + entries map[string]authProfileCacheEntry + ttl time.Duration +} + +type authProfileCacheEntry struct { + profile *authProfile + expires time.Time +} + +func newAuthProfileCache(ttl time.Duration) *authProfileCache { + return &authProfileCache{ + entries: make(map[string]authProfileCacheEntry), + ttl: ttl, + } +} + +func (c *authProfileCache) get(userID string) (*authProfile, bool) { + c.mu.Lock() + defer c.mu.Unlock() + e, ok := c.entries[userID] + if !ok || time.Now().After(e.expires) { + return nil, false + } + return e.profile, true +} + +func (c *authProfileCache) set(userID string, profile *authProfile) { + c.mu.Lock() + defer c.mu.Unlock() + c.entries[userID] = authProfileCacheEntry{ + profile: profile, + expires: time.Now().Add(c.ttl), + } +} + +// invalidate drops the cache entry for userID. Called after any grant or +// revoke so subsequent requests reflect the new state without waiting for TTL. +func (c *authProfileCache) invalidate(userID string) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.entries, userID) +} + +// requirePrivilege returns a middleware that admits the request only if the +// caller (identified by callerHeader) holds the named active privilege. 
+// +// Responses: +// - 401 Unauthorized - no caller header +// - 403 Forbidden - caller is identified but does not hold the privilege +// - 503 Service Unavailable - auth-profile lookup failed +// +// Fail-closed: a failed lookup is answered with 503 and never admits the request; it is never reported as 403. +func (s *Server) requirePrivilege(p models.PrivilegeKey, next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + callerID := r.Header.Get(callerHeader) + if callerID == "" { + writeError(w, http.StatusUnauthorized, errors.New("missing "+callerHeader+" header")) + return + } + profile, err := s.lookupAuthProfile(r.Context(), callerID) + if err != nil { + writeError(w, http.StatusServiceUnavailable, errors.New("auth lookup failed")) + return + } + if !profile.has(p) { + writeError(w, http.StatusForbidden, errors.New("insufficient privilege")) + return + } + next(w, r) + } +} + +// lookupAuthProfile returns the caller's current privilege snapshot, hitting +// the cache first and falling back to the DB. Errors propagate so middleware +// can fail closed. +func (s *Server) lookupAuthProfile(ctx context.Context, userID string) (*authProfile, error) { + if cached, ok := s.authCache.get(userID); ok { + return cached, nil + } + grants, err := s.svc.ListUserPrivileges(ctx, userID) + if err != nil { + return nil, err + } + profile := &authProfile{privileges: make(map[models.PrivilegeKey]struct{}, len(grants))} + for _, g := range grants { + profile.privileges[g.Privilege] = struct{}{} + } + s.authCache.set(userID, profile) + return profile, nil +} diff --git a/internal/server/integration_common_test.go b/internal/server/integration_common_test.go new file mode 100644 index 000000000..490eb8923 --- /dev/null +++ b/internal/server/integration_common_test.go @@ -0,0 +1,128 @@ +//go:build integration + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package server + +import ( + "os" + "sync" + "testing" + + "github.com/google/uuid" + "github.com/jmoiron/sqlx" + + "github.com/apache/airavata-custos/internal/db" + "github.com/apache/airavata-custos/pkg/events" + "github.com/apache/airavata-custos/pkg/models" + "github.com/apache/airavata-custos/pkg/service" +) + +var ( + sharedDB *sqlx.DB + sharedDBOnce sync.Once + sharedDBErr error +) + +func setupTestStack(t *testing.T) (*sqlx.DB, *service.Service, *Server) { + t.Helper() + dsn := os.Getenv("CORE_TEST_DATABASE_DSN") + if dsn == "" { + dsn = os.Getenv("DATABASE_DSN") + } + if dsn == "" { + t.Skip("integration env not set: CORE_TEST_DATABASE_DSN or DATABASE_DSN required") + } + sharedDBOnce.Do(func() { + database, err := db.Open(db.Config{ + DSN: dsn, + MaxOpenConns: 5, + MaxIdleConns: 2, + }) + if err != nil { + sharedDBErr = err + return + } + if err := db.MigrateEmbedded(database); err != nil { + sharedDBErr = err + return + } + sharedDB = database + }) + if sharedDBErr != nil { + t.Fatalf("setup db: %v", sharedDBErr) + } + truncateAll(t, sharedDB) + svc := service.New(sharedDB, events.New()) + return sharedDB, svc, New(svc) +} + +func truncateAll(t *testing.T, database *sqlx.DB) { + t.Helper() + tables := []string{ + "user_privileges", + "audit_events", + 
"user_identities", + "users", + "organizations", + } + if _, err := database.Exec("SET FOREIGN_KEY_CHECKS = 0"); err != nil { + t.Fatalf("disable FK: %v", err) + } + for _, tbl := range tables { + if _, err := database.Exec("TRUNCATE TABLE " + tbl); err != nil { + t.Fatalf("truncate %s: %v", tbl, err) + } + } + if _, err := database.Exec("SET FOREIGN_KEY_CHECKS = 1"); err != nil { + t.Fatalf("re-enable FK: %v", err) + } +} + +func seedUser(t *testing.T, database *sqlx.DB, email string) string { + t.Helper() + orgID := uuid.NewString() + if _, err := database.Exec( + "INSERT INTO organizations (id, originated_id, name) VALUES (?, ?, ?)", + orgID, "TEST-ORG-"+orgID[:8], "Test Org", + ); err != nil { + t.Fatalf("seed org: %v", err) + } + userID := uuid.NewString() + if _, err := database.Exec( + "INSERT INTO users (id, organization_id, first_name, last_name, middle_name, email, status) VALUES (?, ?, ?, ?, ?, ?, ?)", + userID, orgID, "Test", "User", "", email, string(models.UserActive), + ); err != nil { + t.Fatalf("seed user %s: %v", email, err) + } + return userID +} + +// seedPrivilegeGrant directly inserts an active privileges:grant for userID. +// Bypasses the service guards so tests can stand up a granter without a +// chicken-and-egg dependency. +func seedPrivilegeGrant(t *testing.T, database *sqlx.DB, userID string) { + t.Helper() + if _, err := database.Exec( + `INSERT INTO user_privileges (id, user_id, privilege, granted_at, reason) + VALUES (?, ?, ?, NOW(6), 'seed')`, + uuid.NewString(), userID, string(models.PrivilegeGrant), + ); err != nil { + t.Fatalf("seed privileges:grant for %s: %v", userID, err) + } +} diff --git a/internal/server/privilege.go b/internal/server/privilege.go new file mode 100644 index 000000000..b882e6932 --- /dev/null +++ b/internal/server/privilege.go @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package server + +import ( + "errors" + "net/http" + + "github.com/apache/airavata-custos/pkg/models" +) + +// getCallerPrivileges handles GET /user/privileges. Returns the +// authenticated caller's effective privilege set. +func (s *Server) getCallerPrivileges(w http.ResponseWriter, r *http.Request) { + callerID := r.Header.Get(callerHeader) + if callerID == "" { + writeError(w, http.StatusUnauthorized, errors.New("missing "+callerHeader+" header")) + return + } + profile, err := s.lookupAuthProfile(r.Context(), callerID) + if err != nil { + writeError(w, http.StatusServiceUnavailable, errors.New("auth lookup failed")) + return + } + keys := make([]models.PrivilegeKey, 0, len(profile.privileges)) + for k := range profile.privileges { + keys = append(keys, k) + } + writeJSON(w, http.StatusOK, map[string]any{"privileges": keys}) +} + +// getPrivilegeCatalog handles GET /privileges/catalog. Returns the static +// catalog of declared privilege keys. Gated on privileges:grant. +func (s *Server) getPrivilegeCatalog(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, s.svc.PrivilegeCatalog()) +} + +// listUserPrivileges handles GET /users/{id}/privileges. Returns the active +// privileges of the target user. Gated on privileges:grant. 
+func (s *Server) listUserPrivileges(w http.ResponseWriter, r *http.Request) { + userID := r.PathValue("id") + if userID == "" { + writeError(w, http.StatusBadRequest, errors.New("user id is required")) + return + } + rows, err := s.svc.ListUserPrivileges(r.Context(), userID) + if err != nil { + writeServiceError(w, err) + return + } + writeJSON(w, http.StatusOK, rows) +} + +// listPrivilegeHolders handles GET /privileges/{key}/holders. Gated on +// privileges:grant. +func (s *Server) listPrivilegeHolders(w http.ResponseWriter, r *http.Request) { + key := models.PrivilegeKey(r.PathValue("key")) + if !models.IsKnownPrivilege(key) { + writeError(w, http.StatusBadRequest, errors.New("unknown privilege key")) + return + } + rows, err := s.svc.ListPrivilegeHolders(r.Context(), key) + if err != nil { + writeServiceError(w, err) + return + } + writeJSON(w, http.StatusOK, rows) +} + +type grantPrivilegeRequest struct { + Privilege models.PrivilegeKey `json:"privilege"` + Reason string `json:"reason"` +} + +// grantPrivilege handles POST /users/{id}/privileges. Gated on privileges:grant. +func (s *Server) grantPrivilege(w http.ResponseWriter, r *http.Request) { + userID := r.PathValue("id") + if userID == "" { + writeError(w, http.StatusBadRequest, errors.New("user id is required")) + return + } + granterID := r.Header.Get(callerHeader) + if granterID == "" { + writeError(w, http.StatusUnauthorized, errors.New("missing "+callerHeader+" header")) + return + } + var req grantPrivilegeRequest + if err := decodeJSON(r, &req); err != nil { + writeError(w, http.StatusBadRequest, err) + return + } + grant, err := s.svc.GrantPrivilege(r.Context(), userID, req.Privilege, granterID, req.Reason) + if err != nil { + writeServiceError(w, err) + return + } + s.authCache.invalidate(userID) + writeJSON(w, http.StatusCreated, grant) +} + +type revokePrivilegeRequest struct { + Reason string `json:"reason"` +} + +// revokePrivilege handles DELETE /users/{id}/privileges/{key}. 
Gated on +// privileges:grant. +func (s *Server) revokePrivilege(w http.ResponseWriter, r *http.Request) { + userID := r.PathValue("id") + key := models.PrivilegeKey(r.PathValue("key")) + if userID == "" || key == "" { + writeError(w, http.StatusBadRequest, errors.New("user id and privilege key are required")) + return + } + revokerID := r.Header.Get(callerHeader) + if revokerID == "" { + writeError(w, http.StatusUnauthorized, errors.New("missing "+callerHeader+" header")) + return + } + var req revokePrivilegeRequest + _ = decodeJSON(r, &req) + if err := s.svc.RevokePrivilege(r.Context(), userID, key, revokerID, req.Reason); err != nil { + writeServiceError(w, err) + return + } + s.authCache.invalidate(userID) + s.authCache.invalidate(revokerID) + w.WriteHeader(http.StatusNoContent) +} diff --git a/internal/server/privilege_integration_test.go b/internal/server/privilege_integration_test.go new file mode 100644 index 000000000..18cc29d0f --- /dev/null +++ b/internal/server/privilege_integration_test.go @@ -0,0 +1,267 @@ +//go:build integration + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package server + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/apache/airavata-custos/pkg/models" +) + +func TestGetCallerPrivileges_MissingHeader_401(t *testing.T) { + _, _, srv := setupTestStack(t) + rr := httptest.NewRecorder() + srv.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/user/privileges", nil)) + if rr.Code != http.StatusUnauthorized { + t.Errorf("status: got %d, want 401", rr.Code) + } +} + +func TestGetCallerPrivileges_NoGrants_ReturnsEmpty(t *testing.T) { + database, _, srv := setupTestStack(t) + user := seedUser(t, database, "[email protected]") + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/user/privileges", nil) + req.Header.Set(callerHeader, user) + srv.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("status: got %d, want 200", rr.Code) + } + var body struct { + Privileges []string `json:"privileges"` + } + if err := json.NewDecoder(rr.Body).Decode(&body); err != nil { + t.Fatalf("decode: %v", err) + } + if len(body.Privileges) != 0 { + t.Errorf("privileges: got %v, want empty", body.Privileges) + } +} + +func TestGetCallerPrivileges_WithGrants(t *testing.T) { + database, _, srv := setupTestStack(t) + user := seedUser(t, database, "[email protected]") + seedPrivilegeGrant(t, database, user) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/user/privileges", nil) + req.Header.Set(callerHeader, user) + srv.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("status: got %d, want 200", rr.Code) + } + var body struct { + Privileges []string `json:"privileges"` + } + if err := json.NewDecoder(rr.Body).Decode(&body); err != nil { + t.Fatalf("decode: %v", err) + } + if len(body.Privileges) != 1 || body.Privileges[0] != string(models.PrivilegeGrant) { + t.Errorf("privileges: got %v, want [%s]", body.Privileges, models.PrivilegeGrant) + } +} + +func TestRequirePrivilege_NoGrants_403(t *testing.T) { 
+ database, _, srv := setupTestStack(t) + user := seedUser(t, database, "[email protected]") + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/privileges/catalog", nil) + req.Header.Set(callerHeader, user) + srv.ServeHTTP(rr, req) + if rr.Code != http.StatusForbidden { + t.Errorf("status: got %d, want 403", rr.Code) + } +} + +func TestRequirePrivilege_WithGrant_200(t *testing.T) { + database, _, srv := setupTestStack(t) + user := seedUser(t, database, "[email protected]") + seedPrivilegeGrant(t, database, user) + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/privileges/catalog", nil) + req.Header.Set(callerHeader, user) + srv.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Errorf("status: got %d, want 200", rr.Code) + } + var cat []string + if err := json.NewDecoder(rr.Body).Decode(&cat); err != nil { + t.Fatalf("decode: %v", err) + } + if len(cat) == 0 { + t.Errorf("catalog empty") + } +} + +func TestGrantPrivilegeEndpoint_HappyPath(t *testing.T) { + database, svc, srv := setupTestStack(t) + granter := seedUser(t, database, "[email protected]") + target := seedUser(t, database, "[email protected]") + seedPrivilegeGrant(t, database, granter) + + body, _ := json.Marshal(map[string]any{"privilege": "amie:read", "reason": "ops view"}) + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/users/"+target+"/privileges", bytes.NewReader(body)) + req.Header.Set(callerHeader, granter) + req.Header.Set("Content-Type", "application/json") + srv.ServeHTTP(rr, req) + if rr.Code != http.StatusCreated { + t.Fatalf("status: got %d, want 201, body=%s", rr.Code, rr.Body.String()) + } + has, err := svc.HasPrivilege(context.Background(), target, models.PrivilegeAMIERead) + if err != nil || !has { + t.Errorf("HasPrivilege after grant endpoint: has=%v err=%v", has, err) + } +} + +func TestGrantPrivilegeEndpoint_GranterWithoutMeta_403(t *testing.T) { + database, _, srv := setupTestStack(t) + plain 
:= seedUser(t, database, "[email protected]") + target := seedUser(t, database, "[email protected]") + body, _ := json.Marshal(map[string]any{"privilege": "amie:read"}) + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/users/"+target+"/privileges", bytes.NewReader(body)) + req.Header.Set(callerHeader, plain) + srv.ServeHTTP(rr, req) + if rr.Code != http.StatusForbidden { + t.Errorf("status: got %d, want 403 (granter lacks privileges:grant)", rr.Code) + } +} + +func TestRevokePrivilegeEndpoint_HappyPath(t *testing.T) { + database, svc, srv := setupTestStack(t) + granter := seedUser(t, database, "[email protected]") + target := seedUser(t, database, "[email protected]") + seedPrivilegeGrant(t, database, granter) + if _, err := svc.GrantPrivilege(context.Background(), target, models.PrivilegeAMIERead, granter, ""); err != nil { + t.Fatalf("seed grant: %v", err) + } + body, _ := json.Marshal(map[string]any{"reason": "rotated"}) + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, "/users/"+target+"/privileges/amie:read", bytes.NewReader(body)) + req.Header.Set(callerHeader, granter) + srv.ServeHTTP(rr, req) + if rr.Code != http.StatusNoContent { + t.Fatalf("status: got %d, want 204, body=%s", rr.Code, rr.Body.String()) + } + if has, err := svc.HasPrivilege(context.Background(), target, models.PrivilegeAMIERead); err != nil || has { + t.Errorf("HasPrivilege after revoke: has=%v err=%v", has, err) + } +} + +func TestRevokePrivilegeEndpoint_SelfRevokeMeta_400(t *testing.T) { + database, _, srv := setupTestStack(t) + user := seedUser(t, database, "[email protected]") + seedPrivilegeGrant(t, database, user) + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, "/users/"+user+"/privileges/privileges:grant", nil) + req.Header.Set(callerHeader, user) + srv.ServeHTTP(rr, req) + if rr.Code != http.StatusBadRequest { + t.Errorf("status: got %d, want 400 (self-revoke of meta)", rr.Code) + } +} + +func 
TestListUserPrivilegesEndpoint(t *testing.T) { + database, svc, srv := setupTestStack(t) + granter := seedUser(t, database, "[email protected]") + target := seedUser(t, database, "[email protected]") + seedPrivilegeGrant(t, database, granter) + if _, err := svc.GrantPrivilege(context.Background(), target, models.PrivilegeAMIERead, granter, ""); err != nil { + t.Fatalf("grant: %v", err) + } + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/users/"+target+"/privileges", nil) + req.Header.Set(callerHeader, granter) + srv.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("status: got %d, want 200, body=%s", rr.Code, rr.Body.String()) + } + var rows []models.UserPrivilege + if err := json.NewDecoder(rr.Body).Decode(&rows); err != nil { + t.Fatalf("decode: %v", err) + } + if len(rows) != 1 || rows[0].Privilege != models.PrivilegeAMIERead { + t.Errorf("rows: got %v, want [amie:read]", rows) + } +} + +func TestListPrivilegeHoldersEndpoint(t *testing.T) { + database, svc, srv := setupTestStack(t) + granter := seedUser(t, database, "[email protected]") + target := seedUser(t, database, "[email protected]") + seedPrivilegeGrant(t, database, granter) + if _, err := svc.GrantPrivilege(context.Background(), target, models.PrivilegeAMIERead, granter, ""); err != nil { + t.Fatalf("grant: %v", err) + } + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/privileges/amie:read/holders", nil) + req.Header.Set(callerHeader, granter) + srv.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("status: got %d, want 200, body=%s", rr.Code, rr.Body.String()) + } + var rows []models.UserPrivilege + if err := json.NewDecoder(rr.Body).Decode(&rows); err != nil { + t.Fatalf("decode: %v", err) + } + if len(rows) != 1 || rows[0].UserID != target { + t.Errorf("holders: got %v, want [target=%s]", rows, target) + } +} + +func TestRequirePrivilege_StaleCacheStillReturns403AfterRevoke(t *testing.T) { + database, svc, srv := 
setupTestStack(t) + a := seedUser(t, database, "[email protected]") + b := seedUser(t, database, "[email protected]") + seedPrivilegeGrant(t, database, a) + seedPrivilegeGrant(t, database, b) + + // b warms the cache. + warm := httptest.NewRecorder() + warmReq := httptest.NewRequest(http.MethodGet, "/privileges/catalog", nil) + warmReq.Header.Set(callerHeader, b) + srv.ServeHTTP(warm, warmReq) + if warm.Code != http.StatusOK { + t.Fatalf("warm-up: got %d, want 200", warm.Code) + } + + // a revokes b's privileges:grant via service (avoids the self-revoke + // guard since a != b). + if err := svc.RevokePrivilege(context.Background(), b, models.PrivilegeGrant, a, "rotated"); err != nil { + t.Fatalf("revoke b's grant: %v", err) + } + // invalidate b's cache the way the HTTP handler would. + srv.authCache.invalidate(b) + + // b retries and now gets 403. + again := httptest.NewRecorder() + againReq := httptest.NewRequest(http.MethodGet, "/privileges/catalog", nil) + againReq.Header.Set(callerHeader, b) + srv.ServeHTTP(again, againReq) + if again.Code != http.StatusForbidden { + t.Errorf("post-revoke status for b: got %d, want 403", again.Code) + } +} diff --git a/internal/server/server.go b/internal/server/server.go index 086e8de64..15d637874 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -32,13 +32,18 @@ import ( // Server is an HTTP handler that exposes the service API. type Server struct { - svc *service.Service - mux *http.ServeMux + svc *service.Service + mux *http.ServeMux + authCache *authProfileCache } // New builds an HTTP handler wired to the supplied service. 
func New(svc *service.Service) *Server { - s := &Server{svc: svc, mux: http.NewServeMux()} + s := &Server{ + svc: svc, + mux: http.NewServeMux(), + authCache: newAuthProfileCache(authProfileTTL), + } s.routes() return s } @@ -143,6 +148,12 @@ func (s *Server) routes() { s.mux.HandleFunc("GET /user-identities/oidc-subjects/{oidcSub}", s.getUserIdentityByOIDCSub) s.mux.HandleFunc("GET /users/{id}/user-identities", s.listUserIdentitiesForUser) + s.mux.HandleFunc("GET /user/privileges", s.getCallerPrivileges) + s.mux.HandleFunc("GET /privileges/catalog", s.requirePrivilege(models.PrivilegeGrant, s.getPrivilegeCatalog)) + s.mux.HandleFunc("GET /users/{id}/privileges", s.requirePrivilege(models.PrivilegeGrant, s.listUserPrivileges)) + s.mux.HandleFunc("GET /privileges/{key}/holders", s.requirePrivilege(models.PrivilegeGrant, s.listPrivilegeHolders)) + s.mux.HandleFunc("POST /users/{id}/privileges", s.requirePrivilege(models.PrivilegeGrant, s.grantPrivilege)) + s.mux.HandleFunc("DELETE /users/{id}/privileges/{key}", s.requirePrivilege(models.PrivilegeGrant, s.revokePrivilege)) } func (s *Server) healthz(w http.ResponseWriter, _ *http.Request) { diff --git a/internal/store/store.go b/internal/store/store.go index 1cae3728f..a17ed5805 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -328,6 +328,28 @@ type AuditEventStore interface { Delete(ctx context.Context, tx *sql.Tx, id string) error } +// UserPrivilegeStore defines persistence operations for fine-grained admin +// privileges. Only active grants live in the table; revoke is DELETE. The +// full grant/revoke history is in audit_events. +type UserPrivilegeStore interface { + // Find returns the active grant for (userID, privilege), or nil. + Find(ctx context.Context, userID string, privilege models.PrivilegeKey) (*models.UserPrivilege, error) + // FindForUpdate returns the active grant inside a tx with SELECT FOR + // UPDATE so the caller can serialize grant / revoke decisions. 
+ FindForUpdate(ctx context.Context, tx *sql.Tx, userID string, privilege models.PrivilegeKey) (*models.UserPrivilege, error) + // ListByUser returns every active grant held by the user. + ListByUser(ctx context.Context, userID string) ([]models.UserPrivilege, error) + // ListByPrivilege returns every active holder of the given privilege. + ListByPrivilege(ctx context.Context, privilege models.PrivilegeKey) ([]models.UserPrivilege, error) + // CountByPrivilege returns the number of active holders inside a tx. + // Used to enforce the last-meta-holder guard when revoking PrivilegeGrant. + CountByPrivilege(ctx context.Context, tx *sql.Tx, privilege models.PrivilegeKey) (int, error) + // Create inserts a new grant inside the provided transaction. + Create(ctx context.Context, tx *sql.Tx, r *models.UserPrivilege) error + // Delete removes the grant for (userID, privilege) inside the provided tx. + Delete(ctx context.Context, tx *sql.Tx, userID string, privilege models.PrivilegeKey) error +} + // ComputeAllocationUsageStore defines persistence operations for the // append-only log of resource consumption events charged against a compute // allocation. diff --git a/internal/store/user_privilege_store.go b/internal/store/user_privilege_store.go new file mode 100644 index 000000000..2ca42a84b --- /dev/null +++ b/internal/store/user_privilege_store.go @@ -0,0 +1,127 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package store
+
+import (
+	"context"
+	"database/sql"
+	"errors"
+
+	"github.com/jmoiron/sqlx"
+
+	"github.com/apache/airavata-custos/pkg/models"
+)
+
+// userPrivilegeColumns is the column list shared by every SELECT in this
+// file. FindForUpdate scans positionally, so the order here must stay in
+// step with the Scan call there.
+const userPrivilegeColumns = "id, user_id, privilege, granted_by, granted_at, reason"
+
+// mysqlUserPrivilegeStore implements UserPrivilegeStore on top of the
+// user_privileges table.
+type mysqlUserPrivilegeStore struct {
+	db *sqlx.DB
+}
+
+// NewUserPrivilegeStore returns a MySQL-backed UserPrivilegeStore.
+func NewUserPrivilegeStore(db *sqlx.DB) UserPrivilegeStore {
+	return &mysqlUserPrivilegeStore{db: db}
+}
+
+// Find looks up the row for (userID, privilege) on the store's own handle,
+// outside any transaction. It returns (nil, nil) when no row matches and
+// (nil, err) for any other failure.
+func (s *mysqlUserPrivilegeStore) Find(ctx context.Context, userID string, privilege models.PrivilegeKey) (*models.UserPrivilege, error) {
+	var r models.UserPrivilege
+	err := s.db.GetContext(ctx, &r,
+		`SELECT `+userPrivilegeColumns+`
+		   FROM user_privileges
+		  WHERE user_id = ? AND privilege = ?`,
+		userID, privilege)
+	if err != nil {
+		// A missing grant is an expected outcome, not an error.
+		if errors.Is(err, sql.ErrNoRows) {
+			return nil, nil
+		}
+		return nil, err
+	}
+	return &r, nil
+}
+
+// FindForUpdate is Find executed on tx with FOR UPDATE appended to the
+// query. It returns (nil, nil) when no row matches.
+func (s *mysqlUserPrivilegeStore) FindForUpdate(ctx context.Context, tx *sql.Tx, userID string, privilege models.PrivilegeKey) (*models.UserPrivilege, error) {
+	row := tx.QueryRowContext(ctx,
+		`SELECT `+userPrivilegeColumns+`
+		   FROM user_privileges
+		  WHERE user_id = ? AND privilege = ?
+		    FOR UPDATE`,
+		userID, privilege)
+	var r models.UserPrivilege
+	// tx is a plain *sql.Tx, so the columns are scanned by position; the
+	// destinations follow the order of userPrivilegeColumns.
+	if err := row.Scan(
+		&r.ID, &r.UserID, &r.Privilege,
+		&r.GrantedBy, &r.GrantedAt, &r.Reason,
+	); err != nil {
+		if errors.Is(err, sql.ErrNoRows) {
+			return nil, nil
+		}
+		return nil, err
+	}
+	return &r, nil
+}
+
+// ListByUser returns the rows whose user_id is userID, ordered by
+// granted_at. It reads through the store's own handle, not a transaction.
+func (s *mysqlUserPrivilegeStore) ListByUser(ctx context.Context, userID string) ([]models.UserPrivilege, error) {
+	var rows []models.UserPrivilege
+	err := s.db.SelectContext(ctx, &rows,
+		`SELECT `+userPrivilegeColumns+`
+		   FROM user_privileges
+		  WHERE user_id = ?
+		  ORDER BY granted_at`, userID)
+	if err != nil {
+		return nil, err
+	}
+	return rows, nil
+}
+
+// ListByPrivilege returns the rows whose privilege column equals privilege,
+// ordered by granted_at. It reads through the store's own handle, not a
+// transaction.
+func (s *mysqlUserPrivilegeStore) ListByPrivilege(ctx context.Context, privilege models.PrivilegeKey) ([]models.UserPrivilege, error) {
+	var rows []models.UserPrivilege
+	err := s.db.SelectContext(ctx, &rows,
+		`SELECT `+userPrivilegeColumns+`
+		   FROM user_privileges
+		  WHERE privilege = ?
+		  ORDER BY granted_at`, privilege)
+	if err != nil {
+		return nil, err
+	}
+	return rows, nil
+}
+
+// CountByPrivilege returns COUNT(*) of the rows for privilege, evaluated
+// inside tx. It returns 0 together with the error when the query fails.
+func (s *mysqlUserPrivilegeStore) CountByPrivilege(ctx context.Context, tx *sql.Tx, privilege models.PrivilegeKey) (int, error) {
+	var n int
+	err := tx.QueryRowContext(ctx,
+		`SELECT COUNT(*) FROM user_privileges WHERE privilege = ?`, privilege).Scan(&n)
+	if err != nil {
+		return 0, err
+	}
+	return n, nil
+}
+
+// Create inserts r inside tx. Every column value, including the id and the
+// granted_at timestamp, is taken from r as supplied; nothing is generated
+// here.
+func (s *mysqlUserPrivilegeStore) Create(ctx context.Context, tx *sql.Tx, r *models.UserPrivilege) error {
+	_, err := tx.ExecContext(ctx,
+		`INSERT INTO user_privileges
+		   (id, user_id, privilege, granted_by, granted_at, reason)
+		 VALUES (?, ?, ?, ?, ?, ?)`,
+		r.ID, r.UserID, r.Privilege, r.GrantedBy, r.GrantedAt, r.Reason)
+	return err
+}
+
+// Delete removes the row for (userID, privilege) inside tx. The statement
+// result is discarded, so the caller is not told whether a row actually
+// existed.
+func (s *mysqlUserPrivilegeStore) Delete(ctx context.Context, tx *sql.Tx, userID string, privilege models.PrivilegeKey) error {
+	_, err := tx.ExecContext(ctx,
+		`DELETE FROM user_privileges WHERE user_id = ? AND privilege = ?`,
+		userID, privilege)
+	return err
+}
diff --git a/pkg/models/privilege.go b/pkg/models/privilege.go
new file mode 100644
index 000000000..47eb3bdac
--- /dev/null
+++ b/pkg/models/privilege.go
@@ -0,0 +1,69 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package models
+
+import "time"
+
+// PrivilegeKey names a fine-grained admin capability. The set is closed and
+// declared in code; service-layer validation rejects grants of any key not
+// returned by KnownPrivileges.
+type PrivilegeKey string
+
+// Declared privilege keys. The string values are what is stored in the
+// privilege column and what appears in URLs and JSON.
+const (
+	PrivilegeAMIERead    PrivilegeKey = "amie:read"
+	PrivilegeAMIEWrite   PrivilegeKey = "amie:write"
+	PrivilegeHPCRead     PrivilegeKey = "hpc:read"
+	PrivilegeHPCWrite    PrivilegeKey = "hpc:write"
+	PrivilegeSignerRead  PrivilegeKey = "signer:read"
+	PrivilegeSignerWrite PrivilegeKey = "signer:write"
+	PrivilegeGrant       PrivilegeKey = "privileges:grant"
+)
+
+// KnownPrivileges returns the static catalog of declared privilege keys.
+func KnownPrivileges() []PrivilegeKey {
+	// A fresh array is built on every call, so callers may modify the
+	// returned slice without affecting each other.
+	catalog := [...]PrivilegeKey{
+		PrivilegeAMIERead,
+		PrivilegeAMIEWrite,
+		PrivilegeHPCRead,
+		PrivilegeHPCWrite,
+		PrivilegeSignerRead,
+		PrivilegeSignerWrite,
+		PrivilegeGrant,
+	}
+	return catalog[:]
+}
+
+// IsKnownPrivilege reports whether p is one of the keys returned by
+// KnownPrivileges. The comparison is an exact, case-sensitive match.
+func IsKnownPrivilege(p PrivilegeKey) bool {
+	declared := KnownPrivileges()
+	for i := range declared {
+		if declared[i] == p {
+			return true
+		}
+	}
+	return false
+}
+
+// UserPrivilege is a single row of user_privileges, i.e. one grant that is
+// currently in force. Revoking a grant deletes its row; the grant/revoke
+// history is kept in audit_events instead.
+type UserPrivilege struct {
+	ID        string       `json:"id" db:"id"`
+	UserID    string       `json:"user_id" db:"user_id"`
+	Privilege PrivilegeKey `json:"privilege" db:"privilege"`
+	// GrantedBy and Reason are pointers so a NULL column can be told apart
+	// from an empty string.
+	GrantedBy *string   `json:"granted_by" db:"granted_by"`
+	GrantedAt time.Time `json:"granted_at" db:"granted_at"`
+	Reason    *string   `json:"reason" db:"reason"`
+}
diff --git a/pkg/service/integration_common_test.go b/pkg/service/integration_common_test.go
new file mode 100644
index 000000000..0a3fbb891
--- /dev/null
+++ b/pkg/service/integration_common_test.go
@@ -0,0 +1,155 @@
+//go:build integration
+
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package service
+
+import (
+	"context"
+	"os"
+	"sync"
+	"testing"
+
+	"github.com/google/uuid"
+	"github.com/jmoiron/sqlx"
+
+	"github.com/apache/airavata-custos/internal/db"
+	"github.com/apache/airavata-custos/pkg/events"
+	"github.com/apache/airavata-custos/pkg/models"
+)
+
+// Process-wide test database state. sharedDBOnce guards the one-time
+// open-and-migrate step; its outcome is kept in sharedDB or sharedDBErr so
+// every later caller observes the same result.
+var (
+	sharedDB     *sqlx.DB
+	sharedDBOnce sync.Once
+	sharedDBErr  error
+)
+
+// setupTestDB returns the shared test database handle with the core tables
+// emptied.
+//
+// The DSN is read from CORE_TEST_DATABASE_DSN, falling back to DATABASE_DSN;
+// the test is skipped when neither is set. The first call in the process
+// opens the database and runs db.MigrateEmbedded; if either step fails the
+// error is remembered and every caller fails with it. Every call, including
+// the first, truncates the core tables before returning.
+func setupTestDB(t *testing.T) *sqlx.DB {
+	t.Helper()
+	dsn := os.Getenv("CORE_TEST_DATABASE_DSN")
+	if dsn == "" {
+		dsn = os.Getenv("DATABASE_DSN")
+	}
+	if dsn == "" {
+		t.Skip("integration env not set: CORE_TEST_DATABASE_DSN or DATABASE_DSN required")
+	}
+	sharedDBOnce.Do(func() {
+		database, err := db.Open(db.Config{
+			DSN:          dsn,
+			MaxOpenConns: 5,
+			MaxIdleConns: 2,
+		})
+		if err != nil {
+			sharedDBErr = err
+			return
+		}
+		if err := db.MigrateEmbedded(database); err != nil {
+			// The handle is never published on this path, so nothing else
+			// can close it: release the pool here. The migration error is
+			// the one worth reporting, so a Close error is deliberately
+			// dropped.
+			_ = database.Close()
+			sharedDBErr = err
+			return
+		}
+		sharedDB = database
+	})
+	if sharedDBErr != nil {
+		t.Fatalf("setup db: %v", sharedDBErr)
+	}
+	truncateAll(t, sharedDB)
+	return sharedDB
+}
+
+// truncateAll wipes the core tables that privilege tests touch.
+func truncateAll(t *testing.T, database *sqlx.DB) {
+	t.Helper()
+	tables := []string{
+		"user_privileges",
+		"audit_events",
+		"user_identities",
+		"users",
+		"organizations",
+	}
+
+	// FOREIGN_KEY_CHECKS is a per-session variable, while database is a
+	// connection pool: separate Exec calls on the pool are free to run on
+	// different connections. Pin a single connection so the toggle, the
+	// TRUNCATEs and the restore all share one session.
+	bg := context.Background()
+	conn, err := database.Conn(bg)
+	if err != nil {
+		t.Fatalf("acquire connection: %v", err)
+	}
+	// Close only hands the connection back to the pool; its error carries
+	// nothing a test could act on.
+	defer conn.Close()
+
+	if _, err := conn.ExecContext(bg, "SET FOREIGN_KEY_CHECKS = 0"); err != nil {
+		t.Fatalf("disable FK: %v", err)
+	}
+	// Restore the setting on every exit path, including the t.Fatalf below
+	// (deferred calls still run when the test goroutine exits), so the
+	// connection never goes back to the pool with checks switched off.
+	defer func() {
+		if _, err := conn.ExecContext(bg, "SET FOREIGN_KEY_CHECKS = 1"); err != nil {
+			t.Errorf("re-enable FK: %v", err)
+		}
+	}()
+
+	for _, tbl := range tables {
+		if _, err := conn.ExecContext(bg, "TRUNCATE TABLE "+tbl); err != nil {
+			t.Fatalf("truncate %s: %v", tbl, err)
+		}
+	}
+}
+
+// newTestService builds a Service on top of database, wired to a fresh
+// events.New() value.
+func newTestService(database *sqlx.DB) *Service {
+	return New(database, events.New())
+}
+
+// seedOrg inserts a new organization row and returns its id. The id is a
+// random UUID, and originated_id is "TEST-ORG-" followed by the first eight
+// characters of that id.
+func seedOrg(t *testing.T, database *sqlx.DB) string {
+	t.Helper()
+	orgID := uuid.NewString()
+	if _, err := database.Exec(
+		"INSERT INTO organizations (id, originated_id, name) VALUES (?, ?, ?)",
+		orgID, "TEST-ORG-"+orgID[:8], "Test Org",
+	); err != nil {
+		t.Fatalf("seed org: %v", err)
+	}
+	return orgID
+}
+
+// seedUser inserts a user with the given email and returns the new user id.
+// Each call first creates its own organization through seedOrg, and the user
+// is stored with status models.UserActive.
+func seedUser(t *testing.T, database *sqlx.DB, email string) string {
+	t.Helper()
+	orgID := seedOrg(t, database)
+	userID := uuid.NewString()
+	if _, err := database.Exec(
+		"INSERT INTO users (id, organization_id, first_name, last_name, middle_name, email, status) VALUES (?, ?, ?, ?, ?, ?, ?)",
+		userID, orgID, "Test", "User", "", email, string(models.UserActive),
+	); err != nil {
+		t.Fatalf("seed user %s: %v", email, err)
+	}
+	return userID
+}
+
+// seedPrivilegeGrant directly inserts an active grant of privileges:grant for
+// userID. Bypasses the service guards so tests can stand up a granter without
+// a chicken-and-egg dependency.
+func seedPrivilegeGrant(t *testing.T, database *sqlx.DB, userID string) {
+	t.Helper()
+	// granted_by is left out of the column list on purpose: a seeded grant
+	// has no granter.
+	const insertGrant = `INSERT INTO user_privileges (id, user_id, privilege, granted_at, reason)
+		 VALUES (?, ?, ?, NOW(6), 'seed')`
+	_, err := database.Exec(insertGrant, uuid.NewString(), userID, string(models.PrivilegeGrant))
+	if err != nil {
+		t.Fatalf("seed privileges:grant for %s: %v", userID, err)
+	}
+}
+
+// countAuditEventsOfType returns how many audit_events rows have the given
+// event_type and entity_id, failing the test if the query errors.
+func countAuditEventsOfType(t *testing.T, database *sqlx.DB, eventType, entityID string) int {
+	t.Helper()
+	const countQuery = "SELECT COUNT(*) FROM audit_events WHERE event_type = ? AND entity_id = ?"
+	var matches int
+	err := database.Get(&matches, countQuery, eventType, entityID)
+	if err != nil {
+		t.Fatalf("count audit_events: %v", err)
+	}
+	return matches
+}
+
+// ctx is shorthand for context.Background() in tests.
+func ctx() context.Context {
+	return context.Background()
+}
diff --git a/pkg/service/interface.go b/pkg/service/interface.go
index bfbb65337..52ddd8ac0 100644
--- a/pkg/service/interface.go
+++ b/pkg/service/interface.go
@@ -214,6 +214,20 @@ type AuditEventService interface {
 	DeleteAuditEvent(ctx context.Context, id string) error
 }
 
+// UserPrivilegeService exposes the fine-grained capability layer that gates
+// admin surfaces. Privileges are the sole authorization signal; the DB is
+// the source of truth and HasPrivilege re-reads it on every call (callers
+// cache the result if hot).
+type UserPrivilegeService interface {
+	// GrantPrivilege grants privilege to userID on behalf of granterID and
+	// returns the resulting grant. reason is free text and may be empty.
+	// NOTE(review): the PrivilegeKey docs state that keys outside
+	// models.KnownPrivileges are rejected at the service layer; the
+	// implementation is not in this file, so confirm the exact checks there.
+	GrantPrivilege(ctx context.Context, userID string, privilege models.PrivilegeKey, granterID, reason string) (*models.UserPrivilege, error)
+	// RevokePrivilege removes userID's grant of privilege on behalf of
+	// revokerID, with a free-text reason.
+	// NOTE(review): comments elsewhere in this change mention a self-revoke
+	// guard and a last-holder guard for models.PrivilegeGrant; confirm the
+	// precise rules in the implementation.
+	RevokePrivilege(ctx context.Context, userID string, privilege models.PrivilegeKey, revokerID, reason string) error
+	// HasPrivilege reports whether userID currently holds privilege.
+	HasPrivilege(ctx context.Context, userID string, privilege models.PrivilegeKey) (bool, error)
+	// ListUserPrivileges returns the active grants held by userID.
+	ListUserPrivileges(ctx context.Context, userID string) ([]models.UserPrivilege, error)
+	// ListPrivilegeHolders returns the active grants of privilege, one per
+	// holding user.
+	ListPrivilegeHolders(ctx context.Context, privilege models.PrivilegeKey) ([]models.UserPrivilege, error)
+	// PrivilegeCatalog returns the privilege keys known to the service. It
+	// takes no context and returns no error.
+	// NOTE(review): presumably this is models.KnownPrivileges(); confirm.
+	PrivilegeCatalog() []models.PrivilegeKey
+	// BootstrapPrivilegeGrant is invoked by the server at start-up for the
+	// address in CUSTOS_BOOTSTRAP_ADMIN_EMAIL. source is a label recording
+	// where the request originated (the server passes
+	// "env:CUSTOS_BOOTSTRAP_ADMIN_EMAIL").
+	// NOTE(review): which privilege is granted is not visible from this
+	// interface; presumably models.PrivilegeGrant, confirm.
+	BootstrapPrivilegeGrant(ctx context.Context, email, source string) error
+}
+
 // CoreService is the aggregate of every domain interface this package exposes.
 // Most callers should depend on this, or — when they need only a slice of
 // the API — on the narrower per-domain interfaces above.
@@ -235,6 +249,7 @@ type CoreService interface {
 	ComputeAllocationMembershipResourceOverrideService
 	ComputeAllocationUsageService
 	AuditEventService
+	UserPrivilegeService
 }
 
 // Compile-time assertion that *Service satisfies the aggregate CoreService.
diff --git a/pkg/service/mock.go b/pkg/service/mock.go index b60278d1e..2e2393472 100644 --- a/pkg/service/mock.go +++ b/pkg/service/mock.go @@ -23,6 +23,9 @@ var _ CoreService = &CoreServiceMock{} // AttachResourceToAllocationFunc: func(ctx context.Context, allocationID string, resourceID string, resourceAmount int64, resourceTime int64) (*models.ComputeAllocationResourceMapping, error) { // panic("mock out the AttachResourceToAllocation method") // }, +// BootstrapPrivilegeGrantFunc: func(ctx context.Context, email string, source string) error { +// panic("mock out the BootstrapPrivilegeGrant method") +// }, // CreateAuditEventFunc: func(ctx context.Context, e *models.AuditEvent) (*models.AuditEvent, error) { // panic("mock out the CreateAuditEvent method") // }, @@ -212,6 +215,12 @@ var _ CoreService = &CoreServiceMock{} // GetUserIdentityBySourceAndExternalIDFunc: func(ctx context.Context, source string, externalID string) (*models.UserIdentity, error) { // panic("mock out the GetUserIdentityBySourceAndExternalID method") // }, +// GrantPrivilegeFunc: func(ctx context.Context, userID string, privilege models.PrivilegeKey, granterID string, reason string) (*models.UserPrivilege, error) { +// panic("mock out the GrantPrivilege method") +// }, +// HasPrivilegeFunc: func(ctx context.Context, userID string, privilege models.PrivilegeKey) (bool, error) { +// panic("mock out the HasPrivilege method") +// }, // ListAllAuditEventsFunc: func(ctx context.Context) ([]*models.AuditEvent, error) { // panic("mock out the ListAllAuditEvents method") // }, @@ -266,6 +275,9 @@ var _ CoreService = &CoreServiceMock{} // ListOverridesForResourceFunc: func(ctx context.Context, resourceID string) ([]models.ComputeAllocationMembershipResourceOverride, error) { // panic("mock out the ListOverridesForResource method") // }, +// ListPrivilegeHoldersFunc: func(ctx context.Context, privilege models.PrivilegeKey) ([]models.UserPrivilege, error) { +// panic("mock out the ListPrivilegeHolders 
method") +// }, // ListProjectsByPIFunc: func(ctx context.Context, piUserID string) ([]models.Project, error) { // panic("mock out the ListProjectsByPI method") // }, @@ -284,12 +296,21 @@ var _ CoreService = &CoreServiceMock{} // ListUserIdentitiesForUserFunc: func(ctx context.Context, userID string) ([]models.UserIdentity, error) { // panic("mock out the ListUserIdentitiesForUser method") // }, +// ListUserPrivilegesFunc: func(ctx context.Context, userID string) ([]models.UserPrivilege, error) { +// panic("mock out the ListUserPrivileges method") +// }, // ListUsersByOrganizationFunc: func(ctx context.Context, organizationID string) ([]models.User, error) { // panic("mock out the ListUsersByOrganization method") // }, // MergeUsersFunc: func(ctx context.Context, survivingID string, retiringID string) (*models.User, error) { // panic("mock out the MergeUsers method") // }, +// PrivilegeCatalogFunc: func() []models.PrivilegeKey { +// panic("mock out the PrivilegeCatalog method") +// }, +// RevokePrivilegeFunc: func(ctx context.Context, userID string, privilege models.PrivilegeKey, revokerID string, reason string) error { +// panic("mock out the RevokePrivilege method") +// }, // UpdateAllocationResourceMappingFunc: func(ctx context.Context, allocationID string, resourceID string, resourceAmount int64, resourceTime int64) (*models.ComputeAllocationResourceMapping, error) { // panic("mock out the UpdateAllocationResourceMapping method") // }, @@ -348,6 +369,9 @@ type CoreServiceMock struct { // AttachResourceToAllocationFunc mocks the AttachResourceToAllocation method. AttachResourceToAllocationFunc func(ctx context.Context, allocationID string, resourceID string, resourceAmount int64, resourceTime int64) (*models.ComputeAllocationResourceMapping, error) + // BootstrapPrivilegeGrantFunc mocks the BootstrapPrivilegeGrant method. 
+ BootstrapPrivilegeGrantFunc func(ctx context.Context, email string, source string) error + // CreateAuditEventFunc mocks the CreateAuditEvent method. CreateAuditEventFunc func(ctx context.Context, e *models.AuditEvent) (*models.AuditEvent, error) @@ -537,6 +561,12 @@ type CoreServiceMock struct { // GetUserIdentityBySourceAndExternalIDFunc mocks the GetUserIdentityBySourceAndExternalID method. GetUserIdentityBySourceAndExternalIDFunc func(ctx context.Context, source string, externalID string) (*models.UserIdentity, error) + // GrantPrivilegeFunc mocks the GrantPrivilege method. + GrantPrivilegeFunc func(ctx context.Context, userID string, privilege models.PrivilegeKey, granterID string, reason string) (*models.UserPrivilege, error) + + // HasPrivilegeFunc mocks the HasPrivilege method. + HasPrivilegeFunc func(ctx context.Context, userID string, privilege models.PrivilegeKey) (bool, error) + // ListAllAuditEventsFunc mocks the ListAllAuditEvents method. ListAllAuditEventsFunc func(ctx context.Context) ([]*models.AuditEvent, error) @@ -591,6 +621,9 @@ type CoreServiceMock struct { // ListOverridesForResourceFunc mocks the ListOverridesForResource method. ListOverridesForResourceFunc func(ctx context.Context, resourceID string) ([]models.ComputeAllocationMembershipResourceOverride, error) + // ListPrivilegeHoldersFunc mocks the ListPrivilegeHolders method. + ListPrivilegeHoldersFunc func(ctx context.Context, privilege models.PrivilegeKey) ([]models.UserPrivilege, error) + // ListProjectsByPIFunc mocks the ListProjectsByPI method. ListProjectsByPIFunc func(ctx context.Context, piUserID string) ([]models.Project, error) @@ -609,12 +642,21 @@ type CoreServiceMock struct { // ListUserIdentitiesForUserFunc mocks the ListUserIdentitiesForUser method. ListUserIdentitiesForUserFunc func(ctx context.Context, userID string) ([]models.UserIdentity, error) + // ListUserPrivilegesFunc mocks the ListUserPrivileges method. 
+ ListUserPrivilegesFunc func(ctx context.Context, userID string) ([]models.UserPrivilege, error) + // ListUsersByOrganizationFunc mocks the ListUsersByOrganization method. ListUsersByOrganizationFunc func(ctx context.Context, organizationID string) ([]models.User, error) // MergeUsersFunc mocks the MergeUsers method. MergeUsersFunc func(ctx context.Context, survivingID string, retiringID string) (*models.User, error) + // PrivilegeCatalogFunc mocks the PrivilegeCatalog method. + PrivilegeCatalogFunc func() []models.PrivilegeKey + + // RevokePrivilegeFunc mocks the RevokePrivilege method. + RevokePrivilegeFunc func(ctx context.Context, userID string, privilege models.PrivilegeKey, revokerID string, reason string) error + // UpdateAllocationResourceMappingFunc mocks the UpdateAllocationResourceMapping method. UpdateAllocationResourceMappingFunc func(ctx context.Context, allocationID string, resourceID string, resourceAmount int64, resourceTime int64) (*models.ComputeAllocationResourceMapping, error) @@ -678,6 +720,15 @@ type CoreServiceMock struct { // ResourceTime is the resourceTime argument value. ResourceTime int64 } + // BootstrapPrivilegeGrant holds details about calls to the BootstrapPrivilegeGrant method. + BootstrapPrivilegeGrant []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Email is the email argument value. + Email string + // Source is the source argument value. + Source string + } // CreateAuditEvent holds details about calls to the CreateAuditEvent method. CreateAuditEvent []struct { // Ctx is the ctx argument value. @@ -1133,6 +1184,28 @@ type CoreServiceMock struct { // ExternalID is the externalID argument value. ExternalID string } + // GrantPrivilege holds details about calls to the GrantPrivilege method. + GrantPrivilege []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // UserID is the userID argument value. + UserID string + // Privilege is the privilege argument value. 
+ Privilege models.PrivilegeKey + // GranterID is the granterID argument value. + GranterID string + // Reason is the reason argument value. + Reason string + } + // HasPrivilege holds details about calls to the HasPrivilege method. + HasPrivilege []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // UserID is the userID argument value. + UserID string + // Privilege is the privilege argument value. + Privilege models.PrivilegeKey + } // ListAllAuditEvents holds details about calls to the ListAllAuditEvents method. ListAllAuditEvents []struct { // Ctx is the ctx argument value. @@ -1253,6 +1326,13 @@ type CoreServiceMock struct { // ResourceID is the resourceID argument value. ResourceID string } + // ListPrivilegeHolders holds details about calls to the ListPrivilegeHolders method. + ListPrivilegeHolders []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Privilege is the privilege argument value. + Privilege models.PrivilegeKey + } // ListProjectsByPI holds details about calls to the ListProjectsByPI method. ListProjectsByPI []struct { // Ctx is the ctx argument value. @@ -1295,6 +1375,13 @@ type CoreServiceMock struct { // UserID is the userID argument value. UserID string } + // ListUserPrivileges holds details about calls to the ListUserPrivileges method. + ListUserPrivileges []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // UserID is the userID argument value. + UserID string + } // ListUsersByOrganization holds details about calls to the ListUsersByOrganization method. ListUsersByOrganization []struct { // Ctx is the ctx argument value. @@ -1311,6 +1398,22 @@ type CoreServiceMock struct { // RetiringID is the retiringID argument value. RetiringID string } + // PrivilegeCatalog holds details about calls to the PrivilegeCatalog method. + PrivilegeCatalog []struct { + } + // RevokePrivilege holds details about calls to the RevokePrivilege method. 
+ RevokePrivilege []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // UserID is the userID argument value. + UserID string + // Privilege is the privilege argument value. + Privilege models.PrivilegeKey + // RevokerID is the revokerID argument value. + RevokerID string + // Reason is the reason argument value. + Reason string + } // UpdateAllocationResourceMapping holds details about calls to the UpdateAllocationResourceMapping method. UpdateAllocationResourceMapping []struct { // Ctx is the ctx argument value. @@ -1437,6 +1540,7 @@ type CoreServiceMock struct { } } lockAttachResourceToAllocation sync.RWMutex + lockBootstrapPrivilegeGrant sync.RWMutex lockCreateAuditEvent sync.RWMutex lockCreateComputeAllocation sync.RWMutex lockCreateComputeAllocationChangeRequest sync.RWMutex @@ -1500,6 +1604,8 @@ type CoreServiceMock struct { lockGetUserIdentity sync.RWMutex lockGetUserIdentityByOIDCSub sync.RWMutex lockGetUserIdentityBySourceAndExternalID sync.RWMutex + lockGrantPrivilege sync.RWMutex + lockHasPrivilege sync.RWMutex lockListAllAuditEvents sync.RWMutex lockListAllocationsForResource sync.RWMutex lockListAllocationsForUser sync.RWMutex @@ -1518,14 +1624,18 @@ type CoreServiceMock struct { lockListMembersForAllocation sync.RWMutex lockListOverridesForMembership sync.RWMutex lockListOverridesForResource sync.RWMutex + lockListPrivilegeHolders sync.RWMutex lockListProjectsByPI sync.RWMutex lockListRatesForResource sync.RWMutex lockListResourcesForAllocation sync.RWMutex lockListUsagesByUser sync.RWMutex lockListUsagesForAllocation sync.RWMutex lockListUserIdentitiesForUser sync.RWMutex + lockListUserPrivileges sync.RWMutex lockListUsersByOrganization sync.RWMutex lockMergeUsers sync.RWMutex + lockPrivilegeCatalog sync.RWMutex + lockRevokePrivilege sync.RWMutex lockUpdateAllocationResourceMapping sync.RWMutex lockUpdateComputeAllocation sync.RWMutex lockUpdateComputeAllocationChangeRequest sync.RWMutex @@ -1592,6 +1702,46 @@ func (mock 
*CoreServiceMock) AttachResourceToAllocationCalls() []struct { return calls } +// BootstrapPrivilegeGrant calls BootstrapPrivilegeGrantFunc. +func (mock *CoreServiceMock) BootstrapPrivilegeGrant(ctx context.Context, email string, source string) error { + if mock.BootstrapPrivilegeGrantFunc == nil { + panic("CoreServiceMock.BootstrapPrivilegeGrantFunc: method is nil but CoreService.BootstrapPrivilegeGrant was just called") + } + callInfo := struct { + Ctx context.Context + Email string + Source string + }{ + Ctx: ctx, + Email: email, + Source: source, + } + mock.lockBootstrapPrivilegeGrant.Lock() + mock.calls.BootstrapPrivilegeGrant = append(mock.calls.BootstrapPrivilegeGrant, callInfo) + mock.lockBootstrapPrivilegeGrant.Unlock() + return mock.BootstrapPrivilegeGrantFunc(ctx, email, source) +} + +// BootstrapPrivilegeGrantCalls gets all the calls that were made to BootstrapPrivilegeGrant. +// Check the length with: +// +// len(mockedCoreService.BootstrapPrivilegeGrantCalls()) +func (mock *CoreServiceMock) BootstrapPrivilegeGrantCalls() []struct { + Ctx context.Context + Email string + Source string +} { + var calls []struct { + Ctx context.Context + Email string + Source string + } + mock.lockBootstrapPrivilegeGrant.RLock() + calls = mock.calls.BootstrapPrivilegeGrant + mock.lockBootstrapPrivilegeGrant.RUnlock() + return calls +} + // CreateAuditEvent calls CreateAuditEventFunc. func (mock *CoreServiceMock) CreateAuditEvent(ctx context.Context, e *models.AuditEvent) (*models.AuditEvent, error) { if mock.CreateAuditEventFunc == nil { @@ -3888,6 +4038,94 @@ func (mock *CoreServiceMock) GetUserIdentityBySourceAndExternalIDCalls() []struc return calls } +// GrantPrivilege calls GrantPrivilegeFunc. 
+func (mock *CoreServiceMock) GrantPrivilege(ctx context.Context, userID string, privilege models.PrivilegeKey, granterID string, reason string) (*models.UserPrivilege, error) { + if mock.GrantPrivilegeFunc == nil { + panic("CoreServiceMock.GrantPrivilegeFunc: method is nil but CoreService.GrantPrivilege was just called") + } + callInfo := struct { + Ctx context.Context + UserID string + Privilege models.PrivilegeKey + GranterID string + Reason string + }{ + Ctx: ctx, + UserID: userID, + Privilege: privilege, + GranterID: granterID, + Reason: reason, + } + mock.lockGrantPrivilege.Lock() + mock.calls.GrantPrivilege = append(mock.calls.GrantPrivilege, callInfo) + mock.lockGrantPrivilege.Unlock() + return mock.GrantPrivilegeFunc(ctx, userID, privilege, granterID, reason) +} + +// GrantPrivilegeCalls gets all the calls that were made to GrantPrivilege. +// Check the length with: +// +// len(mockedCoreService.GrantPrivilegeCalls()) +func (mock *CoreServiceMock) GrantPrivilegeCalls() []struct { + Ctx context.Context + UserID string + Privilege models.PrivilegeKey + GranterID string + Reason string +} { + var calls []struct { + Ctx context.Context + UserID string + Privilege models.PrivilegeKey + GranterID string + Reason string + } + mock.lockGrantPrivilege.RLock() + calls = mock.calls.GrantPrivilege + mock.lockGrantPrivilege.RUnlock() + return calls +} + +// HasPrivilege calls HasPrivilegeFunc. 
+func (mock *CoreServiceMock) HasPrivilege(ctx context.Context, userID string, privilege models.PrivilegeKey) (bool, error) { + if mock.HasPrivilegeFunc == nil { + panic("CoreServiceMock.HasPrivilegeFunc: method is nil but CoreService.HasPrivilege was just called") + } + callInfo := struct { + Ctx context.Context + UserID string + Privilege models.PrivilegeKey + }{ + Ctx: ctx, + UserID: userID, + Privilege: privilege, + } + mock.lockHasPrivilege.Lock() + mock.calls.HasPrivilege = append(mock.calls.HasPrivilege, callInfo) + mock.lockHasPrivilege.Unlock() + return mock.HasPrivilegeFunc(ctx, userID, privilege) +} + +// HasPrivilegeCalls gets all the calls that were made to HasPrivilege. +// Check the length with: +// +// len(mockedCoreService.HasPrivilegeCalls()) +func (mock *CoreServiceMock) HasPrivilegeCalls() []struct { + Ctx context.Context + UserID string + Privilege models.PrivilegeKey +} { + var calls []struct { + Ctx context.Context + UserID string + Privilege models.PrivilegeKey + } + mock.lockHasPrivilege.RLock() + calls = mock.calls.HasPrivilege + mock.lockHasPrivilege.RUnlock() + return calls +} + // ListAllAuditEvents calls ListAllAuditEventsFunc. func (mock *CoreServiceMock) ListAllAuditEvents(ctx context.Context) ([]*models.AuditEvent, error) { if mock.ListAllAuditEventsFunc == nil { @@ -4524,6 +4762,42 @@ func (mock *CoreServiceMock) ListOverridesForResourceCalls() []struct { return calls } +// ListPrivilegeHolders calls ListPrivilegeHoldersFunc. 
+func (mock *CoreServiceMock) ListPrivilegeHolders(ctx context.Context, privilege models.PrivilegeKey) ([]models.UserPrivilege, error) { + if mock.ListPrivilegeHoldersFunc == nil { + panic("CoreServiceMock.ListPrivilegeHoldersFunc: method is nil but CoreService.ListPrivilegeHolders was just called") + } + callInfo := struct { + Ctx context.Context + Privilege models.PrivilegeKey + }{ + Ctx: ctx, + Privilege: privilege, + } + mock.lockListPrivilegeHolders.Lock() + mock.calls.ListPrivilegeHolders = append(mock.calls.ListPrivilegeHolders, callInfo) + mock.lockListPrivilegeHolders.Unlock() + return mock.ListPrivilegeHoldersFunc(ctx, privilege) +} + +// ListPrivilegeHoldersCalls gets all the calls that were made to ListPrivilegeHolders. +// Check the length with: +// +// len(mockedCoreService.ListPrivilegeHoldersCalls()) +func (mock *CoreServiceMock) ListPrivilegeHoldersCalls() []struct { + Ctx context.Context + Privilege models.PrivilegeKey +} { + var calls []struct { + Ctx context.Context + Privilege models.PrivilegeKey + } + mock.lockListPrivilegeHolders.RLock() + calls = mock.calls.ListPrivilegeHolders + mock.lockListPrivilegeHolders.RUnlock() + return calls +} + // ListProjectsByPI calls ListProjectsByPIFunc. func (mock *CoreServiceMock) ListProjectsByPI(ctx context.Context, piUserID string) ([]models.Project, error) { if mock.ListProjectsByPIFunc == nil { @@ -4740,6 +5014,42 @@ func (mock *CoreServiceMock) ListUserIdentitiesForUserCalls() []struct { return calls } +// ListUserPrivileges calls ListUserPrivilegesFunc. 
+func (mock *CoreServiceMock) ListUserPrivileges(ctx context.Context, userID string) ([]models.UserPrivilege, error) { + if mock.ListUserPrivilegesFunc == nil { + panic("CoreServiceMock.ListUserPrivilegesFunc: method is nil but CoreService.ListUserPrivileges was just called") + } + callInfo := struct { + Ctx context.Context + UserID string + }{ + Ctx: ctx, + UserID: userID, + } + mock.lockListUserPrivileges.Lock() + mock.calls.ListUserPrivileges = append(mock.calls.ListUserPrivileges, callInfo) + mock.lockListUserPrivileges.Unlock() + return mock.ListUserPrivilegesFunc(ctx, userID) +} + +// ListUserPrivilegesCalls gets all the calls that were made to ListUserPrivileges. +// Check the length with: +// +// len(mockedCoreService.ListUserPrivilegesCalls()) +func (mock *CoreServiceMock) ListUserPrivilegesCalls() []struct { + Ctx context.Context + UserID string +} { + var calls []struct { + Ctx context.Context + UserID string + } + mock.lockListUserPrivileges.RLock() + calls = mock.calls.ListUserPrivileges + mock.lockListUserPrivileges.RUnlock() + return calls +} + // ListUsersByOrganization calls ListUsersByOrganizationFunc. func (mock *CoreServiceMock) ListUsersByOrganization(ctx context.Context, organizationID string) ([]models.User, error) { if mock.ListUsersByOrganizationFunc == nil { @@ -4816,6 +5126,81 @@ func (mock *CoreServiceMock) MergeUsersCalls() []struct { return calls } +// PrivilegeCatalog calls PrivilegeCatalogFunc. +func (mock *CoreServiceMock) PrivilegeCatalog() []models.PrivilegeKey { + if mock.PrivilegeCatalogFunc == nil { + panic("CoreServiceMock.PrivilegeCatalogFunc: method is nil but CoreService.PrivilegeCatalog was just called") + } + callInfo := struct { + }{} + mock.lockPrivilegeCatalog.Lock() + mock.calls.PrivilegeCatalog = append(mock.calls.PrivilegeCatalog, callInfo) + mock.lockPrivilegeCatalog.Unlock() + return mock.PrivilegeCatalogFunc() +} + +// PrivilegeCatalogCalls gets all the calls that were made to PrivilegeCatalog. 
+// Check the length with: +// +// len(mockedCoreService.PrivilegeCatalogCalls()) +func (mock *CoreServiceMock) PrivilegeCatalogCalls() []struct { +} { + var calls []struct { + } + mock.lockPrivilegeCatalog.RLock() + calls = mock.calls.PrivilegeCatalog + mock.lockPrivilegeCatalog.RUnlock() + return calls +} + +// RevokePrivilege calls RevokePrivilegeFunc. +func (mock *CoreServiceMock) RevokePrivilege(ctx context.Context, userID string, privilege models.PrivilegeKey, revokerID string, reason string) error { + if mock.RevokePrivilegeFunc == nil { + panic("CoreServiceMock.RevokePrivilegeFunc: method is nil but CoreService.RevokePrivilege was just called") + } + callInfo := struct { + Ctx context.Context + UserID string + Privilege models.PrivilegeKey + RevokerID string + Reason string + }{ + Ctx: ctx, + UserID: userID, + Privilege: privilege, + RevokerID: revokerID, + Reason: reason, + } + mock.lockRevokePrivilege.Lock() + mock.calls.RevokePrivilege = append(mock.calls.RevokePrivilege, callInfo) + mock.lockRevokePrivilege.Unlock() + return mock.RevokePrivilegeFunc(ctx, userID, privilege, revokerID, reason) +} + +// RevokePrivilegeCalls gets all the calls that were made to RevokePrivilege. +// Check the length with: +// +// len(mockedCoreService.RevokePrivilegeCalls()) +func (mock *CoreServiceMock) RevokePrivilegeCalls() []struct { + Ctx context.Context + UserID string + Privilege models.PrivilegeKey + RevokerID string + Reason string +} { + var calls []struct { + Ctx context.Context + UserID string + Privilege models.PrivilegeKey + RevokerID string + Reason string + } + mock.lockRevokePrivilege.RLock() + calls = mock.calls.RevokePrivilege + mock.lockRevokePrivilege.RUnlock() + return calls +} + // UpdateAllocationResourceMapping calls UpdateAllocationResourceMappingFunc. 
func (mock *CoreServiceMock) UpdateAllocationResourceMapping(ctx context.Context, allocationID string, resourceID string, resourceAmount int64, resourceTime int64) (*models.ComputeAllocationResourceMapping, error) { if mock.UpdateAllocationResourceMappingFunc == nil { diff --git a/pkg/service/service.go b/pkg/service/service.go index cab154993..910e318d2 100644 --- a/pkg/service/service.go +++ b/pkg/service/service.go @@ -53,6 +53,7 @@ type Service struct { usages store.ComputeAllocationUsageStore userIdentities store.UserIdentityStore auditEvents store.AuditEventStore + privileges store.UserPrivilegeStore } // New constructs a Service backed by the supplied database handle. @@ -78,6 +79,7 @@ func New(database *sqlx.DB, eventBus *events.Bus) *Service { usages: store.NewComputeAllocationUsageStore(database), userIdentities: store.NewUserIdentityStore(database), auditEvents: store.NewAuditEventStore(database), + privileges: store.NewUserPrivilegeStore(database), } } @@ -104,6 +106,7 @@ func NewWithStores( usages store.ComputeAllocationUsageStore, userIdentities store.UserIdentityStore, auditEvents store.AuditEventStore, + privileges store.UserPrivilegeStore, ) *Service { return &Service{ db: database, @@ -125,6 +128,7 @@ func NewWithStores( usages: usages, userIdentities: userIdentities, auditEvents: auditEvents, + privileges: privileges, } } diff --git a/pkg/service/user_privilege.go b/pkg/service/user_privilege.go new file mode 100644 index 000000000..25452df82 --- /dev/null +++ b/pkg/service/user_privilege.go @@ -0,0 +1,280 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package service

import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"log/slog"

	"github.com/apache/airavata-custos/pkg/models"
)

// Event types written to audit_events for the privilege lifecycle.
const (
	privilegeAuditGrant     = "PRIVILEGE_GRANTED"
	privilegeAuditRevoke    = "PRIVILEGE_REVOKED"
	privilegeAuditBootstrap = "PRIVILEGE_BOOTSTRAPPED"
)

// GrantPrivilege attaches privilege to userID. Caller (granterID) must hold
// an active privileges:grant.
//
// Argument validation happens before any database work: an empty userID or
// granterID, or a privilege outside the catalog, yields ErrInvalidInput.
// Inside the transaction the granter's meta-privilege is checked first, then
// the target's existing row is looked up with FindForUpdate; an existing row
// yields ErrAlreadyExists. On success the inserted grant is returned and a
// PRIVILEGE_GRANTED audit event keyed by userID is written in the same tx.
func (s *Service) GrantPrivilege(
	ctx context.Context,
	userID string,
	privilege models.PrivilegeKey,
	granterID string,
	reason string,
) (*models.UserPrivilege, error) {
	if userID == "" {
		return nil, fmt.Errorf("%w: user_id is required", ErrInvalidInput)
	}
	if granterID == "" {
		return nil, fmt.Errorf("%w: granter_id is required", ErrInvalidInput)
	}
	if !models.IsKnownPrivilege(privilege) {
		return nil, fmt.Errorf("%w: unknown privilege %q", ErrInvalidInput, privilege)
	}

	// Built up front so the same value can be returned after the tx; an empty
	// reason is stored as nil rather than "".
	grant := &models.UserPrivilege{
		ID:        newID(),
		UserID:    userID,
		Privilege: privilege,
		GrantedBy: stringPtrOrNil(granterID),
		GrantedAt: nowUTC(),
		Reason:    stringPtrOrNil(reason),
	}

	// NOTE(review): presumably inTx commits when the callback returns nil and
	// rolls back otherwise — confirm against inTx.
	if err := s.inTx(ctx, func(tx *sql.Tx) error {
		// Authorization first: the actor's row is looked up before the target's.
		if err := s.assertGranterTx(ctx, tx, granterID); err != nil {
			return err
		}
		existing, err := s.privileges.FindForUpdate(ctx, tx, userID, privilege)
		if err != nil {
			return fmt.Errorf("lookup existing grant: %w", err)
		}
		if existing != nil {
			return fmt.Errorf("%w: privilege %q is already active for user", ErrAlreadyExists, privilege)
		}
		if err := s.privileges.Create(ctx, tx, grant); err != nil {
			return fmt.Errorf("insert privilege grant: %w", err)
		}
		// The audit row shares the tx with the insert.
		return s.writePrivilegeAuditTx(ctx, tx, privilegeAuditGrant, userID, map[string]any{
			"privilege": privilege,
			"actor_id":  granterID,
			"reason":    reason,
		})
	}); err != nil {
		return nil, err
	}
	return grant, nil
}

// RevokePrivilege removes the user's grant for privilege via DELETE. The
// full revoke history (who, when, why) is captured in audit_events. The
// meta-privilege (privileges:grant) cannot be self-revoked and cannot be
// removed from the last holder.
//
// Errors: ErrInvalidInput for empty IDs, an unknown privilege, a self-revoke
// of privileges:grant, a revoker without privileges:grant, or an attempt to
// remove the last privileges:grant holder; ErrNotFound when the target has no
// row for privilege.
func (s *Service) RevokePrivilege(
	ctx context.Context,
	userID string,
	privilege models.PrivilegeKey,
	revokerID string,
	reason string,
) error {
	if userID == "" {
		return fmt.Errorf("%w: user_id is required", ErrInvalidInput)
	}
	if revokerID == "" {
		return fmt.Errorf("%w: revoker_id is required", ErrInvalidInput)
	}
	if !models.IsKnownPrivilege(privilege) {
		return fmt.Errorf("%w: unknown privilege %q", ErrInvalidInput, privilege)
	}
	// Rejected before the tx opens: no database access is needed to decide it.
	if privilege == models.PrivilegeGrant && revokerID == userID {
		return fmt.Errorf("%w: cannot self-revoke %s", ErrInvalidInput, models.PrivilegeGrant)
	}

	return s.inTx(ctx, func(tx *sql.Tx) error {
		// Same order as GrantPrivilege: actor row first, then the target row.
		if err := s.assertGranterTx(ctx, tx, revokerID); err != nil {
			return err
		}
		existing, err := s.privileges.FindForUpdate(ctx, tx, userID, privilege)
		if err != nil {
			return fmt.Errorf("lookup active grant: %w", err)
		}
		if existing == nil {
			return fmt.Errorf("%w: no active grant for privilege %q", ErrNotFound, privilege)
		}
		if privilege == models.PrivilegeGrant {
			// Last-holder guard, evaluated in the same tx as the delete.
			count, err := s.privileges.CountByPrivilege(ctx, tx, models.PrivilegeGrant)
			if err != nil {
				return fmt.Errorf("count meta holders: %w", err)
			}
			if count <= 1 {
				return fmt.Errorf("%w: cannot revoke the last active %s", ErrInvalidInput, models.PrivilegeGrant)
			}
		}
		if err := s.privileges.Delete(ctx, tx, userID, privilege); err != nil {
			return fmt.Errorf("delete grant: %w", err)
		}
		return s.writePrivilegeAuditTx(ctx, tx, privilegeAuditRevoke, userID, map[string]any{
			"privilege": privilege,
			"actor_id":  revokerID,
			"reason":    reason,
		})
	})
}

// HasPrivilege returns true iff an active grant of the named privilege
// exists for userID.
//
// An empty userID or an unknown privilege is reported as ErrInvalidInput
// rather than as (false, nil).
func (s *Service) HasPrivilege(ctx context.Context, userID string, privilege models.PrivilegeKey) (bool, error) {
	if userID == "" {
		return false, fmt.Errorf("%w: user_id is required", ErrInvalidInput)
	}
	if !models.IsKnownPrivilege(privilege) {
		return false, fmt.Errorf("%w: unknown privilege %q", ErrInvalidInput, privilege)
	}
	row, err := s.privileges.Find(ctx, userID, privilege)
	if err != nil {
		return false, fmt.Errorf("lookup privilege: %w", err)
	}
	// A nil row with a nil error is treated as "not granted".
	return row != nil, nil
}

// ListUserPrivileges returns the user's active privileges.
// An empty userID yields ErrInvalidInput; store errors are wrapped.
func (s *Service) ListUserPrivileges(ctx context.Context, userID string) ([]models.UserPrivilege, error) {
	if userID == "" {
		return nil, fmt.Errorf("%w: user_id is required", ErrInvalidInput)
	}
	rows, err := s.privileges.ListByUser(ctx, userID)
	if err != nil {
		return nil, fmt.Errorf("list privileges: %w", err)
	}
	return rows, nil
}

// ListPrivilegeHolders returns the active holders of privilege.
// A privilege outside the catalog yields ErrInvalidInput; store errors are
// wrapped.
func (s *Service) ListPrivilegeHolders(ctx context.Context, privilege models.PrivilegeKey) ([]models.UserPrivilege, error) {
	if !models.IsKnownPrivilege(privilege) {
		return nil, fmt.Errorf("%w: unknown privilege %q", ErrInvalidInput, privilege)
	}
	rows, err := s.privileges.ListByPrivilege(ctx, privilege)
	if err != nil {
		return nil, fmt.Errorf("list privilege holders: %w", err)
	}
	return rows, nil
}

// PrivilegeCatalog returns the declared catalog of privilege keys.
// It delegates to models.KnownPrivileges and touches no store.
func (s *Service) PrivilegeCatalog() []models.PrivilegeKey {
	return models.KnownPrivileges()
}

// BootstrapPrivilegeGrant is called once at server startup if
// CUSTOS_BOOTSTRAP_ADMIN_EMAIL is set. Looks up the user by email and grants
// PrivilegeGrant if no active holder exists. Returns nil on every no-op
// case (no env user, user not found, holder already present); startup must
// not fail because of this.
//
// A failing user lookup, holder count, insert or audit write is still
// returned as an error. source is recorded in the PRIVILEGE_BOOTSTRAPPED
// audit event and in the log line; the grant row itself carries a nil
// GrantedBy and the reason "bootstrap".
func (s *Service) BootstrapPrivilegeGrant(ctx context.Context, email, source string) error {
	if email == "" {
		return nil
	}
	user, err := s.users.FindByEmail(ctx, email)
	if err != nil {
		return fmt.Errorf("lookup bootstrap user: %w", err)
	}
	if user == nil {
		slog.Warn("bootstrap: user not found, skipping", "email", email)
		return nil
	}
	return s.inTx(ctx, func(tx *sql.Tx) error {
		// Any existing holder, whoever it is, turns the bootstrap into a no-op.
		existing, err := s.privileges.CountByPrivilege(ctx, tx, models.PrivilegeGrant)
		if err != nil {
			return fmt.Errorf("count grant-holders: %w", err)
		}
		if existing > 0 {
			slog.Info("bootstrap: privileges:grant already held by another user, skipping", "email", email)
			return nil
		}
		grant := &models.UserPrivilege{
			ID:        newID(),
			UserID:    user.ID,
			Privilege: models.PrivilegeGrant,
			GrantedBy: nil,
			GrantedAt: nowUTC(),
			Reason:    stringPtrOrNil("bootstrap"),
		}
		if err := s.privileges.Create(ctx, tx, grant); err != nil {
			return fmt.Errorf("insert bootstrap grant: %w", err)
		}
		if err := s.writePrivilegeAuditTx(ctx, tx, privilegeAuditBootstrap, user.ID, map[string]any{
			"privilege": models.PrivilegeGrant,
			"source":    source,
		}); err != nil {
			return fmt.Errorf("audit bootstrap grant: %w", err)
		}
		// Logged from inside the callback, i.e. before inTx has returned.
		slog.Info("bootstrap: privileges:grant granted", "user_id", user.ID, "email", email, "source", source)
		return nil
	})
}

// assertGranterTx fails with ErrInvalidInput when actorID does not hold an
// active privileges:grant. The check runs inside the supplied tx with
// SELECT FOR UPDATE so concurrent grant + revoke serialize.
+func (s *Service) assertGranterTx(ctx context.Context, tx *sql.Tx, actorID string) error { + grant, err := s.privileges.FindForUpdate(ctx, tx, actorID, models.PrivilegeGrant) + if err != nil { + return fmt.Errorf("lookup actor meta privilege: %w", err) + } + if grant == nil { + return fmt.Errorf("%w: actor does not hold %s", ErrInvalidInput, models.PrivilegeGrant) + } + return nil +} + +// writePrivilegeAuditTx records a privilege lifecycle event in audit_events. +func (s *Service) writePrivilegeAuditTx( + ctx context.Context, + tx *sql.Tx, + eventType string, + entityID string, + details map[string]any, +) error { + payload, err := json.Marshal(details) + if err != nil { + return fmt.Errorf("marshal audit details: %w", err) + } + return s.auditEvents.Create(ctx, tx, &models.AuditEvent{ + ID: newID(), + EventType: eventType, + EventTime: nowUTC(), + EntityID: entityID, + Details: string(payload), + }) +} + +// stringPtrOrNil returns a pointer to s, or nil when s is empty. Used for +// optional VARCHAR / TEXT columns whose absence we want to encode as NULL +// rather than the empty string. +func stringPtrOrNil(s string) *string { + if s == "" { + return nil + } + return &s +} diff --git a/pkg/service/user_privilege_integration_test.go b/pkg/service/user_privilege_integration_test.go new file mode 100644 index 000000000..ca9cda640 --- /dev/null +++ b/pkg/service/user_privilege_integration_test.go @@ -0,0 +1,278 @@ +//go:build integration + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package service + +import ( + "errors" + "strings" + "testing" + + "github.com/google/uuid" + + "github.com/apache/airavata-custos/pkg/models" +) + +func TestGrantPrivilege_HappyPath(t *testing.T) { + database := setupTestDB(t) + svc := newTestService(database) + + granter := seedUser(t, database, "[email protected]") + target := seedUser(t, database, "[email protected]") + seedPrivilegeGrant(t, database, granter) + + grant, err := svc.GrantPrivilege(ctx(), target, models.PrivilegeAMIERead, granter, "ops view") + if err != nil { + t.Fatalf("GrantPrivilege: %v", err) + } + if grant.UserID != target || grant.Privilege != models.PrivilegeAMIERead { + t.Errorf("grant payload mismatch: %+v", grant) + } + + has, err := svc.HasPrivilege(ctx(), target, models.PrivilegeAMIERead) + if err != nil || !has { + t.Errorf("HasPrivilege after grant: has=%v err=%v", has, err) + } + if got := countAuditEventsOfType(t, database, "PRIVILEGE_GRANTED", target); got != 1 { + t.Errorf("audit PRIVILEGE_GRANTED: got %d, want 1", got) + } +} + +func TestGrantPrivilege_RejectsGranterWithoutMeta(t *testing.T) { + database := setupTestDB(t) + svc := newTestService(database) + granter := seedUser(t, database, "[email protected]") + target := seedUser(t, database, "[email protected]") + // granter has NO privileges:grant. 
+ _, err := svc.GrantPrivilege(ctx(), target, models.PrivilegeAMIERead, granter, "") + if !errors.Is(err, ErrInvalidInput) { + t.Errorf("granter w/o meta: got err=%v, want ErrInvalidInput", err) + } + if !strings.Contains(err.Error(), string(models.PrivilegeGrant)) { + t.Errorf("error should mention the missing privilege, got: %v", err) + } +} + +func TestGrantPrivilege_RejectsUnknownPrivilege(t *testing.T) { + svc := newTestService(setupTestDB(t)) + _, err := svc.GrantPrivilege(ctx(), "u", "bogus:thing", "g", "") + if !errors.Is(err, ErrInvalidInput) { + t.Errorf("unknown privilege: got err=%v, want ErrInvalidInput", err) + } +} + +func TestGrantPrivilege_DuplicateActiveReturnsAlreadyExists(t *testing.T) { + database := setupTestDB(t) + svc := newTestService(database) + granter := seedUser(t, database, "[email protected]") + target := seedUser(t, database, "[email protected]") + seedPrivilegeGrant(t, database, granter) + + if _, err := svc.GrantPrivilege(ctx(), target, models.PrivilegeAMIERead, granter, ""); err != nil { + t.Fatalf("first grant: %v", err) + } + _, err := svc.GrantPrivilege(ctx(), target, models.PrivilegeAMIERead, granter, "") + if !errors.Is(err, ErrAlreadyExists) { + t.Errorf("duplicate grant: got err=%v, want ErrAlreadyExists", err) + } +} + +func TestRevokePrivilege_HappyPath(t *testing.T) { + database := setupTestDB(t) + svc := newTestService(database) + granter := seedUser(t, database, "[email protected]") + target := seedUser(t, database, "[email protected]") + seedPrivilegeGrant(t, database, granter) + if _, err := svc.GrantPrivilege(ctx(), target, models.PrivilegeAMIERead, granter, ""); err != nil { + t.Fatalf("grant: %v", err) + } + if err := svc.RevokePrivilege(ctx(), target, models.PrivilegeAMIERead, granter, "no longer needed"); err != nil { + t.Fatalf("RevokePrivilege: %v", err) + } + if has, err := svc.HasPrivilege(ctx(), target, models.PrivilegeAMIERead); err != nil || has { + t.Errorf("HasPrivilege after revoke: has=%v err=%v", 
has, err) + } + if got := countAuditEventsOfType(t, database, "PRIVILEGE_REVOKED", target); got != 1 { + t.Errorf("audit PRIVILEGE_REVOKED: got %d, want 1", got) + } +} + +func TestRevokePrivilege_RejectsSelfRevokeOfMeta(t *testing.T) { + database := setupTestDB(t) + svc := newTestService(database) + caller := seedUser(t, database, "[email protected]") + seedPrivilegeGrant(t, database, caller) + + err := svc.RevokePrivilege(ctx(), caller, models.PrivilegeGrant, caller, "self") + if !errors.Is(err, ErrInvalidInput) { + t.Errorf("self-revoke meta: got err=%v, want ErrInvalidInput", err) + } +} + +func TestRevokePrivilege_NoActiveGrant_ReturnsNotFound(t *testing.T) { + database := setupTestDB(t) + svc := newTestService(database) + granter := seedUser(t, database, "[email protected]") + target := seedUser(t, database, "[email protected]") + seedPrivilegeGrant(t, database, granter) + err := svc.RevokePrivilege(ctx(), target, models.PrivilegeAMIERead, granter, "") + if !errors.Is(err, ErrNotFound) { + t.Errorf("revoke with no active grant: got err=%v, want ErrNotFound", err) + } +} + +func TestHasPrivilege_ReturnsFalseForUngranted(t *testing.T) { + database := setupTestDB(t) + svc := newTestService(database) + user := seedUser(t, database, "[email protected]") + has, err := svc.HasPrivilege(ctx(), user, models.PrivilegeAMIERead) + if err != nil { + t.Fatalf("HasPrivilege: %v", err) + } + if has { + t.Errorf("HasPrivilege for ungranted user: got true, want false") + } +} + +func TestListUserPrivileges_ReturnsActiveOnly(t *testing.T) { + database := setupTestDB(t) + svc := newTestService(database) + granter := seedUser(t, database, "[email protected]") + target := seedUser(t, database, "[email protected]") + seedPrivilegeGrant(t, database, granter) + if _, err := svc.GrantPrivilege(ctx(), target, models.PrivilegeAMIERead, granter, ""); err != nil { + t.Fatalf("grant: %v", err) + } + if _, err := svc.GrantPrivilege(ctx(), target, models.PrivilegeHPCRead, granter, ""); err 
!= nil { + t.Fatalf("grant 2: %v", err) + } + if err := svc.RevokePrivilege(ctx(), target, models.PrivilegeHPCRead, granter, ""); err != nil { + t.Fatalf("revoke: %v", err) + } + rows, err := svc.ListUserPrivileges(ctx(), target) + if err != nil { + t.Fatalf("ListUserPrivileges: %v", err) + } + if len(rows) != 1 || rows[0].Privilege != models.PrivilegeAMIERead { + t.Errorf("active set: got %v, want [amie:read]", rows) + } +} + +func TestListPrivilegeHolders(t *testing.T) { + database := setupTestDB(t) + svc := newTestService(database) + granter := seedUser(t, database, "[email protected]") + a := seedUser(t, database, "[email protected]") + b := seedUser(t, database, "[email protected]") + seedPrivilegeGrant(t, database, granter) + if _, err := svc.GrantPrivilege(ctx(), a, models.PrivilegeAMIERead, granter, ""); err != nil { + t.Fatalf("grant a: %v", err) + } + if _, err := svc.GrantPrivilege(ctx(), b, models.PrivilegeAMIERead, granter, ""); err != nil { + t.Fatalf("grant b: %v", err) + } + rows, err := svc.ListPrivilegeHolders(ctx(), models.PrivilegeAMIERead) + if err != nil { + t.Fatalf("ListPrivilegeHolders: %v", err) + } + if len(rows) != 2 { + t.Errorf("holders: got %d, want 2", len(rows)) + } +} + +func TestPrivilegeCatalog(t *testing.T) { + svc := newTestService(setupTestDB(t)) + cat := svc.PrivilegeCatalog() + seen := map[models.PrivilegeKey]bool{} + for _, k := range cat { + seen[k] = true + } + if !seen[models.PrivilegeGrant] || !seen[models.PrivilegeAMIERead] { + t.Errorf("catalog missing expected keys, got: %v", cat) + } +} + +func TestBootstrapPrivilegeGrant_HappyPath(t *testing.T) { + database := setupTestDB(t) + svc := newTestService(database) + user := seedUser(t, database, "[email protected]") + if err := svc.BootstrapPrivilegeGrant(ctx(), "[email protected]", "env:TEST"); err != nil { + t.Fatalf("BootstrapPrivilegeGrant: %v", err) + } + if has, err := svc.HasPrivilege(ctx(), user, models.PrivilegeGrant); err != nil || !has { + 
t.Errorf("HasPrivilege after bootstrap: has=%v err=%v", has, err) + } + if got := countAuditEventsOfType(t, database, "PRIVILEGE_BOOTSTRAPPED", user); got != 1 { + t.Errorf("audit PRIVILEGE_BOOTSTRAPPED: got %d, want 1", got) + } +} + +func TestBootstrapPrivilegeGrant_NoOpWhenHolderExists(t *testing.T) { + database := setupTestDB(t) + svc := newTestService(database) + existing := seedUser(t, database, "[email protected]") + seedPrivilegeGrant(t, database, existing) + if err := svc.BootstrapPrivilegeGrant(ctx(), "[email protected]", "env:TEST"); err != nil { + t.Fatalf("BootstrapPrivilegeGrant: %v", err) + } + rows, err := svc.ListPrivilegeHolders(ctx(), models.PrivilegeGrant) + if err != nil { + t.Fatalf("ListPrivilegeHolders: %v", err) + } + if len(rows) != 1 || rows[0].UserID != existing { + t.Errorf("holders after bootstrap no-op: got %v, want only [%s]", rows, existing) + } +} + +func TestBootstrapPrivilegeGrant_NoOpWhenEmailNotFound(t *testing.T) { + database := setupTestDB(t) + svc := newTestService(database) + if err := svc.BootstrapPrivilegeGrant(ctx(), "[email protected]", "env:TEST"); err != nil { + t.Fatalf("BootstrapPrivilegeGrant: %v", err) + } + rows, err := svc.ListPrivilegeHolders(ctx(), models.PrivilegeGrant) + if err != nil { + t.Fatalf("ListPrivilegeHolders: %v", err) + } + if len(rows) != 0 { + t.Errorf("holders after missing-email bootstrap: got %d, want 0", len(rows)) + } +} + +func TestSchema_RejectsTwoActivePrivilegesOfSameKey(t *testing.T) { + database := setupTestDB(t) + user := seedUser(t, database, "[email protected]") + if _, err := database.Exec( + `INSERT INTO user_privileges (id, user_id, privilege, granted_at) VALUES (?, ?, ?, NOW(6))`, + uuid.NewString(), user, string(models.PrivilegeAMIERead), + ); err != nil { + t.Fatalf("first insert: %v", err) + } + _, err := database.Exec( + `INSERT INTO user_privileges (id, user_id, privilege, granted_at) VALUES (?, ?, ?, NOW(6))`, + uuid.NewString(), user, string(models.PrivilegeAMIERead), + 
) + if err == nil { + t.Fatal("expected UNIQUE violation, got nil") + } + if !strings.Contains(strings.ToLower(err.Error()), "duplicate") && + !strings.Contains(strings.ToLower(err.Error()), "unique") { + t.Errorf("expected duplicate/unique error, got: %v", err) + } +}
