This is an automated email from the ASF dual-hosted git repository.

atoomula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/samza.git


The following commit(s) were added to refs/heads/master by this push:
     new da0bf39  SAMZA-2289: Samza-sql - Fix udf to work with disabled arg check. (#1124)
da0bf39 is described below

commit da0bf393bc5b6e03985b28ae706c3e05ae5e6992
Author: Aditya Toomula <[email protected]>
AuthorDate: Thu Aug 1 12:44:49 2019 -0700

    SAMZA-2289: Samza-sql - Fix udf to work with disabled arg check. (#1124)
    
    * Samza-sql: Fix udf to work with disabled arg check.
    
    * Samza-sql: Fix udf to work with disabled arg check.
    
    * Samza-sql: Fix udf to work with disabled arg check.
---
 .../org/apache/samza/sql/fn/GetNestedFieldUdf.java | 42 ++++++++++++++++++
 .../org/apache/samza/sql/fn/GetSqlFieldUdf.java    | 50 ++++++++++++++--------
 .../sql/planner/SamzaSqlScalarFunctionImpl.java    |  3 +-
 .../apache/samza/sql/util/SamzaSqlTestConfig.java  |  6 ++-
 .../samza/test/samzasql/TestSamzaSqlEndToEnd.java  | 20 +++++++++
 5 files changed, 100 insertions(+), 21 deletions(-)

diff --git a/samza-sql/src/main/java/org/apache/samza/sql/fn/GetNestedFieldUdf.java b/samza-sql/src/main/java/org/apache/samza/sql/fn/GetNestedFieldUdf.java
new file mode 100644
index 0000000..4ef2a11
--- /dev/null
+++ b/samza-sql/src/main/java/org/apache/samza/sql/fn/GetNestedFieldUdf.java
@@ -0,0 +1,42 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+
+package org.apache.samza.sql.fn;
+
+import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
+import org.apache.samza.sql.schema.SamzaSqlFieldType;
+import org.apache.samza.sql.udfs.SamzaSqlUdf;
+import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
+import org.apache.samza.sql.udfs.ScalarUdf;
+
+
+@SamzaSqlUdf(name = "GetNestedField", description = "UDF that extracts a field 
value from a nested SamzaSqlRelRecord")
+public class GetNestedFieldUdf implements ScalarUdf {
+  @Override
+  public void init(Config udfConfig, Context context) {
+  }
+
+  @SamzaSqlUdfMethod(params = {SamzaSqlFieldType.ANY, 
SamzaSqlFieldType.STRING},
+      returns = SamzaSqlFieldType.ANY)
+  public Object execute(Object currentFieldOrValue, String fieldName) {
+    GetSqlFieldUdf udf = new GetSqlFieldUdf();
+    return udf.getSqlField(currentFieldOrValue, fieldName);
+  }
+}
diff --git a/samza-sql/src/main/java/org/apache/samza/sql/fn/GetSqlFieldUdf.java b/samza-sql/src/main/java/org/apache/samza/sql/fn/GetSqlFieldUdf.java
index f0fbf75..ec05d55 100644
--- a/samza-sql/src/main/java/org/apache/samza/sql/fn/GetSqlFieldUdf.java
+++ b/samza-sql/src/main/java/org/apache/samza/sql/fn/GetSqlFieldUdf.java
@@ -21,7 +21,8 @@ package org.apache.samza.sql.fn;
 
 import java.util.List;
 import java.util.Map;
-import org.apache.commons.lang.Validate;
+import org.apache.avro.util.Utf8;
+import org.apache.commons.lang3.Validate;
 import org.apache.samza.config.Config;
 import org.apache.samza.context.Context;
 import org.apache.samza.sql.SamzaSqlRelRecord;
@@ -53,22 +54,15 @@ import org.apache.samza.sql.udfs.ScalarUdf;
  *           - sessionKey (Scalar)
  *
  */
-@SamzaSqlUdf(name = "GetSqlField", description = "Get an element from complex 
Sql field as a String.")
+@SamzaSqlUdf(name = "GetSqlField", description = "Deprecated : Please use 
GetNestedField.")
 public class GetSqlFieldUdf implements ScalarUdf {
   @Override
   public void init(Config udfConfig, Context context) {
   }
 
-  @SamzaSqlUdfMethod(params = {SamzaSqlFieldType.ANY, 
SamzaSqlFieldType.STRING})
-  public String execute(Object field, String fieldName) {
-    Object currentFieldOrValue = field;
-    Validate.isTrue(currentFieldOrValue == null
-        || currentFieldOrValue instanceof SamzaSqlRelRecord);
-
-    String[] fieldNameChain = fieldName.split("\\.");
-    for (int i = 0; i < fieldNameChain.length && currentFieldOrValue != null; 
i++) {
-      currentFieldOrValue = extractField(fieldNameChain[i], 
currentFieldOrValue);
-    }
+  @SamzaSqlUdfMethod(params = {SamzaSqlFieldType.ANY, 
SamzaSqlFieldType.STRING}, returns = SamzaSqlFieldType.STRING)
+  public String execute(Object currentFieldOrValue, String fieldName) {
+    currentFieldOrValue = getSqlField(currentFieldOrValue, fieldName);
 
     if (currentFieldOrValue != null) {
       return currentFieldOrValue.toString();
@@ -77,23 +71,41 @@ public class GetSqlFieldUdf implements ScalarUdf {
     return null;
   }
 
-  static Object extractField(String fieldName, Object current) {
+  public Object getSqlField(Object currentFieldOrValue, String fieldName) {
+    if (currentFieldOrValue != null) {
+      String[] fieldNameChain = (fieldName).split("\\.");
+      for (int i = 0; i < fieldNameChain.length && currentFieldOrValue != 
null; i++) {
+        currentFieldOrValue = extractField(fieldNameChain[i], 
currentFieldOrValue, true);
+      }
+    }
+
+    return currentFieldOrValue;
+  }
+
+  static Object extractField(String fieldName, Object current, boolean 
validateField) {
     if (current instanceof SamzaSqlRelRecord) {
       SamzaSqlRelRecord record = (SamzaSqlRelRecord) current;
-      Validate.isTrue(record.getFieldNames().contains(fieldName),
-          String.format("Invalid field %s in %s", fieldName, record));
+      if (validateField) {
+        Validate.isTrue(record.getFieldNames().contains(fieldName),
+            String.format("Invalid field %s in record %s", fieldName, record));
+      }
       return record.getField(fieldName).orElse(null);
     } else if (current instanceof Map) {
       Map map = (Map) current;
-      Validate.isTrue(map.containsKey(fieldName), String.format("Invalid field 
%s in %s", fieldName, map));
-      return map.get(fieldName);
+      if (map.containsKey(fieldName)) {
+        return map.get(fieldName);
+      } else if (map.containsKey(new Utf8(fieldName))) {
+        return map.get(new Utf8(fieldName));
+      } else {
+        throw new IllegalArgumentException(String.format("Couldn't find the 
field %s in map %s", fieldName, map));
+      }
     } else if (current instanceof List && fieldName.endsWith("]")) {
       List list = (List) current;
       int index = Integer.parseInt(fieldName.substring(fieldName.indexOf("[") 
+ 1, fieldName.length() - 1));
       return list.get(index);
     }
 
-    throw new IllegalArgumentException(String.format(
-        "Unsupported accessing operation for data type: %s with field: %s.", 
current.getClass(), fieldName));
+    throw new IllegalArgumentException(
+        String.format("Unsupported accessing operation for data type: %s with 
field: %s.", current.getClass(), fieldName));
   }
 }
diff --git a/samza-sql/src/main/java/org/apache/samza/sql/planner/SamzaSqlScalarFunctionImpl.java b/samza-sql/src/main/java/org/apache/samza/sql/planner/SamzaSqlScalarFunctionImpl.java
index b2b6119..21a48e9 100644
--- a/samza-sql/src/main/java/org/apache/samza/sql/planner/SamzaSqlScalarFunctionImpl.java
+++ b/samza-sql/src/main/java/org/apache/samza/sql/planner/SamzaSqlScalarFunctionImpl.java
@@ -87,7 +87,8 @@ public class SamzaSqlScalarFunctionImpl implements ScalarFunction, Implementable
       // SAMZA: 2230 To allow UDFS to accept Untyped arguments.
       // We explicitly Convert the untyped arguments to type that the UDf 
expects.
       for(int index = 0; index < translatedOperands.size(); index++) {
-        if (translatedOperands.get(index).type == Object.class && 
udfMethod.getParameters()[index].getType() != Object.class) {
+        if (!udfMetadata.isDisableArgCheck() && 
translatedOperands.get(index).type == Object.class
+            && udfMethod.getParameters()[index].getType() != Object.class) {
           
convertedOperands.add(Expressions.convert_(translatedOperands.get(index), 
udfMethod.getParameters()[index].getType()));
         } else {
           convertedOperands.add(translatedOperands.get(index));
diff --git a/samza-sql/src/test/java/org/apache/samza/sql/util/SamzaSqlTestConfig.java b/samza-sql/src/test/java/org/apache/samza/sql/util/SamzaSqlTestConfig.java
index c6067a3..b3bb1ee 100644
--- a/samza-sql/src/test/java/org/apache/samza/sql/util/SamzaSqlTestConfig.java
+++ b/samza-sql/src/test/java/org/apache/samza/sql/util/SamzaSqlTestConfig.java
@@ -37,6 +37,7 @@ import org.apache.samza.sql.avro.schemas.Profile;
 import org.apache.samza.sql.avro.schemas.SimpleRecord;
 import org.apache.samza.sql.fn.BuildOutputRecordUdf;
 import org.apache.samza.sql.fn.FlattenUdf;
+import org.apache.samza.sql.fn.GetNestedFieldUdf;
 import org.apache.samza.sql.fn.RegexMatchUdf;
 import org.apache.samza.sql.impl.ConfigBasedIOResolverFactory;
 import org.apache.samza.sql.impl.ConfigBasedUdfResolver;
@@ -100,7 +101,7 @@ public class SamzaSqlTestConfig {
     staticConfigs.put(configUdfResolverDomain + 
ConfigBasedUdfResolver.CFG_UDF_CLASSES, Joiner.on(",")
         .join(MyTestUdf.class.getName(), RegexMatchUdf.class.getName(), 
FlattenUdf.class.getName(),
             MyTestArrayUdf.class.getName(), 
BuildOutputRecordUdf.class.getName(), MyTestPolyUdf.class.getName(),
-            MyTestObjUdf.class.getName()));
+            MyTestObjUdf.class.getName(), GetNestedFieldUdf.class.getName()));
 
     String avroSystemConfigPrefix =
         String.format(ConfigBasedIOResolverFactory.CFG_FMT_SAMZA_PREFIX, 
SAMZA_SYSTEM_TEST_AVRO);
@@ -189,6 +190,9 @@ public class SamzaSqlTestConfig {
         "testavro", "PROFILE"), Profile.SCHEMA$.toString());
 
     staticConfigs.put(configAvroRelSchemaProviderDomain + 
String.format(ConfigBasedAvroRelSchemaProviderFactory.CFG_SOURCE_SCHEMA,
+        "testavro", "PROFILE1"), Profile.SCHEMA$.toString());
+
+    staticConfigs.put(configAvroRelSchemaProviderDomain + 
String.format(ConfigBasedAvroRelSchemaProviderFactory.CFG_SOURCE_SCHEMA,
         "testavro", "PAGEVIEW"), PageView.SCHEMA$.toString());
 
     staticConfigs.put(configAvroRelSchemaProviderDomain + 
String.format(ConfigBasedAvroRelSchemaProviderFactory.CFG_SOURCE_SCHEMA,
diff --git a/samza-test/src/test/java/org/apache/samza/test/samzasql/TestSamzaSqlEndToEnd.java b/samza-test/src/test/java/org/apache/samza/test/samzasql/TestSamzaSqlEndToEnd.java
index eb47af1..0cec337 100644
--- a/samza-test/src/test/java/org/apache/samza/test/samzasql/TestSamzaSqlEndToEnd.java
+++ b/samza-test/src/test/java/org/apache/samza/test/samzasql/TestSamzaSqlEndToEnd.java
@@ -523,6 +523,26 @@ public class TestSamzaSqlEndToEnd extends SamzaSqlIntegrationTestHarness {
   }
 
   @Test
+  public void testEndToEndUdfWithDisabledArgCheck() throws Exception {
+    int numMessages = 20;
+    TestAvroSystemFactory.messages.clear();
+    Map<String, String> staticConfigs = 
SamzaSqlTestConfig.fetchStaticConfigsWithFactories(configs, numMessages);
+    String sql1 = "Insert into testavro.PROFILE1(id, address) "
+        + "select id, BuildOutputRecord('key', GetNestedField(address, 'zip')) 
as address from testavro.PROFILE";
+    List<String> sqlStmts = Collections.singletonList(sql1);
+    staticConfigs.put(SamzaSqlApplicationConfig.CFG_SQL_STMTS_JSON, 
JsonUtil.toJson(sqlStmts));
+    runApplication(new MapConfig(staticConfigs));
+
+    LOG.info("output Messages " + TestAvroSystemFactory.messages);
+
+    List<Integer> outMessages = TestAvroSystemFactory.messages.stream()
+        .map(x -> Integer.valueOf(((GenericRecord) 
x.getMessage()).get("id").toString()))
+        .sorted()
+        .collect(Collectors.toList());
+    Assert.assertEquals(outMessages.size(), numMessages);
+  }
+
+  @Test
   public void testEndToEndUdfPolymorphism() throws Exception {
     int numMessages = 20;
     TestAvroSystemFactory.messages.clear();

Reply via email to