This is an automated email from the ASF dual-hosted git repository.

rcordier pushed a commit to branch postgresql
in repository https://gitbox.apache.org/repos/asf/james-project.git

commit c1fc08fee6bb2d92d057a29e74bf17b835dcff5a
Author: vttran <[email protected]>
AuthorDate: Thu Nov 2 10:00:11 2023 +0700

    JAMES-2586 PostgresTableManager supports creating tables with row-level security enabled
---
 .../james/backends/postgres/PostgresTable.java     | 24 +++++++-
 .../backends/postgres/PostgresTableManager.java    | 24 +++++++-
 .../backends/postgres/utils/PostgresExecutor.java  |  4 ++
 .../postgres/PostgresTableManagerTest.java         | 65 ++++++++++++++++++----
 4 files changed, 102 insertions(+), 15 deletions(-)

diff --git a/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/PostgresTable.java b/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/PostgresTable.java
index 0e8c22ed43..331f530ad7 100644
--- a/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/PostgresTable.java
+++ b/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/PostgresTable.java
@@ -30,7 +30,7 @@ public class PostgresTable {
 
     @FunctionalInterface
     public interface RequireCreateTableStep {
-        PostgresTable createTableStep(CreateTableFunction createTableFunction);
+        RequireRowLevelSecurity createTableStep(CreateTableFunction createTableFunction);
     }
 
 
@@ -39,17 +39,32 @@ public class PostgresTable {
         DDLQuery createTable(DSLContext dsl, String tableName);
     }
 
+    @FunctionalInterface
+    public interface RequireRowLevelSecurity {
+        PostgresTable enableRLS(boolean enableRowLevelSecurity);
+
+        default PostgresTable noRLS() {
+            return enableRLS(false);
+        }
+
+        default PostgresTable enableRLS() {
+            return enableRLS(true);
+        }
+    }
+
     public static RequireCreateTableStep name(String tableName) {
         Preconditions.checkNotNull(tableName);
 
-        return createTableFunction -> new PostgresTable(tableName, dsl -> createTableFunction.createTable(dsl, tableName));
+        return createTableFunction -> enableRLS -> new PostgresTable(tableName, enableRLS, dsl -> createTableFunction.createTable(dsl, tableName));
     }
 
     private final String name;
+    private final boolean enableRowLevelSecurity;
     private final Function<DSLContext, DDLQuery> createTableStepFunction;
 
-    private PostgresTable(String name, Function<DSLContext, DDLQuery> createTableStepFunction) {
+    private PostgresTable(String name, boolean enableRowLevelSecurity, Function<DSLContext, DDLQuery> createTableStepFunction) {
         this.name = name;
+        this.enableRowLevelSecurity = enableRowLevelSecurity;
         this.createTableStepFunction = createTableStepFunction;
     }
 
@@ -62,4 +77,7 @@ public class PostgresTable {
         return createTableStepFunction;
     }
 
+    public boolean isEnableRowLevelSecurity() {
+        return enableRowLevelSecurity;
+    }
 }
diff --git a/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/PostgresTableManager.java b/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/PostgresTableManager.java
index 23749fed72..c563b5918b 100644
--- a/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/PostgresTableManager.java
+++ b/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/PostgresTableManager.java
@@ -24,6 +24,7 @@ import org.jooq.exception.DataAccessException;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import io.r2dbc.spi.Result;
 import reactor.core.publisher.Flux;
 import reactor.core.publisher.Mono;
 
@@ -41,10 +42,10 @@ public class PostgresTableManager {
         return postgresExecutor.dslContext()
             .flatMap(dsl -> Flux.fromIterable(module.tables())
                 .flatMap(table -> Mono.from(table.getCreateTableStepFunction().apply(dsl))
+                    .then(alterTableEnableRLSIfNeed(table))
                     .doOnSuccess(any -> LOGGER.info("Table {} created", table.getName()))
                     .onErrorResume(DataAccessException.class, exception -> {
                         if (exception.getMessage().contains(String.format("\"%s\" already exists", table.getName()))) {
-                            LOGGER.info("Table {} already exists", table.getName());
                             return Mono.empty();
                         }
                         return Mono.error(exception);
@@ -53,6 +54,26 @@ public class PostgresTableManager {
                 .then());
     }
 
+    private Mono<Void> alterTableEnableRLSIfNeed(PostgresTable table) {
+        if (table.isEnableRowLevelSecurity()) {
+            return alterTableEnableRLS(table);
+        }
+        return Mono.empty();
+    }
+
+    public Mono<Void> alterTableEnableRLS(PostgresTable table) {
+        return postgresExecutor.connection()
+            .flatMapMany(con -> con.createStatement(getAlterRLSStatement(table.getName())).execute())
+            .flatMap(Result::getRowsUpdated)
+            .then();
+    }
+
+    private String getAlterRLSStatement(String tableName) {
+        return "SET app.current_domain = ''; ALTER TABLE " + tableName + " ADD DOMAIN varchar(255) not null DEFAULT current_setting('app.current_domain')::text;" +
+            "ALTER TABLE " + tableName + " ENABLE ROW LEVEL SECURITY; " +
+            "CREATE POLICY DOMAIN_" + tableName + "_POLICY ON " + tableName + " USING (DOMAIN = current_setting('app.current_domain')::text);";
+    }
+
     public Mono<Void> truncate() {
         return postgresExecutor.dslContext()
             .flatMap(dsl -> Flux.fromIterable(module.tables())
@@ -69,7 +90,6 @@ public class PostgresTableManager {
                     .doOnSuccess(any -> LOGGER.info("Index {} created", index.getName()))
                     .onErrorResume(DataAccessException.class, exception -> {
                         if (exception.getMessage().contains(String.format("\"%s\" already exists", index.getName()))) {
-                            LOGGER.info("Index {} already exists", index.getName());
                             return Mono.empty();
                         }
                         return Mono.error(exception);
diff --git a/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/utils/PostgresExecutor.java b/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/utils/PostgresExecutor.java
index f3a86d41a3..81a8cc8d2c 100644
--- a/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/utils/PostgresExecutor.java
+++ b/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/utils/PostgresExecutor.java
@@ -44,4 +44,8 @@ public class PostgresExecutor {
     public Mono<DSLContext> dslContext() {
         return connection.map(con -> DSL.using(con, PGSQL_DIALECT, SETTINGS));
     }
+
+    public Mono<Connection> connection() {
+        return connection;
+    }
 }
diff --git a/backends-common/postgres/src/test/java/org/apache/james/backends/postgres/PostgresTableManagerTest.java b/backends-common/postgres/src/test/java/org/apache/james/backends/postgres/PostgresTableManagerTest.java
index 3a853fbe54..62eae38316 100644
--- a/backends-common/postgres/src/test/java/org/apache/james/backends/postgres/PostgresTableManagerTest.java
+++ b/backends-common/postgres/src/test/java/org/apache/james/backends/postgres/PostgresTableManagerTest.java
@@ -87,7 +87,8 @@ public class PostgresTableManagerTest {
             .createTableStep((dsl, tbn) -> dsl.createTable(tbn)
                 .column("colum1", SQLDataType.UUID.notNull())
                 .column("colum2", SQLDataType.INTEGER)
-                .column("colum3", SQLDataType.VARCHAR(255).notNull()));
+                .column("colum3", SQLDataType.VARCHAR(255).notNull()))
+            .noRLS();
 
         PostgresModule module = PostgresModule.table(table);
 
@@ -109,12 +110,12 @@ public class PostgresTableManagerTest {
 
         PostgresTable table1 = PostgresTable.name(tableName1)
             .createTableStep((dsl, tbn) -> dsl.createTable(tbn)
-                .column("columA", SQLDataType.UUID.notNull()));
+                .column("columA", SQLDataType.UUID.notNull())).noRLS();
 
         String tableName2 = "tableName2";
         PostgresTable table2 = PostgresTable.name(tableName2)
             .createTableStep((dsl, tbn) -> dsl.createTable(tbn)
-                .column("columB", SQLDataType.INTEGER));
+                .column("columB", SQLDataType.INTEGER)).noRLS();
 
         PostgresTableManager testee = tableManagerFactory.apply(PostgresModule.table(table1, table2));
 
@@ -135,7 +136,7 @@ public class PostgresTableManagerTest {
 
         PostgresTable table1 = PostgresTable.name(tableName1)
             .createTableStep((dsl, tbn) -> dsl.createTable(tbn)
-                .column("columA", SQLDataType.UUID.notNull()));
+                .column("columA", SQLDataType.UUID.notNull())).noRLS();
 
         PostgresTableManager testee = tableManagerFactory.apply(PostgresModule.table(table1));
 
@@ -151,7 +152,7 @@ public class PostgresTableManagerTest {
         String tableName1 = "tableName1";
         PostgresTable table1 = PostgresTable.name(tableName1)
             .createTableStep((dsl, tbn) -> dsl.createTable(tbn)
-                .column("columA", SQLDataType.UUID.notNull()));
+                .column("columA", SQLDataType.UUID.notNull())).noRLS();
 
         tableManagerFactory.apply(PostgresModule.table(table1))
             .initializeTables()
@@ -159,7 +160,7 @@ public class PostgresTableManagerTest {
 
         PostgresTable table1Changed = PostgresTable.name(tableName1)
             .createTableStep((dsl, tbn) -> dsl.createTable(tbn)
-                .column("columB", SQLDataType.INTEGER));
+                .column("columB", SQLDataType.INTEGER)).noRLS();
 
         tableManagerFactory.apply(PostgresModule.table(table1Changed))
             .initializeTables()
@@ -178,7 +179,8 @@ public class PostgresTableManagerTest {
             .createTableStep((dsl, tbn) -> dsl.createTable(tbn)
                 .column("colum1", SQLDataType.UUID.notNull())
                 .column("colum2", SQLDataType.INTEGER)
-                .column("colum3", SQLDataType.VARCHAR(255).notNull()));
+                .column("colum3", SQLDataType.VARCHAR(255).notNull()))
+            .noRLS();
 
         String indexName = "idx_test_1";
         PostgresIndex index = PostgresIndex.name(indexName)
@@ -210,7 +212,8 @@ public class PostgresTableManagerTest {
             .createTableStep((dsl, tbn) -> dsl.createTable(tbn)
                 .column("colum1", SQLDataType.UUID.notNull())
                 .column("colum2", SQLDataType.INTEGER)
-                .column("colum3", SQLDataType.VARCHAR(255).notNull()));
+                .column("colum3", SQLDataType.VARCHAR(255).notNull()))
+            .noRLS();
 
         String indexName1 = "idx_test_1";
         PostgresIndex index1 = PostgresIndex.name(indexName1)
@@ -247,7 +250,8 @@ public class PostgresTableManagerTest {
             .createTableStep((dsl, tbn) -> dsl.createTable(tbn)
                 .column("colum1", SQLDataType.UUID.notNull())
                 .column("colum2", SQLDataType.INTEGER)
-                .column("colum3", SQLDataType.VARCHAR(255).notNull()));
+                .column("colum3", SQLDataType.VARCHAR(255).notNull()))
+            .noRLS();
 
         String indexName = "idx_test_1";
         PostgresIndex index = PostgresIndex.name(indexName)
@@ -275,7 +279,7 @@ public class PostgresTableManagerTest {
         String tableName1 = "tbn1";
         PostgresTable table1 = PostgresTable.name(tableName1)
             .createTableStep((dsl, tbn) -> dsl.createTable(tbn)
-                .column("column1", SQLDataType.INTEGER.notNull()));
+                .column("column1", SQLDataType.INTEGER.notNull())).noRLS();
 
         PostgresTableManager testee = tableManagerFactory.apply(PostgresModule.table(table1));
         testee.initializeTables()
@@ -312,6 +316,47 @@ public class PostgresTableManagerTest {
         assertThat(getTotalRecordInDB.get()).isEqualTo(0L);
     }
 
+    @Test
+    void createTableShouldSucceedWhenEnableRLS() {
+        String tableName = "tbn1";
+
+        PostgresTable table = PostgresTable.name(tableName)
+            .createTableStep((dsl, tbn) -> dsl.createTable(tbn)
+                .column("clm1", SQLDataType.UUID.notNull())
+                .column("clm2", SQLDataType.VARCHAR(255).notNull()))
+            .enableRLS();
+
+        PostgresModule module = PostgresModule.table(table);
+
+        PostgresTableManager testee = tableManagerFactory.apply(module);
+
+        testee.initializeTables()
+            .block();
+
+        assertThat(getColumnNameAndDataType(tableName))
+            .containsExactlyInAnyOrder(
+                Pair.of("clm1", "uuid"),
+                Pair.of("clm2", "character varying"),
+                Pair.of("domain", "character varying"));
+
+        List<Pair<String, Boolean>> pgClassCheckResult = Flux.usingWhen(connectionFactory.create(),
+                connection -> Mono.from(connection.createStatement("select relname, relrowsecurity " +
+                            "from pg_class " +
+                            "where oid = '" + tableName + "'::regclass;")
+                        .execute())
+                    .flatMapMany(result ->
+                        result.map((row, rowMetadata) ->
+                            Pair.of(row.get("relname", String.class),
+                                row.get("relrowsecurity", Boolean.class)))),
+                Connection::close)
+            .collectList()
+            .block();
+
+        assertThat(pgClassCheckResult)
+            .containsExactlyInAnyOrder(
+                Pair.of(tableName, true));
+    }
+
 
     private List<Pair<String, String>> getColumnNameAndDataType(String tableName) {
         return Flux.usingWhen(connectionFactory.create(),


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to