This is an automated email from the ASF dual-hosted git repository.

wuzhiguo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/bigtop-manager.git


The following commit(s) were added to refs/heads/main by this push:
     new f5bc0c2  BIGTOP-4227: Address SQL injection (#70)
f5bc0c2 is described below

commit f5bc0c299a9b9fa2c44ffaa8f9db54989e437d0e
Author: timyuer <[email protected]>
AuthorDate: Sat Sep 14 23:04:46 2024 +0800

    BIGTOP-4227: Address SQL injection (#70)
---
 .../dao/interceptor/AuditingInterceptor.java       |  62 ++++++------
 .../apache/bigtop/manager/dao/sql/SQLBuilder.java  | 112 ++++++++++++---------
 2 files changed, 94 insertions(+), 80 deletions(-)

diff --git a/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/interceptor/AuditingInterceptor.java b/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/interceptor/AuditingInterceptor.java
index 4cc536c..a178fe5 100644
--- a/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/interceptor/AuditingInterceptor.java
+++ b/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/interceptor/AuditingInterceptor.java
@@ -68,22 +68,23 @@ public class AuditingInterceptor implements Interceptor {
         Object parameter = invocation.getArgs()[1];
         log.debug("sqlCommandType {}", sqlCommandType);
 
-        Collection<Object> objects;
-        if (parameter instanceof MapperMethod.ParamMap) {
-            MapperMethod.ParamMap<Object> paramMap = 
((MapperMethod.ParamMap<Object>) parameter);
-            if (paramMap.get("param1") instanceof Collection) {
-                objects = ((Collection<Object>) paramMap.get("param1"));
+        if (SqlCommandType.INSERT == sqlCommandType || SqlCommandType.UPDATE 
== sqlCommandType) {
+            Collection<Object> objects;
+            if (parameter instanceof MapperMethod.ParamMap) {
+                MapperMethod.ParamMap<Object> paramMap = 
((MapperMethod.ParamMap<Object>) parameter);
+                if (paramMap.get("param1") instanceof Collection) {
+                    objects = ((Collection<Object>) paramMap.get("param1"));
+                } else {
+                    objects = 
Collections.singletonList(paramMap.get("param1"));
+                }
             } else {
-                objects = Collections.singletonList(paramMap.get("param1"));
+                objects = Collections.singletonList(parameter);
             }
-        } else {
-            objects = Collections.singletonList(parameter);
-        }
 
-        for (Object o : objects) {
-            setAuditFields(o, sqlCommandType);
+            for (Object o : objects) {
+                setAuditFields(o, sqlCommandType);
+            }
         }
-
         return invocation.proceed();
     }
 
@@ -92,26 +93,25 @@ public class AuditingInterceptor implements Interceptor {
         Timestamp timestamp = new Timestamp(System.currentTimeMillis());
 
         List<Field> fields = ClassUtils.getFields(object.getClass());
-        if (SqlCommandType.INSERT == sqlCommandType || SqlCommandType.UPDATE 
== sqlCommandType) {
-            for (Field field : fields) {
-                boolean accessible = field.canAccess(object);
-                field.setAccessible(true);
-                if (field.isAnnotationPresent(CreateBy.class)
-                        && SqlCommandType.INSERT == sqlCommandType
-                        && userId != null) {
-                    field.set(object, userId);
-                }
-                if (field.isAnnotationPresent(CreateTime.class) && 
SqlCommandType.INSERT == sqlCommandType) {
-                    field.set(object, timestamp);
-                }
-                if (field.isAnnotationPresent(UpdateBy.class) && userId != 
null) {
-                    field.set(object, userId);
-                }
-                if (field.isAnnotationPresent(UpdateTime.class)) {
-                    field.set(object, timestamp);
-                }
-                field.setAccessible(accessible);
+
+        for (Field field : fields) {
+            boolean accessible = field.canAccess(object);
+            field.setAccessible(true);
+            if (field.isAnnotationPresent(CreateBy.class)
+                    && SqlCommandType.INSERT == sqlCommandType
+                    && userId != null) {
+                field.set(object, userId);
+            }
+            if (field.isAnnotationPresent(CreateTime.class) && 
SqlCommandType.INSERT == sqlCommandType) {
+                field.set(object, timestamp);
+            }
+            if (field.isAnnotationPresent(UpdateBy.class) && userId != null) {
+                field.set(object, userId);
+            }
+            if (field.isAnnotationPresent(UpdateTime.class)) {
+                field.set(object, timestamp);
             }
+            field.setAccessible(accessible);
         }
     }
 }
diff --git a/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/sql/SQLBuilder.java b/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/sql/SQLBuilder.java
index 8a6d7ba..35b9101 100644
--- a/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/sql/SQLBuilder.java
+++ b/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/sql/SQLBuilder.java
@@ -34,12 +34,10 @@ import lombok.extern.slf4j.Slf4j;
 import java.beans.PropertyDescriptor;
 import java.io.Serializable;
 import java.lang.reflect.Field;
-import java.text.MessageFormat;
 import java.util.Collection;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
-import java.util.stream.Collectors;
 
 /**
  * Multiple data source support
@@ -117,7 +115,7 @@ public class SQLBuilder {
                     }
                     Object value = 
ReflectionUtils.invokeMethod(ps.getReadMethod(), entity);
                     if (!ObjectUtils.isEmpty(value)) {
-                        sql.SET("`" + getEquals(entry.getValue() + "`", 
entry.getKey()));
+                        sql.SET(getEquals(entry.getValue(), entry.getKey()));
                     }
                 }
 
@@ -158,7 +156,7 @@ public class SQLBuilder {
             case MYSQL: {
                 sql.SELECT(tableMetaData.getBaseColumns());
                 sql.FROM(tableMetaData.getTableName());
-                sql.WHERE(tableMetaData.getPkColumn() + " = '" + id + "'");
+                sql.WHERE(getEquals(tableMetaData.getPkColumn(), 
tableMetaData.getPkProperty()));
                 break;
             }
             case POSTGRESQL: {
@@ -185,10 +183,20 @@ public class SQLBuilder {
         SQL sql = new SQL();
         switch (DBType.toType(databaseId)) {
             case MYSQL: {
-                String idsStr = 
ids.stream().map(String::valueOf).collect(Collectors.joining("', '"));
                 sql.SELECT(tableMetaData.getBaseColumns());
                 sql.FROM(tableMetaData.getTableName());
-                sql.WHERE(tableMetaData.getPkColumn() + " in ('" + idsStr + 
"')");
+                if (ids == null || ids.isEmpty()) {
+                    sql.WHERE("1 = 0");
+                    break;
+                }
+
+                StringBuilder idStr = new StringBuilder();
+                for (int i = 0; i < ids.size(); i++) {
+                    idStr.append(getTokenParam("arg0[" + i + "]")).append(",");
+                }
+                idStr.deleteCharAt(idStr.lastIndexOf(","));
+
+                sql.WHERE(tableMetaData.getPkColumn() + " IN ( " + idStr + " 
)");
                 break;
             }
             case POSTGRESQL: {
@@ -240,7 +248,7 @@ public class SQLBuilder {
         switch (DBType.toType(databaseId)) {
             case MYSQL: {
                 sql.DELETE_FROM(tableMetaData.getTableName());
-                sql.WHERE(tableMetaData.getPkColumn() + " = '" + id + "'");
+                sql.WHERE(getEquals(tableMetaData.getPkColumn(), 
tableMetaData.getPkProperty()));
                 break;
             }
             case POSTGRESQL: {
@@ -261,9 +269,18 @@ public class SQLBuilder {
         SQL sql = new SQL();
         switch (DBType.toType(databaseId)) {
             case MYSQL: {
-                String idsStr = 
ids.stream().map(String::valueOf).collect(Collectors.joining("', '"));
+                if (ids == null || ids.isEmpty()) {
+                    break;
+                }
                 sql.DELETE_FROM(tableMetaData.getTableName());
-                sql.WHERE(tableMetaData.getPkColumn() + " in ('" + idsStr + 
"')");
+
+                StringBuilder idStr = new StringBuilder();
+                for (int i = 0; i < ids.size(); i++) {
+                    idStr.append(getTokenParam("arg0[" + i + "]")).append(",");
+                }
+                idStr.deleteCharAt(idStr.lastIndexOf(","));
+
+                sql.WHERE(tableMetaData.getPkColumn() + " IN ( " + idStr + " 
)");
                 break;
             }
             case POSTGRESQL: {
@@ -282,14 +299,13 @@ public class SQLBuilder {
 
     public static <Condition> String findByCondition(
             TableMetaData tableMetaData, String databaseId, Condition 
condition) throws IllegalAccessException {
-        String tableName = tableMetaData.getTableName();
         log.info("databaseId: {}", databaseId);
         SQL sql = new SQL();
         switch (DBType.toType(databaseId)) {
             case POSTGRESQL:
                 tableName = "\"" + tableName + "\"";
             case MYSQL: {
-                sql = mysqlCondition(condition, tableName);
+                sql = mysqlCondition(condition, tableMetaData);
                 break;
             }
             default: {
@@ -301,14 +317,15 @@ public class SQLBuilder {
     }
 
     private static String getEquals(String column, String property) {
-        return column + " = " + getTokenParam(property);
+        return "`" + column + "` = " + getTokenParam(property);
     }
 
     private static String getTokenParam(String property) {
         return "#{" + property + "}";
     }
 
-    private static <Condition> SQL mysqlCondition(Condition condition, String 
tableName) throws IllegalAccessException {
+    private static <Condition> SQL mysqlCondition(Condition condition, 
TableMetaData tableMetaData)
+            throws IllegalAccessException {
 
         Class<?> loadClass;
         try {
@@ -321,7 +338,7 @@ public class SQLBuilder {
         /* Prepare SQL */
         SQL sql = new SQL();
         sql.SELECT("*");
-        sql.FROM(tableName);
+        sql.FROM(tableMetaData.getTableName());
         for (Field field : fieldList) {
             field.setAccessible(true);
             String fieldName = field.getName();
@@ -329,75 +346,72 @@ public class SQLBuilder {
             if (field.isAnnotationPresent(QueryCondition.class) && 
Objects.nonNull(field.get(condition))) {
                 QueryCondition annotation = 
field.getAnnotation(QueryCondition.class);
 
-                String queryKey = fieldName;
+                String property = fieldName;
                 if (!annotation.queryKey().isEmpty()) {
-                    queryKey = annotation.queryKey();
+                    property = annotation.queryKey();
                 }
 
-                log.info(
-                        "[queryKey] {}, [queryType] {}, [queryValue] {}",
-                        queryKey,
-                        annotation.queryType().toString(),
-                        field.get(condition));
-
                 Object value = field.get(condition);
-                if (value != null) {
+                Map<String, String> fieldColumnMap = 
tableMetaData.getFieldColumnMap();
+
+                if (value != null && fieldColumnMap.containsKey(property)) {
+                    String columnName = fieldColumnMap.get(property);
+
+                    log.info(
+                            "[queryKey] {}, [queryType] {}, [queryValue] {}",
+                            property,
+                            annotation.queryType().toString(),
+                            field.get(condition));
                     switch (annotation.queryType()) {
                         case EQ:
-                            sql.WHERE(MessageFormat.format("{0} = ''{1}''", 
queryKey, value));
+                            sql.WHERE(getEquals(columnName, fieldName));
                             break;
                         case NOT_EQ:
-                            sql.WHERE(MessageFormat.format("{0} != ''{1}''", 
queryKey, value));
+                            sql.WHERE(columnName + " != " + 
getTokenParam(fieldName));
                             break;
                         case IN:
-                            sql.WHERE(MessageFormat.format(
-                                    "{0} IN (''{1}'')",
-                                    queryKey,
-                                    String.join("','", 
value.toString().split(annotation.multipleDelimiter()))));
+                            sql.WHERE(columnName + " IN ( REPLACE( " + 
getTokenParam(fieldName) + ", '"
+                                    + annotation.multipleDelimiter() + "', 
',') )");
                             break;
                         case NOT_IN:
-                            sql.WHERE(MessageFormat.format(
-                                    "{0} NOT IN (''{1}'')",
-                                    queryKey,
-                                    String.join("','", 
value.toString().split(annotation.multipleDelimiter()))));
+                            sql.WHERE(columnName + " NOT IN ( REPLACE( " + 
getTokenParam(fieldName) + ", '"
+                                    + annotation.multipleDelimiter() + "', 
',') )");
                             break;
                         case GT:
-                            sql.WHERE(MessageFormat.format("{0} > ''{1}''", 
queryKey, value));
+                            sql.WHERE(columnName + " > " + 
getTokenParam(fieldName));
                             break;
                         case GTE:
-                            sql.WHERE(MessageFormat.format("{0} >= ''{1}''", 
queryKey, value));
+                            sql.WHERE(columnName + " >= " + 
getTokenParam(fieldName));
                             break;
                         case LT:
-                            sql.WHERE(MessageFormat.format("{0} < ''{1}''", 
queryKey, value));
+                            sql.WHERE(columnName + " < " + 
getTokenParam(fieldName));
                             break;
                         case LTE:
-                            sql.WHERE(MessageFormat.format("{0} <= ''{1}''", 
queryKey, value));
+                            sql.WHERE(columnName + " <= " + 
getTokenParam(fieldName));
                             break;
                         case BETWEEN:
-                            String[] valueArr = 
field.get(condition).toString().split(annotation.pairDelimiter());
-                            if (valueArr.length == 2) {
-                                sql.WHERE(MessageFormat.format(
-                                        "{0} BETWEEN ''{1}'' AND ''{2}''", 
queryKey, valueArr[0], valueArr[1]));
-                            }
+                            sql.WHERE(columnName + " BETWEEN SUBSTRING_INDEX( 
" + getTokenParam(fieldName) + ", '"
+                                    + annotation.pairDelimiter() + "', 1) AND 
SUBSTRING_INDEX( "
+                                    + getTokenParam(fieldName) + ", '"
+                                    + annotation.pairDelimiter() + "', 2)");
                             break;
                         case PREFIX_LIKE:
-                            sql.WHERE(MessageFormat.format("{0} LIKE 
CONCAT(''{1}'', ''%'')", queryKey, value));
+                            sql.WHERE(columnName + " LIKE CONCAT( " + 
getTokenParam(fieldName) + ", '%')");
                             break;
                         case SUFFIX_LIKE:
-                            sql.WHERE(MessageFormat.format("{0} LIKE 
CONCAT(''%'', ''{1}'')", queryKey, value));
+                            sql.WHERE(columnName + " LIKE CONCAT('%', " + 
getTokenParam(fieldName) + ")");
                             break;
                         case LIKE:
-                            sql.WHERE(MessageFormat.format("{0} LIKE 
CONCAT(''%'', ''{1}'', ''%'')", queryKey, value));
+                            sql.WHERE(columnName + " LIKE CONCAT('%', " + 
getTokenParam(fieldName) + ", '%')");
                             break;
                         case NOT_LIKE:
-                            sql.WHERE(MessageFormat.format(
-                                    "{0} NOT LIKE CONCAT(''%'', ''{1}'', 
''%'')", queryKey, value));
+                            sql.WHERE(columnName + " NOT LIKE CONCAT('%', " + 
getTokenParam(fieldName) + ", '%')");
                             break;
                         case NULL:
-                            sql.WHERE(queryKey + " IS NULL");
+                            sql.WHERE(columnName + " IS NULL");
                             break;
                         case NOT_NULL:
-                            sql.WHERE(queryKey + " IS NOT NULL");
+                            sql.WHERE(columnName + " IS NOT NULL");
                             break;
                         default:
                             log.warn("Unknown query type: {}", 
annotation.queryType());

Reply via email to