JackieTien97 commented on code in PR #15577:
URL: https://github.com/apache/iotdb/pull/15577#discussion_r2162918093


##########
example/udf/src/main/java/org/apache/iotdb/udf/table/MatchingKeyState.java:
##########
@@ -0,0 +1,623 @@
+package org.apache.iotdb.udf.table;
+
+import org.apache.iotdb.udf.api.State;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeSet;
+
+public class MatchingKeyState implements State {
+
+  int labelIndex;
+  double min_support;
+  double min_confidence;
+  int columnLength;
+  int totalPairs;
+  int[] distanceMin;
+  int[] distanceMax;
+  List<TupleEntry> tupleList;
+  Set<String> positivePairs;
+  Set<String> negativePairs;
+  Map<Integer, Set<Integer>> distanceMap;
+  Map<String, int[]> distanceCache;
+  Set<String> CandidateCache;

Review Comment:
   make it private



##########
example/udf/src/main/java/org/apache/iotdb/udf/UDAFMKIdentify.java:
##########
@@ -0,0 +1,157 @@
+package org.apache.iotdb.udf;
+
+import org.apache.iotdb.udf.api.State;
+import org.apache.iotdb.udf.api.customizer.analysis.AggregateFunctionAnalysis;
+import org.apache.iotdb.udf.api.customizer.parameter.FunctionArguments;
+import org.apache.iotdb.udf.api.exception.UDFArgumentNotValidException;
+import org.apache.iotdb.udf.api.relational.AggregateFunction;
+import org.apache.iotdb.udf.api.relational.access.Record;
+import org.apache.iotdb.udf.api.type.Type;
+import org.apache.iotdb.udf.api.utils.ResultValue;
+import org.apache.iotdb.udf.table.MatchingKeyState;
+
+import org.apache.tsfile.utils.Binary;
+
+import java.nio.charset.Charset;
+import java.util.List;
+import java.util.Set;
+
+public class UDAFMKIdentify implements AggregateFunction {
+  int label;
+  int length;
+  double min_confidence;
+  double min_support;
+  String a;
+  int l = 0;
+
+  private FunctionArguments arguments;

Review Comment:
   Don't store this in your function. `UDAFMKIdentify.analyze` is only called 
in the query planning stage, while other methods such as `addInput` are called 
in the query execution stage. The two stages do not use the same object, so a 
field you initialize in `analyze` won't be accessible to `addInput`.



##########
example/udf/src/main/java/org/apache/iotdb/udf/UDAFMKIdentify.java:
##########
@@ -0,0 +1,157 @@
+package org.apache.iotdb.udf;
+
+import org.apache.iotdb.udf.api.State;
+import org.apache.iotdb.udf.api.customizer.analysis.AggregateFunctionAnalysis;
+import org.apache.iotdb.udf.api.customizer.parameter.FunctionArguments;
+import org.apache.iotdb.udf.api.exception.UDFArgumentNotValidException;
+import org.apache.iotdb.udf.api.relational.AggregateFunction;
+import org.apache.iotdb.udf.api.relational.access.Record;
+import org.apache.iotdb.udf.api.type.Type;
+import org.apache.iotdb.udf.api.utils.ResultValue;
+import org.apache.iotdb.udf.table.MatchingKeyState;
+
+import org.apache.tsfile.utils.Binary;
+
+import java.nio.charset.Charset;
+import java.util.List;
+import java.util.Set;
+
+public class UDAFMKIdentify implements AggregateFunction {
+  int label;
+  int length;
+  double min_confidence;
+  double min_support;
+  String a;
+  int l = 0;

Review Comment:
   make them private



##########
example/udf/src/main/java/org/apache/iotdb/udf/table/MatchingKeyState.java:
##########
@@ -0,0 +1,623 @@
+package org.apache.iotdb.udf.table;
+
+import org.apache.iotdb.udf.api.State;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeSet;
+
+public class MatchingKeyState implements State {
+
+  int labelIndex;
+  double min_support;
+  double min_confidence;
+  int columnLength;
+  int totalPairs;
+  int[] distanceMin;
+  int[] distanceMax;
+  List<TupleEntry> tupleList;
+  Set<String> positivePairs;
+  Set<String> negativePairs;
+  Map<Integer, Set<Integer>> distanceMap;
+  Map<String, int[]> distanceCache;
+  Set<String> CandidateCache;
+
+  public MatchingKeyState() {
+    reset();
+  }
+
+  public void init(int labelIndex, double min_support, double min_confidence) {
+    this.labelIndex = labelIndex;
+    this.min_support = min_support;
+    this.min_confidence = min_confidence;
+  }
+
+  public void add(int index, long time, String[] fullTuple) {
+    if (fullTuple.length <= labelIndex) {
+      throw new IllegalArgumentException("Tuple length is less than label 
index");
+    }
+    tupleList.add(new TupleEntry(index, time, fullTuple));
+  }
+
+  public Candidate computeAllPairs() {
+    for (int i = 0; i < tupleList.size(); i++) {
+      for (int j = i + 1; j < tupleList.size(); j++) {
+        totalPairs++;
+        TupleEntry t1 = tupleList.get(i);
+        TupleEntry t2 = tupleList.get(j);
+        int index = 0;
+        int[] distances = new int[columnLength];
+        int timeDiff = (int) Math.abs((t1.time - t2.time) / 1000);
+        distances[index] = timeDiff;
+        distanceMap.get(index).add(timeDiff);
+        for (int k = 0; k < columnLength; k++) {
+          if (k != labelIndex) {
+            index++;
+            int distance = editDistance(t1.tuple[k], t2.tuple[k]);
+            distances[index] = distance;
+            distanceMap.get(index).add(distance);
+          }
+        }
+        String label1 = t1.tuple[labelIndex];
+        String label2 = t2.tuple[labelIndex];
+        String key = i + "+" + j;
+        if (label1.equals(label2)) {
+          positivePairs.add(key);
+          distanceCache.put(key, distances);
+        } else {
+          negativePairs.add(key);
+          distanceCache.put(key, distances);
+        }
+      }
+    }
+    List<int[]> distanceRestrictions = new ArrayList<>();
+    for (int i = 0; i < distanceMap.size(); i++) {
+      Set<Integer> distances = distanceMap.get(i);
+      if (distances == null || distances.isEmpty()) {
+        distanceRestrictions.add(new int[] {0, 0});
+      } else {
+        int min = Collections.min(distances);
+        int max = Collections.max(distances);
+        distanceRestrictions.add(new int[] {min, max});
+        distanceMin[i] = min;
+        distanceMax[i] = max;
+      }
+    }
+    Candidate psi = new Candidate(distanceRestrictions, totalPairs);
+    for (String pair : positivePairs) {
+      psi.addpositive(pair);
+    }
+    for (String pair : negativePairs) {
+      psi.addNegative(pair);
+    }
+    return psi;
+  }
+
+  public void reset() {
+    tupleList = new ArrayList<>();
+    positivePairs = new HashSet<>();
+    negativePairs = new HashSet<>();
+    totalPairs = 0;
+    columnLength = 0;
+    distanceMap = new HashMap<>();
+    distanceCache = new HashMap<>();
+    CandidateCache = new HashSet<>();
+    labelIndex = 0;
+    min_support = 0.0;
+    min_confidence = 0.0;
+  }
+
+  public byte[] serialize() {

Review Comment:
   Add a unit test for the serialize and deserialize methods to make sure that 
you can deserialize what you serialized and that the deserialized object is 
the same as the original one.



##########
example/udf/src/main/java/org/apache/iotdb/udf/UDAFMKIdentify.java:
##########
@@ -0,0 +1,157 @@
+package org.apache.iotdb.udf;
+
+import org.apache.iotdb.udf.api.State;
+import org.apache.iotdb.udf.api.customizer.analysis.AggregateFunctionAnalysis;
+import org.apache.iotdb.udf.api.customizer.parameter.FunctionArguments;
+import org.apache.iotdb.udf.api.exception.UDFArgumentNotValidException;
+import org.apache.iotdb.udf.api.relational.AggregateFunction;
+import org.apache.iotdb.udf.api.relational.access.Record;
+import org.apache.iotdb.udf.api.type.Type;
+import org.apache.iotdb.udf.api.utils.ResultValue;
+import org.apache.iotdb.udf.table.MatchingKeyState;
+
+import org.apache.tsfile.utils.Binary;
+
+import java.nio.charset.Charset;
+import java.util.List;
+import java.util.Set;
+
+public class UDAFMKIdentify implements AggregateFunction {
+  int label;
+  int length;
+  double min_confidence;
+  double min_support;
+  String a;
+  int l = 0;
+
+  private FunctionArguments arguments;
+
+  public AggregateFunctionAnalysis analyze(FunctionArguments arguments)
+      throws UDFArgumentNotValidException {
+    int num = arguments.getArgumentsSize();
+    if (num < 5) {
+      throw new UDFArgumentNotValidException("At least 2 columns and 3 
parameters are required.");
+    }
+
+    this.arguments = arguments;
+    Type thirdLastType = arguments.getDataType(num - 3);
+    Type secondLastType = arguments.getDataType(num - 2);
+    Type lastType = arguments.getDataType(num - 1);
+
+    if (thirdLastType != Type.INT32) {
+      throw new UDFArgumentNotValidException(
+          String.format(
+              "The third last parameter must be of type INT, but found: %s", 
thirdLastType));
+    }
+
+    if (secondLastType != Type.DOUBLE) {
+      throw new UDFArgumentNotValidException(
+          String.format(
+              "The second last parameter must be of type DOUBLE, but found: 
%s", secondLastType));
+    }
+
+    if (lastType != Type.DOUBLE) {
+      throw new UDFArgumentNotValidException(
+          String.format("The last parameter must be of type DOUBLE, but found: 
%s", lastType));
+    }
+
+    return new AggregateFunctionAnalysis.Builder()
+        .outputDataType(Type.TEXT)
+        .removable(true)
+        .build();
+  }
+
+  @Override
+  public State createState() {
+    return new MatchingKeyState();
+  }
+
+  @Override
+  public void addInput(State state, Record input) {
+    MatchingKeyState mkState = (MatchingKeyState) state;
+    int num = input.size();
+    length = num - 3;
+    label = input.getInt(length) - 1;
+    min_support = input.getDouble(length + 1);
+    min_confidence = input.getDouble(length + 2);
+    mkState.init(label, min_support, min_confidence);
+    length = length - 1;
+    mkState.setColumnLength(length);
+    long time = input.getLong(0);
+    String[] fullTuple = new String[length];
+    for (int i = 0; i < length; i++) {
+      Type col0Type = arguments.getDataType(i + 1);

Review Comment:
   You can use `Type getDataType(int columnIndex);` from `Record` instead.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to