This is an automated email from the ASF dual-hosted git repository.

srinivasulu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/samza.git


The following commit(s) were added to refs/heads/master by this push:
     new 29c0a1c  SAMZA-2234: Passing context to Samza SQL UDFs (#1063)
29c0a1c is described below

commit 29c0a1cdcc269324bff89bcb8ad2de922dbae019
Author: Srinivasulu Punuru <[email protected]>
AuthorDate: Thu Jun 6 09:12:10 2019 -0700

    SAMZA-2234: Passing context to Samza SQL UDFs (#1063)
    
    * Passing context to UDFs
    
    * Fix for the tests
    
    * Fixed doc comments
---
 .../src/main/java/org/apache/samza/sql/udfs/ScalarUdf.java |  4 +++-
 .../main/java/org/apache/samza/sql/data/Expression.java    |  8 +++++---
 .../java/org/apache/samza/sql/data/RexToJavaCompiler.java  | 14 ++++++++------
 .../apache/samza/sql/data/SamzaSqlExecutionContext.java    |  9 +++++----
 .../java/org/apache/samza/sql/fn/BuildOutputRecordUdf.java |  3 ++-
 .../java/org/apache/samza/sql/fn/ConvertToStringUdf.java   |  3 ++-
 .../src/main/java/org/apache/samza/sql/fn/FlattenUdf.java  |  3 ++-
 .../main/java/org/apache/samza/sql/fn/GetSqlFieldUdf.java  |  3 ++-
 .../main/java/org/apache/samza/sql/fn/RegexMatchUdf.java   |  3 ++-
 .../samza/sql/planner/SamzaSqlScalarFunctionImpl.java      |  7 ++++---
 .../org/apache/samza/sql/translator/FilterTranslator.java  |  4 +++-
 .../org/apache/samza/sql/translator/ProjectTranslator.java |  4 +++-
 .../apache/samza/sql/translator/TestFilterTranslator.java  |  8 ++++----
 .../apache/samza/sql/translator/TestProjectTranslator.java | 11 +++++++----
 .../java/org/apache/samza/sql/util/MyTestArrayUdf.java     |  3 ++-
 .../test/java/org/apache/samza/sql/util/MyTestPolyUdf.java |  3 ++-
 .../src/test/java/org/apache/samza/sql/util/MyTestUdf.java |  3 ++-
 17 files changed, 58 insertions(+), 35 deletions(-)

diff --git a/samza-api/src/main/java/org/apache/samza/sql/udfs/ScalarUdf.java 
b/samza-api/src/main/java/org/apache/samza/sql/udfs/ScalarUdf.java
index ff67487..5dda42f 100644
--- a/samza-api/src/main/java/org/apache/samza/sql/udfs/ScalarUdf.java
+++ b/samza-api/src/main/java/org/apache/samza/sql/udfs/ScalarUdf.java
@@ -20,6 +20,7 @@
 package org.apache.samza.sql.udfs;
 
 import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 
 
 /**
@@ -34,6 +35,7 @@ public interface ScalarUdf {
   /**
    * Udfs can implement this method to perform any initialization that they 
may need.
    * @param udfConfig Config specific to the udf.
+   * @param context Samza application and framework context
    */
-  void init(Config udfConfig);
+  void init(Config udfConfig, Context context);
 }
diff --git a/samza-sql/src/main/java/org/apache/samza/sql/data/Expression.java 
b/samza-sql/src/main/java/org/apache/samza/sql/data/Expression.java
index 9386e31..dabea07 100644
--- a/samza-sql/src/main/java/org/apache/samza/sql/data/Expression.java
+++ b/samza-sql/src/main/java/org/apache/samza/sql/data/Expression.java
@@ -20,6 +20,7 @@
 package org.apache.samza.sql.data;
 
 import org.apache.calcite.DataContext;
+import org.apache.samza.context.Context;
 
 
 /**
@@ -29,10 +30,11 @@ import org.apache.calcite.DataContext;
 public interface Expression {
   /**
    * This method is used to implement the expressions that takes in columns as 
input and returns multiple values.
-   * @param context the context
-   * @param root the root
+   * @param sqlContext Samza SQL execution context; generated code uses it to get or create UDF instances
+   * @param context Samza context that contains both framework and application 
specific context.
+   * @param root Calcite DataContext
    * @param inputValues All the relational columns for the particular row
-   * @param results the results Result values after executing the java code 
corresponding to the relational expression.
+   * @param results Result values after executing the java code corresponding to the relational expression.
    */
-  void execute(SamzaSqlExecutionContext context, DataContext root, Object[] 
inputValues, Object[] results);
+  void execute(SamzaSqlExecutionContext sqlContext, Context context, 
DataContext root, Object[] inputValues, Object[] results);
 }
