This is an automated email from the ASF dual-hosted git repository.
panjuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new efd0ce0 fix encrypt rewrite exception when execute multiple table
join query (#13705)
efd0ce0 is described below
commit efd0ce06371a6d8d5260963ac7537c20b3e4a37f
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Fri Nov 19 19:09:48 2021 +0800
fix encrypt rewrite exception when execute multiple table join query
(#13705)
* fix encrypt rewrite exception when execute multiple table join query
* fix encrypt rewrite exception when execute multiple table join query
* optimize logic
* optimize rewrite logic
* fix test case
* delete useless method
---
.../merge/dql/EncryptAlgorithmMetaData.java | 9 +-
.../rewrite/condition/EncryptConditionEngine.java | 15 +-
.../impl/EncryptPredicateColumnTokenGenerator.java | 31 ++--
.../impl/EncryptProjectionTokenGenerator.java | 157 +++++++++++----------
.../merge/dql/EncryptAlgorithmMetaDataTest.java | 8 +-
.../impl/EncryptProjectionTokenGeneratorTest.java | 6 +
.../impl/WhereClauseShardingConditionEngine.java | 16 ++-
.../infra/binder/segment/table/TablesContext.java | 91 ++++--------
.../binder/segment/table/TablesContextTest.java | 30 ++--
.../encrypt/case/select_for_query_with_cipher.xml | 5 +
10 files changed, 181 insertions(+), 187 deletions(-)
diff --git
a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/merge/dql/EncryptAlgorithmMetaData.java
b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/merge/dql/EncryptAlgorithmMetaData.java
index 6d6e82b..0a4586e 100644
---
a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/merge/dql/EncryptAlgorithmMetaData.java
+++
b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/merge/dql/EncryptAlgorithmMetaData.java
@@ -20,12 +20,14 @@ package org.apache.shardingsphere.encrypt.merge.dql;
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.encrypt.spi.EncryptAlgorithm;
-import org.apache.shardingsphere.infra.metadata.schema.ShardingSphereSchema;
import
org.apache.shardingsphere.infra.binder.segment.select.projection.Projection;
import
org.apache.shardingsphere.infra.binder.segment.select.projection.impl.ColumnProjection;
import
org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
+import org.apache.shardingsphere.infra.metadata.schema.ShardingSphereSchema;
+import java.util.Collections;
import java.util.List;
+import java.util.Map;
import java.util.Optional;
/**
@@ -58,9 +60,10 @@ public final class EncryptAlgorithmMetaData {
Projection projection = expandProjections.get(columnIndex - 1);
if (projection instanceof ColumnProjection) {
String columnName = ((ColumnProjection) projection).getName();
- Optional<String> tableName =
selectStatementContext.getTablesContext().findTableName((ColumnProjection)
projection, schema);
+ Map<String, String> columnTableNames =
selectStatementContext.getTablesContext().findTableName(Collections.singletonList((ColumnProjection)
projection), schema);
String schemaName = selectStatementContext.getSchemaName();
- return tableName.isPresent() ? findEncryptor(schemaName,
tableName.get(), columnName) : findEncryptor(schemaName, columnName);
+ return columnTableNames.containsKey(projection.getExpression())
+ ? findEncryptor(schemaName,
columnTableNames.get(projection.getExpression()), columnName) :
findEncryptor(schemaName, columnName);
}
return Optional.empty();
}
diff --git
a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/condition/EncryptConditionEngine.java
b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/condition/EncryptConditionEngine.java
index 0b50a23..23f12a94 100644
---
a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/condition/EncryptConditionEngine.java
+++
b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/condition/EncryptConditionEngine.java
@@ -21,6 +21,7 @@ import lombok.RequiredArgsConstructor;
import
org.apache.shardingsphere.encrypt.rewrite.condition.impl.EncryptEqualCondition;
import
org.apache.shardingsphere.encrypt.rewrite.condition.impl.EncryptInCondition;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
+import
org.apache.shardingsphere.infra.binder.segment.select.projection.impl.ColumnProjection;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import
org.apache.shardingsphere.infra.binder.statement.dml.InsertStatementContext;
import
org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
@@ -104,9 +105,10 @@ public final class EncryptConditionEngine {
private Collection<EncryptCondition> createEncryptConditions(final String
schemaName, final ExpressionSegment expression, final Map<String, String>
columnTableNames) {
Collection<EncryptCondition> result = new LinkedList<>();
for (ColumnSegment each : ColumnExtractor.extract(expression)) {
- Optional<String> tableName =
Optional.ofNullable(columnTableNames.get(each.getQualifiedName()));
+ ColumnProjection projection = buildColumnProjection(each);
+ Optional<String> tableName =
Optional.ofNullable(columnTableNames.get(projection.getExpression()));
Optional<EncryptCondition> encryptCondition =
tableName.isPresent()
- && encryptRule.findEncryptor(schemaName, tableName.get(),
each.getIdentifier().getValue()).isPresent() ?
createEncryptCondition(expression, tableName.get()) : Optional.empty();
+ && encryptRule.findEncryptor(schemaName, tableName.get(),
projection.getName()).isPresent() ? createEncryptCondition(expression,
tableName.get()) : Optional.empty();
encryptCondition.ifPresent(result::add);
}
return result;
@@ -145,11 +147,16 @@ public final class EncryptConditionEngine {
}
private Map<String, String> getColumnTableNames(final
SQLStatementContext<?> sqlStatementContext, final Collection<AndPredicate>
andPredicates) {
- Collection<ColumnSegment> columns =
andPredicates.stream().flatMap(each -> each.getPredicates().stream())
- .flatMap(each ->
ColumnExtractor.extract(each).stream()).collect(Collectors.toList());
+ Collection<ColumnProjection> columns =
andPredicates.stream().flatMap(each -> each.getPredicates().stream())
+ .flatMap(each ->
ColumnExtractor.extract(each).stream()).map(this::buildColumnProjection).collect(Collectors.toList());
return sqlStatementContext.getTablesContext().findTableName(columns,
schema);
}
+ private ColumnProjection buildColumnProjection(final ColumnSegment
segment) {
+ String owner = segment.getOwner().map(optional ->
optional.getIdentifier().getValue()).orElse(null);
+ return new ColumnProjection(owner, segment.getIdentifier().getValue(),
null);
+ }
+
private static Optional<EncryptCondition>
createCompareEncryptCondition(final String tableName, final
BinaryOperationExpression expression, final ExpressionSegment
compareRightValue) {
if (!(expression.getLeft() instanceof ColumnSegment)) {
return Optional.empty();
diff --git
a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptPredicateColumnTokenGenerator.java
b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptPredicateColumnTokenGenerator.java
index 2f4e450..502dfb9 100644
---
a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptPredicateColumnTokenGenerator.java
+++
b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptPredicateColumnTokenGenerator.java
@@ -23,7 +23,6 @@ import
org.apache.shardingsphere.encrypt.rewrite.token.generator.BaseEncryptSQLT
import org.apache.shardingsphere.encrypt.rule.EncryptTable;
import
org.apache.shardingsphere.infra.binder.segment.select.projection.impl.ColumnProjection;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
-import
org.apache.shardingsphere.infra.binder.statement.dml.InsertStatementContext;
import
org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.type.WhereAvailable;
import org.apache.shardingsphere.infra.metadata.schema.ShardingSphereSchema;
@@ -44,7 +43,6 @@ import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.Map;
-import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
@@ -61,10 +59,9 @@ public final class EncryptPredicateColumnTokenGenerator
extends BaseEncryptSQLTo
@SuppressWarnings("rawtypes")
@Override
protected boolean isGenerateSQLTokenForEncrypt(final SQLStatementContext
sqlStatementContext) {
- boolean containsJoinQueryOrSubquery = sqlStatementContext instanceof
SelectStatementContext
- && (((SelectStatementContext)
sqlStatementContext).isContainsJoinQuery() || ((SelectStatementContext)
sqlStatementContext).isContainsSubquery());
- boolean containsInsertSelectContext = (sqlStatementContext instanceof
InsertStatementContext) && null != ((InsertStatementContext)
sqlStatementContext).getInsertSelectContext();
- return containsJoinQueryOrSubquery || containsInsertSelectContext ||
(sqlStatementContext instanceof WhereAvailable && ((WhereAvailable)
sqlStatementContext).getWhere().isPresent());
+ boolean containsJoinQuery = sqlStatementContext instanceof
SelectStatementContext && ((SelectStatementContext)
sqlStatementContext).isContainsJoinQuery();
+ boolean containsSubquery = sqlStatementContext instanceof
SelectStatementContext && ((SelectStatementContext)
sqlStatementContext).isContainsSubquery();
+ return containsJoinQuery || containsSubquery || (sqlStatementContext
instanceof WhereAvailable && ((WhereAvailable)
sqlStatementContext).getWhere().isPresent());
}
@SuppressWarnings("rawtypes")
@@ -85,7 +82,7 @@ public final class EncryptPredicateColumnTokenGenerator
extends BaseEncryptSQLTo
Collection<SubstitutableColumnNameToken> result = new LinkedList<>();
for (ExpressionSegment each : predicates) {
for (ColumnSegment column : ColumnExtractor.extract(each)) {
- Optional<EncryptTable> encryptTable =
findEncryptTable(columnTableNames, column);
+ Optional<EncryptTable> encryptTable =
findEncryptTable(columnTableNames, buildColumnProjection(column));
if (!encryptTable.isPresent() ||
!encryptTable.get().findEncryptorName(column.getIdentifier().getValue()).isPresent())
{
continue;
}
@@ -111,27 +108,29 @@ public final class EncryptPredicateColumnTokenGenerator
extends BaseEncryptSQLTo
private Collection<WhereSegment> getWhereSegments(final
SQLStatementContext<?> sqlStatementContext) {
Collection<WhereSegment> result = new LinkedList<>();
- if (sqlStatementContext instanceof WhereAvailable) {
- ((WhereAvailable)
sqlStatementContext).getWhere().ifPresent(result::add);
- }
if (sqlStatementContext instanceof SelectStatementContext) {
result.addAll(WhereExtractUtil.getSubqueryWhereSegments((SelectStatement)
sqlStatementContext.getSqlStatement()));
result.addAll(WhereExtractUtil.getJoinWhereSegments((SelectStatement)
sqlStatementContext.getSqlStatement()));
}
- if (sqlStatementContext instanceof InsertStatementContext && null !=
((InsertStatementContext) sqlStatementContext).getInsertSelectContext()) {
- result.addAll(getWhereSegments(((InsertStatementContext)
sqlStatementContext).getInsertSelectContext().getSelectStatementContext()));
+ if (sqlStatementContext instanceof WhereAvailable) {
+ ((WhereAvailable)
sqlStatementContext).getWhere().ifPresent(result::add);
}
return result;
}
private Map<String, String> getColumnTableNames(final
SQLStatementContext<?> sqlStatementContext, final Collection<AndPredicate>
andPredicates) {
- Collection<ColumnSegment> columns =
andPredicates.stream().flatMap(each -> each.getPredicates().stream())
- .flatMap(each ->
ColumnExtractor.extract(each).stream()).filter(Objects::nonNull).collect(Collectors.toList());
+ Collection<ColumnProjection> columns =
andPredicates.stream().flatMap(each -> each.getPredicates().stream())
+ .flatMap(each ->
ColumnExtractor.extract(each).stream()).map(this::buildColumnProjection).collect(Collectors.toList());
return sqlStatementContext.getTablesContext().findTableName(columns,
schema);
}
- private Optional<EncryptTable> findEncryptTable(final Map<String, String>
columnTableNames, final ColumnSegment column) {
- return
Optional.ofNullable(columnTableNames.get(column.getQualifiedName())).flatMap(tableName
-> getEncryptRule().findEncryptTable(tableName));
+ private ColumnProjection buildColumnProjection(final ColumnSegment
segment) {
+ String owner = segment.getOwner().map(optional ->
optional.getIdentifier().getValue()).orElse(null);
+ return new ColumnProjection(owner, segment.getIdentifier().getValue(),
null);
+ }
+
+ private Optional<EncryptTable> findEncryptTable(final Map<String, String>
columnTableNames, final ColumnProjection column) {
+ return
Optional.ofNullable(columnTableNames.get(column.getExpression())).flatMap(tableName
-> getEncryptRule().findEncryptTable(tableName));
}
private Collection<ColumnProjection> getColumnProjections(final String
columnName) {
diff --git
a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptProjectionTokenGenerator.java
b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptProjectionTokenGenerator.java
index 152b46a..e54d1a4 100644
---
a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptProjectionTokenGenerator.java
+++
b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptProjectionTokenGenerator.java
@@ -18,7 +18,6 @@
package org.apache.shardingsphere.encrypt.rewrite.token.generator.impl;
import com.google.common.base.Preconditions;
-import com.google.common.base.Strings;
import lombok.Setter;
import
org.apache.shardingsphere.encrypt.rewrite.aware.QueryWithCipherColumnAware;
import
org.apache.shardingsphere.encrypt.rewrite.token.generator.BaseEncryptSQLTokenGenerator;
@@ -27,24 +26,25 @@ import
org.apache.shardingsphere.infra.binder.segment.select.projection.Projecti
import
org.apache.shardingsphere.infra.binder.segment.select.projection.ProjectionsContext;
import
org.apache.shardingsphere.infra.binder.segment.select.projection.impl.ColumnProjection;
import
org.apache.shardingsphere.infra.binder.segment.select.projection.impl.ShorthandProjection;
-import org.apache.shardingsphere.infra.binder.segment.table.TablesContext;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import
org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.database.type.DatabaseType;
+import org.apache.shardingsphere.infra.metadata.schema.ShardingSphereSchema;
import
org.apache.shardingsphere.infra.rewrite.sql.token.generator.CollectionSQLTokenGenerator;
import
org.apache.shardingsphere.infra.rewrite.sql.token.generator.aware.PreviousSQLTokensAware;
+import
org.apache.shardingsphere.infra.rewrite.sql.token.generator.aware.SchemaMetaDataAware;
import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.SQLToken;
import
org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.SubstitutableColumnNameToken;
import org.apache.shardingsphere.sql.parser.sql.common.constant.SubqueryType;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ColumnProjectionSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ShorthandProjectionSegment;
-import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.OwnerSegment;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
+import java.util.Map;
import java.util.Optional;
/**
@@ -52,12 +52,14 @@ import java.util.Optional;
*/
@Setter
public final class EncryptProjectionTokenGenerator extends
BaseEncryptSQLTokenGenerator
- implements CollectionSQLTokenGenerator<SQLStatementContext>,
QueryWithCipherColumnAware, PreviousSQLTokensAware {
+ implements CollectionSQLTokenGenerator<SQLStatementContext>,
QueryWithCipherColumnAware, PreviousSQLTokensAware, SchemaMetaDataAware {
private boolean queryWithCipherColumn;
private List<SQLToken> previousSQLTokens;
+ private ShardingSphereSchema schema;
+
@Override
protected boolean isGenerateSQLTokenForEncrypt(final SQLStatementContext
sqlStatementContext) {
return sqlStatementContext instanceof SelectStatementContext &&
!((SelectStatementContext) sqlStatementContext).getAllTables().isEmpty();
@@ -68,59 +70,51 @@ public final class EncryptProjectionTokenGenerator extends
BaseEncryptSQLTokenGe
Preconditions.checkState(sqlStatementContext instanceof
SelectStatementContext);
Collection<SubstitutableColumnNameToken> result = new
LinkedHashSet<>();
for (SelectStatementContext each :
getSelectStatementContexts((SelectStatementContext) sqlStatementContext)) {
- for (String table : each.getTablesContext().getTableNames()) {
- Optional<EncryptTable> encryptTable =
getEncryptRule().findEncryptTable(table);
- encryptTable.ifPresent(optional ->
result.addAll(generateSQLTokens(each, optional, table)));
+ Map<String, String> columnTableNames = getColumnTableNames(each);
+ for (ProjectionSegment projection :
each.getSqlStatement().getProjections().getProjections()) {
+ result.addAll(generateSQLTokens(each, projection,
columnTableNames));
}
}
return result;
}
- private Collection<SubstitutableColumnNameToken> generateSQLTokens(final
SelectStatementContext selectStatementContext,
- final
EncryptTable encryptTable, final String tableName) {
+ private Collection<SubstitutableColumnNameToken> generateSQLTokens(final
SelectStatementContext selectStatementContext,
+ final
ProjectionSegment projection, final Map<String, String> columnTableNames) {
Collection<SubstitutableColumnNameToken> result = new LinkedList<>();
SubqueryType subqueryType = selectStatementContext.getSubqueryType();
- TablesContext tablesContext =
selectStatementContext.getTablesContext();
- Collection<ProjectionSegment> projections =
selectStatementContext.getSqlStatement().getProjections().getProjections();
- for (ProjectionSegment each : projections) {
- if (each instanceof ColumnProjectionSegment) {
- ColumnProjectionSegment columnSegment =
(ColumnProjectionSegment) each;
- String columnName =
columnSegment.getColumn().getIdentifier().getValue();
- String owner =
columnSegment.getColumn().getOwner().map(optional ->
optional.getIdentifier().getValue()).orElse(null);
- if (encryptTable.getLogicColumns().contains(columnName) &&
isOwnerSameWithTableNameOrAlias(tableName, owner, tablesContext)) {
- result.add(generateSQLTokens(columnSegment, encryptTable,
tableName, subqueryType));
- }
- Optional<String> subqueryTableName =
tablesContext.findTableNameFromSubquery(columnName, owner);
- subqueryTableName.ifPresent(optional ->
result.add(generateSQLTokens(columnSegment, encryptTable, optional,
subqueryType)));
+ if (projection instanceof ColumnProjectionSegment) {
+ ColumnProjectionSegment columnSegment = (ColumnProjectionSegment)
projection;
+ ColumnProjection columnProjection =
buildColumnProjection(columnSegment);
+ String tableName =
columnTableNames.get(columnProjection.getExpression());
+ if (null != tableName && getEncryptRule().findEncryptor(tableName,
columnProjection.getName()).isPresent()) {
+ result.add(generateSQLTokens(tableName, columnSegment,
columnProjection, subqueryType));
}
- if (isToGeneratedSQLToken(each, selectStatementContext,
tableName)) {
- ShorthandProjectionSegment shorthandSegment =
(ShorthandProjectionSegment) each;
- ShorthandProjection shorthandProjection =
getShorthandProjection(shorthandSegment,
selectStatementContext.getProjectionsContext());
- if (!shorthandProjection.getActualColumns().isEmpty()) {
- result.add(generateSQLTokens(shorthandSegment,
shorthandProjection, tableName, encryptTable,
selectStatementContext.getDatabaseType(), subqueryType));
- }
+ }
+ if (projection instanceof ShorthandProjectionSegment) {
+ ShorthandProjectionSegment shorthandSegment =
(ShorthandProjectionSegment) projection;
+ Collection<ColumnProjection> actualColumns =
getShorthandProjection(shorthandSegment,
selectStatementContext.getProjectionsContext()).getActualColumns().values();
+ if (!actualColumns.isEmpty()) {
+ result.add(generateSQLTokens(shorthandSegment, actualColumns,
selectStatementContext.getDatabaseType(), subqueryType, columnTableNames));
}
}
return result;
}
- private SubstitutableColumnNameToken generateSQLTokens(final
ColumnProjectionSegment segment, final EncryptTable encryptTable,
- final String
tableName, final SubqueryType subqueryType) {
- String columnName = segment.getColumn().getIdentifier().getValue();
- String alias = segment.getAlias().orElseGet(() ->
segment.getColumn().getIdentifier().getValue());
- Collection<ColumnProjection> projections =
generateProjections(tableName, columnName, alias, null, encryptTable,
subqueryType);
- int startIndex = segment.getColumn().getOwner().isPresent() ?
segment.getColumn().getOwner().get().getStopIndex() + 2 :
segment.getColumn().getStartIndex();
- int stopIndex = segment.getStopIndex();
+ private SubstitutableColumnNameToken generateSQLTokens(final String
tableName, final ColumnProjectionSegment columnSegment,
+ final
ColumnProjection columnProjection, final SubqueryType subqueryType) {
+ Collection<ColumnProjection> projections =
generateProjections(tableName, columnProjection, subqueryType, false);
+ int startIndex = columnSegment.getColumn().getOwner().isPresent() ?
columnSegment.getColumn().getOwner().get().getStopIndex() + 2 :
columnSegment.getColumn().getStartIndex();
+ int stopIndex = columnSegment.getStopIndex();
return new SubstitutableColumnNameToken(startIndex, stopIndex,
projections);
}
- private SubstitutableColumnNameToken generateSQLTokens(final
ShorthandProjectionSegment segment, final ShorthandProjection
shorthandProjection, final String tableName,
- final EncryptTable
encryptTable, final DatabaseType databaseType, final SubqueryType subqueryType)
{
+ private SubstitutableColumnNameToken generateSQLTokens(final
ShorthandProjectionSegment segment, final Collection<ColumnProjection>
actualColumns,
+ final DatabaseType
databaseType, final SubqueryType subqueryType, final Map<String, String>
columnTableNames) {
List<ColumnProjection> projections = new LinkedList<>();
- for (ColumnProjection each :
shorthandProjection.getActualColumns().values()) {
- if (encryptTable.getLogicColumns().contains(each.getName())) {
- String owner = null == each.getOwner() ? null :
each.getOwner();
- projections.addAll(generateProjections(tableName,
each.getName(), each.getName(), owner, encryptTable, subqueryType));
+ for (ColumnProjection each : actualColumns) {
+ String tableName = columnTableNames.get(each.getExpression());
+ if (null != tableName && getEncryptRule().findEncryptor(tableName,
each.getName()).isPresent()) {
+ projections.addAll(generateProjections(tableName, each,
subqueryType, true));
} else {
projections.add(new ColumnProjection(each.getOwner(),
each.getName(), each.getAlias().orElse(null)));
}
@@ -129,6 +123,24 @@ public final class EncryptProjectionTokenGenerator extends
BaseEncryptSQLTokenGe
return new SubstitutableColumnNameToken(segment.getStartIndex(),
segment.getStopIndex(), projections, databaseType.getQuoteCharacter());
}
+ private ColumnProjection buildColumnProjection(final
ColumnProjectionSegment segment) {
+ String owner = segment.getColumn().getOwner().map(optional ->
optional.getIdentifier().getValue()).orElse(null);
+ return new ColumnProjection(owner,
segment.getColumn().getIdentifier().getValue(),
segment.getAlias().orElse(null));
+ }
+
+ private Map<String, String> getColumnTableNames(final
SelectStatementContext selectStatementContext) {
+ Collection<ColumnProjection> columns = new LinkedList<>();
+ for (Projection projection :
selectStatementContext.getProjectionsContext().getProjections()) {
+ if (projection instanceof ColumnProjection) {
+ columns.add((ColumnProjection) projection);
+ }
+ if (projection instanceof ShorthandProjection) {
+ columns.addAll(((ShorthandProjection)
projection).getActualColumns().values());
+ }
+ }
+ return
selectStatementContext.getTablesContext().findTableName(columns, schema);
+ }
+
private Collection<SelectStatementContext>
getSelectStatementContexts(final SelectStatementContext selectStatementContext)
{
Collection<SelectStatementContext> result = new LinkedList<>();
result.add(selectStatementContext);
@@ -136,62 +148,48 @@ public final class EncryptProjectionTokenGenerator
extends BaseEncryptSQLTokenGe
return result;
}
- private boolean isOwnerSameWithTableNameOrAlias(final String tableName,
final String owner, final TablesContext tablesContext) {
- if (Strings.isNullOrEmpty(owner)) {
- return true;
- }
- return tablesContext.findTableNameFromSQL(owner).filter(optional ->
optional.equals(tableName)).isPresent();
- }
-
- private Collection<ColumnProjection> generateProjections(final String
tableName, final String columnName, final String alias, final String owner,
- final
EncryptTable encryptTable, final SubqueryType subqueryType) {
+ private Collection<ColumnProjection> generateProjections(final String
tableName, final ColumnProjection column, final SubqueryType subqueryType,
final boolean shorthand) {
Collection<ColumnProjection> result = new LinkedList<>();
if (SubqueryType.PREDICATE_SUBQUERY.equals(subqueryType)) {
- result.add(generatePredicateSubqueryProjection(tableName,
columnName, owner, encryptTable));
+ result.add(generatePredicateSubqueryProjection(tableName, column));
} else if (SubqueryType.TABLE_SUBQUERY.equals(subqueryType)) {
- result.addAll(generateTableSubqueryProjections(tableName,
columnName, alias, owner));
+ result.addAll(generateTableSubqueryProjections(tableName, column));
} else {
- result.add(generateCommonProjection(tableName, columnName, alias,
owner));
+ result.add(generateCommonProjection(tableName, column, shorthand));
}
return result;
}
- private ColumnProjection generatePredicateSubqueryProjection(final String
tableName, final String columnName, final String owner, final EncryptTable
encryptTable) {
- if (Boolean.FALSE.equals(encryptTable.getQueryWithCipherColumn()) ||
!queryWithCipherColumn) {
- Optional<String> plainColumn =
getEncryptRule().findPlainColumn(tableName, columnName);
+ private ColumnProjection generatePredicateSubqueryProjection(final String
tableName, final ColumnProjection column) {
+ Boolean queryWithCipherColumn =
getEncryptRule().findEncryptTable(tableName).map(EncryptTable::getQueryWithCipherColumn).orElse(null);
+ if (Boolean.FALSE.equals(queryWithCipherColumn) ||
!this.queryWithCipherColumn) {
+ Optional<String> plainColumn =
getEncryptRule().findPlainColumn(tableName, column.getName());
if (plainColumn.isPresent()) {
- return new ColumnProjection(owner, plainColumn.get(), null);
+ return new ColumnProjection(column.getOwner(),
plainColumn.get(), null);
}
}
- Optional<String> assistedQueryColumn =
getEncryptRule().findAssistedQueryColumn(tableName, columnName);
+ Optional<String> assistedQueryColumn =
getEncryptRule().findAssistedQueryColumn(tableName, column.getName());
if (assistedQueryColumn.isPresent()) {
- return new ColumnProjection(owner, assistedQueryColumn.get(),
null);
+ return new ColumnProjection(column.getOwner(),
assistedQueryColumn.get(), null);
}
- String cipherColumn = getEncryptRule().getCipherColumn(tableName,
columnName);
- return new ColumnProjection(owner, cipherColumn, null);
+ String cipherColumn = getEncryptRule().getCipherColumn(tableName,
column.getName());
+ return new ColumnProjection(column.getOwner(), cipherColumn, null);
}
- private Collection<ColumnProjection>
generateTableSubqueryProjections(final String tableName, final String
columnName, final String alias, final String owner) {
+ private Collection<ColumnProjection>
generateTableSubqueryProjections(final String tableName, final ColumnProjection
column) {
Collection<ColumnProjection> result = new LinkedList<>();
- result.add(new ColumnProjection(owner,
getEncryptRule().getCipherColumn(tableName, columnName), alias));
- Optional<String> assistedQueryColumn =
getEncryptRule().findAssistedQueryColumn(tableName, columnName);
- assistedQueryColumn.ifPresent(optional -> result.add(new
ColumnProjection(owner, optional, null)));
- Optional<String> plainColumn =
getEncryptRule().findPlainColumn(tableName, columnName);
- plainColumn.ifPresent(optional -> result.add(new
ColumnProjection(owner, optional, null)));
+ result.add(new ColumnProjection(column.getOwner(),
getEncryptRule().getCipherColumn(tableName, column.getName()),
column.getAlias().orElse(column.getName())));
+ Optional<String> assistedQueryColumn =
getEncryptRule().findAssistedQueryColumn(tableName, column.getName());
+ assistedQueryColumn.ifPresent(optional -> result.add(new
ColumnProjection(column.getOwner(), optional, null)));
+ Optional<String> plainColumn =
getEncryptRule().findPlainColumn(tableName, column.getName());
+ plainColumn.ifPresent(optional -> result.add(new
ColumnProjection(column.getOwner(), optional, null)));
return result;
}
- private ColumnProjection generateCommonProjection(final String tableName,
final String columnName, final String alias, final String owner) {
- String encryptColumnName = getEncryptColumnName(tableName, columnName);
- return new ColumnProjection(owner, encryptColumnName, alias);
- }
-
- private boolean isToGeneratedSQLToken(final ProjectionSegment
projectionSegment, final SelectStatementContext selectStatementContext, final
String tableName) {
- if (!(projectionSegment instanceof ShorthandProjectionSegment)) {
- return false;
- }
- Optional<OwnerSegment> ownerSegment = ((ShorthandProjectionSegment)
projectionSegment).getOwner();
- return ownerSegment.map(segment ->
selectStatementContext.getTablesContext().findTableNameFromSQL(segment.getIdentifier().getValue()).orElse("").equalsIgnoreCase(tableName)).orElse(true);
+ private ColumnProjection generateCommonProjection(final String tableName,
final ColumnProjection column, final boolean shorthand) {
+ String encryptColumnName = getEncryptColumnName(tableName,
column.getName());
+ String owner = shorthand ? column.getOwner() : null;
+ return new ColumnProjection(owner, encryptColumnName,
column.getAlias().orElse(column.getName()));
}
private String getEncryptColumnName(final String tableName, final String
logicEncryptColumnName) {
@@ -218,4 +216,9 @@ public final class EncryptProjectionTokenGenerator extends
BaseEncryptSQLTokenGe
public void setPreviousSQLTokens(final List<SQLToken> previousSQLTokens) {
this.previousSQLTokens = previousSQLTokens;
}
+
+ @Override
+ public void setSchema(final ShardingSphereSchema schema) {
+ this.schema = schema;
+ }
}
diff --git
a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/test/java/org/apache/shardingsphere/encrypt/merge/dql/EncryptAlgorithmMetaDataTest.java
b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/test/java/org/apache/shardingsphere/encrypt/merge/dql/EncryptAlgorithmMetaDataTest.java
index 8b41925..22fda95 100644
---
a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/test/java/org/apache/shardingsphere/encrypt/merge/dql/EncryptAlgorithmMetaDataTest.java
+++
b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/test/java/org/apache/shardingsphere/encrypt/merge/dql/EncryptAlgorithmMetaDataTest.java
@@ -36,7 +36,9 @@ import org.mockito.junit.MockitoJUnitRunner;
import java.util.Arrays;
import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.Optional;
import java.util.Properties;
@@ -82,7 +84,9 @@ public final class EncryptAlgorithmMetaDataTest {
@Test
public void assertFindEncryptorByTableNameAndColumnName() {
- when(tablesContext.findTableName(columnProjection,
schema)).thenReturn(Optional.of("t_order"));
+ Map<String, String> columnTableNames = new HashMap<>();
+ columnTableNames.put(columnProjection.getExpression(), "t_order");
+ when(tablesContext.findTableName(Collections.singletonList(columnProjection), schema)).thenReturn(columnTableNames);
when(encryptRule.findEncryptor(null, "t_order",
"id")).thenReturn(Optional.of(encryptAlgorithm));
EncryptAlgorithmMetaData encryptAlgorithmMetaData = new
EncryptAlgorithmMetaData(schema, encryptRule, selectStatementContext);
Optional<EncryptAlgorithm> actualEncryptor =
encryptAlgorithmMetaData.findEncryptor(1);
@@ -92,7 +96,7 @@ public final class EncryptAlgorithmMetaDataTest {
@Test
public void assertFindEncryptorByColumnName() {
- when(tablesContext.findTableName(columnProjection,
schema)).thenReturn(Optional.empty());
+ when(tablesContext.findTableName(Collections.singletonList(columnProjection), schema)).thenReturn(Collections.emptyMap());
when(tablesContext.getTableNames()).thenReturn(Arrays.asList("t_user",
"t_user_item", "t_order_item"));
when(encryptRule.findEncryptor(null, "t_order_item",
"id")).thenReturn(Optional.of(encryptAlgorithm));
EncryptAlgorithmMetaData encryptAlgorithmMetaData = new
EncryptAlgorithmMetaData(schema, encryptRule, selectStatementContext);
diff --git
a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/impl/EncryptProjectionTokenGeneratorTest.java
b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/impl/EncryptProjectionTokenGeneratorTest.java
index 57b30d4..79d1b34 100644
---
a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/impl/EncryptProjectionTokenGeneratorTest.java
+++
b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/impl/EncryptProjectionTokenGeneratorTest.java
@@ -17,9 +17,11 @@
package org.apache.shardingsphere.encrypt.rewrite.impl;
+import org.apache.shardingsphere.encrypt.fixture.TestEncryptAlgorithm;
import
org.apache.shardingsphere.encrypt.rewrite.token.generator.impl.EncryptProjectionTokenGenerator;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.encrypt.rule.EncryptTable;
+import
org.apache.shardingsphere.infra.binder.segment.select.projection.impl.ColumnProjection;
import org.apache.shardingsphere.infra.binder.segment.table.TablesContext;
import
org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
import
org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.SubstitutableColumnNameToken;
@@ -68,6 +70,7 @@ public final class EncryptProjectionTokenGeneratorTest {
when(sqlStatementContext.getSubqueryContexts().values()).thenReturn(Collections.emptyList());
SimpleTableSegment doctorOneTable = new SimpleTableSegment(new
TableNameSegment(0, 0, new IdentifierValue("doctor1")));
when(sqlStatementContext.getTablesContext()).thenReturn(new
TablesContext(Arrays.asList(doctorTable, doctorOneTable)));
+ when(sqlStatementContext.getProjectionsContext().getProjections()).thenReturn(Collections.singletonList(new ColumnProjection("a", "mobile", null)));
Collection<SubstitutableColumnNameToken> actual =
generator.generateSQLTokens(sqlStatementContext);
assertThat(actual.size(), is(1));
}
@@ -85,6 +88,7 @@ public final class EncryptProjectionTokenGeneratorTest {
when(sqlStatementContext.getSubqueryContexts().values()).thenReturn(Collections.emptyList());
SimpleTableSegment sameDoctorTable = new SimpleTableSegment(new
TableNameSegment(0, 0, new IdentifierValue("doctor")));
when(sqlStatementContext.getTablesContext()).thenReturn(new
TablesContext(Arrays.asList(doctorTable, sameDoctorTable)));
+ when(sqlStatementContext.getProjectionsContext().getProjections()).thenReturn(Collections.singletonList(new ColumnProjection("a", "mobile", null)));
Collection<SubstitutableColumnNameToken> actual =
generator.generateSQLTokens(sqlStatementContext);
assertThat(actual.size(), is(1));
}
@@ -101,6 +105,7 @@ public final class EncryptProjectionTokenGeneratorTest {
SimpleTableSegment doctorTable = new SimpleTableSegment(new
TableNameSegment(0, 0, new IdentifierValue("doctor")));
SimpleTableSegment doctorOneTable = new SimpleTableSegment(new
TableNameSegment(0, 0, new IdentifierValue("doctor1")));
when(sqlStatementContext.getTablesContext()).thenReturn(new
TablesContext(Arrays.asList(doctorTable, doctorOneTable)));
+ when(sqlStatementContext.getProjectionsContext().getProjections()).thenReturn(Collections.singletonList(new ColumnProjection("doctor", "mobile", null)));
Collection<SubstitutableColumnNameToken> actual =
generator.generateSQLTokens(sqlStatementContext);
assertThat(actual.size(), is(1));
}
@@ -115,6 +120,7 @@ public final class EncryptProjectionTokenGeneratorTest {
when(encryptRule.findPlainColumn("doctor1",
"mobile")).thenReturn(Optional.of("Mobile"));
when(encryptRule.findEncryptTable("doctor")).thenReturn(Optional.of(encryptTable1));
when(encryptRule.findEncryptTable("doctor1")).thenReturn(Optional.of(encryptTable2));
+ when(encryptRule.findEncryptor("doctor",
"mobile")).thenReturn(Optional.of(new TestEncryptAlgorithm()));
return encryptRule;
}
}
diff --git
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/engine/impl/WhereClauseShardingConditionEngine.java
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/engine/impl/WhereClauseShardingConditionEngine.java
index a0cd2d2..61dbe0c 100644
---
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/engine/impl/WhereClauseShardingConditionEngine.java
+++
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/engine/impl/WhereClauseShardingConditionEngine.java
@@ -19,6 +19,7 @@ package
org.apache.shardingsphere.sharding.route.engine.condition.engine.impl;
import com.google.common.collect.Range;
import lombok.RequiredArgsConstructor;
+import
org.apache.shardingsphere.infra.binder.segment.select.projection.impl.ColumnProjection;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.type.WhereAvailable;
import org.apache.shardingsphere.infra.exception.ShardingSphereException;
@@ -51,7 +52,6 @@ import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
-import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
@@ -95,11 +95,16 @@ public final class WhereClauseShardingConditionEngine
implements ShardingConditi
}
private Map<String, String> getColumnTableNames(final
SQLStatementContext<?> sqlStatementContext, final Collection<AndPredicate>
andPredicates) {
- Collection<ColumnSegment> columns =
andPredicates.stream().flatMap(each -> each.getPredicates().stream())
- .flatMap(each ->
ColumnExtractor.extract(each).stream()).filter(Objects::nonNull).collect(Collectors.toList());
+ Collection<ColumnProjection> columns =
andPredicates.stream().flatMap(each -> each.getPredicates().stream())
+ .flatMap(each ->
ColumnExtractor.extract(each).stream().map(this::buildColumnProjection)).collect(Collectors.toList());
return sqlStatementContext.getTablesContext().findTableName(columns,
schema);
}
+ private ColumnProjection buildColumnProjection(final ColumnSegment
segment) {
+ String owner = segment.getOwner().map(optional ->
optional.getIdentifier().getValue()).orElse(null);
+ return new ColumnProjection(owner, segment.getIdentifier().getValue(),
null);
+ }
+
private Collection<WhereSegment> getWhereSegments(final
SQLStatementContext<?> sqlStatementContext) {
Collection<WhereSegment> result = new LinkedList<>();
((WhereAvailable)
sqlStatementContext).getWhere().ifPresent(result::add);
@@ -115,11 +120,12 @@ public final class WhereClauseShardingConditionEngine
implements ShardingConditi
Map<Column, Collection<ShardingConditionValue>> result = new
HashMap<>(predicates.size(), 1);
for (ExpressionSegment each : predicates) {
for (ColumnSegment columnSegment : ColumnExtractor.extract(each)) {
- Optional<String> tableName =
Optional.ofNullable(columnTableNames.get(columnSegment.getQualifiedName()));
+ ColumnProjection projection =
buildColumnProjection(columnSegment);
+ Optional<String> tableName =
Optional.ofNullable(columnTableNames.get(projection.getExpression()));
if (!tableName.isPresent() ||
!shardingRule.isShardingColumn(columnSegment.getIdentifier().getValue(),
tableName.get())) {
continue;
}
- Column column = new
Column(columnSegment.getIdentifier().getValue(), tableName.get());
+ Column column = new Column(projection.getName(),
tableName.get());
Optional<ShardingConditionValue> shardingConditionValue =
ConditionValueGeneratorFactory.generate(each, column, parameters);
if (!shardingConditionValue.isPresent()) {
continue;
diff --git
a/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/table/TablesContext.java
b/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/table/TablesContext.java
index 5d27a21..d182228 100644
---
a/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/table/TablesContext.java
+++
b/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/table/TablesContext.java
@@ -25,7 +25,6 @@ import
org.apache.shardingsphere.infra.binder.segment.select.subquery.SubqueryTa
import
org.apache.shardingsphere.infra.binder.segment.select.subquery.engine.SubqueryTableContextEngine;
import
org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.metadata.schema.ShardingSphereSchema;
-import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SubqueryTableSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment;
@@ -35,7 +34,6 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
-import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TreeMap;
@@ -96,70 +94,58 @@ public final class TablesContext {
/**
* Find table name.
*
- * @param columns column segment collection
+ * @param columns column projection collection
* @param schema schema meta data
* @return table name map
*/
- public Map<String, String> findTableName(final Collection<ColumnSegment>
columns, final ShardingSphereSchema schema) {
- if (1 == tableNames.size()) {
- String tableName = tableNames.iterator().next();
- return columns.stream().collect(Collectors.toMap(ColumnSegment::getQualifiedName, each -> tableName, (oldValue, currentValue) -> oldValue));
+ public Map<String, String> findTableName(final
Collection<ColumnProjection> columns, final ShardingSphereSchema schema) {
+ if (1 == tables.size()) {
+ String tableName =
tables.iterator().next().getTableName().getIdentifier().getValue();
+ return columns.stream().collect(Collectors.toMap(ColumnProjection::getExpression, each -> tableName, (oldValue, currentValue) -> oldValue));
}
Map<String, String> result = new HashMap<>(columns.size(), 1);
- Map<String, List<ColumnSegment>> ownerColumns =
columns.stream().filter(each ->
each.getOwner().isPresent()).collect(Collectors.groupingBy(each
- -> each.getOwner().map(optional ->
optional.getIdentifier().getValue()).orElse(null), () -> new
TreeMap<>(String.CASE_INSENSITIVE_ORDER), Collectors.toList()));
- result.putAll(findTableNameFromSQL(ownerColumns));
- Collection<String> columnNames = columns.stream().filter(each ->
!each.getOwner().isPresent()).map(each ->
each.getIdentifier().getValue()).collect(Collectors.toSet());
+ result.putAll(findTableNameFromSQL(getOwnerColumnNames(columns)));
+ Collection<String> columnNames = columns.stream().filter(each -> null
== each.getOwner()).map(ColumnProjection::getName).collect(Collectors.toSet());
result.putAll(findTableNameFromMetaData(columnNames, schema));
+ if (result.size() < columns.size() && !subqueryTables.isEmpty()) {
+ appendRemainingResult(columns, result);
+ }
return result;
}
- /**
- * Find table name.
- *
- * @param column column projection
- * @param schema schema meta data
- * @return table name
- */
- public Optional<String> findTableName(final ColumnProjection column, final
ShardingSphereSchema schema) {
- if (1 == tableNames.size()) {
- return Optional.of(tableNames.iterator().next());
- }
- if (null != column.getOwner()) {
- return findTableNameFromSQL(column.getOwner());
+ private void appendRemainingResult(final Collection<ColumnProjection>
columns, final Map<String, String> result) {
+ Collection<ColumnProjection> remainingColumns =
columns.stream().filter(each ->
!result.containsKey(each.getExpression())).collect(Collectors.toList());
+ for (ColumnProjection each : remainingColumns) {
+ findTableNameFromSubquery(each.getName(),
each.getOwner()).ifPresent(optional -> result.put(each.getExpression(),
optional));
}
- return findTableNameFromMetaData(column.getName(), schema);
}
- /**
- * Find table name from SQL.
- *
- * @param tableNameOrAlias table name or alias
- * @return table name
- */
- public Optional<String> findTableNameFromSQL(final String
tableNameOrAlias) {
- for (SimpleTableSegment each : tables) {
- String tableName = each.getTableName().getIdentifier().getValue();
- if (tableNameOrAlias.equalsIgnoreCase(tableName) ||
tableNameOrAlias.equalsIgnoreCase(each.getAlias().orElse(null))) {
- return Optional.of(tableName);
+ private Map<String, Collection<String>> getOwnerColumnNames(final
Collection<ColumnProjection> columns) {
+ Map<String, Collection<String>> result = new
TreeMap<>(String.CASE_INSENSITIVE_ORDER);
+ for (ColumnProjection each : columns) {
+ if (null == each.getOwner()) {
+ continue;
}
+ Collection<String> columnExpressions =
result.getOrDefault(each.getOwner(), new LinkedList<>());
+ columnExpressions.add(each.getExpression());
+ result.put(each.getOwner(), columnExpressions);
}
- return Optional.empty();
+ return result;
}
- private Map<String, String> findTableNameFromSQL(final Map<String,
List<ColumnSegment>> ownerColumns) {
- if (ownerColumns.isEmpty()) {
+ private Map<String, String> findTableNameFromSQL(final Map<String,
Collection<String>> ownerColumnNames) {
+ if (ownerColumnNames.isEmpty()) {
return Collections.emptyMap();
}
Map<String, String> result = new HashMap<>();
for (SimpleTableSegment each : tables) {
String tableName = each.getTableName().getIdentifier().getValue();
- if (ownerColumns.containsKey(tableName)) {
- ownerColumns.get(tableName).stream().map(ColumnSegment::getQualifiedName).forEach(column -> result.put(column, tableName));
+ if (ownerColumnNames.containsKey(tableName)) {
+ ownerColumnNames.get(tableName).forEach(column ->
result.put(column, tableName));
}
Optional<String> alias = each.getAlias();
- if (alias.isPresent() && ownerColumns.containsKey(alias.get())) {
- ownerColumns.get(alias.get()).stream().map(ColumnSegment::getQualifiedName).forEach(column -> result.put(column, tableName));
+ if (alias.isPresent() &&
ownerColumnNames.containsKey(alias.get())) {
+ ownerColumnNames.get(alias.get()).forEach(column ->
result.put(column, tableName));
}
}
return result;
@@ -184,24 +170,7 @@ public final class TablesContext {
return result;
}
- private Optional<String> findTableNameFromMetaData(final String
columnName, final ShardingSphereSchema schema) {
- for (SimpleTableSegment each : tables) {
- String tableName = each.getTableName().getIdentifier().getValue();
- if (schema.containsColumn(tableName, columnName)) {
- return Optional.of(tableName);
- }
- }
- return Optional.empty();
- }
-
- /**
- * Find table name from subquery.
- *
- * @param columnName column name
- * @param owner column owner
- * @return table name
- */
- public Optional<String> findTableNameFromSubquery(final String columnName,
final String owner) {
+ private Optional<String> findTableNameFromSubquery(final String
columnName, final String owner) {
Collection<SubqueryTableContext> subqueryTableContexts =
subqueryTables.get(owner);
if (null != subqueryTableContexts) {
return subqueryTableContexts.stream().filter(each ->
each.getColumnNames().contains(columnName)).map(SubqueryTableContext::getTableName).findFirst();
diff --git
a/shardingsphere-infra/shardingsphere-infra-binder/src/test/java/org/apache/shardingsphere/infra/binder/segment/table/TablesContextTest.java
b/shardingsphere-infra/shardingsphere-infra-binder/src/test/java/org/apache/shardingsphere/infra/binder/segment/table/TablesContextTest.java
index 7da9e93..b652682 100644
---
a/shardingsphere-infra/shardingsphere-infra-binder/src/test/java/org/apache/shardingsphere/infra/binder/segment/table/TablesContextTest.java
+++
b/shardingsphere-infra/shardingsphere-infra-binder/src/test/java/org/apache/shardingsphere/infra/binder/segment/table/TablesContextTest.java
@@ -18,8 +18,8 @@
package org.apache.shardingsphere.infra.binder.segment.table;
import com.google.common.collect.Sets;
+import
org.apache.shardingsphere.infra.binder.segment.select.projection.impl.ColumnProjection;
import org.apache.shardingsphere.infra.metadata.schema.ShardingSphereSchema;
-import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.AliasSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.OwnerSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
@@ -30,7 +30,6 @@ import org.junit.Test;
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
-import java.util.Optional;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertFalse;
@@ -58,7 +57,8 @@ public final class TablesContextTest {
@Test
public void assertFindTableNameWhenSingleTable() {
SimpleTableSegment tableSegment = createTableSegment("table_1",
"tbl_1");
- Map<String, String> actual = new
TablesContext(Collections.singletonList(tableSegment)).findTableName(Collections.singletonList(createColumnSegment()),
mock(ShardingSphereSchema.class));
+ ColumnProjection columnProjection = createColumnProjection(null,
"col", null);
+ Map<String, String> actual = new
TablesContext(Collections.singletonList(tableSegment)).findTableName(Collections.singletonList(columnProjection),
mock(ShardingSphereSchema.class));
assertFalse(actual.isEmpty());
assertThat(actual.get("col"), is("table_1"));
}
@@ -67,9 +67,8 @@ public final class TablesContextTest {
public void assertFindTableNameWhenColumnSegmentOwnerPresent() {
SimpleTableSegment tableSegment1 = createTableSegment("table_1",
"tbl_1");
SimpleTableSegment tableSegment2 = createTableSegment("table_2",
"tbl_2");
- ColumnSegment columnSegment = createColumnSegment();
- columnSegment.setOwner(new OwnerSegment(0, 10, new
IdentifierValue("table_1")));
- Map<String, String> actual = new
TablesContext(Arrays.asList(tableSegment1,
tableSegment2)).findTableName(Collections.singletonList(columnSegment),
mock(ShardingSphereSchema.class));
+ ColumnProjection columnProjection = createColumnProjection("table_1",
"col", "");
+ Map<String, String> actual = new
TablesContext(Arrays.asList(tableSegment1,
tableSegment2)).findTableName(Collections.singletonList(columnProjection),
mock(ShardingSphereSchema.class));
assertFalse(actual.isEmpty());
assertThat(actual.get("table_1.col"), is("table_1"));
}
@@ -78,7 +77,8 @@ public final class TablesContextTest {
public void assertFindTableNameWhenColumnSegmentOwnerAbsent() {
SimpleTableSegment tableSegment1 = createTableSegment("table_1",
"tbl_1");
SimpleTableSegment tableSegment2 = createTableSegment("table_2",
"tbl_2");
- Map<String, String> actual = new
TablesContext(Arrays.asList(tableSegment1,
tableSegment2)).findTableName(Collections.singletonList(createColumnSegment()),
mock(ShardingSphereSchema.class));
+ ColumnProjection columnProjection = createColumnProjection(null,
"col", null);
+ Map<String, String> actual = new
TablesContext(Arrays.asList(tableSegment1,
tableSegment2)).findTableName(Collections.singletonList(columnProjection),
mock(ShardingSphereSchema.class));
assertTrue(actual.isEmpty());
}
@@ -88,20 +88,12 @@ public final class TablesContextTest {
SimpleTableSegment tableSegment2 = createTableSegment("table_2",
"tbl_2");
ShardingSphereSchema schema = mock(ShardingSphereSchema.class);
when(schema.getAllColumnNames("table_1")).thenReturn(Collections.singletonList("col"));
- Map<String, String> actual = new
TablesContext(Arrays.asList(tableSegment1,
tableSegment2)).findTableName(Collections.singletonList(createColumnSegment()),
schema);
+ ColumnProjection columnProjection = createColumnProjection(null,
"col", null);
+ Map<String, String> actual = new
TablesContext(Arrays.asList(tableSegment1,
tableSegment2)).findTableName(Collections.singletonList(columnProjection),
schema);
assertFalse(actual.isEmpty());
assertThat(actual.get("col"), is("table_1"));
}
- @Test
- public void assertFindTableNameWhenTableNameOrAliasIgnoreCase() {
- SimpleTableSegment tableSegment1 = createTableSegment("table_1",
"tbl_1");
- SimpleTableSegment tableSegment2 = createTableSegment("table_2",
"tbl_2");
- Optional<String> actual = new
TablesContext(Arrays.asList(tableSegment1,
tableSegment2)).findTableNameFromSQL("Tbl_1");
- assertTrue(actual.isPresent());
- assertThat(actual.get(), is("table_1"));
- }
-
private SimpleTableSegment createTableSegment(final String tableName,
final String alias) {
SimpleTableSegment result = new SimpleTableSegment(new
TableNameSegment(0, 0, new IdentifierValue(tableName)));
AliasSegment aliasSegment = new AliasSegment(0, 0, new
IdentifierValue(alias));
@@ -109,8 +101,8 @@ public final class TablesContextTest {
return result;
}
- private ColumnSegment createColumnSegment() {
- return new ColumnSegment(0, 0, new IdentifierValue("col"));
+ private ColumnProjection createColumnProjection(final String owner, final
String name, final String alias) {
+ return new ColumnProjection(owner, name, alias);
}
@Test
diff --git
a/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/encrypt/case/select_for_query_with_cipher.xml
b/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/encrypt/case/select_for_query_with_cipher.xml
index 6d97a3e..de53f92 100644
---
a/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/encrypt/case/select_for_query_with_cipher.xml
+++
b/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/encrypt/case/select_for_query_with_cipher.xml
@@ -73,6 +73,11 @@
<input sql="SELECT a.*, account_id, 1+1 FROM t_account a" />
<output sql="SELECT `a`.`account_id`, `a`.`cipher_certificate_number`
AS certificate_number, `a`.`cipher_password` AS password, `a`.`cipher_amount`
AS amount, `a`.`status`, account_id, 1+1 FROM t_account a" />
</rewrite-assertion>
+
+ <rewrite-assertion id="select_unqualified_shorthand_projection_with_join"
db-types="MySQL">
+ <input sql="SELECT * FROM t_account t INNER JOIN t_account_bak b ON
t.id = b.id WHERE t.amount = ? OR b.amount = ?" parameters="1, 2" />
+ <output sql="SELECT `t`.`account_id`, `t`.`cipher_certificate_number`
AS certificate_number, `t`.`cipher_password` AS password, `t`.`cipher_amount`
AS amount, `t`.`status`, `b`.`account_id`, `b`.`cipher_certificate_number` AS
certificate_number, `b`.`cipher_password` AS password, `b`.`cipher_amount` AS
amount, `b`.`status` FROM t_account t INNER JOIN t_account_bak b ON t.id = b.id
WHERE t.cipher_amount = ? OR b.cipher_amount = ?" parameters="encrypt_1,
encrypt_2" />
+ </rewrite-assertion>
<rewrite-assertion id="select_with_join" db-types="MySQL">
<input sql="SELECT t_account.amount, t_account_bak.amount FROM
t_account LEFT JOIN t_account_bak ON t_account.id = t_account_bak.id WHERE
t_account.amount = ? OR t_account_bak.amount = ?" parameters="1, 2" />