This is an automated email from the ASF dual-hosted git repository.

critas pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new 7e1152938c4 fix: UDAFPatternMatch add validate (#14573)
7e1152938c4 is described below

commit 7e1152938c4422b584d8b208856dd6505f12c176
Author: CritasWang <[email protected]>
AuthorDate: Fri Dec 27 22:13:54 2024 +0800

    fix: UDAFPatternMatch add validate (#14573)
    
    * fix: UDAFPatternMatch add validate
    
    * fix: Add boundary conditions.
---
 .../iotdb/library/match/PatternExecutor.java       |  2 +-
 .../iotdb/library/match/UDAFPatternMatch.java      | 62 ++++++++++++++-----
 .../org/apache/iotdb/library/UDAFPatternTest.java  | 70 ++++++++++++++++++++++
 3 files changed, 117 insertions(+), 17 deletions(-)

diff --git a/library-udf/src/main/java/org/apache/iotdb/library/match/PatternExecutor.java b/library-udf/src/main/java/org/apache/iotdb/library/match/PatternExecutor.java
index b37a0c9ce28..01d1c533493 100644
--- a/library-udf/src/main/java/org/apache/iotdb/library/match/PatternExecutor.java
+++ b/library-udf/src/main/java/org/apache/iotdb/library/match/PatternExecutor.java
@@ -390,7 +390,7 @@ public class PatternExecutor {
       boolean partialQuery) {
     PatternCalculationResult pointsMatchRes =
         calculatePointsMatch(querySections, matchedSections, partialQuery);
-    if (pointsMatchRes == null) {
+    if (pointsMatchRes == null || pointsMatchRes.getMatchedPoints().isEmpty()) 
{
       return null;
     }
     if (pointsMatchRes.getMatch() > queryCtx.getThreshold()) {
diff --git a/library-udf/src/main/java/org/apache/iotdb/library/match/UDAFPatternMatch.java b/library-udf/src/main/java/org/apache/iotdb/library/match/UDAFPatternMatch.java
index 0f8798f979b..6c12ff0de2b 100644
--- a/library-udf/src/main/java/org/apache/iotdb/library/match/UDAFPatternMatch.java
+++ b/library-udf/src/main/java/org/apache/iotdb/library/match/UDAFPatternMatch.java
@@ -29,6 +29,7 @@ import org.apache.iotdb.udf.api.UDAF;
 import org.apache.iotdb.udf.api.customizer.config.UDAFConfigurations;
 import org.apache.iotdb.udf.api.customizer.parameter.UDFParameterValidator;
 import org.apache.iotdb.udf.api.customizer.parameter.UDFParameters;
+import org.apache.iotdb.udf.api.exception.UDFParameterNotValidException;
 import org.apache.iotdb.udf.api.type.Type;
 import org.apache.iotdb.udf.api.utils.ResultValue;
 
@@ -40,9 +41,14 @@ import java.nio.charset.Charset;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
+import java.util.stream.IntStream;
 
 public class UDAFPatternMatch implements UDAF {
 
+  static final String THRESHOLD_PARAM = "threshold";
+  static final String TIME_PATTERN_PARAM = "timePattern";
+  static final String VALUE_PATTERN_PARAM = "valuePattern";
+
   private Long[] timePattern;
   private Double[] valuePattern;
   private float threshold;
@@ -52,19 +58,7 @@ public class UDAFPatternMatch implements UDAF {
   public void beforeStart(UDFParameters udfParameters, UDAFConfigurations 
udafConfigurations) {
     udafConfigurations.setOutputDataType(Type.TEXT);
     Map<String, String> attributes = udfParameters.getAttributes();
-    if (!attributes.containsKey("threshold")) {
-      threshold = 100;
-    } else {
-      threshold = Float.parseFloat(attributes.get("threshold"));
-    }
-    timePattern =
-        Arrays.stream(attributes.get("timePattern").split(","))
-            .map(Long::valueOf)
-            .toArray(Long[]::new);
-    valuePattern =
-        Arrays.stream(attributes.get("valuePattern").split(","))
-            .map(Double::valueOf)
-            .toArray(Double[]::new);
+    threshold = Float.parseFloat(attributes.get(THRESHOLD_PARAM));
   }
 
   @Override
@@ -136,13 +130,49 @@ public class UDAFPatternMatch implements UDAF {
 
   @Override
   public void validate(UDFParameterValidator validator) {
+
+    try {
+      String timePatternStr = 
validator.getParameters().getStringOrDefault(TIME_PATTERN_PARAM, "");
+      timePattern =
+          
Arrays.stream(timePatternStr.split(",")).map(Long::valueOf).toArray(Long[]::new);
+
+    } catch (Exception e) {
+      throw new UDFParameterNotValidException(
+          "Illegal parameter, timePattern must be long,long...");
+    }
+    try {
+      String valuePatternStr =
+          validator.getParameters().getStringOrDefault(VALUE_PATTERN_PARAM, 
"");
+      valuePattern =
+          
Arrays.stream(valuePatternStr.split(",")).map(Double::valueOf).toArray(Double[]::new);
+    } catch (Exception e) {
+      throw new UDFParameterNotValidException(
+          "Illegal parameter, valuePattern must be double,double...");
+    }
     validator
         .validateInputSeriesNumber(1)
         .validateInputSeriesDataType(
             0, Type.INT32, Type.INT64, Type.FLOAT, Type.DOUBLE, Type.BOOLEAN)
-        .validateRequiredAttribute("timePattern")
-        .validateRequiredAttribute("valuePattern")
-        .validateRequiredAttribute("threshold");
+        .validateRequiredAttribute(THRESHOLD_PARAM)
+        .validateRequiredAttribute(TIME_PATTERN_PARAM)
+        .validateRequiredAttribute(VALUE_PATTERN_PARAM)
+        .validate(
+            (UDFParameterValidator.SingleObjectValidationRule)
+                payload -> ((Long[]) payload).length > 1,
+            "Illegal parameter, timePattern size must larger 1.",
+            timePattern)
+        .validate(
+            (UDFParameterValidator.SingleObjectValidationRule)
+                payload ->
+                    IntStream.range(1, ((Long[]) payload).length)
+                        .allMatch(i -> ((Long[]) payload)[i] > ((Long[]) 
payload)[i - 1]),
+            "Illegal parameter, timePattern value must be in ascending order.",
+            timePattern)
+        .validate(
+            payload -> ((Long[]) payload[0]).length == ((Double[]) 
payload[1]).length,
+            "Illegal parameter, timePattern size must equals valuePattern 
size.",
+            timePattern,
+            valuePattern);
   }
 
   private double getValue(Column column, int i) {
diff --git a/library-udf/src/test/java/org/apache/iotdb/library/UDAFPatternTest.java b/library-udf/src/test/java/org/apache/iotdb/library/UDAFPatternTest.java
index 4cca64a4fcb..6f1036ff0b5 100644
--- a/library-udf/src/test/java/org/apache/iotdb/library/UDAFPatternTest.java
+++ b/library-udf/src/test/java/org/apache/iotdb/library/UDAFPatternTest.java
@@ -20,9 +20,16 @@
 package org.apache.iotdb.library;
 
 import org.apache.iotdb.library.match.PatternExecutor;
+import org.apache.iotdb.library.match.UDAFPatternMatch;
 import org.apache.iotdb.library.match.model.PatternContext;
 import org.apache.iotdb.library.match.model.PatternResult;
 import org.apache.iotdb.library.match.model.Point;
+import org.apache.iotdb.udf.api.customizer.parameter.UDFParameterValidator;
+import org.apache.iotdb.udf.api.customizer.parameter.UDFParameters;
+import org.apache.iotdb.udf.api.exception.UDFAttributeNotProvidedException;
+import 
org.apache.iotdb.udf.api.exception.UDFInputSeriesNumberNotValidException;
+import org.apache.iotdb.udf.api.exception.UDFParameterNotValidException;
+import org.apache.iotdb.udf.api.type.Type;
 
 import org.apache.commons.lang3.StringUtils;
 import org.junit.Assert;
@@ -33,7 +40,9 @@ import java.io.IOException;
 import java.io.InputStream;
 import java.io.InputStreamReader;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 public class UDAFPatternTest {
   private final PatternExecutor executor = new PatternExecutor();
@@ -85,4 +94,65 @@ public class UDAFPatternTest {
     Assert.assertNotNull(results);
     Assert.assertEquals(1, results.size());
   }
+
+  @Test
+  public void testParameterValidator() {
+    UDAFPatternMatch patternMatch = new UDAFPatternMatch();
+    List<String> stringList = new ArrayList<>();
+    List<Type> typeList = new ArrayList<>();
+    Map<String, String> userAttributes = new HashMap<>();
+    userAttributes.put("timePattern", "1,2,3");
+    userAttributes.put("valuePattern", "1.0,2.0");
+    userAttributes.put("threshold", "100");
+
+    UDFParameterValidator validator =
+        new UDFParameterValidator(new UDFParameters(stringList, typeList, 
userAttributes));
+
+    Assert.assertThrows(
+        UDFInputSeriesNumberNotValidException.class, () -> 
patternMatch.validate(validator));
+
+    stringList.add("s1");
+    typeList.add(Type.FLOAT);
+    userAttributes.clear();
+    Assert.assertThrows(
+        "Illegal parameter, timePattern must be long,long...",
+        UDFParameterNotValidException.class,
+        () -> patternMatch.validate(validator));
+
+    userAttributes.put("timePattern", "1,3,2");
+    Assert.assertThrows(
+        "Illegal parameter, valuePattern must be double,double...",
+        UDFParameterNotValidException.class,
+        () -> patternMatch.validate(validator));
+
+    userAttributes.put("valuePattern", "1.0,2.0");
+    Assert.assertThrows(
+        "Illegal parameter, timePattern size must equals valuePattern size",
+        UDFParameterNotValidException.class,
+        () -> patternMatch.validate(validator));
+
+    userAttributes.remove("valuePattern");
+    userAttributes.put("valuePattern", "1.0,2.0,3.0");
+
+    Assert.assertThrows(
+        "Illegal parameter, timePattern value must be in ascending order.",
+        UDFParameterNotValidException.class,
+        () -> patternMatch.validate(validator));
+
+    userAttributes.remove("timePattern");
+    userAttributes.put("timePattern", "1,2,3");
+
+    Assert.assertThrows(
+        "attribute threshold is required but was not provided.",
+        UDFAttributeNotProvidedException.class,
+        () -> patternMatch.validate(validator));
+
+    userAttributes.put("threshold", "100");
+
+    try {
+      patternMatch.validate(validator);
+    } catch (Exception e) {
+      Assert.fail("Should not throw exception");
+    }
+  }
 }

Reply via email to