diff --git 
a/samza-sql/src/main/java/org/apache/samza/sql/data/RexToJavaCompiler.java 
b/samza-sql/src/main/java/org/apache/samza/sql/data/RexToJavaCompiler.java
index 1cfa95f..37ce229 100644
--- a/samza-sql/src/main/java/org/apache/samza/sql/data/RexToJavaCompiler.java
+++ b/samza-sql/src/main/java/org/apache/samza/sql/data/RexToJavaCompiler.java
@@ -49,6 +49,7 @@ import org.apache.calcite.rex.RexProgram;
 import org.apache.calcite.rex.RexProgramBuilder;
 import org.apache.calcite.util.Pair;
 import org.apache.samza.SamzaException;
+import org.apache.samza.context.Context;
 import org.apache.samza.sql.interfaces.SamzaSqlJavaTypeFactoryImpl;
 import org.codehaus.commons.compiler.CompileException;
 import org.codehaus.commons.compiler.CompilerFactoryFactory;
@@ -111,7 +112,8 @@ public class RexToJavaCompiler {
     final RexProgram program = programBuilder.getProgram();
 
     final BlockBuilder builder = new BlockBuilder();
-    final ParameterExpression executionContext = 
Expressions.parameter(SamzaSqlExecutionContext.class, "context");
+    final ParameterExpression sqlContext = 
Expressions.parameter(SamzaSqlExecutionContext.class, "sqlContext");
+    final ParameterExpression context = Expressions.parameter(Context.class, 
"context");
     final ParameterExpression root = DataContext.ROOT;
     final ParameterExpression inputValues = 
Expressions.parameter(Object[].class, "inputValues");
     final ParameterExpression outputValues = 
Expressions.parameter(Object[].class, "outputValues");
@@ -130,7 +132,7 @@ public class RexToJavaCompiler {
       builder.add(Expressions.statement(
           Expressions.assign(Expressions.arrayIndex(outputValues, 
Expressions.constant(i)), list.get(i))));
     }
-    return createSamzaExpressionFromCalcite(executionContext, root, 
inputValues, outputValues, builder.toBlock());
+    return createSamzaExpressionFromCalcite(sqlContext, context, root, 
inputValues, outputValues, builder.toBlock());
   }
 
   /**
@@ -159,14 +161,14 @@ public class RexToJavaCompiler {
    *
    */
   static org.apache.samza.sql.data.Expression 
createSamzaExpressionFromCalcite(ParameterExpression executionContext,
-      ParameterExpression dataContext, ParameterExpression inputValues, 
ParameterExpression outputValues,
-      BlockStatement block) {
+      ParameterExpression context, ParameterExpression dataContext, 
ParameterExpression inputValues,
+      ParameterExpression outputValues, BlockStatement block) {
     final List<MemberDeclaration> declarations = Lists.newArrayList();
 
-    // public void execute(Object[] inputValues, Object[] outputValues)
+    // public void execute(executionContext, context, dataContext, inputValues, outputValues)
     declarations.add(
         Expressions.methodDecl(Modifier.PUBLIC, void.class, 
SamzaBuiltInMethod.EXPR_EXECUTE2.method.getName(),
-            ImmutableList.of(executionContext, dataContext, inputValues, 
outputValues), block));
+            ImmutableList.of(executionContext, context, dataContext, 
inputValues, outputValues), block));
 
     final ClassDeclaration classDeclaration = 
