This is an automated email from the ASF dual-hosted git repository. amagyar pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/knox.git
The following commit(s) were added to refs/heads/master by this push: new 083dc8977 KNOX-2983 - Combine the functionality of different identity assertion providers (#817) 083dc8977 is described below commit 083dc8977fcae7e6669670412d62061a599b49cf Author: Attila Magyar <m.magy...@gmail.com> AuthorDate: Fri Nov 17 00:31:39 2023 +0100 KNOX-2983 - Combine the functionality of different identity assertion providers (#817) --- .../knox/gateway/IdentityAsserterMessages.java | 3 + .../filter/CommonIdentityAssertionFilter.java | 35 ++- .../common/filter/VirtualGroupMapper.java | 6 +- .../filter/CommonIdentityAssertionFilterTest.java | 3 + gateway-provider-identity-assertion-regex/pom.xml | 5 + .../regex/filter/RegexTemplate.java | 0 .../java/org/apache/knox/gateway/plang/Arity.java | 18 ++ .../org/apache/knox/gateway/plang/Interpreter.java | 120 +++++++++- .../apache/knox/gateway/plang/InterpreterTest.java | 246 +++++++++++++++++++++ 9 files changed, 423 insertions(+), 13 deletions(-) diff --git a/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/IdentityAsserterMessages.java b/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/IdentityAsserterMessages.java index f533f202a..ab1a3eaab 100644 --- a/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/IdentityAsserterMessages.java +++ b/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/IdentityAsserterMessages.java @@ -63,4 +63,7 @@ public interface IdentityAsserterMessages { @Message( level = MessageLevel.ERROR, text = "Proxy user Authentication failed: {0}" ) void hadoopAuthProxyUserFailed(@StackTrace Throwable t); + + @Message( level = MessageLevel.WARN, text = "Invalid result: {2}. Expected String when evaluating mapping: {1} for user: {0}.") + void invalidAdvancedPrincipalMappingResult(String principalName, AbstractSyntaxTree mapping, Object result); } diff --git a/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/identityasserter/common/filter/CommonIdentityAssertionFilter.java b/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/identityasserter/common/filter/CommonIdentityAssertionFilter.java index 537a99281..f3bb669d7 100644 --- a/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/identityasserter/common/filter/CommonIdentityAssertionFilter.java +++ b/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/identityasserter/common/filter/CommonIdentityAssertionFilter.java @@ -19,6 +19,7 @@ package org.apache.knox.gateway.identityasserter.common.filter; import static org.apache.knox.gateway.identityasserter.common.filter.AbstractIdentityAsserterDeploymentContributor.IMPERSONATION_PARAMS; import static org.apache.knox.gateway.identityasserter.common.filter.AbstractIdentityAsserterDeploymentContributor.ROLE; +import static org.apache.knox.gateway.identityasserter.common.filter.VirtualGroupMapper.addRequestFunctions; import java.io.IOException; import java.security.AccessController; @@ -32,7 +33,6 @@ import java.util.Map; import java.util.Set; import java.util.StringTokenizer; import java.util.stream.Collectors; - import javax.security.auth.Subject; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; @@ -50,6 +50,7 @@ import org.apache.knox.gateway.IdentityAsserterMessages; import org.apache.knox.gateway.context.ContextAttributes; import org.apache.knox.gateway.i18n.messages.MessagesFactory; import org.apache.knox.gateway.plang.AbstractSyntaxTree; +import org.apache.knox.gateway.plang.Interpreter; import org.apache.knox.gateway.plang.Parser; import org.apache.knox.gateway.plang.SyntaxException; import org.apache.knox.gateway.security.GroupPrincipal; @@ -67,6 +68,7 @@ public class CommonIdentityAssertionFilter extends AbstractIdentityAssertionFilt public static final String VIRTUAL_GROUP_MAPPING_PREFIX = "group.mapping."; public static final String GROUP_PRINCIPAL_MAPPING = "group.principal.mapping"; public static final String PRINCIPAL_MAPPING = "principal.mapping"; + public static final String ADVANCED_PRINCIPAL_MAPPING = "expression.principal.mapping"; private static final String PRINCIPAL_PARAM = "user.name"; private static final String DOAS_PRINCIPAL_PARAM = "doAs"; static final String IMPERSONATION_ENABLED_PARAM = AuthFilterUtils.PROXYUSER_PREFIX + ".impersonation.enabled"; @@ -77,6 +79,7 @@ public class CommonIdentityAssertionFilter extends AbstractIdentityAssertionFilt /* List of all default and configured impersonation params */ protected final List<String> impersonationParamsList = new ArrayList<>(); protected boolean impersonationEnabled; + private AbstractSyntaxTree expressionPrincipalMapping; private String topologyName; @Override @@ -97,6 +100,7 @@ public class CommonIdentityAssertionFilter extends AbstractIdentityAssertionFilt throw new ServletException("Unable to load principal mapping table.", e); } } + expressionPrincipalMapping = parseAdvancedPrincipalMapping(filterConfig); final List<String> initParameterNames = AuthFilterUtils.getInitParameterNamesAsList(filterConfig); @@ -106,6 +110,14 @@ public class CommonIdentityAssertionFilter extends AbstractIdentityAssertionFilt initProxyUserConfiguration(filterConfig, initParameterNames); } + private AbstractSyntaxTree parseAdvancedPrincipalMapping(FilterConfig filterConfig) { + String expression = filterConfig.getInitParameter(ADVANCED_PRINCIPAL_MAPPING); + if (StringUtils.isBlank(expression)) { + expression = filterConfig.getServletContext().getInitParameter(ADVANCED_PRINCIPAL_MAPPING); + } + return StringUtils.isBlank(expression) ? null : parser.parse(expression); + } + /* * Initialize the impersonation params list. * This list contains query params that needs to be scrubbed @@ -228,7 +240,12 @@ public class CommonIdentityAssertionFilter extends AbstractIdentityAssertionFilt // mapping principal name using user principal mapping (if configured) mappedPrincipalName = mapUserPrincipalBase(mappedPrincipalName); mappedPrincipalName = mapUserPrincipal(mappedPrincipalName); - + if (expressionPrincipalMapping != null) { + String result = evalAdvancedPrincipalMapping(request, subject, mappedPrincipalName); + if (result != null) { + mappedPrincipalName = result; + } + } String[] mappedGroups = mapGroupPrincipalsBase(mappedPrincipalName, subject); String[] groups = mapGroupPrincipals(mappedPrincipalName, subject); String[] virtualGroups = virtualGroupMapper.mapGroups(mappedPrincipalName, combine(subject, groups), request).toArray(new String[0]); @@ -241,6 +258,20 @@ public class CommonIdentityAssertionFilter extends AbstractIdentityAssertionFilt continueChainAsPrincipal(wrapper, response, chain, mappedPrincipalName, unique(groups)); } + private String evalAdvancedPrincipalMapping(ServletRequest request, Subject subject, String originalPrincipal) { + Interpreter interpreter = new Interpreter(); + interpreter.addConstant("username", originalPrincipal); + interpreter.addConstant("groups", groups(subject)); + addRequestFunctions(request, interpreter); + Object mappedPrincipal = interpreter.eval(expressionPrincipalMapping); + if (mappedPrincipal instanceof String) { + return (String)mappedPrincipal; + } else { + LOG.invalidAdvancedPrincipalMappingResult(originalPrincipal, expressionPrincipalMapping, mappedPrincipal); + return null; + } + } + private String handleProxyUserImpersonation(ServletRequest request, Subject subject) throws AuthorizationException { String principalName = SubjectUtils.getEffectivePrincipalName(subject); if (impersonationEnabled) { diff --git a/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/identityasserter/common/filter/VirtualGroupMapper.java b/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/identityasserter/common/filter/VirtualGroupMapper.java index 7a99119ae..9ab392019 100644 --- a/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/identityasserter/common/filter/VirtualGroupMapper.java +++ b/gateway-provider-identity-assertion-common/src/main/java/org/apache/knox/gateway/identityasserter/common/filter/VirtualGroupMapper.java @@ -73,7 +73,7 @@ public class VirtualGroupMapper { return (boolean)result; } - private void addRequestFunctions(ServletRequest req, Interpreter interpreter) { + public static void addRequestFunctions(ServletRequest req, Interpreter interpreter) { if (req instanceof HttpServletRequest) { interpreter.addFunction("request-attribute", Arity.UNARY, params -> ensureNotNull(req.getAttribute((String)params.get(0)))); @@ -84,11 +84,11 @@ public class VirtualGroupMapper { } } - private String ensureNotNull(Object value) { + private static String ensureNotNull(Object value) { return value == null ? "" : value.toString(); } - private Object sessionAttribute(HttpServletRequest req, String key) { + private static Object sessionAttribute(HttpServletRequest req, String key) { HttpSession session = req.getSession(false); return session != null ? session.getAttribute(key) : ""; } diff --git a/gateway-provider-identity-assertion-common/src/test/java/org/apache/knox/gateway/identityasserter/common/filter/CommonIdentityAssertionFilterTest.java b/gateway-provider-identity-assertion-common/src/test/java/org/apache/knox/gateway/identityasserter/common/filter/CommonIdentityAssertionFilterTest.java index f68545be7..4d3f9ec74 100644 --- a/gateway-provider-identity-assertion-common/src/test/java/org/apache/knox/gateway/identityasserter/common/filter/CommonIdentityAssertionFilterTest.java +++ b/gateway-provider-identity-assertion-common/src/test/java/org/apache/knox/gateway/identityasserter/common/filter/CommonIdentityAssertionFilterTest.java @@ -111,6 +111,8 @@ public class CommonIdentityAssertionFilterTest { EasyMock.replay(servletContext); FilterConfig config = EasyMock.createNiceMock( FilterConfig.class ); EasyMock.expect(config.getServletContext()).andReturn(servletContext).anyTimes(); + EasyMock.expect(config.getInitParameter(CommonIdentityAssertionFilter.ADVANCED_PRINCIPAL_MAPPING)). + andReturn("username").anyTimes(); EasyMock.expect(config.getInitParameter(CommonIdentityAssertionFilter.GROUP_PRINCIPAL_MAPPING)). andReturn("*=everyone;lmccay=test-virtual-group").once(); EasyMock.expect(config.getInitParameter(CommonIdentityAssertionFilter.PRINCIPAL_MAPPING)). @@ -272,6 +274,7 @@ public class CommonIdentityAssertionFilterTest { EasyMock.expect(servletContext.getAttribute(GatewayServices.GATEWAY_CLUSTER_ATTRIBUTE)).andReturn("topology1").anyTimes(); EasyMock.expect(servletContext.getInitParameter(CommonIdentityAssertionFilter.PRINCIPAL_MAPPING)).andReturn(null).anyTimes(); EasyMock.expect(servletContext.getInitParameter(CommonIdentityAssertionFilter.GROUP_PRINCIPAL_MAPPING)).andReturn(null).anyTimes(); + EasyMock.expect(servletContext.getInitParameter(CommonIdentityAssertionFilter.ADVANCED_PRINCIPAL_MAPPING)).andReturn("username").anyTimes(); EasyMock.expect(servletContext.getInitParameterNames()).andReturn(Collections.enumeration(filterConfigParameterNames)).anyTimes(); EasyMock.expect(servletContext.getInitParameter(IMPERSONATION_PARAMS)).andReturn("doAs").anyTimes(); if (!configuredInHadoopAuth) { diff --git a/gateway-provider-identity-assertion-regex/pom.xml b/gateway-provider-identity-assertion-regex/pom.xml index 8de776537..7b03ee536 100644 --- a/gateway-provider-identity-assertion-regex/pom.xml +++ b/gateway-provider-identity-assertion-regex/pom.xml @@ -48,6 +48,11 @@ <artifactId>javax.servlet-api</artifactId> </dependency> + <dependency> + <groupId>org.apache.knox</groupId> + <artifactId>gateway-util-common</artifactId> + </dependency> + <dependency> <groupId>org.apache.knox</groupId> <artifactId>gateway-test-utils</artifactId> diff --git a/gateway-provider-identity-assertion-regex/src/main/java/org/apache/knox/gateway/identityasserter/regex/filter/RegexTemplate.java b/gateway-util-common/src/main/java/org/apache/knox/gateway/identityasserter/regex/filter/RegexTemplate.java similarity index 100% rename from gateway-provider-identity-assertion-regex/src/main/java/org/apache/knox/gateway/identityasserter/regex/filter/RegexTemplate.java rename to gateway-util-common/src/main/java/org/apache/knox/gateway/identityasserter/regex/filter/RegexTemplate.java diff --git a/gateway-util-common/src/main/java/org/apache/knox/gateway/plang/Arity.java b/gateway-util-common/src/main/java/org/apache/knox/gateway/plang/Arity.java index fb05407c7..cb2310361 100644 --- a/gateway-util-common/src/main/java/org/apache/knox/gateway/plang/Arity.java +++ b/gateway-util-common/src/main/java/org/apache/knox/gateway/plang/Arity.java @@ -39,4 +39,22 @@ public interface Arity { } }; } + + static Arity even() { + return (methodName, params) -> { + if (params.size() % 2 != 0) { + throw new ArityException("wrong number of arguments in call to '" + methodName + + "'. Expected even number of arguments, got " + params.size() + "."); + } + }; + } + + static Arity between(int min, int max) { + return (methodName, params) -> { + if (params.size() < min || params.size() > max) { + throw new ArityException("wrong number of arguments in call to '" + methodName + + "'. Expected at least " + min + ", at max " + max + " arguments, got " + params.size() + "."); + } + }; + } } diff --git a/gateway-util-common/src/main/java/org/apache/knox/gateway/plang/Interpreter.java b/gateway-util-common/src/main/java/org/apache/knox/gateway/plang/Interpreter.java index 503727898..a3223fe78 100644 --- a/gateway-util-common/src/main/java/org/apache/knox/gateway/plang/Interpreter.java +++ b/gateway-util-common/src/main/java/org/apache/knox/gateway/plang/Interpreter.java @@ -25,7 +25,9 @@ import java.util.List; import java.util.Locale; import java.util.Map; import java.util.regex.Pattern; +import java.util.stream.Collectors; +import org.apache.knox.gateway.identityasserter.regex.filter.RegexTemplate; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -44,17 +46,28 @@ public class Interpreter { } public Interpreter() { - specialForms.put("or", args -> { - Arity.min(1).check("or", args); - return args.stream().anyMatch(each -> (boolean)eval(each)); - }); - specialForms.put("and", args -> { - Arity.min(1).check("and", args); - return args.stream().allMatch(each -> (boolean)eval(each)); + addSpecialForm(Arity.min(1), "or", args -> args.stream().anyMatch(each -> (boolean)eval(each))); + addSpecialForm(Arity.min(1), "and", args -> args.stream().allMatch(each -> (boolean)eval(each))); + addSpecialForm(Arity.between(2, 3), "if", args -> { + if ((boolean)eval(args.get(0))) { + return eval(args.get(1)); + } else if (args.size() == 3) { + return eval(args.get(2)); + } + return null; }); addFunction("not", Arity.UNARY, args -> !(boolean)args.get(0)); addFunction("=", Arity.BINARY, args -> equalTo(args.get(0), args.get(1))); addFunction("!=", Arity.BINARY, args -> !equalTo(args.get(0), args.get(1))); + // The comparisons are floating point based, we might need proper integer-integer comparison in the future + addFunction("<", Arity.BINARY, args -> ((Number)args.get(0)).doubleValue() < ((Number)args.get(1)).doubleValue()); + addFunction("<=", Arity.BINARY, args -> ((Number)args.get(0)).doubleValue() <= ((Number)args.get(1)).doubleValue()); + addFunction(">", Arity.BINARY, args -> ((Number)args.get(0)).doubleValue() > ((Number)args.get(1)).doubleValue()); + addFunction(">=", Arity.BINARY, args -> ((Number)args.get(0)).doubleValue() >= ((Number)args.get(1)).doubleValue()); + addFunction("+", Arity.BINARY, args -> add((Number)args.get(0), (Number)args.get(1))); + addFunction("-", Arity.BINARY, args -> sub((Number)args.get(0), (Number)args.get(1))); + addFunction("*", Arity.BINARY, args -> mul((Number)args.get(0), (Number)args.get(1))); + addFunction("/", Arity.BINARY, args -> div((Number)args.get(0), (Number)args.get(1))); addFunction("match", Arity.BINARY, args -> args.get(0) instanceof String ? Pattern.matches((String)args.get(1), (String)args.get(0)) @@ -63,17 +76,101 @@ public class Interpreter { addFunction("size", Arity.UNARY, args -> ((Collection<?>) args.get(0)).size()); addFunction("empty", Arity.UNARY, args -> ((Collection<?>) args.get(0)).isEmpty()); addFunction("username", Arity.UNARY, args -> constants.get("username").equals(args.get(0))); - addFunction("member", Arity.UNARY, args -> ((List<String>)constants.get("groups")).contains((String)args.get(0))); + addFunction("member", Arity.UNARY, args -> ((Collection<String>)constants.get("groups")).contains((String)args.get(0))); addFunction("lowercase", Arity.UNARY, args -> ((String)args.get(0)).toLowerCase(Locale.getDefault())); addFunction("uppercase", Arity.UNARY, args -> ((String)args.get(0)).toUpperCase(Locale.getDefault())); + addFunction("concat", Arity.min(1), args -> args.stream().map(Object::toString).collect(Collectors.joining())); + addFunction("substr", Arity.min(2), args -> + args.size() == 2 + ? ((String)args.get(0)).substring(((Number)args.get(1)).intValue()) + : ((String)args.get(0)).substring(((Number)args.get(1)).intValue(), ((Number)args.get(2)).intValue()) + ); + addFunction("strlen", Arity.UNARY, args -> ((String)args.get(0)).length()); + addFunction("starts-with", Arity.BINARY, args -> ((String)args.get(0)).startsWith((String)args.get(1))); + addFunction("ends-with", Arity.BINARY, args -> ((String)args.get(0)).endsWith((String)args.get(1))); + addFunction("contains", Arity.BINARY, args -> ((String)args.get(1)).contains((String)args.get(0))); + addFunction("index-of", Arity.BINARY, args -> ((String)args.get(1)).indexOf((String)args.get(0))); + addFunction("regex-template", Arity.between(3, 5), args -> { + String str = (String) args.get(0); + String regex = (String) args.get(1); + String template = (String) args.get(2); + if (args.size() == 3) { + return new RegexTemplate(regex, template, null, false).apply(str); + } else { + boolean useOriginalOnLookupFailure = args.size() >= 5 && (boolean) args.get(4); + return new RegexTemplate(regex, template, (Map)args.get(3), useOriginalOnLookupFailure).apply(str); + } + }); addFunction("print", Arity.min(1), args -> { // for debugging args.forEach(arg -> LOG.info(arg == null ? "null" : arg.toString())); return false; }); + addFunction("hash", Arity.even(), args -> { // create a hashmap, number of arguments must be an even number, this is needed for the RegExp lookup table + Map<Object,Object> map = new HashMap<>(); + for (int i = 0; i < args.size() -1; i+=2) { + map.put(args.get(i), args.get(i +1)); + } + return map; + }); + addFunction("at", Arity.BINARY, args -> ((Map<Object,Object>)args.get(1)).get(args.get(0))); constants.put("true", true); constants.put("false", false); } + private Number add(Number a, Number b) { + if (isFloatingPoint(a) && isFloatingPoint(b)) { + return a.doubleValue() + b.doubleValue(); + } else if (isInteger(a) && isInteger(b)) { + return a.longValue() + b.longValue(); + } else if (isInteger(a) && isFloatingPoint(b)) { + return a.longValue() + b.doubleValue(); + } else if (isFloatingPoint(a) && isInteger(b)) { + return a.doubleValue() + b.longValue(); + } else { + throw new TypeException("Unsupported operands: (+ " + a + " " + b + ")", null); + } + } + + private Number sub(Number a, Number b) { + if (isFloatingPoint(a) && isFloatingPoint(b)) { + return a.doubleValue() - b.doubleValue(); + } else if (isInteger(a) && isInteger(b)) { + return a.longValue() - b.longValue(); + } else if (isInteger(a) && isFloatingPoint(b)) { + return a.longValue() - b.doubleValue(); + } else if (isFloatingPoint(a) && isInteger(b)) { + return a.doubleValue() - b.longValue(); + } else { + throw new TypeException("Unsupported operands: (- " + a + " " + b + ")", null); + } + } + + private Number mul(Number a, Number b) { + if (isFloatingPoint(a) && isFloatingPoint(b)) { + return a.doubleValue() * b.doubleValue(); + } else if (isInteger(a) && isInteger(b)) { + return a.longValue() * b.longValue(); + } else if (isInteger(a) && isFloatingPoint(b)) { + return a.longValue() * b.doubleValue(); + } else if (isFloatingPoint(a) && isInteger(b)) { + return a.doubleValue() * b.longValue(); + } else { + throw new TypeException("Unsupported operands: (* " + a + " " + b + ")", null); + } + } + + private Number div(Number a, Number b) { + return a.doubleValue() / b.doubleValue(); // div will always result a floating point result to + } + + private static boolean isInteger(Number n) { + return n instanceof Long || n instanceof Integer; + } + + private static boolean isFloatingPoint(Number n) { + return n instanceof Double || n instanceof Float; + } + private static boolean equalTo(Object a, Object b) { if (a instanceof Number && b instanceof Number) { return Double.compare(((Number)a).doubleValue(), ((Number)b).doubleValue()) == 0; @@ -86,6 +183,13 @@ public class Interpreter { constants.put(name, value); } + private void addSpecialForm(Arity arity, String name, SpecialForm form) { + specialForms.put(name, parameters -> { + arity.check(name, parameters); + return form.call(parameters); + }); + } + public void addFunction(String name, Arity arity, Func func) { functions.put(name, parameters -> { arity.check(name, parameters); diff --git a/gateway-util-common/src/test/java/org/apache/knox/gateway/plang/InterpreterTest.java b/gateway-util-common/src/test/java/org/apache/knox/gateway/plang/InterpreterTest.java index e2ed2d3fa..8e7880d7d 100644 --- a/gateway-util-common/src/test/java/org/apache/knox/gateway/plang/InterpreterTest.java +++ b/gateway-util-common/src/test/java/org/apache/knox/gateway/plang/InterpreterTest.java @@ -25,6 +25,8 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import java.util.HashMap; + import org.junit.Test; public class InterpreterTest { @@ -78,6 +80,42 @@ public class InterpreterTest { assertFalse((boolean)eval("(= 12 '12')")); } + @Test + public void testLessThan() { + assertTrue((boolean)eval("(< 1 10)")); + assertTrue((boolean)eval("(< 1.21 1.32)")); + assertFalse((boolean)eval("(< 1 1)")); + assertFalse((boolean)eval("(< 10 1)")); + assertFalse((boolean)eval("(< 1.31 1.30)")); + } + + @Test + public void testLessEqThan() { + assertTrue((boolean)eval("(<= 1 10)")); + assertTrue((boolean)eval("(<= 1.21 1.32)")); + assertTrue((boolean)eval("(<= 1 1)")); + assertFalse((boolean)eval("(<= 10 1)")); + assertFalse((boolean)eval("(<= 1.31 1.30)")); + } + + @Test + public void testGreaterThan() { + assertFalse((boolean)eval("(> 1 10)")); + assertFalse((boolean)eval("(> 1.21 1.32)")); + assertFalse((boolean)eval("(> 1 1)")); + assertTrue((boolean)eval("(> 10 1)")); + assertTrue((boolean)eval("(> 1.31 1.30)")); + } + + @Test + public void testGreaterEqThan() { + assertFalse((boolean)eval("(>= 1 10)")); + assertFalse((boolean)eval("(>= 1.21 1.32)")); + assertTrue((boolean)eval("(>= 1 1)")); + assertTrue((boolean)eval("(>= 10 1)")); + assertTrue((boolean)eval("(>= 1.31 1.30)")); + } + @Test public void testOr() { assertTrue((boolean)eval("(or true true)")); @@ -106,6 +144,76 @@ public class InterpreterTest { assertTrue((boolean)eval("(not (not true))")); } + @Test + public void testAdd() { + assertEquals(10L, eval("(+ 5 5)")); + assertEquals(10d, eval("(+ 5.0 5)")); + assertEquals(10d, eval("(+ 5 5.0)")); + assertEquals(10d, eval("(+ 5.0 5.0)")); + assertEquals(5.7d, eval("(+ 2.5 3.2)")); + assertEquals(8.2d, eval("(+ 5 3.2)")); + assertEquals(9L, eval("(+ 7 2)")); + assertEquals(5L, eval("(+ (strlen 'ab') 3)")); + assertEquals(5d, eval("(+ (strlen 'ab') 3.0)")); + } + + @Test(expected = TypeException.class) + public void testAddInvalidType() { + eval("(+ 'apple' 'orange')"); + } + + @Test + public void testSub() { + assertEquals(10L, eval("(- 15 5)")); + assertEquals(10d, eval("(- 15.0 5)")); + assertEquals(10d, eval("(- 15 5.0)")); + assertEquals(10d, eval("(- 15.0 5.0)")); + assertEquals(2.3d, eval("(- 5.5 3.2)")); + assertEquals(1.8d, (double)eval("(- 5 3.2)"), 0.01d); + assertEquals(5L, eval("(- 7 2)")); + assertEquals(1L, eval("(- (strlen 'abcd') 3)")); + assertEquals(1d, eval("(- (strlen 'abcd') 3.0)")); + } + + @Test(expected = TypeException.class) + public void testSubInvalidType() { + eval("(- 'apple' 'orange')"); + } + + @Test + public void testMul() { + assertEquals(50L, eval("(* 10 5)")); + assertEquals(50d, eval("(* 10.0 5)")); + assertEquals(50d, eval("(* 10 5.0)")); + assertEquals(50d, eval("(* 10.0 5.0)")); + assertEquals(17.6d, eval("(* 5.5 3.2)")); + assertEquals(16d, eval("(* 5 3.2)")); + assertEquals(14L, eval("(* 7 2)")); + assertEquals(6L, eval("(* (strlen 'ab') 3)")); + assertEquals(6d, eval("(* (strlen 'ab') 3.0)")); + } + + @Test(expected = TypeException.class) + public void testMulInvalidType() { + eval("(* 'apple' 'orange')"); + } + + @Test + public void testDiv() { + assertEquals(2d, eval("(/ 10 5)")); + assertEquals(2d, eval("(/ 10.0 5)")); + assertEquals(2d, eval("(/ 10 5.0)")); + assertEquals(2d, eval("(/ 10.0 5.0)")); + assertEquals(1.71875d, eval("(/ 5.5 3.2)")); + assertEquals(1.5625d, eval("(/ 5 3.2)")); + assertEquals(3.5d, eval("(/ 7 2)")); + } + + @Test(expected = TypeException.class) + public void testDivInvalidType() { + eval("(/ 'apple' 'orange')"); + } + @Test public void testComplex() { assertTrue((boolean)eval("(and (not false) (or (not (or (not true) (not false) )) true))")); @@ -172,6 +280,144 @@ public class InterpreterTest { assertFalse((boolean)eval("(and false (invalid-expression 1 2 3))")); } + @Test + public void testIf() { + assertNull(eval("(if false (invalid-expression))")); + assertEquals("testStr", eval("(if true 'testStr')")); + } + + @Test + public void testIfElse() { + assertEquals("apple", eval("(if false (invalid-expression) (lowercase 'APPLE'))")); + assertEquals("orange", eval("(if true (lowercase 'ORANGE') (invalid-expression) )")); + } + + @Test(expected = ArityException.class) + public void testIfWrongNumberOfArgs() { + eval("(if true 1 2 3)"); + } + + @Test + public void testConcat() { + assertEquals("asdf", eval("(concat 'asdf')")); + assertEquals("asdfjkl", eval("(concat 'asdf' 'jkl')")); + assertEquals("asdfjklqwerty", eval("(concat 'asdf' 'jkl' 'qwerty')")); + assertEquals("orange APPLE", eval("(concat (lowercase 'ORANGE') ' ' (uppercase 'apple'))")); + assertEquals("123", eval("(concat 1 2 3)")); + } + + @Test + public void testSubStr1() { + assertEquals("123456789", eval("(substr '123456789' 0)")); + assertEquals("456789", eval("(substr '123456789' 3)")); + assertEquals("9", eval("(substr '123456789' 8)")); + assertEquals("", eval("(substr '123456789' 9)")); + } + + @Test + public void testSubStr2() { + assertEquals("123456789", eval("(substr '123456789' 0 9)")); + assertEquals("", eval("(substr '123456789' 9 9)")); + assertEquals("123", eval("(substr '123456789' 0 3)")); + assertEquals("3", eval("(substr '123456789' 2 3)")); + assertEquals("345", eval("(substr '123456789' 2 5)")); + } + + @Test + public void testStrLen() { + assertEquals(0, eval("(strlen '')")); + assertEquals(5, eval("(strlen (uppercase 'apple'))")); + } + + @Test + public void testStartsWith() { + assertTrue((boolean)eval("(starts-with '' '')")); + assertTrue((boolean)eval("(starts-with 'apple' '')")); + assertTrue((boolean)eval("(starts-with 'apple' 'ap')")); + assertTrue((boolean)eval("(starts-with 'apple' 'app')")); + assertTrue((boolean)eval("(starts-with 'apple' 'appl')")); + assertTrue((boolean)eval("(starts-with 'apple' 'apple')")); + assertFalse((boolean)eval("(starts-with 'apple' 'applex')")); + assertFalse((boolean)eval("(starts-with '' 'a')")); + } + + @Test + public void testEndsWith() { + assertTrue((boolean)eval("(ends-with '' '')")); + assertTrue((boolean)eval("(ends-with 'apple' '')")); + assertTrue((boolean)eval("(ends-with 'apple' 'e')")); + assertTrue((boolean)eval("(ends-with 'apple' 'ple')")); + assertTrue((boolean)eval("(ends-with 'apple' 'apple' )")); + assertFalse((boolean)eval("(ends-with 'apple' 'xapple' )")); + assertFalse((boolean)eval("(ends-with '' 'a')")); + } + + @Test + public void testStrIn() { + assertTrue((boolean)eval("(contains 'ppl' 'apple')")); + assertTrue((boolean)eval("(contains '' 'apple')")); + assertTrue((boolean)eval("(contains 'a' 'apple')")); + assertFalse((boolean)eval("(contains 'x' 'apple')")); + } + + @Test + public void testStrIndex() { + assertEquals(1, eval("(index-of 'ppl' 'apple')")); + assertEquals(-1, eval("(index-of 'xx' 'apple')")); + } + + @Test + public void testRegexpGroup() { + assertEquals("user.1", eval("(regex-template 'prefix_user-1_suffix' 'prefix_(\\w+)\\-(\\d)_suffix' '{1}.{2}')")); + assertEquals("123", eval("(regex-template 'usr123' 'usr(\\d+)' '{1}')")); + assertEquals("usr123", eval("(regex-template 'usr123' 'usr\\d+' '{0}')")); + assertEquals("{0}", eval("(regex-template 'admin' '\\d+' '{0}')")); + } + + @Test + public void testRegexpGroupWithLookup() { + // See RegexTemplateTest + String script = "(regex-template 'nob...@us.imaginary.tld' '(.*)@(.*?)\\..*' '{1}_{[2]}' (hash 'us' 'USA' 'ca' 'CANADA'))"; + assertEquals("nobody_USA", eval(script)); + + script = "(regex-template 'mem...@us.apache.org' '(.*)@(.*?)\\..*' 'prefix_{1}:{[2]}_suffix' (hash 'us' 'USA' 'ca' 'CANADA'))"; + assertEquals("prefix_member:USA_suffix", eval(script)); + + script = "(regex-template 'mem...@ca.apache.org' '(.*)@(.*?)\\..*' 'prefix_{1}:{[2]}_suffix' (hash 'us' 'USA' 'ca' 'CANADA'))"; + assertEquals("prefix_member:CANADA_suffix", eval(script)); + + script = "(regex-template 'mem...@uk.apache.org' '(.*)@(.*?)\\..*' 'prefix_{1}:{[2]}_suffix' (hash 'us' 'USA' 'ca' 'CANADA'))"; + assertEquals("prefix_member:_suffix", eval(script)); + + script = "(regex-template 'mem...@uk.apache.org' '(.*)@(.*?)\\..*' 'prefix_{1}:{[2]}_suffix' (hash 'us' 'USA' 'ca' 'CANADA') true)"; + assertEquals("prefix_member:uk_suffix", eval(script)); + } + + @Test + public void testHashMaps() { + HashMap<Object, Object> expected = new HashMap<>(); + assertEquals(expected, eval("(hash)")); + expected.put(1L , 2L); + assertEquals(expected, eval("(hash 1 2)")); + expected.put("a", "b"); + assertEquals(expected, eval("(hash 1 2 'a' 'b')")); + expected.clear(); + expected.put("apple123", true); + assertEquals(expected, eval("(hash (lowercase (concat 'Apple' '123')) (and (< 10 12) (> 10 1)))")); + } + + @Test + public void testHashMapLookup() { + assertEquals(2L, eval("(at 1 (hash 1 2 'a' 'b'))")); + assertEquals("b", eval("(at 'a' (hash 1 2 'a' 'b'))")); + assertNull(eval("(at 'b' (hash 1 2 'a' 'b'))")); + } + + @Test(expected = ArityException.class) + public void testHashMapInvalid() { + eval("(hash 'key1' 'value1' 'key2')"); + } + private Object eval(String script) { return interpreter.eval(parser.parse(script)); }