This is an automated email from the ASF dual-hosted git repository.

kimmking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new a8fbbc2  fix subquery sharding key check result and rewrite result 
(#7043)
a8fbbc2 is described below

commit a8fbbc21716cdbe5cd30f8179aeda11f4f7e593e
Author: DuanZhengqiang <[email protected]>
AuthorDate: Mon Aug 24 22:21:23 2020 -0500

    fix subquery sharding key check result and rewrite result (#7043)
    
    * fix subquery sharding key check result
    
    * fix checkstyle and test case
---
 .../src/test/resources/sharding/select.xml         | 10 +++
 .../route/engine/ShardingRouteDecorator.java       |  4 +-
 .../engine/WhereClauseShardingConditionEngine.java |  6 +-
 .../ConditionValueBetweenOperatorGenerator.java    |  4 +-
 ...ConditionValueBetweenOperatorGeneratorTest.java |  6 +-
 ...ionUtils.java => SafeNumberOperationUtils.java} | 30 ++++++--
 .../sql/parser/sql/util/TableExtractUtils.java     | 29 ++++++--
 ...Test.java => SafeNumberOperationUtilsTest.java} | 82 +++++++++++++++++-----
 8 files changed, 133 insertions(+), 38 deletions(-)

diff --git 
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-rewrite/src/test/resources/sharding/select.xml
 
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-rewrite/src/test/resources/sharding/select.xml
index 044325f..acdb3f9 100644
--- 
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-rewrite/src/test/resources/sharding/select.xml
+++ 
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-rewrite/src/test/resources/sharding/select.xml
@@ -63,6 +63,16 @@
         <input sql="SELECT (select id from t_account limit 1) as myid FROM 
(select b.account_id from (select t_account.account_id from t_account) b where 
b.account_id=?) a WHERE account_id >= (select account_id from t_account limit 
1)"  parameters="100"/>
         <output sql="SELECT (select id from t_account_0 limit 1) as myid FROM 
(select b.account_id from (select t_account_0.account_id from t_account_0) b 
where b.account_id=?) a WHERE account_id >= (select account_id from t_account_0 
limit 1)"  parameters="100"/>
     </rewrite-assertion>
+
+    <rewrite-assertion id="select_with_subquery_for_where_in_predicate" 
db-type="MySQL">
+        <input sql="SELECT * FROM t_account WHERE account_id = ? AND amount IN 
(SELECT amount FROM t_account WHERE account_id = ?)"  parameters="100, 100"/>
+        <output sql="SELECT * FROM t_account_0 WHERE account_id = ? AND amount 
IN (SELECT amount FROM t_account_0 WHERE account_id = ?)"  parameters="100, 
100"/>
+    </rewrite-assertion>
+
+    <rewrite-assertion 
id="select_with_subquery_for_where_between_and_predicate" db-type="MySQL">
+        <input sql="SELECT * FROM t_account WHERE account_id = ? AND amount 
BETWEEN (SELECT amount FROM t_account WHERE account_id = ?) AND ?"  
parameters="100, 100, 1500"/>
+        <output sql="SELECT * FROM t_account_0 WHERE account_id = ? AND amount 
BETWEEN (SELECT amount FROM t_account_0 WHERE account_id = ?) AND ?"  
parameters="100, 100, 1500"/>
+    </rewrite-assertion>
     
     <rewrite-assertion id="select_without_sharding_value_for_parameters">
         <input sql="SELECT * FROM db.t_account WHERE amount = ?" 
parameters="1000" />
diff --git 
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/ShardingRouteDecorator.java
 
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/ShardingRouteDecorator.java
index 77cc5e8..0c9ada8 100644
--- 
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/ShardingRouteDecorator.java
+++ 
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/ShardingRouteDecorator.java
@@ -45,6 +45,7 @@ import 
org.apache.shardingsphere.sql.parser.binder.statement.dml.InsertStatement
 import 