Expressions.classDecl(Modifier.PUBLIC, "SqlExpression", null,
         ImmutableList.<Type>of(org.apache.samza.sql.data.Expression.class), 
declarations);
@@ -210,7 +212,7 @@ public class RexToJavaCompiler {
    */
   public enum SamzaBuiltInMethod {
     EXPR_EXECUTE2(org.apache.samza.sql.data.Expression.class, "execute", 
SamzaSqlExecutionContext.class,
-        DataContext.class, Object[].class, Object[].class);
+        Context.class, DataContext.class, Object[].class, Object[].class);
 
     public final Method method;
 
diff --git 
a/samza-sql/src/main/java/org/apache/samza/sql/data/SamzaSqlExecutionContext.java
 
b/samza-sql/src/main/java/org/apache/samza/sql/data/SamzaSqlExecutionContext.java
index 65f7ebd..3307bcb 100644
--- 
a/samza-sql/src/main/java/org/apache/samza/sql/data/SamzaSqlExecutionContext.java
+++ 
b/samza-sql/src/main/java/org/apache/samza/sql/data/SamzaSqlExecutionContext.java
@@ -26,6 +26,7 @@ import java.util.Map;
 
 import org.apache.samza.SamzaException;
 import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.sql.interfaces.UdfMetadata;
 import org.apache.samza.sql.runner.SamzaSqlApplicationConfig;
 import org.apache.samza.sql.udfs.ScalarUdf;
@@ -62,15 +63,15 @@ public class SamzaSqlExecutionContext implements Cloneable {
     }
   }
 
