This is an automated email from the ASF dual-hosted git repository.
amagyar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/knox.git
The following commit(s) were added to refs/heads/master by this push:
new 083dc8977 KNOX-2983 - Combine the functionality of different identity
assertion providers (#817)
083dc8977 is described below
commit 083dc8977fcae7e6669670412d62061a599b49cf
Author: Attila Magyar <[email protected]>
AuthorDate: Fri Nov 17 00:31:39 2023 +0100
KNOX-2983 - Combine the functionality of different identity assertion
providers (#817)
---
.../knox/gateway/IdentityAsserterMessages.java | 3 +
.../filter/CommonIdentityAssertionFilter.java | 35 ++-
.../common/filter/VirtualGroupMapper.java | 6 +-
.../filter/CommonIdentityAssertionFilterTest.java | 3 +
gateway-provider-identity-assertion-regex/pom.xml | 5 +
.../regex/filter/RegexTemplate.java | 0
.../java/org/apache/knox/gateway/plang/Arity.java | 18 ++
.../org/apache/knox/gateway/plang/Interpreter.java | 120 +++++++++-
.../apache/knox/gateway/plang/InterpreterTest.java | 246 +++++++++++++++++++++
9 files changed, 423 insertions(+), 13 deletions(-)
diff --git
a/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/IdentityAsserterMessages.java
b/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/IdentityAsserterMessages.java
index f533f202a..ab1a3eaab 100644
---
a/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/IdentityAsserterMessages.java
+++
b/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/IdentityAsserterMessages.java
@@ -63,4 +63,7 @@ public interface IdentityAsserterMessages {
@Message( level = MessageLevel.ERROR, text = "Proxy user Authentication
failed: {0}" )
void hadoopAuthProxyUserFailed(@StackTrace Throwable t);
+
+ @Message( level = MessageLevel.WARN, text = "Invalid result: {2}. Expected
String when evaluating mapping: {1} for user: {0}.")
+ void invalidAdvancedPrincipalMappingResult(String principalName,
AbstractSyntaxTree mapping, Object result);
}
diff --git
a/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/identityasserter/common/filter/CommonIdentityAssertionFilter.java
b/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/identityasserter/common/filter/CommonIdentityAssertionFilter.java
index 537a99281..f3bb669d7 100644
---
a/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/identityasserter/common/filter/CommonIdentityAssertionFilter.java
+++
b/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/identityasserter/common/filter/CommonIdentityAssertionFilter.java
@@ -19,6 +19,7 @@ package
org.apache.knox.gateway.identityasserter.common.filter;
import static
org.apache.knox.gateway.identityasserter.common.filter.AbstractIdentityAsserterDeploymentContributor.IMPERSONATION_PARAMS;
import static
org.apache.knox.gateway.identityasserter.common.filter.AbstractIdentityAsserterDeploymentContributor.ROLE;
+import static
org.apache.knox.gateway.identityasserter.common.filter.VirtualGroupMapper.addRequestFunctions;
import java.io.IOException;
import java.security.AccessController;
@@ -32,7 +33,6 @@ import java.util.Map;
import java.util.Set;
import java.util.StringTokenizer;
import java.util.stream.Collectors;
-
import javax.security.auth.Subject;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
@@ -50,6 +50,7 @@ import org.apache.knox.gateway.IdentityAsserterMessages;
import org.apache.knox.gateway.context.ContextAttributes;
import org.apache.knox.gateway.i18n.messages.MessagesFactory;
import org.apache.knox.gateway.plang.AbstractSyntaxTree;
+import org.apache.knox.gateway.plang.Interpreter;
import org.apache.knox.gateway.plang.Parser;
import org.apache.knox.gateway.plang.SyntaxException;
import org.apache.knox.gateway.security.GroupPrincipal;
@@ -67,6 +68,7 @@ public class CommonIdentityAssertionFilter extends
AbstractIdentityAssertionFilt
public static final String VIRTUAL_GROUP_MAPPING_PREFIX = "group.mapping.";
public static final String GROUP_PRINCIPAL_MAPPING =
"group.principal.mapping";
public static final String PRINCIPAL_MAPPING = "principal.mapping";
+ public static final String ADVANCED_PRINCIPAL_MAPPING =
"expression.principal.mapping";
private static final String PRINCIPAL_PARAM = "user.name";
private static final String DOAS_PRINCIPAL_PARAM = "doAs";
static final String IMPERSONATION_ENABLED_PARAM =
AuthFilterUtils.PROXYUSER_PREFIX + ".impersonation.enabled";
@@ -77,6 +79,7 @@ public class CommonIdentityAssertionFilter extends
AbstractIdentityAssertionFilt
/* List of all default and configured impersonation params */
protected final List<String> impersonationParamsList = new ArrayList<>();
protected boolean impersonationEnabled;
+ private AbstractSyntaxTree expressionPrincipalMapping;
private String topologyName;
@Override
@@ -97,6 +100,7 @@ public class CommonIdentityAssertionFilter extends
AbstractIdentityAssertionFilt
throw new ServletException("Unable to load principal mapping table.",
e);
}
}
+ expressionPrincipalMapping = parseAdvancedPrincipalMapping(filterConfig);
final List<String> initParameterNames =
AuthFilterUtils.getInitParameterNamesAsList(filterConfig);
@@ -106,6 +110,14 @@ public class CommonIdentityAssertionFilter extends
AbstractIdentityAssertionFilt
initProxyUserConfiguration(filterConfig, initParameterNames);
}
+ private AbstractSyntaxTree parseAdvancedPrincipalMapping(FilterConfig
filterConfig) {
+ String expression =
filterConfig.getInitParameter(ADVANCED_PRINCIPAL_MAPPING);
+ if (StringUtils.isBlank(expression)) {
+ expression =
filterConfig.getServletContext().getInitParameter(ADVANCED_PRINCIPAL_MAPPING);
+ }
+ return StringUtils.isBlank(expression) ? null : parser.parse(expression);
+ }
+
/*
* Initialize the impersonation params list.
* This list contains query params that needs to be scrubbed
@@ -228,7 +240,12 @@ public class CommonIdentityAssertionFilter extends
AbstractIdentityAssertionFilt
// mapping principal name using user principal mapping (if configured)
mappedPrincipalName = mapUserPrincipalBase(mappedPrincipalName);
mappedPrincipalName = mapUserPrincipal(mappedPrincipalName);
-
+ if (expressionPrincipalMapping != null) {
+ String result = evalAdvancedPrincipalMapping(request, subject,
mappedPrincipalName);
+ if (result != null) {
+ mappedPrincipalName = result;
+ }
+ }
String[] mappedGroups = mapGroupPrincipalsBase(mappedPrincipalName,
subject);
String[] groups = mapGroupPrincipals(mappedPrincipalName, subject);
String[] virtualGroups = virtualGroupMapper.mapGroups(mappedPrincipalName,
combine(subject, groups), request).toArray(new String[0]);
@@ -241,6 +258,20 @@ public class CommonIdentityAssertionFilter extends
AbstractIdentityAssertionFilt
continueChainAsPrincipal(wrapper, response, chain, mappedPrincipalName,
unique(groups));
}
+ private String evalAdvancedPrincipalMapping(ServletRequest request, Subject
subject, String originalPrincipal) {
+ Interpreter interpreter = new Interpreter();
+ interpreter.addConstant("username", originalPrincipal);
+ interpreter.addConstant("groups", groups(subject));
+ addRequestFunctions(request, interpreter);
+ Object mappedPrincipal = interpreter.eval(expressionPrincipalMapping);
+ if (mappedPrincipal instanceof String) {
+ return (String)mappedPrincipal;
+ } else {
+ LOG.invalidAdvancedPrincipalMappingResult(originalPrincipal,
expressionPrincipalMapping, mappedPrincipal);
+ return null;
+ }
+ }
+
private String handleProxyUserImpersonation(ServletRequest request, Subject
subject) throws AuthorizationException {
String principalName = SubjectUtils.getEffectivePrincipalName(subject);
if (impersonationEnabled) {
diff --git
a/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/identityasserter/common/filter/VirtualGroupMapper.java
b/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/identityasserter/common/filter/VirtualGroupMapper.java
index 7a99119ae..9ab392019 100644
---
a/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/identityasserter/common/filter/VirtualGroupMapper.java
+++
b/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/identityasserter/common/filter/VirtualGroupMapper.java
@@ -73,7 +73,7 @@ public class VirtualGroupMapper {
return (boolean)result;
}
- private void addRequestFunctions(ServletRequest req, Interpreter
interpreter) {
+ public static void addRequestFunctions(ServletRequest req, Interpreter
interpreter) {
if (req instanceof HttpServletRequest) {
interpreter.addFunction("request-attribute", Arity.UNARY, params ->
ensureNotNull(req.getAttribute((String)params.get(0))));
@@ -84,11 +84,11 @@ public class VirtualGroupMapper {
}
}
- private String ensureNotNull(Object value) {
+ private static String ensureNotNull(Object value) {
return value == null ? "" : value.toString();
}
- private Object sessionAttribute(HttpServletRequest req, String key) {
+ private static Object sessionAttribute(HttpServletRequest req, String key)
{
HttpSession session = req.getSession(false);
return session != null ? session.getAttribute(key) : "";
}
diff --git
a/gateway-provider-identity-assertion-common/src/test/java/org/apache/knox/gateway/identityasserter/common/filter/CommonIdentityAssertionFilterTest.java
b/gateway-provider-identity-assertion-common/src/test/java/org/apache/knox/gateway/identityasserter/common/filter/CommonIdentityAssertionFilterTest.java
index f68545be7..4d3f9ec74 100644
---
a/gateway-provider-identity-assertion-common/src/test/java/org/apache/knox/gateway/identityasserter/common/filter/CommonIdentityAssertionFilterTest.java
+++
b/gateway-provider-identity-assertion-common/src/test/java/org/apache/knox/gateway/identityasserter/common/filter/CommonIdentityAssertionFilterTest.java
@@ -111,6 +111,8 @@ public class CommonIdentityAssertionFilterTest {
EasyMock.replay(servletContext);
FilterConfig config = EasyMock.createNiceMock( FilterConfig.class );
EasyMock.expect(config.getServletContext()).andReturn(servletContext).anyTimes();
+
EasyMock.expect(config.getInitParameter(CommonIdentityAssertionFilter.ADVANCED_PRINCIPAL_MAPPING)).
+ andReturn("username").anyTimes();
EasyMock.expect(config.getInitParameter(CommonIdentityAssertionFilter.GROUP_PRINCIPAL_MAPPING)).
andReturn("*=everyone;lmccay=test-virtual-group").once();
EasyMock.expect(config.getInitParameter(CommonIdentityAssertionFilter.PRINCIPAL_MAPPING)).
@@ -272,6 +274,7 @@ public class CommonIdentityAssertionFilterTest {
EasyMock.expect(servletContext.getAttribute(GatewayServices.GATEWAY_CLUSTER_ATTRIBUTE)).andReturn("topology1").anyTimes();
EasyMock.expect(servletContext.getInitParameter(CommonIdentityAssertionFilter.PRINCIPAL_MAPPING)).andReturn(null).anyTimes();
EasyMock.expect(servletContext.getInitParameter(CommonIdentityAssertionFilter.GROUP_PRINCIPAL_MAPPING)).andReturn(null).anyTimes();
+
EasyMock.expect(servletContext.getInitParameter(CommonIdentityAssertionFilter.ADVANCED_PRINCIPAL_MAPPING)).andReturn("username").anyTimes();
EasyMock.expect(servletContext.getInitParameterNames()).andReturn(Collections.enumeration(filterConfigParameterNames)).anyTimes();
EasyMock.expect(servletContext.getInitParameter(IMPERSONATION_PARAMS)).andReturn("doAs").anyTimes();
if (!configuredInHadoopAuth) {
diff --git a/gateway-provider-identity-assertion-regex/pom.xml
b/gateway-provider-identity-assertion-regex/pom.xml
index 8de776537..7b03ee536 100644
--- a/gateway-provider-identity-assertion-regex/pom.xml
+++ b/gateway-provider-identity-assertion-regex/pom.xml
@@ -48,6 +48,11 @@
<artifactId>javax.servlet-api</artifactId>
</dependency>
+ <dependency>
+ <groupId>org.apache.knox</groupId>
+ <artifactId>gateway-util-common</artifactId>
+ </dependency>
+
<dependency>
<groupId>org.apache.knox</groupId>
<artifactId>gateway-test-utils</artifactId>
diff --git
a/gateway-provider-identity-assertion-regex/src/main/java/org/apache/knox/gateway/identityasserter/regex/filter/RegexTemplate.java
b/gateway-util-common/src/main/java/org/apache/knox/gateway/identityasserter/regex/filter/RegexTemplate.java
similarity index 100%
rename from
gateway-provider-identity-assertion-regex/src/main/java/org/apache/knox/gateway/identityasserter/regex/filter/RegexTemplate.java
rename to
gateway-util-common/src/main/java/org/apache/knox/gateway/identityasserter/regex/filter/RegexTemplate.java
diff --git
a/gateway-util-common/src/main/java/org/apache/knox/gateway/plang/Arity.java
b/gateway-util-common/src/main/java/org/apache/knox/gateway/plang/Arity.java
index fb05407c7..cb2310361 100644
--- a/gateway-util-common/src/main/java/org/apache/knox/gateway/plang/Arity.java
+++ b/gateway-util-common/src/main/java/org/apache/knox/gateway/plang/Arity.java
@@ -39,4 +39,22 @@ public interface Arity {
}
};
}
+
+ static Arity even() {
+ return (methodName, params) -> {
+ if (params.size() % 2 != 0) {
+ throw new ArityException("wrong number of arguments in call to
'" + methodName
+ + "'. Expected even number of arguments, got " +
params.size() + ".");
+ }
+ };
+ }
+
+ static Arity between(int min, int max) {
+ return (methodName, params) -> {
+ if (params.size() < min || params.size() > max) {
+ throw new ArityException("wrong number of arguments in call to
'" + methodName
+ + "'. Expected at least " + min + ", at max " + max +
" arguments, got " + params.size() + ".");
+ }
+ };
+ }
}
diff --git
a/gateway-util-common/src/main/java/org/apache/knox/gateway/plang/Interpreter.java
b/gateway-util-common/src/main/java/org/apache/knox/gateway/plang/Interpreter.java
index 503727898..a3223fe78 100644
---
a/gateway-util-common/src/main/java/org/apache/knox/gateway/plang/Interpreter.java
+++
b/gateway-util-common/src/main/java/org/apache/knox/gateway/plang/Interpreter.java
@@ -25,7 +25,9 @@ import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+import org.apache.knox.gateway.identityasserter.regex.filter.RegexTemplate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@@ -44,17 +46,28 @@ public class Interpreter {
}
public Interpreter() {
- specialForms.put("or", args -> {
- Arity.min(1).check("or", args);
- return args.stream().anyMatch(each -> (boolean)eval(each));
- });
- specialForms.put("and", args -> {
- Arity.min(1).check("and", args);
- return args.stream().allMatch(each -> (boolean)eval(each));
+ addSpecialForm(Arity.min(1), "or", args -> args.stream().anyMatch(each
-> (boolean)eval(each)));
+ addSpecialForm(Arity.min(1), "and", args ->
args.stream().allMatch(each -> (boolean)eval(each)));
+ addSpecialForm(Arity.between(2, 3), "if", args -> {
+ if ((boolean)eval(args.get(0))) {
+ return eval(args.get(1));
+ } else if (args.size() == 3) {
+ return eval(args.get(2));
+ }
+ return null;
});
addFunction("not", Arity.UNARY, args -> !(boolean)args.get(0));
addFunction("=", Arity.BINARY, args -> equalTo(args.get(0),
args.get(1)));
addFunction("!=", Arity.BINARY, args -> !equalTo(args.get(0),
args.get(1)));
+ // The comparisons are floating point based, we might need proper
integer-integer comparison in the future
+ addFunction("<", Arity.BINARY, args ->
((Number)args.get(0)).doubleValue() < ((Number)args.get(1)).doubleValue());
+ addFunction("<=", Arity.BINARY, args ->
((Number)args.get(0)).doubleValue() <= ((Number)args.get(1)).doubleValue());
+ addFunction(">", Arity.BINARY, args ->
((Number)args.get(0)).doubleValue() > ((Number)args.get(1)).doubleValue());
+ addFunction(">=", Arity.BINARY, args ->
((Number)args.get(0)).doubleValue() >= ((Number)args.get(1)).doubleValue());
+ addFunction("+", Arity.BINARY, args -> add((Number)args.get(0),
(Number)args.get(1)));
+ addFunction("-", Arity.BINARY, args -> sub((Number)args.get(0),
(Number)args.get(1)));
+ addFunction("*", Arity.BINARY, args -> mul((Number)args.get(0),
(Number)args.get(1)));
+ addFunction("/", Arity.BINARY, args -> div((Number)args.get(0),
(Number)args.get(1)));
addFunction("match", Arity.BINARY, args ->
args.get(0) instanceof String
? Pattern.matches((String)args.get(1), (String)args.get(0))
@@ -63,17 +76,101 @@ public class Interpreter {
addFunction("size", Arity.UNARY, args -> ((Collection<?>)
args.get(0)).size());
addFunction("empty", Arity.UNARY, args -> ((Collection<?>)
args.get(0)).isEmpty());
addFunction("username", Arity.UNARY, args ->
constants.get("username").equals(args.get(0)));
- addFunction("member", Arity.UNARY, args ->
((List<String>)constants.get("groups")).contains((String)args.get(0)));
+ addFunction("member", Arity.UNARY, args ->
((Collection<String>)constants.get("groups")).contains((String)args.get(0)));
addFunction("lowercase", Arity.UNARY, args ->
((String)args.get(0)).toLowerCase(Locale.getDefault()));
addFunction("uppercase", Arity.UNARY, args ->
((String)args.get(0)).toUpperCase(Locale.getDefault()));
+ addFunction("concat", Arity.min(1), args ->
args.stream().map(Object::toString).collect(Collectors.joining()));
+ addFunction("substr", Arity.min(2), args ->
+ args.size() == 2
+ ?
((String)args.get(0)).substring(((Number)args.get(1)).intValue())
+ :
((String)args.get(0)).substring(((Number)args.get(1)).intValue(),
((Number)args.get(2)).intValue())
+ );
+ addFunction("strlen", Arity.UNARY, args ->
((String)args.get(0)).length());
+ addFunction("starts-with", Arity.BINARY, args ->
((String)args.get(0)).startsWith((String)args.get(1)));
+ addFunction("ends-with", Arity.BINARY, args ->
((String)args.get(0)).endsWith((String)args.get(1)));
+ addFunction("contains", Arity.BINARY, args ->
((String)args.get(1)).contains((String)args.get(0)));
+ addFunction("index-of", Arity.BINARY, args ->
((String)args.get(1)).indexOf((String)args.get(0)));
+ addFunction("regex-template", Arity.between(3, 5), args -> {
+ String str = (String) args.get(0);
+ String regex = (String) args.get(1);
+ String template = (String) args.get(2);
+ if (args.size() == 3) {
+ return new RegexTemplate(regex, template, null,
false).apply(str);
+ } else {
+ boolean useOriginalOnLookupFailure = args.size() >= 5 &&
(boolean) args.get(4);
+ return new RegexTemplate(regex, template, (Map)args.get(3),
useOriginalOnLookupFailure).apply(str);
+ }
+ });
addFunction("print", Arity.min(1), args -> { // for debugging
args.forEach(arg -> LOG.info(arg == null ? "null" :
arg.toString()));
return false;
});
+ addFunction("hash", Arity.even(), args -> { // create a hashmap,
number of arguments must be an even number, this is needed for the RegExp
lookup table
+ Map<Object,Object> map = new HashMap<>();
+ for (int i = 0; i < args.size() -1; i+=2) {
+ map.put(args.get(i), args.get(i +1));
+ }
+ return map;
+ });
+ addFunction("at", Arity.BINARY, args ->
((Map<Object,Object>)args.get(1)).get(args.get(0)));
constants.put("true", true);
constants.put("false", false);
}
+ private Number add(Number a, Number b) {
+ if (isFloatingPoint(a) && isFloatingPoint(b)) {
+ return a.doubleValue() + b.doubleValue();
+ } else if (isInteger(a) && isInteger(b)) {
+ return a.longValue() + b.longValue();
+ } else if (isInteger(a) && isFloatingPoint(b)) {
+ return a.longValue() + b.doubleValue();
+ } else if (isFloatingPoint(a) && isInteger(b)) {
+ return a.doubleValue() + b.longValue();
+ } else {
+ throw new TypeException("Unsupported operands: (+ " + a + " " + b
+ ")", null);
+ }
+ }
+
+ private Number sub(Number a, Number b) {
+ if (isFloatingPoint(a) && isFloatingPoint(b)) {
+ return a.doubleValue() - b.doubleValue();
+ } else if (isInteger(a) && isInteger(b)) {
+ return a.longValue() - b.longValue();
+ } else if (isInteger(a) && isFloatingPoint(b)) {
+ return a.longValue() - b.doubleValue();
+ } else if (isFloatingPoint(a) && isInteger(b)) {
+ return a.doubleValue() - b.longValue();
+ } else {
+ throw new TypeException("Unsupported operands: (- " + a + " " + b
+ ")", null);
+ }
+ }
+
+ private Number mul(Number a, Number b) {
+ if (isFloatingPoint(a) && isFloatingPoint(b)) {
+ return a.doubleValue() * b.doubleValue();
+ } else if (isInteger(a) && isInteger(b)) {
+ return a.longValue() * b.longValue();
+ } else if (isInteger(a) && isFloatingPoint(b)) {
+ return a.longValue() * b.doubleValue();
+ } else if (isFloatingPoint(a) && isInteger(b)) {
+ return a.doubleValue() * b.longValue();
+ } else {
+ throw new TypeException("Unsupported operands: (* " + a + " " + b
+ ")", null);
+ }
+ }
+
+ private Number div(Number a, Number b) {
+ return a.doubleValue() / b.doubleValue(); // div will always result a
floating point result to
+ }
+
+ private static boolean isInteger(Number n) {
+ return n instanceof Long || n instanceof Integer;
+ }
+
+ private static boolean isFloatingPoint(Number n) {
+ return n instanceof Double || n instanceof Float;
+ }
+
private static boolean equalTo(Object a, Object b) {
if (a instanceof Number && b instanceof Number) {
return Double.compare(((Number)a).doubleValue(),
((Number)b).doubleValue()) == 0;
@@ -86,6 +183,13 @@ public class Interpreter {
constants.put(name, value);
}
+ private void addSpecialForm(Arity arity, String name, SpecialForm form) {
+ specialForms.put(name, parameters -> {
+ arity.check(name, parameters);
+ return form.call(parameters);
+ });
+ }
+
public void addFunction(String name, Arity arity, Func func) {
functions.put(name, parameters -> {
arity.check(name, parameters);
diff --git
a/gateway-util-common/src/test/java/org/apache/knox/gateway/plang/InterpreterTest.java
b/gateway-util-common/src/test/java/org/apache/knox/gateway/plang/InterpreterTest.java
index e2ed2d3fa..8e7880d7d 100644
---
a/gateway-util-common/src/test/java/org/apache/knox/gateway/plang/InterpreterTest.java
+++
b/gateway-util-common/src/test/java/org/apache/knox/gateway/plang/InterpreterTest.java
@@ -25,6 +25,8 @@ import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
+import java.util.HashMap;
+
import org.junit.Test;
public class InterpreterTest {
@@ -78,6 +80,42 @@ public class InterpreterTest {
assertFalse((boolean)eval("(= 12 '12')"));
}
+ @Test
+ public void testLessThan() {
+ assertTrue((boolean)eval("(< 1 10)"));
+ assertTrue((boolean)eval("(< 1.21 1.32)"));
+ assertFalse((boolean)eval("(< 1 1)"));
+ assertFalse((boolean)eval("(< 10 1)"));
+ assertFalse((boolean)eval("(< 1.31 1.30)"));
+ }
+
+ @Test
+ public void testLessEqThan() {
+ assertTrue((boolean)eval("(<= 1 10)"));
+ assertTrue((boolean)eval("(<= 1.21 1.32)"));
+ assertTrue((boolean)eval("(<= 1 1)"));
+ assertFalse((boolean)eval("(<= 10 1)"));
+ assertFalse((boolean)eval("(<= 1.31 1.30)"));
+ }
+
+ @Test
+ public void testGreaterThan() {
+ assertFalse((boolean)eval("(> 1 10)"));
+ assertFalse((boolean)eval("(> 1.21 1.32)"));
+ assertFalse((boolean)eval("(> 1 1)"));
+ assertTrue((boolean)eval("(> 10 1)"));
+ assertTrue((boolean)eval("(> 1.31 1.30)"));
+ }
+
+ @Test
+ public void testGreaterEqThan() {
+ assertFalse((boolean)eval("(>= 1 10)"));
+ assertFalse((boolean)eval("(>= 1.21 1.32)"));
+ assertTrue((boolean)eval("(>= 1 1)"));
+ assertTrue((boolean)eval("(>= 10 1)"));
+ assertTrue((boolean)eval("(>= 1.31 1.30)"));
+ }
+
@Test
public void testOr() {
assertTrue((boolean)eval("(or true true)"));
@@ -106,6 +144,76 @@ public class InterpreterTest {
assertTrue((boolean)eval("(not (not true))"));
}
+ @Test
+ public void testAdd() {
+ assertEquals(10L, eval("(+ 5 5)"));
+ assertEquals(10d, eval("(+ 5.0 5)"));
+ assertEquals(10d, eval("(+ 5 5.0)"));
+ assertEquals(10d, eval("(+ 5.0 5.0)"));
+ assertEquals(5.7d, eval("(+ 2.5 3.2)"));
+ assertEquals(8.2d, eval("(+ 5 3.2)"));
+ assertEquals(9L, eval("(+ 7 2)"));
+ assertEquals(5L, eval("(+ (strlen 'ab') 3)"));
+ assertEquals(5d, eval("(+ (strlen 'ab') 3.0)"));
+ }
+
+ @Test(expected = TypeException.class)
+ public void testAddInvalidType() {
+ eval("(+ 'apple' 'orange')");
+ }
+
+ @Test
+ public void testSub() {
+ assertEquals(10L, eval("(- 15 5)"));
+ assertEquals(10d, eval("(- 15.0 5)"));
+ assertEquals(10d, eval("(- 15 5.0)"));
+ assertEquals(10d, eval("(- 15.0 5.0)"));
+ assertEquals(2.3d, eval("(- 5.5 3.2)"));
+ assertEquals(1.8d, (double)eval("(- 5 3.2)"), 0.01d);
+ assertEquals(5L, eval("(- 7 2)"));
+ assertEquals(1L, eval("(- (strlen 'abcd') 3)"));
+ assertEquals(1d, eval("(- (strlen 'abcd') 3.0)"));
+ }
+
+ @Test(expected = TypeException.class)
+ public void testSubInvalidType() {
+ eval("(- 'apple' 'orange')");
+ }
+
+ @Test
+ public void testMul() {
+ assertEquals(50L, eval("(* 10 5)"));
+ assertEquals(50d, eval("(* 10.0 5)"));
+ assertEquals(50d, eval("(* 10 5.0)"));
+ assertEquals(50d, eval("(* 10.0 5.0)"));
+ assertEquals(17.6d, eval("(* 5.5 3.2)"));
+ assertEquals(16d, eval("(* 5 3.2)"));
+ assertEquals(14L, eval("(* 7 2)"));
+ assertEquals(6L, eval("(* (strlen 'ab') 3)"));
+ assertEquals(6d, eval("(* (strlen 'ab') 3.0)"));
+ }
+
+ @Test(expected = TypeException.class)
+ public void testMulInvalidType() {
+ eval("(* 'apple' 'orange')");
+ }
+
+ @Test
+ public void testDiv() {
+ assertEquals(2d, eval("(/ 10 5)"));
+ assertEquals(2d, eval("(/ 10.0 5)"));
+ assertEquals(2d, eval("(/ 10 5.0)"));
+ assertEquals(2d, eval("(/ 10.0 5.0)"));
+ assertEquals(1.71875d, eval("(/ 5.5 3.2)"));
+ assertEquals(1.5625d, eval("(/ 5 3.2)"));
+ assertEquals(3.5d, eval("(/ 7 2)"));
+ }
+
+ @Test(expected = TypeException.class)
+ public void testDivInvalidType() {
+ eval("(/ 'apple' 'orange')");
+ }
+
@Test
public void testComplex() {
assertTrue((boolean)eval("(and (not false) (or (not (or (not true)
(not false) )) true))"));
@@ -172,6 +280,144 @@ public class InterpreterTest {
assertFalse((boolean)eval("(and false (invalid-expression 1 2 3))"));
}
+ @Test
+ public void testIf() {
+ assertNull(eval("(if false (invalid-expression))"));
+ assertEquals("testStr", eval("(if true 'testStr')"));
+ }
+
+ @Test
+ public void testIfElse() {
+ assertEquals("apple", eval("(if false (invalid-expression) (lowercase
'APPLE'))"));
+ assertEquals("orange", eval("(if true (lowercase 'ORANGE')
(invalid-expression) )"));
+ }
+
+ @Test(expected = ArityException.class)
+ public void testIfWrongNumberOfArgs() {
+ eval("(if true 1 2 3)");
+ }
+
+ @Test
+ public void testConcat() {
+ assertEquals("asdf", eval("(concat 'asdf')"));
+ assertEquals("asdfjkl", eval("(concat 'asdf' 'jkl')"));
+ assertEquals("asdfjklqwerty", eval("(concat 'asdf' 'jkl' 'qwerty')"));
+ assertEquals("orange APPLE", eval("(concat (lowercase 'ORANGE') ' '
(uppercase 'apple'))"));
+ assertEquals("123", eval("(concat 1 2 3)"));
+ }
+
+ @Test
+ public void testSubStr1() {
+ assertEquals("123456789", eval("(substr '123456789' 0)"));
+ assertEquals("456789", eval("(substr '123456789' 3)"));
+ assertEquals("9", eval("(substr '123456789' 8)"));
+ assertEquals("", eval("(substr '123456789' 9)"));
+ }
+
+ @Test
+ public void testSubStr2() {
+ assertEquals("123456789", eval("(substr '123456789' 0 9)"));
+ assertEquals("", eval("(substr '123456789' 9 9)"));
+ assertEquals("123", eval("(substr '123456789' 0 3)"));
+ assertEquals("3", eval("(substr '123456789' 2 3)"));
+ assertEquals("345", eval("(substr '123456789' 2 5)"));
+ }
+
+ @Test
+ public void testStrLen() {
+ assertEquals(0, eval("(strlen '')"));
+ assertEquals(5, eval("(strlen (uppercase 'apple'))"));
+ }
+
+ @Test
+ public void testStartsWith() {
+ assertTrue((boolean)eval("(starts-with '' '')"));
+ assertTrue((boolean)eval("(starts-with 'apple' '')"));
+ assertTrue((boolean)eval("(starts-with 'apple' 'ap')"));
+ assertTrue((boolean)eval("(starts-with 'apple' 'app')"));
+ assertTrue((boolean)eval("(starts-with 'apple' 'appl')"));
+ assertTrue((boolean)eval("(starts-with 'apple' 'apple')"));
+ assertFalse((boolean)eval("(starts-with 'apple' 'applex')"));
+ assertFalse((boolean)eval("(starts-with '' 'a')"));
+ }
+
+ @Test
+ public void testEndsWith() {
+ assertTrue((boolean)eval("(ends-with '' '')"));
+ assertTrue((boolean)eval("(ends-with 'apple' '')"));
+ assertTrue((boolean)eval("(ends-with 'apple' 'e')"));
+ assertTrue((boolean)eval("(ends-with 'apple' 'ple')"));
+ assertTrue((boolean)eval("(ends-with 'apple' 'apple' )"));
+ assertFalse((boolean)eval("(ends-with 'apple' 'xapple' )"));
+ assertFalse((boolean)eval("(ends-with '' 'a')"));
+ }
+
+ @Test
+ public void testStrIn() {
+ assertTrue((boolean)eval("(contains 'ppl' 'apple')"));
+ assertTrue((boolean)eval("(contains '' 'apple')"));
+ assertTrue((boolean)eval("(contains 'a' 'apple')"));
+ assertFalse((boolean)eval("(contains 'x' 'apple')"));
+ }
+
+ @Test
+ public void testStrIndex() {
+ assertEquals(1, eval("(index-of 'ppl' 'apple')"));
+ assertEquals(-1, eval("(index-of 'xx' 'apple')"));
+ }
+
+ @Test
+ public void testRegexpGroup() {
+ assertEquals("user.1", eval("(regex-template 'prefix_user-1_suffix'
'prefix_(\\w+)\\-(\\d)_suffix' '{1}.{2}')"));
+ assertEquals("123", eval("(regex-template 'usr123' 'usr(\\d+)'
'{1}')"));
+ assertEquals("usr123", eval("(regex-template 'usr123' 'usr\\d+'
'{0}')"));
+ assertEquals("{0}", eval("(regex-template 'admin' '\\d+' '{0}')"));
+ }
+
+ @Test
+ public void testRegexpGroupWithLookup() {
+ // See RegexTemplateTest
+ String script = "(regex-template 'nobody@us.imaginary.tld'
'(.*)@(.*?)\\..*' '{1}_{[2]}' (hash 'us' 'USA' 'ca' 'CANADA'))";
+ assertEquals("nobody_USA", eval(script));
+
+ script = "(regex-template 'member@us.apache.org' '(.*)@(.*?)\\..*'
'prefix_{1}:{[2]}_suffix' (hash 'us' 'USA' 'ca' 'CANADA'))";
+ assertEquals("prefix_member:USA_suffix", eval(script));
+
+ script = "(regex-template 'member@ca.apache.org' '(.*)@(.*?)\\..*'
'prefix_{1}:{[2]}_suffix' (hash 'us' 'USA' 'ca' 'CANADA'))";
+ assertEquals("prefix_member:CANADA_suffix", eval(script));
+
+ script = "(regex-template 'member@uk.apache.org' '(.*)@(.*?)\\..*'
'prefix_{1}:{[2]}_suffix' (hash 'us' 'USA' 'ca' 'CANADA'))";
+ assertEquals("prefix_member:_suffix", eval(script));
+
+ script = "(regex-template 'member@uk.apache.org' '(.*)@(.*?)\\..*'
'prefix_{1}:{[2]}_suffix' (hash 'us' 'USA' 'ca' 'CANADA') true)";
+ assertEquals("prefix_member:uk_suffix", eval(script));
+ }
+
+ @Test
+ public void testHashMaps() {
+ HashMap<Object, Object> expected = new HashMap<>();
+ assertEquals(expected, eval("(hash)"));
+ expected.put(1L , 2L);
+ assertEquals(expected, eval("(hash 1 2)"));
+ expected.put("a", "b");
+ assertEquals(expected, eval("(hash 1 2 'a' 'b')"));
+ expected.clear();
+ expected.put("apple123", true);
+ assertEquals(expected, eval("(hash (lowercase (concat 'Apple' '123'))
(and (< 10 12) (> 10 1)))"));
+ }
+
+ @Test
+ public void testHashMapLookup() {
+ assertEquals(2L, eval("(at 1 (hash 1 2 'a' 'b'))"));
+ assertEquals("b", eval("(at 'a' (hash 1 2 'a' 'b'))"));
+ assertNull(eval("(at 'b' (hash 1 2 'a' 'b'))"));
+ }
+
+ @Test(expected = ArityException.class)
+ public void testHashMapInvalid() {
+ eval("(hash 'key1' 'value1' 'key2')");
+ }
+
private Object eval(String script) {
return interpreter.eval(parser.parse(script));
}