This is an automated email from the ASF dual-hosted git repository. rcordier pushed a commit to branch postgresql in repository https://gitbox.apache.org/repos/asf/james-project.git
commit c1fc08fee6bb2d92d057a29e74bf17b835dcff5a Author: vttran <[email protected]> AuthorDate: Thu Nov 2 10:00:11 2023 +0700 JAMES-2586 PostgresTableManager support create table when enable row level security --- .../james/backends/postgres/PostgresTable.java | 24 +++++++- .../backends/postgres/PostgresTableManager.java | 24 +++++++- .../backends/postgres/utils/PostgresExecutor.java | 4 ++ .../postgres/PostgresTableManagerTest.java | 65 ++++++++++++++++++---- 4 files changed, 102 insertions(+), 15 deletions(-) diff --git a/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/PostgresTable.java b/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/PostgresTable.java index 0e8c22ed43..331f530ad7 100644 --- a/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/PostgresTable.java +++ b/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/PostgresTable.java @@ -30,7 +30,7 @@ public class PostgresTable { @FunctionalInterface public interface RequireCreateTableStep { - PostgresTable createTableStep(CreateTableFunction createTableFunction); + RequireRowLevelSecurity createTableStep(CreateTableFunction createTableFunction); } @@ -39,17 +39,32 @@ public class PostgresTable { DDLQuery createTable(DSLContext dsl, String tableName); } + @FunctionalInterface + public interface RequireRowLevelSecurity { + PostgresTable enableRLS(boolean enableRowLevelSecurity); + + default PostgresTable noRLS() { + return enableRLS(false); + } + + default PostgresTable enableRLS() { + return enableRLS(true); + } + } + public static RequireCreateTableStep name(String tableName) { Preconditions.checkNotNull(tableName); - return createTableFunction -> new PostgresTable(tableName, dsl -> createTableFunction.createTable(dsl, tableName)); + return createTableFunction -> enableRLS -> new PostgresTable(tableName, enableRLS, dsl -> createTableFunction.createTable(dsl, tableName)); } private final String name; + private final boolean enableRowLevelSecurity; private final Function<DSLContext, DDLQuery> createTableStepFunction; - private PostgresTable(String name, Function<DSLContext, DDLQuery> createTableStepFunction) { + private PostgresTable(String name, boolean enableRowLevelSecurity, Function<DSLContext, DDLQuery> createTableStepFunction) { this.name = name; + this.enableRowLevelSecurity = enableRowLevelSecurity; this.createTableStepFunction = createTableStepFunction; } @@ -62,4 +77,7 @@ public class PostgresTable { return createTableStepFunction; } + public boolean isEnableRowLevelSecurity() { + return enableRowLevelSecurity; + } } diff --git a/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/PostgresTableManager.java b/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/PostgresTableManager.java index 23749fed72..c563b5918b 100644 --- a/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/PostgresTableManager.java +++ b/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/PostgresTableManager.java @@ -24,6 +24,7 @@ import org.jooq.exception.DataAccessException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import io.r2dbc.spi.Result; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -41,10 +42,10 @@ public class PostgresTableManager { return postgresExecutor.dslContext() .flatMap(dsl -> Flux.fromIterable(module.tables()) .flatMap(table -> Mono.from(table.getCreateTableStepFunction().apply(dsl)) + .then(alterTableEnableRLSIfNeed(table)) .doOnSuccess(any -> LOGGER.info("Table {} created", table.getName())) .onErrorResume(DataAccessException.class, exception -> { if (exception.getMessage().contains(String.format("\"%s\" already exists", table.getName()))) { - LOGGER.info("Table {} already exists", table.getName()); return Mono.empty(); } return Mono.error(exception); @@ -53,6 +54,26 @@ public class PostgresTableManager { .then()); } + private Mono<Void> alterTableEnableRLSIfNeed(PostgresTable table) { + if (table.isEnableRowLevelSecurity()) { + return alterTableEnableRLS(table); + } + return Mono.empty(); + } + + public Mono<Void> alterTableEnableRLS(PostgresTable table) { + return postgresExecutor.connection() + .flatMapMany(con -> con.createStatement(getAlterRLSStatement(table.getName())).execute()) + .flatMap(Result::getRowsUpdated) + .then(); + } + + private String getAlterRLSStatement(String tableName) { + return "SET app.current_domain = ''; ALTER TABLE " + tableName + " ADD DOMAIN varchar(255) not null DEFAULT current_setting('app.current_domain')::text;" + + "ALTER TABLE " + tableName + " ENABLE ROW LEVEL SECURITY; " + + "CREATE POLICY DOMAIN_" + tableName + "_POLICY ON " + tableName + " USING (DOMAIN = current_setting('app.current_domain')::text);"; + } + public Mono<Void> truncate() { return postgresExecutor.dslContext() .flatMap(dsl -> Flux.fromIterable(module.tables()) @@ -69,7 +90,6 @@ public class PostgresTableManager { .doOnSuccess(any -> LOGGER.info("Index {} created", index.getName())) .onErrorResume(DataAccessException.class, exception -> { if (exception.getMessage().contains(String.format("\"%s\" already exists", index.getName()))) { - LOGGER.info("Index {} already exists", index.getName()); return Mono.empty(); } return Mono.error(exception); diff --git a/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/utils/PostgresExecutor.java b/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/utils/PostgresExecutor.java index f3a86d41a3..81a8cc8d2c 100644 --- a/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/utils/PostgresExecutor.java +++ b/backends-common/postgres/src/main/java/org/apache/james/backends/postgres/utils/PostgresExecutor.java @@ -44,4 +44,8 @@ public class PostgresExecutor { public Mono<DSLContext> dslContext() { return connection.map(con -> DSL.using(con, PGSQL_DIALECT, SETTINGS)); } + + public Mono<Connection> connection() { + return connection; + } } diff --git a/backends-common/postgres/src/test/java/org/apache/james/backends/postgres/PostgresTableManagerTest.java b/backends-common/postgres/src/test/java/org/apache/james/backends/postgres/PostgresTableManagerTest.java index 3a853fbe54..62eae38316 100644 --- a/backends-common/postgres/src/test/java/org/apache/james/backends/postgres/PostgresTableManagerTest.java +++ b/backends-common/postgres/src/test/java/org/apache/james/backends/postgres/PostgresTableManagerTest.java @@ -87,7 +87,8 @@ public class PostgresTableManagerTest { .createTableStep((dsl, tbn) -> dsl.createTable(tbn) .column("colum1", SQLDataType.UUID.notNull()) .column("colum2", SQLDataType.INTEGER) - .column("colum3", SQLDataType.VARCHAR(255).notNull())); + .column("colum3", SQLDataType.VARCHAR(255).notNull())) + .noRLS(); PostgresModule module = PostgresModule.table(table); @@ -109,12 +110,12 @@ public class PostgresTableManagerTest { PostgresTable table1 = PostgresTable.name(tableName1) .createTableStep((dsl, tbn) -> dsl.createTable(tbn) - .column("columA", SQLDataType.UUID.notNull())); + .column("columA", SQLDataType.UUID.notNull())).noRLS(); String tableName2 = "tableName2"; PostgresTable table2 = PostgresTable.name(tableName2) .createTableStep((dsl, tbn) -> dsl.createTable(tbn) - .column("columB", SQLDataType.INTEGER)); + .column("columB", SQLDataType.INTEGER)).noRLS(); PostgresTableManager testee = tableManagerFactory.apply(PostgresModule.table(table1, table2)); @@ -135,7 +136,7 @@ public class PostgresTableManagerTest { PostgresTable table1 = PostgresTable.name(tableName1) .createTableStep((dsl, tbn) -> dsl.createTable(tbn) - .column("columA", SQLDataType.UUID.notNull())); + .column("columA", SQLDataType.UUID.notNull())).noRLS(); PostgresTableManager testee = tableManagerFactory.apply(PostgresModule.table(table1)); @@ -151,7 +152,7 @@ public class PostgresTableManagerTest { String tableName1 = "tableName1"; PostgresTable table1 = PostgresTable.name(tableName1) .createTableStep((dsl, tbn) -> dsl.createTable(tbn) - .column("columA", SQLDataType.UUID.notNull())); + .column("columA", SQLDataType.UUID.notNull())).noRLS(); tableManagerFactory.apply(PostgresModule.table(table1)) .initializeTables() @@ -159,7 +160,7 @@ public class PostgresTableManagerTest { PostgresTable table1Changed = PostgresTable.name(tableName1) .createTableStep((dsl, tbn) -> dsl.createTable(tbn) - .column("columB", SQLDataType.INTEGER)); + .column("columB", SQLDataType.INTEGER)).noRLS(); tableManagerFactory.apply(PostgresModule.table(table1Changed)) .initializeTables() @@ -178,7 +179,8 @@ public class PostgresTableManagerTest { .createTableStep((dsl, tbn) -> dsl.createTable(tbn) .column("colum1", SQLDataType.UUID.notNull()) .column("colum2", SQLDataType.INTEGER) - .column("colum3", SQLDataType.VARCHAR(255).notNull())); + .column("colum3", SQLDataType.VARCHAR(255).notNull())) + .noRLS(); String indexName = "idx_test_1"; PostgresIndex index = PostgresIndex.name(indexName) @@ -210,7 +212,8 @@ public class PostgresTableManagerTest { .createTableStep((dsl, tbn) -> dsl.createTable(tbn) .column("colum1", SQLDataType.UUID.notNull()) .column("colum2", SQLDataType.INTEGER) - .column("colum3", SQLDataType.VARCHAR(255).notNull())); + .column("colum3", SQLDataType.VARCHAR(255).notNull())) + .noRLS(); String indexName1 = "idx_test_1"; PostgresIndex index1 = PostgresIndex.name(indexName1) @@ -247,7 +250,8 @@ public class PostgresTableManagerTest { .createTableStep((dsl, tbn) -> dsl.createTable(tbn) .column("colum1", SQLDataType.UUID.notNull()) .column("colum2", SQLDataType.INTEGER) - .column("colum3", SQLDataType.VARCHAR(255).notNull())); + .column("colum3", SQLDataType.VARCHAR(255).notNull())) + .noRLS(); String indexName = "idx_test_1"; PostgresIndex index = PostgresIndex.name(indexName) @@ -275,7 +279,7 @@ public class PostgresTableManagerTest { String tableName1 = "tbn1"; PostgresTable table1 = PostgresTable.name(tableName1) .createTableStep((dsl, tbn) -> dsl.createTable(tbn) - .column("column1", SQLDataType.INTEGER.notNull())); + .column("column1", SQLDataType.INTEGER.notNull())).noRLS(); PostgresTableManager testee = tableManagerFactory.apply(PostgresModule.table(table1)); testee.initializeTables() @@ -312,6 +316,47 @@ public class PostgresTableManagerTest { assertThat(getTotalRecordInDB.get()).isEqualTo(0L); } + @Test + void createTableShouldSucceedWhenEnableRLS() { + String tableName = "tbn1"; + + PostgresTable table = PostgresTable.name(tableName) + .createTableStep((dsl, tbn) -> dsl.createTable(tbn) + .column("clm1", SQLDataType.UUID.notNull()) + .column("clm2", SQLDataType.VARCHAR(255).notNull())) + .enableRLS(); + + PostgresModule module = PostgresModule.table(table); + + PostgresTableManager testee = tableManagerFactory.apply(module); + + testee.initializeTables() + .block(); + + assertThat(getColumnNameAndDataType(tableName)) + .containsExactlyInAnyOrder( + Pair.of("clm1", "uuid"), + Pair.of("clm2", "character varying"), + Pair.of("domain", "character varying")); + + List<Pair<String, Boolean>> pgClassCheckResult = Flux.usingWhen(connectionFactory.create(), + connection -> Mono.from(connection.createStatement("select relname, relrowsecurity " + + "from pg_class " + + "where oid = 'tbn1'::regclass;;") + .execute()) + .flatMapMany(result -> + result.map((row, rowMetadata) -> + Pair.of(row.get("relname", String.class), + row.get("relrowsecurity", Boolean.class)))), + Connection::close) + .collectList() + .block(); + + assertThat(pgClassCheckResult) + .containsExactlyInAnyOrder( + Pair.of("tbn1", true)); + } + private List<Pair<String, String>> getColumnNameAndDataType(String tableName) { return Flux.usingWhen(connectionFactory.create(), --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