org.apache.shardingsphere.sql.parser.binder.statement.dml.SelectStatementContext;
 import org.apache.shardingsphere.sql.parser.sql.statement.SQLStatement;
 import org.apache.shardingsphere.sql.parser.sql.statement.dml.DMLStatement;
+import org.apache.shardingsphere.sql.parser.sql.util.SafeNumberOperationUtils;
 
 import java.util.Collections;
 import java.util.List;
@@ -136,7 +137,8 @@ public final class ShardingRouteDecorator implements 
RouteDecorator<ShardingRule
     }
     
     private boolean isSameRouteValue(final ShardingRule shardingRule, final 
ListRouteValue routeValue1, final ListRouteValue routeValue2) {
-        return isSameLogicTable(shardingRule, routeValue1, routeValue2) && 
routeValue1.getColumnName().equals(routeValue2.getColumnName()) && 
routeValue1.getValues().equals(routeValue2.getValues());
+        return isSameLogicTable(shardingRule, routeValue1, routeValue2) && 
routeValue1.getColumnName().equals(routeValue2.getColumnName()) 
+                && 
SafeNumberOperationUtils.safeEquals(routeValue1.getValues(), 
routeValue2.getValues());
     }
     
     private boolean isSameLogicTable(final ShardingRule shardingRule, final 
ListRouteValue shardingValue1, final ListRouteValue shardingValue2) {
diff --git 
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/engine/WhereClauseShardingConditionEngine.java
 
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/engine/WhereClauseShardingConditionEngine.java
index 623e533..929d1f4 100644
--- 
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/engine/WhereClauseShardingConditionEngine.java
+++ 
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/engine/WhereClauseShardingConditionEngine.java
@@ -36,7 +36,7 @@ import 
org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.AndPredica
 import 
org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.PredicateSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.WhereSegment;
 import org.apache.shardingsphere.sql.parser.sql.statement.dml.SelectStatement;
-import org.apache.shardingsphere.sql.parser.sql.util.SafeRangeOperationUtils;
+import org.apache.shardingsphere.sql.parser.sql.util.SafeNumberOperationUtils;
 import org.apache.shardingsphere.sql.parser.sql.util.WhereSegmentExtractUtils;
 
 import java.util.ArrayList;
@@ -171,13 +171,13 @@ public final class WhereClauseShardingConditionEngine {
     }
     
     private Range<Comparable<?>> mergeRangeRouteValues(final 
Range<Comparable<?>> value1, final Range<Comparable<?>> value2) {
-        return null == value2 ? value1 : 
SafeRangeOperationUtils.safeIntersection(value1, value2);
+        return null == value2 ? value1 : 
SafeNumberOperationUtils.safeIntersection(value1, value2);
     }
     
     private Collection<Comparable<?>> mergeListAndRangeRouteValues(final 
Collection<Comparable<?>> listValue, final Range<Comparable<?>> rangeValue) {
         Collection<Comparable<?>> result = new LinkedList<>();
         for (Comparable<?> each : listValue) {
-            if (SafeRangeOperationUtils.safeContains(rangeValue, each)) {
+            if (SafeNumberOperationUtils.safeContains(rangeValue, each)) {
                 result.add(each);
             }
         }
diff --git 
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/generator/impl/ConditionValueBetweenOperatorGenerator.java
 
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/generator/impl/ConditionValueBetweenOperatorGenerator.java
index cb25672..f816584 100644
--- 
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/generator/impl/ConditionValueBetweenOperatorGenerator.java
+++ 
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/generator/impl/ConditionValueBetweenOperatorGenerator.java
@@ -26,7 +26,7 @@ import 
org.apache.shardingsphere.sharding.route.engine.condition.generator.Condi
 import 
org.apache.shardingsphere.sharding.route.engine.condition.generator.ConditionValueGenerator;
 import org.apache.shardingsphere.sharding.route.spi.SPITimeService;
 import 
org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.value.PredicateBetweenRightValue;
-import org.apache.shardingsphere.sql.parser.sql.util.SafeRangeOperationUtils;
+import org.apache.shardingsphere.sql.parser.sql.util.SafeNumberOperationUtils;
 
 import java.util.Date;
 import java.util.List;
@@ -42,7 +42,7 @@ public final class ConditionValueBetweenOperatorGenerator 
implements ConditionVa
         Optional<Comparable<?>> betweenRouteValue = new 
ConditionValue(predicateRightValue.getBetweenExpression(), 
parameters).getValue();
         Optional<Comparable<?>> andRouteValue = new 
ConditionValue(predicateRightValue.getAndExpression(), parameters).getValue();
         if (betweenRouteValue.isPresent() && andRouteValue.isPresent()) {
-            return Optional.of(new RangeRouteValue<>(column.getName(), 
column.getTableName(), 
SafeRangeOperationUtils.safeClosed(betweenRouteValue.get(), 
andRouteValue.get())));
+            return Optional.of(new RangeRouteValue<>(column.getName(), 
column.getTableName(), 
SafeNumberOperationUtils.safeClosed(betweenRouteValue.get(), 
andRouteValue.get())));
         }
         Date date = new SPITimeService().getTime();
         if (!betweenRouteValue.isPresent() && 
ExpressionConditionUtils.isNowExpression(predicateRightValue.getBetweenExpression()))
 {
diff --git 
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/test/java/org/apache/shardingsphere/sharding/route/engine/condition/generator/impl/ConditionValueBetweenOperatorGeneratorTest.java
 
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/test/java/org/apache/shardingsphere/sharding/route/engine/condition/generator/impl/ConditionValueBetweenOperatorGeneratorTest.java
index 940e8ca..e781168 100644
--- 
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/test/java/org/apache/shardingsphere/sharding/route/engine/condition/generator/impl/ConditionValueBetweenOperatorGeneratorTest.java
+++ 
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/test/java/org/apache/shardingsphere/sharding/route/engine/condition/generator/impl/ConditionValueBetweenOperatorGeneratorTest.java
@@ -24,7 +24,7 @@ import 
org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.ExpressionSegme
 import 
org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.complex.CommonExpressionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.LiteralExpressionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.value.PredicateBetweenRightValue;
-import org.apache.shardingsphere.sql.parser.sql.util.SafeRangeOperationUtils;
+import org.apache.shardingsphere.sql.parser.sql.util.SafeNumberOperationUtils;
 import org.junit.Test;
 
 import java.util.Calendar;
@@ -72,8 +72,8 @@ public final class ConditionValueBetweenOperatorGeneratorTest 
{
         RangeRouteValue<Comparable<?>> rangeRouteValue = 
(RangeRouteValue<Comparable<?>>) routeValue.get();
         assertThat(rangeRouteValue.getColumnName(), is(column.getName()));
         assertThat(rangeRouteValue.getTableName(), is(column.getTableName()));
-        
assertTrue(SafeRangeOperationUtils.safeContains(rangeRouteValue.getValueRange(),
 between));
-        
assertTrue(SafeRangeOperationUtils.safeContains(rangeRouteValue.getValueRange(),
 and));
+        
assertTrue(SafeNumberOperationUtils.safeContains(rangeRouteValue.getValueRange(),
 between));
+        
assertTrue(SafeNumberOperationUtils.safeContains(rangeRouteValue.getValueRange(),
 and));
     }
     
     @Test(expected = ClassCastException.class)
diff --git 
a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/util/SafeRangeOperationUtils.java
 
b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/util/SafeNumberOperationUtils.java
similarity index 82%
rename from 
shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/util/SafeRangeOperationUtils.java
rename to 
shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/util/SafeNumberOperationUtils.java
index dfb41ab..13c8756 100644
--- 
a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/util/SafeRangeOperationUtils.java
+++ 
b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/util/SafeNumberOperationUtils.java
@@ -26,19 +26,20 @@ import lombok.SneakyThrows;
 
 import java.math.BigDecimal;
 import java.math.BigInteger;
+import java.util.Collection;
 import java.util.List;
 import java.util.Objects;
 import java.util.Set;
 import java.util.stream.Collectors;
 
 /**
- * Safe range operation utility class.
+ * Safe number operation utility class.
  */
 @NoArgsConstructor(access = AccessLevel.PRIVATE)
-public final class SafeRangeOperationUtils {
+public final class SafeNumberOperationUtils {
     
     /**
-     * Execute intersection method by safe mode.
+     * Execute range intersection method by safe mode.
      *
      * @param range range
      * @param connectedRange connected range
@@ -63,7 +64,7 @@ public final class SafeRangeOperationUtils {
     }
     
     /**
-     * Execute closed method by safe mode.
+     * Execute range closed method by safe mode.
      *
      * @param lowerEndpoint lower endpoint
      * @param upperEndpoint upper endpoint
@@ -82,7 +83,7 @@ public final class SafeRangeOperationUtils {
     }
     
     /**
-     * Execute contains method by safe mode.
+     * Execute range contains method by safe mode.
      *
      * @param range range
      * @param endpoint endpoint
@@ -102,6 +103,25 @@ public final class SafeRangeOperationUtils {
             return newRange.contains(parseNumberByClazz(endpoint.toString(), 
clazz));
         }
     }
+
+    /**
+     * Execute collection equals method by safe mode.
+     *
+     * @param sourceCollection source collection
+     * @param targetCollection target collection
+     * @return whether the element in source collection and target collection 
are all same
+     */
+    public static boolean safeEquals(final Collection<Comparable<?>> 
sourceCollection, final Collection<Comparable<?>> targetCollection) {
+        List<Comparable<?>> collection = Lists.newArrayList(sourceCollection);
+        collection.addAll(targetCollection);
+        Class<?> clazz = getTargetNumericType(collection);
+        if (null == clazz) {
+            return sourceCollection.equals(targetCollection);
+        }
+        List<Comparable<?>> sourceClazzCollection = 
sourceCollection.stream().map(number -> parseNumberByClazz(number.toString(), 
clazz)).collect(Collectors.toList());
+        List<Comparable<?>> targetClazzCollection = 
targetCollection.stream().map(number -> parseNumberByClazz(number.toString(), 
clazz)).collect(Collectors.toList());
+        return sourceClazzCollection.equals(targetClazzCollection);
+    }
     
     private static Range<Comparable<?>> createTargetNumericTypeRange(final 
Range<Comparable<?>> range, final Class<?> clazz) {
         if (range.hasLowerBound() && range.hasUpperBound()) {
diff --git 
a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/util/TableExtractUtils.java
 
b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/util/TableExtractUtils.java
index 38467c4..e4d7d45 100644
--- 
a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/util/TableExtractUtils.java
+++ 
b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/util/TableExtractUtils.java
@@ -23,6 +23,7 @@ import 
org.apache.shardingsphere.sql.parser.sql.segment.dml.JoinedTableSegment;
 import org.apache.shardingsphere.sql.parser.sql.segment.dml.TableFactorSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.segment.dml.TableReferenceSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.ExpressionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.subquery.SubqueryExpressionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.segment.dml.item.ColumnProjectionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.segment.dml.item.ProjectionSegment;
@@ -33,7 +34,9 @@ import 
org.apache.shardingsphere.sql.parser.sql.segment.dml.order.item.OrderByIt
 import 
org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.AndPredicate;
 import 
org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.PredicateSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.WhereSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.value.PredicateBetweenRightValue;
 import 
org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.value.PredicateCompareRightValue;
+import 
org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.value.PredicateInRightValue;
 import org.apache.shardingsphere.sql.parser.sql.segment.generic.OwnerAvailable;
 import org.apache.shardingsphere.sql.parser.sql.segment.generic.OwnerSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.segment.generic.table.SimpleTableSegment;
@@ -224,12 +227,28 @@ public final class TableExtractUtils {
             if (((PredicateCompareRightValue) 
predicate.getRightValue()).getExpression() instanceof 
SubqueryExpressionSegment) {
                 
result.addAll(TableExtractUtils.getTablesFromSelect(((SubqueryExpressionSegment)
 ((PredicateCompareRightValue) 
predicate.getRightValue()).getExpression()).getSubquery().getSelect()));
             }
-        } else {
-            if (predicate.getRightValue() instanceof ColumnSegment) {
-                Preconditions.checkState(((ColumnSegment) 
predicate.getRightValue()).getOwner().isPresent());
-                OwnerSegment segment = ((ColumnSegment) 
predicate.getRightValue()).getOwner().get();
-                result.add(new SimpleTableSegment(segment.getStartIndex(), 
segment.getStopIndex(), segment.getIdentifier()));
+        }
+        if (predicate.getRightValue() instanceof PredicateInRightValue) {
+            for (ExpressionSegment expressionSegment : 
((PredicateInRightValue) predicate.getRightValue()).getSqlExpressions()) {
+                if (expressionSegment instanceof SubqueryExpressionSegment) {
+                    
result.addAll(TableExtractUtils.getTablesFromSelect(((SubqueryExpressionSegment)
 expressionSegment).getSubquery().getSelect()));
+                }
+            }
+        } 
+        if (predicate.getRightValue() instanceof PredicateBetweenRightValue) {
+            if (((PredicateBetweenRightValue) 
predicate.getRightValue()).getBetweenExpression() instanceof 
SubqueryExpressionSegment) {
+                SelectStatement subquerySelect = ((SubqueryExpressionSegment) 
(((PredicateBetweenRightValue) 
predicate.getRightValue()).getBetweenExpression())).getSubquery().getSelect();
+                
result.addAll(TableExtractUtils.getTablesFromSelect(subquerySelect));    
             }
+            if (((PredicateBetweenRightValue) 
predicate.getRightValue()).getAndExpression() instanceof 
SubqueryExpressionSegment) {
+                SelectStatement subquerySelect = ((SubqueryExpressionSegment) 
(((PredicateBetweenRightValue) 
predicate.getRightValue()).getAndExpression())).getSubquery().getSelect();
+                
result.addAll(TableExtractUtils.getTablesFromSelect(subquerySelect));
+            }
+        } 
+        if (predicate.getRightValue() instanceof ColumnSegment) {
+            Preconditions.checkState(((ColumnSegment) 
predicate.getRightValue()).getOwner().isPresent());
+            OwnerSegment segment = ((ColumnSegment) 
predicate.getRightValue()).getOwner().get();
+            result.add(new SimpleTableSegment(segment.getStartIndex(), 
segment.getStopIndex(), segment.getIdentifier()));
         }
         return result;
     }
diff --git 
a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/util/SafeRangeOperationUtilsTest.java
 
b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/util/SafeNumberOperationUtilsTest.java
similarity index 62%
rename from 
shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/util/SafeRangeOperationUtilsTest.java
rename to 
shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/util/SafeNumberOperationUtilsTest.java
index 4be724d..3d05cab 100644
--- 
a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/util/SafeRangeOperationUtilsTest.java
+++ 
b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/util/SafeNumberOperationUtilsTest.java
@@ -18,24 +18,26 @@
 package org.apache.shardingsphere.sql.parser.sql.util;
 
 import com.google.common.collect.BoundType;
+import com.google.common.collect.Lists;
 import com.google.common.collect.Range;
 import org.junit.Test;
 
 import java.math.BigDecimal;
 import java.math.BigInteger;
+import java.util.List;
 
 import static org.hamcrest.CoreMatchers.is;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertThat;
 import static org.junit.Assert.assertTrue;
 
-public class SafeRangeOperationUtilsTest {
+public class SafeNumberOperationUtilsTest {
     
     @Test
     public void assertSafeIntersectionForInteger() {
         Range<Comparable<?>> range = Range.closed(10, 2000);
         Range<Comparable<?>> connectedRange = Range.closed(1500, 4000);
-        Range<Comparable<?>> newRange = 
SafeRangeOperationUtils.safeIntersection(range, connectedRange);
+        Range<Comparable<?>> newRange = 
SafeNumberOperationUtils.safeIntersection(range, connectedRange);
         assertThat(newRange.lowerEndpoint(), is(1500));
         assertThat(newRange.lowerBoundType(), is(BoundType.CLOSED));
         assertThat(newRange.upperEndpoint(), is(2000));
@@ -46,7 +48,7 @@ public class SafeRangeOperationUtilsTest {
     public void assertSafeIntersectionForLong() {
         Range<Comparable<?>> range = Range.upTo(3147483647L, BoundType.OPEN);
         Range<Comparable<?>> connectedRange = Range.downTo(3, BoundType.OPEN);
-        Range<Comparable<?>> newRange = 
SafeRangeOperationUtils.safeIntersection(range, connectedRange);
+        Range<Comparable<?>> newRange = 
SafeNumberOperationUtils.safeIntersection(range, connectedRange);
         assertThat(newRange.lowerEndpoint(), is(3L));
         assertThat(newRange.lowerBoundType(), is(BoundType.OPEN));
         assertThat(newRange.upperEndpoint(), is(3147483647L));
@@ -57,7 +59,7 @@ public class SafeRangeOperationUtilsTest {
     public void assertSafeIntersectionForBigInteger() {
         Range<Comparable<?>> range = Range.upTo(new 
BigInteger("131323233123211"), BoundType.CLOSED);
         Range<Comparable<?>> connectedRange = Range.downTo(35, BoundType.OPEN);
-        Range<Comparable<?>> newRange = 
SafeRangeOperationUtils.safeIntersection(range, connectedRange);
+        Range<Comparable<?>> newRange = 
SafeNumberOperationUtils.safeIntersection(range, connectedRange);
         assertThat(newRange.lowerEndpoint(), is(new BigInteger("35")));
         assertThat(newRange.lowerBoundType(), is(BoundType.OPEN));
         assertThat(newRange.upperEndpoint(), is(new 
BigInteger("131323233123211")));
@@ -68,7 +70,7 @@ public class SafeRangeOperationUtilsTest {
     public void assertSafeIntersectionForFloat() {
         Range<Comparable<?>> range = Range.closed(5.5F, 13.8F);
         Range<Comparable<?>> connectedRange = Range.closed(7.14F, 11.3F);
-        Range<Comparable<?>> newRange = 
SafeRangeOperationUtils.safeIntersection(range, connectedRange);
+        Range<Comparable<?>> newRange = 
SafeNumberOperationUtils.safeIntersection(range, connectedRange);
         assertThat(newRange.lowerEndpoint(), is(7.14F));
         assertThat(newRange.lowerBoundType(), is(BoundType.CLOSED));
         assertThat(newRange.upperEndpoint(), is(11.3F));
@@ -79,7 +81,7 @@ public class SafeRangeOperationUtilsTest {
     public void assertSafeIntersectionForDouble() {
         Range<Comparable<?>> range = Range.closed(1242.114, 31474836.12);
         Range<Comparable<?>> connectedRange = Range.downTo(567.34F, 
BoundType.OPEN);
-        Range<Comparable<?>> newRange = 
SafeRangeOperationUtils.safeIntersection(range, connectedRange);
+        Range<Comparable<?>> newRange = 
SafeNumberOperationUtils.safeIntersection(range, connectedRange);
         assertThat(newRange.lowerEndpoint(), is(1242.114));
         assertThat(newRange.lowerBoundType(), is(BoundType.CLOSED));
         assertThat(newRange.upperEndpoint(), is(31474836.12));
@@ -90,7 +92,7 @@ public class SafeRangeOperationUtilsTest {
     public void assertSafeIntersectionForBigDecimal() {
         Range<Comparable<?>> range = Range.upTo(new BigDecimal("2331.23211"), 
BoundType.CLOSED);
         Range<Comparable<?>> connectedRange = Range.open(135.13F, 45343.23F);
-        Range<Comparable<?>> newRange = 
SafeRangeOperationUtils.safeIntersection(range, connectedRange);
+        Range<Comparable<?>> newRange = 
SafeNumberOperationUtils.safeIntersection(range, connectedRange);
         assertThat(newRange.lowerEndpoint(), is(new BigDecimal("135.13")));
         assertThat(newRange.lowerBoundType(), is(BoundType.OPEN));
         assertThat(newRange.upperEndpoint(), is(new BigDecimal("2331.23211")));
@@ -99,42 +101,42 @@ public class SafeRangeOperationUtilsTest {
     
     @Test
     public void assertSafeClosedForInteger() {
-        Range<Comparable<?>> range = SafeRangeOperationUtils.safeClosed(12, 
500);
+        Range<Comparable<?>> range = SafeNumberOperationUtils.safeClosed(12, 
500);
         assertThat(range.lowerEndpoint(), is(12));
         assertThat(range.upperEndpoint(), is(500));
     }
     
     @Test
     public void assertSafeClosedForLong() {
-        Range<Comparable<?>> range = SafeRangeOperationUtils.safeClosed(12, 
5001L);
+        Range<Comparable<?>> range = SafeNumberOperationUtils.safeClosed(12, 
5001L);
         assertThat(range.lowerEndpoint(), is(12L));
         assertThat(range.upperEndpoint(), is(5001L));
     }
     
     @Test
     public void assertSafeClosedForBigInteger() {
-        Range<Comparable<?>> range = SafeRangeOperationUtils.safeClosed(12L, 
new BigInteger("12344"));
+        Range<Comparable<?>> range = SafeNumberOperationUtils.safeClosed(12L, 
new BigInteger("12344"));
         assertThat(range.lowerEndpoint(), is(new BigInteger("12")));
         assertThat(range.upperEndpoint(), is(new BigInteger("12344")));
     }
     
     @Test
     public void assertSafeClosedForFloat() {
-        Range<Comparable<?>> range = SafeRangeOperationUtils.safeClosed(4.5F, 
11.13F);
+        Range<Comparable<?>> range = SafeNumberOperationUtils.safeClosed(4.5F, 
11.13F);
         assertThat(range.lowerEndpoint(), is(4.5F));
         assertThat(range.upperEndpoint(), is(11.13F));
     }
     
     @Test
     public void assertSafeClosedForDouble() {
-        Range<Comparable<?>> range = SafeRangeOperationUtils.safeClosed(5.12F, 
13.75);
+        Range<Comparable<?>> range = 
SafeNumberOperationUtils.safeClosed(5.12F, 13.75);
         assertThat(range.lowerEndpoint(), is(5.12));
         assertThat(range.upperEndpoint(), is(13.75));
     }
     
     @Test
     public void assertSafeClosedForBigDecimal() {
-        Range<Comparable<?>> range = SafeRangeOperationUtils.safeClosed(5.1F, 
new BigDecimal("17.666"));
+        Range<Comparable<?>> range = SafeNumberOperationUtils.safeClosed(5.1F, 
new BigDecimal("17.666"));
         assertThat(range.lowerEndpoint(), is(new BigDecimal("5.1")));
         assertThat(range.upperEndpoint(), is(new BigDecimal("17.666")));
     }
@@ -142,36 +144,78 @@ public class SafeRangeOperationUtilsTest {
     @Test
     public void assertSafeContainsForInteger() {
         Range<Comparable<?>> range = Range.closed(12, 100);
-        assertFalse(SafeRangeOperationUtils.safeContains(range, 500));
+        assertFalse(SafeNumberOperationUtils.safeContains(range, 500));
     }
     
     @Test
     public void assertSafeContainsForLong() {
         Range<Comparable<?>> range = Range.closed(12L, 1000L);
-        assertTrue(SafeRangeOperationUtils.safeContains(range, 500));
+        assertTrue(SafeNumberOperationUtils.safeContains(range, 500));
     }
     
     @Test
     public void assertSafeContainsForBigInteger() {
         Range<Comparable<?>> range = Range.closed(new BigInteger("123"), new 
BigInteger("1000"));
-        assertTrue(SafeRangeOperationUtils.safeContains(range, 510));
+        assertTrue(SafeNumberOperationUtils.safeContains(range, 510));
     }
     
     @Test
     public void assertSafeContainsForFloat() {
         Range<Comparable<?>> range = Range.closed(123.11F, 9999.123F);
-        assertTrue(SafeRangeOperationUtils.safeContains(range, 510.12));
+        assertTrue(SafeNumberOperationUtils.safeContains(range, 510.12));
     }
     
     @Test
     public void assertSafeContainsForDouble() {
         Range<Comparable<?>> range = Range.closed(11.11, 9999.99);
-        assertTrue(SafeRangeOperationUtils.safeContains(range, new 
BigDecimal("510.12")));
+        assertTrue(SafeNumberOperationUtils.safeContains(range, new 
BigDecimal("510.12")));
     }
     
     @Test
     public void assertSafeContainsForBigDecimal() {
         Range<Comparable<?>> range = Range.closed(new BigDecimal("123.11"), 
new BigDecimal("9999.123"));
-        assertTrue(SafeRangeOperationUtils.safeContains(range, 510.12));
+        assertTrue(SafeNumberOperationUtils.safeContains(range, 510.12));
+    }
+
+    @Test
+    public void assertSafeEqualsForInteger() {
+        List<Comparable<?>> sourceCollection = Lists.newArrayList(10, 12);
+        List<Comparable<?>> targetCollection = Lists.newArrayList(10, 12);
+        assertTrue(SafeNumberOperationUtils.safeEquals(sourceCollection, 
targetCollection));
+    }
+
+    @Test
+    public void assertSafeEqualsForLong() {
+        List<Comparable<?>> sourceCollection = Lists.newArrayList(10, 12);
+        List<Comparable<?>> targetCollection = Lists.newArrayList(10L, 12L);
+        assertTrue(SafeNumberOperationUtils.safeEquals(sourceCollection, 
targetCollection));
+    }
+
+    @Test
+    public void assertSafeEqualsForBigInteger() {
+        List<Comparable<?>> sourceCollection = Lists.newArrayList(10, 12);
+        List<Comparable<?>> targetCollection = 
Lists.newArrayList(BigInteger.valueOf(10), BigInteger.valueOf(12L));
+        assertTrue(SafeNumberOperationUtils.safeEquals(sourceCollection, 
targetCollection));
+    }
+
+    @Test
+    public void assertSafeEqualsForFloat() {
+        List<Comparable<?>> sourceCollection = Lists.newArrayList(10.01F, 
12.01F);
+        List<Comparable<?>> targetCollection = Lists.newArrayList(10.01F, 
12.01F);
+        assertTrue(SafeNumberOperationUtils.safeEquals(sourceCollection, 
targetCollection));
+    }
+
+    @Test
+    public void assertSafeEqualsForDouble() {
+        List<Comparable<?>> sourceCollection = Lists.newArrayList(10.01, 
12.01);
+        List<Comparable<?>> targetCollection = Lists.newArrayList(10.01F, 
12.01);
+        assertTrue(SafeNumberOperationUtils.safeEquals(sourceCollection, 
targetCollection));
+    }
+
+    @Test
+    public void assertSafeEqualsForBigDecimal() {
+        List<Comparable<?>> sourceCollection = Lists.newArrayList(10.01, 
12.01);
+        List<Comparable<?>> targetCollection = 
Lists.newArrayList(BigDecimal.valueOf(10.01), BigDecimal.valueOf(12.01));
+        assertTrue(SafeNumberOperationUtils.safeEquals(sourceCollection, 
targetCollection));
     }
 }

Reply via email to