-  public ScalarUdf getOrCreateUdf(String clazz, String udfName) {
-    return udfInstances.computeIfAbsent(udfName, s -> createInstance(clazz, 
udfName));
+  public ScalarUdf getOrCreateUdf(String clazz, String udfName, Context 
context) {
+    return udfInstances.computeIfAbsent(udfName, s -> createInstance(clazz, 
udfName, context));
   }
 
-  public ScalarUdf createInstance(String clazz, String udfName) {
+  public ScalarUdf createInstance(String clazz, String udfName, Context 
context) {
     // Configs should be same for all the UDF methods within a UDF. Hence 
taking the first one.
     Config udfConfig = udfMetadata.get(udfName).get(0).getUdfConfig();
     ScalarUdf scalarUdf = ReflectionUtil.getObj(getClass().getClassLoader(), 
clazz, ScalarUdf.class);
-    scalarUdf.init(udfConfig);
+    scalarUdf.init(udfConfig, context);
     return scalarUdf;
   }
 
diff --git 
a/samza-sql/src/main/java/org/apache/samza/sql/fn/BuildOutputRecordUdf.java 
b/samza-sql/src/main/java/org/apache/samza/sql/fn/BuildOutputRecordUdf.java
index e0c34f1..7f9de1a 100644
--- a/samza-sql/src/main/java/org/apache/samza/sql/fn/BuildOutputRecordUdf.java
+++ b/samza-sql/src/main/java/org/apache/samza/sql/fn/BuildOutputRecordUdf.java
@@ -23,6 +23,7 @@ import java.util.ArrayList;
 import java.util.List;
 import org.apache.commons.lang.Validate;
 import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.sql.SamzaSqlRelRecord;
 import org.apache.samza.sql.udfs.SamzaSqlUdf;
 import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
@@ -64,7 +65,7 @@ import org.apache.samza.sql.udfs.ScalarUdf;
 @SamzaSqlUdf(name = "BuildOutputRecord", description = "Creates an Output 
record.")
 public class BuildOutputRecordUdf implements ScalarUdf {
   @Override
-  public void init(Config udfConfig) {
+  public void init(Config udfConfig, Context context) {
   }
 
   @SamzaSqlUdfMethod(disableArgumentCheck = true)
diff --git 
a/samza-sql/src/main/java/org/apache/samza/sql/fn/ConvertToStringUdf.java 
b/samza-sql/src/main/java/org/apache/samza/sql/fn/ConvertToStringUdf.java
index 659f7e3..f7e599b 100644
--- a/samza-sql/src/main/java/org/apache/samza/sql/fn/ConvertToStringUdf.java
+++ b/samza-sql/src/main/java/org/apache/samza/sql/fn/ConvertToStringUdf.java
@@ -20,6 +20,7 @@
 package org.apache.samza.sql.fn;
 
 import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.sql.schema.SamzaSqlFieldType;
 import org.apache.samza.sql.udfs.SamzaSqlUdf;
 import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
@@ -32,7 +33,7 @@ import org.apache.samza.sql.udfs.ScalarUdf;
 @SamzaSqlUdf(name = "convertToString", description = "Converts the object to 
string.")
 public class ConvertToStringUdf implements ScalarUdf {
   @Override
-  public void init(Config udfConfig) {
+  public void init(Config udfConfig, Context context) {
   }
 
   @SamzaSqlUdfMethod(params = SamzaSqlFieldType.ANY)
diff --git a/samza-sql/src/main/java/org/apache/samza/sql/fn/FlattenUdf.java 
b/samza-sql/src/main/java/org/apache/samza/sql/fn/FlattenUdf.java
index 0734a3a..a24a992 100644
--- a/samza-sql/src/main/java/org/apache/samza/sql/fn/FlattenUdf.java
+++ b/samza-sql/src/main/java/org/apache/samza/sql/fn/FlattenUdf.java
@@ -21,6 +21,7 @@ package org.apache.samza.sql.fn;
 
 import java.util.List;
 import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.sql.schema.SamzaSqlFieldType;
 import org.apache.samza.sql.udfs.SamzaSqlUdf;
 import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
@@ -30,7 +31,7 @@ import org.apache.samza.sql.udfs.ScalarUdf;
 @SamzaSqlUdf(name = "Flatten", description = "Flattens the array.")
 public class FlattenUdf implements ScalarUdf {
   @Override
-  public void init(Config udfConfig) {
+  public void init(Config udfConfig, Context context) {
   }
 
   @SamzaSqlUdfMethod(params = SamzaSqlFieldType.ARRAY)
diff --git 
a/samza-sql/src/main/java/org/apache/samza/sql/fn/GetSqlFieldUdf.java 
b/samza-sql/src/main/java/org/apache/samza/sql/fn/GetSqlFieldUdf.java
index de56fa0..f0fbf75 100644
--- a/samza-sql/src/main/java/org/apache/samza/sql/fn/GetSqlFieldUdf.java
+++ b/samza-sql/src/main/java/org/apache/samza/sql/fn/GetSqlFieldUdf.java
@@ -23,6 +23,7 @@ import java.util.List;
 import java.util.Map;
 import org.apache.commons.lang.Validate;
 import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.sql.SamzaSqlRelRecord;
 import org.apache.samza.sql.schema.SamzaSqlFieldType;
 import org.apache.samza.sql.udfs.SamzaSqlUdf;
@@ -55,7 +56,7 @@ import org.apache.samza.sql.udfs.ScalarUdf;
 @SamzaSqlUdf(name = "GetSqlField", description = "Get an element from complex 
Sql field as a String.")
 public class GetSqlFieldUdf implements ScalarUdf {
   @Override
-  public void init(Config udfConfig) {
+  public void init(Config udfConfig, Context context) {
   }
 
   @SamzaSqlUdfMethod(params = {SamzaSqlFieldType.ANY, 
SamzaSqlFieldType.STRING})
diff --git a/samza-sql/src/main/java/org/apache/samza/sql/fn/RegexMatchUdf.java 
b/samza-sql/src/main/java/org/apache/samza/sql/fn/RegexMatchUdf.java
index c157112..4ae7a80 100644
--- a/samza-sql/src/main/java/org/apache/samza/sql/fn/RegexMatchUdf.java
+++ b/samza-sql/src/main/java/org/apache/samza/sql/fn/RegexMatchUdf.java
@@ -21,6 +21,7 @@ package org.apache.samza.sql.fn;
 
 import java.util.regex.Pattern;
 import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.sql.schema.SamzaSqlFieldType;
 import org.apache.samza.sql.udfs.SamzaSqlUdf;
 import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
@@ -33,7 +34,7 @@ import org.apache.samza.sql.udfs.ScalarUdf;
 @SamzaSqlUdf(name="RegexMatch", description = "Function to perform the regex 
match.")
 public class RegexMatchUdf implements ScalarUdf {
   @Override
-  public void init(Config config) {
+  public void init(Config config, Context context) {
 
   }
 
diff --git 
a/samza-sql/src/main/java/org/apache/samza/sql/planner/SamzaSqlScalarFunctionImpl.java
 
b/samza-sql/src/main/java/org/apache/samza/sql/planner/SamzaSqlScalarFunctionImpl.java
index c5d0121..0793bce 100644
--- 
a/samza-sql/src/main/java/org/apache/samza/sql/planner/SamzaSqlScalarFunctionImpl.java
+++ 
b/samza-sql/src/main/java/org/apache/samza/sql/planner/SamzaSqlScalarFunctionImpl.java
@@ -77,9 +77,11 @@ public class SamzaSqlScalarFunctionImpl implements 
ScalarFunction, Implementable
   @Override
   public CallImplementor getImplementor() {
     return RexImpTable.createImplementor((translator, call, 
translatedOperands) -> {
-      final Expression context = 
Expressions.parameter(SamzaSqlExecutionContext.class, "context");
-      final Expression getUdfInstance = Expressions.call(ScalarUdf.class, 
context, getUdfMethod,
-          Expressions.constant(udfMethod.getDeclaringClass().getName()), 
Expressions.constant(udfName));
+      final Expression sqlContext = 
Expressions.parameter(SamzaSqlExecutionContext.class, "sqlContext");
+      // Typed as Context to match the "context" parameter of the generated execute() method.
+      final Expression samzaContext = 
Expressions.parameter(org.apache.samza.context.Context.class, "context");
+      final Expression getUdfInstance = Expressions.call(ScalarUdf.class, 
sqlContext, getUdfMethod,
+          Expressions.constant(udfMethod.getDeclaringClass().getName()), 
Expressions.constant(udfName), samzaContext);
       final Expression callExpression = 
Expressions.call(Expressions.convert_(getUdfInstance, 
udfMethod.getDeclaringClass()), udfMethod,
           translatedOperands);
       return callExpression;
diff --git 
a/samza-sql/src/main/java/org/apache/samza/sql/translator/FilterTranslator.java 
b/samza-sql/src/main/java/org/apache/samza/sql/translator/FilterTranslator.java
index 911024a..ccab15a 100644
--- 
a/samza-sql/src/main/java/org/apache/samza/sql/translator/FilterTranslator.java
+++ 
b/samza-sql/src/main/java/org/apache/samza/sql/translator/FilterTranslator.java
@@ -69,6 +69,7 @@ class FilterTranslator {
     private final int queryId;
     private final int filterId;
     private final String logicalOpId;
+    private Context context;
 
     FilterTranslatorFunction(int filterId, int queryId, String logicalOpId) {
       this.filterId = filterId;
@@ -78,6 +79,7 @@ class FilterTranslator {
 
     @Override
     public void init(Context context) {
+      this.context = context;
       this.translatorContext = ((SamzaSqlApplicationContext) 
context.getApplicationTaskContext()).getTranslatorContexts().get(queryId);
       this.filter = (LogicalFilter) 
this.translatorContext.getRelNode(filterId);
       this.expr = 
this.translatorContext.getExpressionCompiler().compile(filter.getInputs(), 
Collections.singletonList(filter.getCondition()));
@@ -96,7 +98,7 @@ class FilterTranslator {
     public boolean apply(SamzaSqlRelMessage message) {
       Instant startProcessing = Instant.now();
       Object[] result = new Object[1];
-      expr.execute(translatorContext.getExecutionContext(), 
translatorContext.getDataContext(),
+      expr.execute(translatorContext.getExecutionContext(), context, 
translatorContext.getDataContext(),
           message.getSamzaSqlRelRecord().getFieldValues().toArray(), result);
       if (result.length > 0 && result[0] instanceof Boolean) {
         boolean retVal = (Boolean) result[0];
diff --git 
a/samza-sql/src/main/java/org/apache/samza/sql/translator/ProjectTranslator.java
 
b/samza-sql/src/main/java/org/apache/samza/sql/translator/ProjectTranslator.java
index 6e6ff45..8bbe84a 100644
--- 
a/samza-sql/src/main/java/org/apache/samza/sql/translator/ProjectTranslator.java
+++ 
b/samza-sql/src/main/java/org/apache/samza/sql/translator/ProjectTranslator.java
@@ -77,6 +77,7 @@ class ProjectTranslator {
     private final int queryId;
     private final int projectId;
     private final String logicalOpId;
+    private Context context;
 
     ProjectMapFunction(int projectId, int queryId, String logicalOpId) {
       this.projectId = projectId;
@@ -90,6 +91,7 @@ class ProjectTranslator {
      */
     @Override
     public void init(Context context) {
+      this.context = context;
       this.translatorContext = ((SamzaSqlApplicationContext) 
context.getApplicationTaskContext()).getTranslatorContexts().get(queryId);
       this.project = (Project) this.translatorContext.getRelNode(projectId);
       this.expr = 
this.translatorContext.getExpressionCompiler().compile(project.getInputs(), 
project.getProjects());
@@ -112,7 +114,7 @@ class ProjectTranslator {
       Instant arrivalTime = Instant.now();
       RelDataType type = project.getRowType();
       Object[] output = new Object[type.getFieldCount()];
-      expr.execute(translatorContext.getExecutionContext(), 
translatorContext.getDataContext(),
+      expr.execute(translatorContext.getExecutionContext(), context, 
translatorContext.getDataContext(),
           message.getSamzaSqlRelRecord().getFieldValues().toArray(), output);
       List<String> names = new ArrayList<>();
       for (int index = 0; index < output.length; index++) {
diff --git 
a/samza-sql/src/test/java/org/apache/samza/sql/translator/TestFilterTranslator.java
 
b/samza-sql/src/test/java/org/apache/samza/sql/translator/TestFilterTranslator.java
index 037201e..2a98c8c 100644
--- 
a/samza-sql/src/test/java/org/apache/samza/sql/translator/TestFilterTranslator.java
+++ 
b/samza-sql/src/test/java/org/apache/samza/sql/translator/TestFilterTranslator.java
@@ -133,18 +133,18 @@ public class TestFilterTranslator extends 
TranslatorTestBase {
     Object[] result = new Object[1];
 
     doAnswer( invocation -> {
-      Object[] retValue = invocation.getArgumentAt(3, Object[].class);
+      Object[] retValue = invocation.getArgumentAt(4, Object[].class);
       retValue[0] = new Boolean(true);
       return null;
-    }).when(mockExpr).execute(eq(executionContext), eq(dataContext),
+    }).when(mockExpr).execute(eq(executionContext), eq(mockContext), 
eq(dataContext),
         eq(mockInputMsg.getSamzaSqlRelRecord().getFieldValues().toArray()), 
eq(result));
     assertTrue(filterFn.apply(mockInputMsg));
 
     doAnswer( invocation -> {
-      Object[] retValue = invocation.getArgumentAt(3, Object[].class);
+      Object[] retValue = invocation.getArgumentAt(4, Object[].class);
       retValue[0] = new Boolean(false);
       return null;
-    }).when(mockExpr).execute(eq(executionContext), eq(dataContext),
+    }).when(mockExpr).execute(eq(executionContext), eq(mockContext), 
eq(dataContext),
         eq(mockInputMsg.getSamzaSqlRelRecord().getFieldValues().toArray()), 
eq(result));
     assertFalse(filterFn.apply(mockInputMsg));
 
diff --git 
a/samza-sql/src/test/java/org/apache/samza/sql/translator/TestProjectTranslator.java
 
b/samza-sql/src/test/java/org/apache/samza/sql/translator/TestProjectTranslator.java
index 050971f..0129284 100644
--- 
a/samza-sql/src/test/java/org/apache/samza/sql/translator/TestProjectTranslator.java
+++ 
b/samza-sql/src/test/java/org/apache/samza/sql/translator/TestProjectTranslator.java
@@ -73,6 +73,7 @@ import static org.mockito.Mockito.when;
 @PrepareForTest(LogicalProject.class)
 public class TestProjectTranslator extends TranslatorTestBase {
   final private String LOGICAL_OP_ID = "sql0_project_0";
+
   @Test
   public void testTranslate() throws IOException, ClassNotFoundException {
     // setup mock values to the constructor of FilterTranslator
@@ -150,10 +151,10 @@ public class TestProjectTranslator extends 
TranslatorTestBase {
     final Object mockFieldObj = new Object();
 
     doAnswer( invocation -> {
-      Object[] retValue = invocation.getArgumentAt(3, Object[].class);
+      Object[] retValue = invocation.getArgumentAt(4, Object[].class);
       retValue[0] = mockFieldObj;
       return null;
-    }).when(mockExpr).execute(eq(executionContext), eq(dataContext),
+    }).when(mockExpr).execute(eq(executionContext), eq(mockContext), 
eq(dataContext),
         eq(mockInputMsg.getSamzaSqlRelRecord().getFieldValues().toArray()), 
eq(result));
     SamzaSqlRelMessage retMsg = (SamzaSqlRelMessage) mapFn.apply(mockInputMsg);
     assertEquals(retMsg.getSamzaSqlRelRecord().getFieldNames(),
@@ -308,10 +309,10 @@ public class TestProjectTranslator extends 
TranslatorTestBase {
     final Object mockFieldObj = new Object();
 
     doAnswer( invocation -> {
-      Object[] retValue = invocation.getArgumentAt(3, Object[].class);
+      Object[] retValue = invocation.getArgumentAt(4, Object[].class);
       retValue[0] = mockFieldObj;
       return null;
-    }).when(mockExpr).execute(eq(executionContext), eq(dataContext),
+    }).when(mockExpr).execute(eq(executionContext), eq(mockContext), 
eq(dataContext),
         eq(mockInputMsg.getSamzaSqlRelRecord().getFieldValues().toArray()), 
eq(result));
     SamzaSqlRelMessage retMsg = (SamzaSqlRelMessage) mapFn.apply(mockInputMsg);
     assertEquals(retMsg.getSamzaSqlRelRecord().getFieldNames(),
diff --git 
a/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestArrayUdf.java 
b/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestArrayUdf.java
index 7f6ee50..4bda04b 100644
--- a/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestArrayUdf.java
+++ b/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestArrayUdf.java
@@ -23,6 +23,7 @@ import java.util.List;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.sql.schema.SamzaSqlFieldType;
 import org.apache.samza.sql.udfs.SamzaSqlUdf;
 import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
@@ -32,7 +33,7 @@ import org.apache.samza.sql.udfs.ScalarUdf;
 @SamzaSqlUdf(name = "MyTestArray", description = "Test udf that returns an 
array")
 public class MyTestArrayUdf implements ScalarUdf {
   @Override
-  public void init(Config udfConfig) {
+  public void init(Config udfConfig, Context context) {
   }
 
   @SamzaSqlUdfMethod(params = SamzaSqlFieldType.INT32)
diff --git 
a/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestPolyUdf.java 
b/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestPolyUdf.java
index f4afbd6..29769c0 100644
--- a/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestPolyUdf.java
+++ b/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestPolyUdf.java
@@ -19,6 +19,7 @@
 package org.apache.samza.sql.util;
 
 import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.sql.schema.SamzaSqlFieldType;
 import org.apache.samza.sql.udfs.SamzaSqlUdf;
 import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
@@ -46,7 +47,7 @@ public class MyTestPolyUdf implements ScalarUdf {
 
 
   @Override
-  public void init(Config udfConfig) {
+  public void init(Config udfConfig, Context context) {
     LOG.info("Init called with {}", udfConfig);
   }
 }
diff --git a/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestUdf.java 
b/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestUdf.java
index 35a44e3..d0ac517 100644
--- a/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestUdf.java
+++ b/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestUdf.java
@@ -20,6 +20,7 @@
 package org.apache.samza.sql.util;
 
 import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.sql.schema.SamzaSqlFieldType;
 import org.apache.samza.sql.udfs.SamzaSqlUdf;
 import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
@@ -48,7 +49,7 @@ public class MyTestUdf implements ScalarUdf {
 
 
   @Override
-  public void init(Config udfConfig) {
+  public void init(Config udfConfig, Context context) {
     LOG.info("Init called with {}", udfConfig);
   }
 }

Reply via email to