TAJO-1562: Python UDAF support. (jihoon) Closes #551
Project: http://git-wip-us.apache.org/repos/asf/tajo/repo Commit: http://git-wip-us.apache.org/repos/asf/tajo/commit/9540f16e Tree: http://git-wip-us.apache.org/repos/asf/tajo/tree/9540f16e Diff: http://git-wip-us.apache.org/repos/asf/tajo/diff/9540f16e Branch: refs/heads/index_support Commit: 9540f16edb0de1a66b016b8a7b65568cc2d64709 Parents: 9d73267 Author: Jihoon Son <[email protected]> Authored: Wed Apr 29 10:50:50 2015 +0900 Committer: Jihoon Son <[email protected]> Committed: Wed Apr 29 10:52:20 2015 +0900 ---------------------------------------------------------------------- CHANGES | 2 + .../tajo/algebra/GeneralSetFunctionExpr.java | 1 - .../tajo/function/FunctionInvocation.java | 19 +- .../tajo/function/PythonInvocationDesc.java | 48 +-- .../src/main/proto/CatalogProtos.proto | 1 + .../tajo/engine/function/FunctionLoader.java | 2 +- .../engine/planner/global/GlobalPlanner.java | 13 +- .../src/main/resources/python/controller.py | 283 ++++++++++---- .../src/main/resources/python/tajo_util.py | 5 +- .../engine/function/TestPythonFunctions.java | 4 +- .../tajo/engine/query/TestGroupByQuery.java | 52 +++ .../src/test/resources/python/test_funcs.py | 10 +- .../src/test/resources/python/test_funcs.pyc | Bin 1042 -> 0 bytes .../src/test/resources/python/test_funcs2.py | 8 +- .../src/test/resources/python/test_udaf.py | 76 ++++ .../testComplexTargetWithPythonUdaf.sql | 1 + .../testDistinctPythonUdafWithUnion1.sql | 21 ++ .../testGroupbyWithPythonFunc.sql | 2 +- .../queries/TestGroupByQuery/testPythonUdaf.sql | 1 + .../TestGroupByQuery/testPythonUdaf2.sql | 1 + .../TestGroupByQuery/testPythonUdaf3.sql | 1 + .../testPythonUdafWithHaving.sql | 3 + .../testPythonUdafWithNullData.sql | 4 + .../testNestedPythonFunction.sql | 2 +- .../TestSelectQuery/testSelectPythonFuncs.sql | 2 +- .../testSelectWithPredicateOnPythonFunc.sql | 2 +- .../testComplexTargetWithPythonUdaf.result | 3 + .../testDistinctPythonUdafWithUnion1.result | 4 + .../TestGroupByQuery/testPythonUdaf.result 
| 3 + .../TestGroupByQuery/testPythonUdaf2.result | 4 + .../TestGroupByQuery/testPythonUdaf3.result | 5 + .../testPythonUdafWithHaving.result | 4 + .../testPythonUdafWithNullData.result | 2 + tajo-docs/src/main/sphinx/functions.rst | 61 +--- tajo-docs/src/main/sphinx/functions/python.rst | 159 ++++++++ .../org/apache/tajo/plan/ExprAnnotator.java | 24 +- .../plan/expr/AggregationFunctionCallEval.java | 63 +++- .../org/apache/tajo/plan/expr/EvalContext.java | 3 +- .../tajo/plan/expr/WindowFunctionEval.java | 14 +- .../apache/tajo/plan/function/AggFunction.java | 10 + .../tajo/plan/function/AggFunctionInvoke.java | 88 +++++ .../function/ClassBasedAggFunctionInvoke.java | 82 +++++ .../ClassBasedScalarFunctionInvoke.java | 80 ++++ .../tajo/plan/function/FunctionInvoke.java | 4 +- .../plan/function/FunctionInvokeContext.java | 24 +- .../function/LegacyScalarFunctionInvoke.java | 81 ---- .../plan/function/PythonAggFunctionInvoke.java | 139 +++++++ .../plan/function/PythonFunctionInvoke.java | 2 +- .../function/python/PythonScriptEngine.java | 365 +++++++++++++++---- .../plan/function/python/TajoScriptEngine.java | 21 +- .../tajo/plan/function/stream/BufferPool.java | 65 +++- .../function/stream/CSVLineDeserializer.java | 17 +- .../tajo/plan/function/stream/CSVLineSerDe.java | 4 +- .../plan/function/stream/CSVLineSerializer.java | 61 +++- .../stream/FieldSerializerDeserializer.java | 6 +- .../tajo/plan/function/stream/InputHandler.java | 19 +- .../plan/function/stream/OutputHandler.java | 73 +++- .../plan/function/stream/StreamingUtil.java | 12 + .../stream/TextFieldSerializerDeserializer.java | 14 +- .../function/stream/TextLineDeserializer.java | 3 + .../plan/function/stream/TextLineSerDe.java | 2 +- .../function/stream/TextLineSerializer.java | 9 +- .../tajo/plan/serder/EvalNodeDeserializer.java | 10 +- 63 files changed, 1673 insertions(+), 431 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/CHANGES ---------------------------------------------------------------------- diff --git a/CHANGES b/CHANGES index c7a8edd..8bda2bd 100644 --- a/CHANGES +++ b/CHANGES @@ -4,6 +4,8 @@ Release 0.11.0 - unreleased NEW FEATURES + TAJO-1562: Python UDAF support. (jihoon) + TAJO-1344: Python UDF support. (jihoon) TAJO-923: Add VAR_SAMP and VAR_POP window functions. http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-algebra/src/main/java/org/apache/tajo/algebra/GeneralSetFunctionExpr.java ---------------------------------------------------------------------- diff --git a/tajo-algebra/src/main/java/org/apache/tajo/algebra/GeneralSetFunctionExpr.java b/tajo-algebra/src/main/java/org/apache/tajo/algebra/GeneralSetFunctionExpr.java index c10bd76..2ea9a29 100644 --- a/tajo-algebra/src/main/java/org/apache/tajo/algebra/GeneralSetFunctionExpr.java +++ b/tajo-algebra/src/main/java/org/apache/tajo/algebra/GeneralSetFunctionExpr.java @@ -14,7 +14,6 @@ package org.apache.tajo.algebra; -import com.google.common.base.Preconditions; import com.google.gson.annotations.Expose; import com.google.gson.annotations.SerializedName; http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-catalog/tajo-catalog-common/src/main/java/org/apache/tajo/function/FunctionInvocation.java ---------------------------------------------------------------------- diff --git a/tajo-catalog/tajo-catalog-common/src/main/java/org/apache/tajo/function/FunctionInvocation.java b/tajo-catalog/tajo-catalog-common/src/main/java/org/apache/tajo/function/FunctionInvocation.java index 911d5dd..ae57e1d 100644 --- a/tajo-catalog/tajo-catalog-common/src/main/java/org/apache/tajo/function/FunctionInvocation.java +++ b/tajo-catalog/tajo-catalog-common/src/main/java/org/apache/tajo/function/FunctionInvocation.java @@ -127,7 +127,7 @@ public class FunctionInvocation implements ProtoObject<FunctionInvocationProto> } public boolean hasPython() { - return 
python != null; + return python != null && python.isScalarFunction(); } public void setPython(PythonInvocationDesc python) { @@ -138,6 +138,18 @@ public class FunctionInvocation implements ProtoObject<FunctionInvocationProto> return python; } + public boolean hasPythonAggregation() { + return python != null && !python.isScalarFunction(); + } + + public void setPythonAggregation(PythonInvocationDesc pythonAggregation) { + this.python = pythonAggregation; + } + + public PythonInvocationDesc getPythonAggregation() { + return this.python; + } + @Override public FunctionInvocationProto getProto() { FunctionInvocationProto.Builder builder = FunctionInvocationProto.newBuilder(); @@ -156,7 +168,7 @@ public class FunctionInvocation implements ProtoObject<FunctionInvocationProto> if (hasAggregationJIT()) { builder.setAggregationJIT(aggregationJIT.getProto()); } - if (hasPython()) { + if (hasPython() || hasPythonAggregation()) { builder.setPython(python.getProto()); } return builder.build(); @@ -169,6 +181,7 @@ public class FunctionInvocation implements ProtoObject<FunctionInvocationProto> public String toString() { return "legacy=" + hasLegacy() + ",scalar=" + hasScalar() + ",agg=" + hasAggregation() + - ",scalarJIT=" + hasScalarJIT() + ",aggJIT=" + hasAggregationJIT() + ",python=" + hasPython(); + ",scalarJIT=" + hasScalarJIT() + ",aggJIT=" + hasAggregationJIT() + ",python=" + hasPython() + + ",aggPython=" + hasPythonAggregation(); } } http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-catalog/tajo-catalog-common/src/main/java/org/apache/tajo/function/PythonInvocationDesc.java ---------------------------------------------------------------------- diff --git a/tajo-catalog/tajo-catalog-common/src/main/java/org/apache/tajo/function/PythonInvocationDesc.java b/tajo-catalog/tajo-catalog-common/src/main/java/org/apache/tajo/function/PythonInvocationDesc.java index 160b169..d3365e5 100644 --- 
a/tajo-catalog/tajo-catalog-common/src/main/java/org/apache/tajo/function/PythonInvocationDesc.java +++ b/tajo-catalog/tajo-catalog-common/src/main/java/org/apache/tajo/function/PythonInvocationDesc.java @@ -29,20 +29,21 @@ import org.apache.tajo.util.TUtil; * and a file path to the script where the function is defined. */ public class PythonInvocationDesc implements ProtoObject<PythonInvocationDescProto>, Cloneable { - @Expose private String funcName; - @Expose private String filePath; - - public PythonInvocationDesc() { - - } - - public PythonInvocationDesc(String funcName, String filePath) { - this.funcName = funcName; + @Expose private boolean isScalarFunction; // true if udf, false if udaf + @Expose private String funcOrClassName; // function name if udf, class name if udaf + @Expose private String filePath; // file path to the python module + + /** + * Constructor of {@link PythonInvocationDesc}. + * + * @param funcOrClassName if udf, function name. else, class name. + * @param filePath path to script file + * @param isScalarFunction + */ + public PythonInvocationDesc(String funcOrClassName, String filePath, boolean isScalarFunction) { + this.funcOrClassName = funcOrClassName; this.filePath = filePath; - } - - public void setFuncName(String funcName) { - this.funcName = funcName; + this.isScalarFunction = isScalarFunction; } public void setFilePath(String filePath) { @@ -50,21 +51,25 @@ public class PythonInvocationDesc implements ProtoObject<PythonInvocationDescPro } public PythonInvocationDesc(PythonInvocationDescProto proto) { - this(proto.getFuncName(), proto.getFilePath()); + this(proto.getFuncName(), proto.getFilePath(), proto.getIsScalarFunction()); } public String getName() { - return funcName; + return funcOrClassName; } public String getPath() { return filePath; } + public boolean isScalarFunction() { + return this.isScalarFunction; + } + @Override public PythonInvocationDescProto getProto() { PythonInvocationDescProto.Builder builder = 
PythonInvocationDescProto.newBuilder(); - builder.setFuncName(funcName).setFilePath(filePath); + builder.setFuncName(funcOrClassName).setFilePath(filePath).setIsScalarFunction(isScalarFunction); return builder.build(); } @@ -72,27 +77,28 @@ public class PythonInvocationDesc implements ProtoObject<PythonInvocationDescPro public boolean equals(Object o) { if (o instanceof PythonInvocationDesc) { PythonInvocationDesc other = (PythonInvocationDesc) o; - return TUtil.checkEquals(funcName, other.funcName) && - TUtil.checkEquals(filePath, other.filePath); + return TUtil.checkEquals(funcOrClassName, other.funcOrClassName) && + TUtil.checkEquals(filePath, other.filePath) && isScalarFunction == other.isScalarFunction; } return false; } @Override public int hashCode() { - return Objects.hashCode(funcName, filePath); + return Objects.hashCode(funcOrClassName, filePath, isScalarFunction); } @Override public String toString() { - return funcName + " at " + filePath; + return isScalarFunction ? "[UDF] " : "[UDAF] " + funcOrClassName + " at " + filePath; } @Override public Object clone() throws CloneNotSupportedException { PythonInvocationDesc clone = (PythonInvocationDesc) super.clone(); - clone.funcName = funcName == null ? null : funcName; + clone.funcOrClassName = funcOrClassName == null ? null : funcOrClassName; clone.filePath = filePath == null ? 
null : filePath; + clone.isScalarFunction = isScalarFunction; return clone; } } http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-catalog/tajo-catalog-common/src/main/proto/CatalogProtos.proto ---------------------------------------------------------------------- diff --git a/tajo-catalog/tajo-catalog-common/src/main/proto/CatalogProtos.proto b/tajo-catalog/tajo-catalog-common/src/main/proto/CatalogProtos.proto index fd2cb19..c467b4e 100644 --- a/tajo-catalog/tajo-catalog-common/src/main/proto/CatalogProtos.proto +++ b/tajo-catalog/tajo-catalog-common/src/main/proto/CatalogProtos.proto @@ -428,4 +428,5 @@ message StaticMethodInvocationDescProto { message PythonInvocationDescProto { required string funcName = 1; required string filePath = 2; + required bool isScalarFunction = 3; } http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/main/java/org/apache/tajo/engine/function/FunctionLoader.java ---------------------------------------------------------------------- diff --git a/tajo-core/src/main/java/org/apache/tajo/engine/function/FunctionLoader.java b/tajo-core/src/main/java/org/apache/tajo/engine/function/FunctionLoader.java index 6061d1b..b7e4085 100644 --- a/tajo-core/src/main/java/org/apache/tajo/engine/function/FunctionLoader.java +++ b/tajo-core/src/main/java/org/apache/tajo/engine/function/FunctionLoader.java @@ -84,7 +84,7 @@ public class FunctionLoader { } /** - * Load functions that are defined by users. + * Load functions defined by users. 
* * @param conf * @param functionMap http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/main/java/org/apache/tajo/engine/planner/global/GlobalPlanner.java ---------------------------------------------------------------------- diff --git a/tajo-core/src/main/java/org/apache/tajo/engine/planner/global/GlobalPlanner.java b/tajo-core/src/main/java/org/apache/tajo/engine/planner/global/GlobalPlanner.java index 54b920f..36bdf21 100644 --- a/tajo-core/src/main/java/org/apache/tajo/engine/planner/global/GlobalPlanner.java +++ b/tajo-core/src/main/java/org/apache/tajo/engine/planner/global/GlobalPlanner.java @@ -536,31 +536,31 @@ public class GlobalPlanner { private AggregationFunctionCallEval createSumFunction(EvalNode[] args) throws InternalException { FunctionDesc functionDesc = getCatalog().getFunction("sum", CatalogProtos.FunctionType.AGGREGATION, args[0].getValueType()); - return new AggregationFunctionCallEval(functionDesc, (AggFunction) functionDesc.newInstance(), args); + return new AggregationFunctionCallEval(functionDesc, args); } private AggregationFunctionCallEval createCountFunction(EvalNode [] args) throws InternalException { FunctionDesc functionDesc = getCatalog().getFunction("count", CatalogProtos.FunctionType.AGGREGATION, args[0].getValueType()); - return new AggregationFunctionCallEval(functionDesc, (AggFunction) functionDesc.newInstance(), args); + return new AggregationFunctionCallEval(functionDesc, args); } private AggregationFunctionCallEval createCountRowFunction(EvalNode[] args) throws InternalException { FunctionDesc functionDesc = getCatalog().getFunction("count", CatalogProtos.FunctionType.AGGREGATION, new TajoDataTypes.DataType[]{}); - return new AggregationFunctionCallEval(functionDesc, (AggFunction) functionDesc.newInstance(), args); + return new AggregationFunctionCallEval(functionDesc, args); } private AggregationFunctionCallEval createMaxFunction(EvalNode [] args) throws InternalException { FunctionDesc functionDesc = 
getCatalog().getFunction("max", CatalogProtos.FunctionType.AGGREGATION, args[0].getValueType()); - return new AggregationFunctionCallEval(functionDesc, (AggFunction) functionDesc.newInstance(), args); + return new AggregationFunctionCallEval(functionDesc, args); } private AggregationFunctionCallEval createMinFunction(EvalNode [] args) throws InternalException { FunctionDesc functionDesc = getCatalog().getFunction("min", CatalogProtos.FunctionType.AGGREGATION, args[0].getValueType()); - return new AggregationFunctionCallEval(functionDesc, (AggFunction) functionDesc.newInstance(), args); + return new AggregationFunctionCallEval(functionDesc, args); } /** @@ -960,8 +960,9 @@ public class GlobalPlanner { firstPhaseEvals[i].setFirstPhase(); firstPhaseEvalNames[i] = plan.generateUniqueColumnName(firstPhaseEvals[i]); FieldEval param = new FieldEval(firstPhaseEvalNames[i], firstPhaseEvals[i].getValueType()); + secondPhaseEvals[i].setFinalPhase(); - secondPhaseEvals[i].setArgs(new EvalNode[] {param}); + secondPhaseEvals[i].setArgs(new EvalNode[]{param}); } secondPhaseGroupBy.setAggFunctions(secondPhaseEvals); http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/main/resources/python/controller.py ---------------------------------------------------------------------- diff --git a/tajo-core/src/main/resources/python/controller.py b/tajo-core/src/main/resources/python/controller.py index d969b34..126ccdc 100644 --- a/tajo-core/src/main/resources/python/controller.py +++ b/tajo-core/src/main/resources/python/controller.py @@ -19,6 +19,7 @@ import sys import os import logging import base64 +import json from datetime import datetime try: @@ -69,17 +70,42 @@ TYPE_DATETIME = "T" TYPE_BIGINTEGER = "N" TYPE_BIGDECIMAL = "E" +EVAL_FUNC = "eval" +MERGE_FUNC = "merge" +GET_PARTIAL_RESULT_FUNC = "get_partial_result" +GET_FINAL_RESULT_FUNC = "get_final_result" +GET_INTERM_SCHEMA_FUNC = "get_interm_schema" +UPDATE_CONTEXT = "update_context" +GET_CONTEXT = "get_context" + 
+WRAPPED_EVAL_FUNC = PRE_WRAP_DELIM + EVAL_FUNC + POST_WRAP_DELIM +WRAPPED_MERGE_FUNC = PRE_WRAP_DELIM + MERGE_FUNC + POST_WRAP_DELIM +WRAPPED_GET_PARTIAL_RESULT_FUNC = PRE_WRAP_DELIM + GET_PARTIAL_RESULT_FUNC + POST_WRAP_DELIM +WRAPPED_GET_FINAL_RESULT_FUNC = PRE_WRAP_DELIM + GET_FINAL_RESULT_FUNC + POST_WRAP_DELIM +WRAPPED_GET_INTERM_SCHEMA_FUNC = PRE_WRAP_DELIM + GET_INTERM_SCHEMA_FUNC + POST_WRAP_DELIM +WRAPPED_UPDATE_CONTEXT = PRE_WRAP_DELIM + UPDATE_CONTEXT + POST_WRAP_DELIM +WRAPPED_GET_CONTEXT = PRE_WRAP_DELIM + GET_CONTEXT + POST_WRAP_DELIM + END_OF_STREAM = TYPE_CHARARRAY + "\x04" + END_RECORD_DELIM TURN_ON_OUTPUT_CAPTURING = TYPE_CHARARRAY + "TURN_ON_OUTPUT_CAPTURING" + END_RECORD_DELIM NUM_LINES_OFFSET_TRACE = int(os.environ.get('PYTHON_TRACE_OFFSET', 0)) + class PythonStreamingController: + scalar_func = None + udaf_instance = None + + should_log = False + log_message = logging.info + module_name = None + output_schema = None + def __init__(self, profiling_mode=False): self.profiling_mode = profiling_mode def main(self, - module_name, file_path, func_name, cache_path, - output_stream_path, error_stream_path, log_file_name, output_schema): + module_name, file_path, cache_path, + output_stream_path, error_stream_path, log_file_name, output_schema, name, func_type): sys.stdin = os.fdopen(sys.stdin.fileno(), 'rb', 0) # Need to ensure that user functions can't write to the streams we use to communicate with pig. 
@@ -88,73 +114,107 @@ class PythonStreamingController: self.input_stream = sys.stdin # TODO: support controller logging - # self.log_stream = open(output_stream_path, 'a') - # sys.stderr = open(error_stream_path, 'w') + self.log_stream = open(output_stream_path, 'a') + sys.stderr = open(error_stream_path, 'w') sys.path.append(file_path) sys.path.append(cache_path) sys.path.append('.') - should_log = False - if should_log: + if self.should_log: logging.basicConfig(filename=log_file_name, format="%(asctime)s %(levelname)s %(message)s", level=udf_logging.udf_log_level) logging.info("To reduce the amount of information being logged only a small subset of rows are logged at the " "INFO level. Call udf_logging.set_log_level_debug in tajo_util to see all rows being processed.") + self.module_name = module_name + self.output_schema = output_schema input_str = self.get_next_input() - try: - func = __import__(module_name, globals(), locals(), [func_name], -1).__dict__[func_name] - except: - # These errors should always be caused by user code. - write_user_exception(module_name, self.stream_error, NUM_LINES_OFFSET_TRACE) - self.close_controller(-1) - - log_message = logging.info if udf_logging.udf_log_level == logging.DEBUG: - log_message = logging.debug + self.log_message = logging.debug while input_str != END_OF_STREAM: - try: - try: - if should_log: - log_message("Serialized Input: %s" % (input_str)) - inputs = deserialize_input(input_str) - if should_log: - log_message("Deserialized Input: %s" % (unicode(inputs))) - except: - # Capture errors where the user passes in bad data. - write_user_exception(module_name, self.stream_error, NUM_LINES_OFFSET_TRACE) - self.close_controller(-3) - - try: - func_output = func(*inputs) - if should_log: - log_message("UDF Output: %s" % (unicode(func_output))) - except: - # These errors should always be caused by user code. 
- write_user_exception(module_name, self.stream_error, NUM_LINES_OFFSET_TRACE) - self.close_controller(-2) - - output = serialize_output(func_output, output_schema) - if should_log: - log_message("Serialized Output: %s" % (output)) - - self.stream_output.write( "%s%s" % (output, END_RECORD_DELIM) ) - except Exception as e: - # This should only catch internal exceptions with the controller - # and pig- not with user code. - import traceback - traceback.print_exc(file=self.stream_error) - sys.exit(-3) - - sys.stdout.flush() - sys.stderr.flush() - self.stream_output.flush() - self.stream_error.flush() + + if func_type == 'UDAF': + class_name = name + func_name = self.get_func_name(input_str) + data_start = input_str.find(WRAPPED_PARAMETER_DELIMITER) + len(WRAPPED_PARAMETER_DELIMITER) + input_str = input_str[data_start:] + + if func_name == UPDATE_CONTEXT: + self.update_context(input_str) + elif func_name == GET_CONTEXT: + self.get_context() + else: + func = self.load_udaf(module_name, class_name, func_name) + if func_name == MERGE_FUNC: + json_data = input_str.split(WRAPPED_PARAMETER_DELIMITER)[1] + deserialized = json.loads(json_data) + func(deserialized) + self.stream_output.write(END_RECORD_DELIM) + sys.stdout.flush() + sys.stderr.flush() + self.stream_output.flush() + self.stream_error.flush() + del deserialized + del json_data + else: + self.process_input(func_name, func, input_str) + + elif func_type == 'UDF': + func_name = name + if self.scalar_func is None: + self.scalar_func = self.load_udf(module_name, func_name) + self.process_input(func_name, self.scalar_func, input_str) + else: + raise Exception("Unsupported type: " + func_type) input_str = self.get_next_input() + def process_input(self, func_name, func, input_str): + try: + try: + if self.should_log: + self.log_message("Serialized Input: %s" % (input_str)) + inputs = deserialize_input(input_str) + if self.should_log: + self.log_message("Deserialized Input: %s" % (unicode(inputs))) + except: + # Capture 
errors where the user passes in bad data. + write_user_exception(self.module_name, self.stream_error, NUM_LINES_OFFSET_TRACE) + self.close_controller(-3) + + try: + if func_name == GET_PARTIAL_RESULT_FUNC: + func_output = func() + output = json.dumps(func_output) + elif func_name == GET_FINAL_RESULT_FUNC: + func_output = func() + output = serialize_output(func_output, self.output_schema) + else: + func_output = func(*inputs) + output = serialize_output(func_output, self.output_schema) + + if self.should_log: + self.log_message("Serialized Output: %s" % output) + except: + # These errors should always be caused by user code. + write_user_exception(self.module_name, self.stream_error, NUM_LINES_OFFSET_TRACE) + self.close_controller(-2) + + self.stream_output.write("%s%s" % (output, END_RECORD_DELIM)) + except Exception as e: + # This should only catch internal exceptions with the controller + # and pig- not with user code. + import traceback + traceback.print_exc(file=self.stream_error) + sys.exit(-3) + + sys.stdout.flush() + sys.stderr.flush() + self.stream_output.flush() + self.stream_error.flush() + def get_next_input(self): input_stream = self.input_stream # log_stream = self.log_stream @@ -185,6 +245,77 @@ class PythonStreamingController: self.stream_output.close() sys.exit(exit_code) + def load_udf(self, module_name, func_name): + try: + func = __import__(module_name, globals(), locals(), [func_name], -1).__dict__[func_name] + return func + except: + # These errors should always be caused by user code. + write_user_exception(module_name, self.stream_error, NUM_LINES_OFFSET_TRACE) + self.close_controller(-1) + + def load_udaf(self, module_name, class_name, func_name): + try: + if self.udaf_instance is None: + clazz = __import__(module_name, globals(), locals(), [class_name]).__dict__[class_name] + self.udaf_instance = clazz() + func = getattr(self.udaf_instance, func_name) + return func + except: + # These errors should always be caused by user code. 
+ write_user_exception(module_name, self.stream_error, NUM_LINES_OFFSET_TRACE) + self.close_controller(-1) + + @staticmethod + def get_func_name(input_str): + splits = input_str.split(WRAPPED_PARAMETER_DELIMITER) + if splits[0] == WRAPPED_EVAL_FUNC: + return EVAL_FUNC + elif splits[0] == WRAPPED_MERGE_FUNC: + return MERGE_FUNC + elif splits[0] == WRAPPED_GET_PARTIAL_RESULT_FUNC: + return GET_PARTIAL_RESULT_FUNC + elif splits[0] == WRAPPED_GET_FINAL_RESULT_FUNC: + return GET_FINAL_RESULT_FUNC + elif splits[0] == WRAPPED_GET_INTERM_SCHEMA_FUNC: + return GET_INTERM_SCHEMA_FUNC + elif splits[0] == WRAPPED_UPDATE_CONTEXT: + return UPDATE_CONTEXT + elif splits[0] == WRAPPED_GET_CONTEXT: + return GET_CONTEXT + else: + raise Exception("Not supported function: " + splits[0]) + + def update_context(self, input_str): + if self.udaf_instance is not None: + deserialize_class(self.udaf_instance, input_str) + self.stream_output.write(END_RECORD_DELIM) + sys.stdout.flush() + sys.stderr.flush() + self.stream_output.flush() + self.stream_error.flush() + + def get_context(self): + serialized = '' + if self.udaf_instance is not None: + serialized = serialize_class(self.udaf_instance) + self.stream_output.write("%s%s" % (serialized, END_RECORD_DELIM)) + sys.stdout.flush() + sys.stderr.flush() + self.stream_output.flush() + self.stream_error.flush() + +def serialize_class(instance): + serialized = json.dumps(instance.__dict__) + return serialized + +def deserialize_class(instance, json_data): + if json_data == NULL_BYTE: + instance.reset() + else: + instance.reset() + instance.__dict__ = json.loads(json_data) + def deserialize_input(input_str): if len(input_str) == 0: return [] @@ -209,30 +340,33 @@ def _deserialize_input(input_str, si, ei): schema = tokens[0]; param = tokens[1]; - if schema == NULL_BYTE: + return deserialize_data(schema, param) + +def deserialize_data(type, data_str): + if type == NULL_BYTE: return None - elif schema == TYPE_CHARARRAY: - return unicode(param, 'utf-8') 
- elif schema == TYPE_BYTEARRAY: - return bytearray(param) - elif schema == TYPE_INTEGER: - return int(param) - elif schema == TYPE_LONG or schema == TYPE_BIGINTEGER: - return long(param) - elif schema == TYPE_FLOAT or schema == TYPE_DOUBLE or schema == TYPE_BIGDECIMAL: - return float(param) - elif schema == TYPE_BOOLEAN: - return param == "true" - elif schema == TYPE_DATETIME: + elif type == TYPE_CHARARRAY: + return unicode(data_str, 'utf-8') + elif type == TYPE_BYTEARRAY: + return bytearray(data_str) + elif type == TYPE_INTEGER: + return int(data_str) + elif type == TYPE_LONG or type == TYPE_BIGINTEGER: + return long(data_str) + elif type == TYPE_FLOAT or type == TYPE_DOUBLE or type == TYPE_BIGDECIMAL: + return float(data_str) + elif type == TYPE_BOOLEAN: + return data_str == "true" + elif type == TYPE_DATETIME: # Format is "yyyy-MM-ddTHH:mm:ss.SSS+00:00" or "2013-08-23T18:14:03.123+ZZ" if USE_DATEUTIL: - return parser.parse(param) + return parser.parse(data_str) else: # Try to use datetime even though it doesn't handle time zones properly, # We only use the first 3 microsecond digits and drop time zone (first 23 characters) - return datetime.strptime(param, "%Y-%m-%dT%H:%M:%S.%f") + return datetime.strptime(data_str, "%Y-%m-%dT%H:%M:%S.%f") else: - raise Exception("Can't determine type of input: %s" % param) + raise Exception("Can't determine type of input: %s" % data_str) def _deserialize_collection(input_str, return_type, si, ei): list_result = [] @@ -313,6 +447,8 @@ def serialize_output(output, out_schema, utfEncodeAllFields=False): result = str(output) elif output_type == datetime: result = output.isoformat() + elif output_type == list: + result = list_to_str(output, out_schema) elif utfEncodeAllFields or output_type == str or output_type == unicode: # unicode is necessary in cases where we're encoding non-strings. 
result = unicode(output).encode('utf-8') @@ -324,7 +460,14 @@ def serialize_output(output, out_schema, utfEncodeAllFields=False): else: return result +def list_to_str(list_of_item, out_schema): + result = '' + for item in list_of_item: + result += serialize_output(item, out_schema) + WRAPPED_FIELD_DELIMITER + result = result[:len(result)-len(WRAPPED_FIELD_DELIMITER)] + return result + if __name__ == '__main__': controller = PythonStreamingController() controller.main(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4], - sys.argv[5], sys.argv[6], sys.argv[7], sys.argv[8]) + sys.argv[5], sys.argv[6], sys.argv[7], sys.argv[8], sys.argv[9]) http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/main/resources/python/tajo_util.py ---------------------------------------------------------------------- diff --git a/tajo-core/src/main/resources/python/tajo_util.py b/tajo-core/src/main/resources/python/tajo_util.py index 77b28a6..20ff734 100644 --- a/tajo-core/src/main/resources/python/tajo_util.py +++ b/tajo-core/src/main/resources/python/tajo_util.py @@ -18,6 +18,7 @@ import logging + class udf_logging(object): udf_log_level = logging.INFO @@ -37,13 +38,15 @@ class udf_logging(object): def set_log_level_debug(cls): cls.udf_log_level = logging.DEBUG -def outputType(type_str): + +def output_type(*type_str): def wrap(f): def wrapped_f(*args): return f(*args) return wrapped_f return wrap + def write_user_exception(filename, stream_err_output, num_lines_offset_trace=0): import sys import traceback http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/java/org/apache/tajo/engine/function/TestPythonFunctions.java ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/java/org/apache/tajo/engine/function/TestPythonFunctions.java b/tajo-core/src/test/java/org/apache/tajo/engine/function/TestPythonFunctions.java index 47a0ad2..6f73d01 100644 --- 
a/tajo-core/src/test/java/org/apache/tajo/engine/function/TestPythonFunctions.java +++ b/tajo-core/src/test/java/org/apache/tajo/engine/function/TestPythonFunctions.java @@ -31,14 +31,14 @@ public class TestPythonFunctions extends ExprTestBase { testSimpleEval("select helloworld()", new String[]{"Hello, World"}); testSimpleEval("select concat_py('1')", new String[]{"11"}); testSimpleEval("select comma_format(12345)", new String[]{"12,345"}); - testSimpleEval("select sum_py(1,2)", new String[]{"3"}); + testSimpleEval("select add_py(1,2)", new String[]{"3"}); testSimpleEval("select percent(386, 1000)", new String[]{"38.6"}); testSimpleEval("select concat4('Tajo', 'is', 'awesome', '!')", new String[]{"Tajo is awesome !"}); } @Test public void testNestedFunctions() throws IOException { - testSimpleEval("select sum_py(3, return_one())", new String[]{"4"}); + testSimpleEval("select add_py(3, return_one())", new String[]{"4"}); testSimpleEval("select concat_py(helloworld())", new String[]{"Hello, WorldHello, World"}); } } http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/java/org/apache/tajo/engine/query/TestGroupByQuery.java ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/java/org/apache/tajo/engine/query/TestGroupByQuery.java b/tajo-core/src/test/java/org/apache/tajo/engine/query/TestGroupByQuery.java index 15a1c9f..1da3ee9 100644 --- a/tajo-core/src/test/java/org/apache/tajo/engine/query/TestGroupByQuery.java +++ b/tajo-core/src/test/java/org/apache/tajo/engine/query/TestGroupByQuery.java @@ -813,4 +813,56 @@ public class TestGroupByQuery extends QueryTestCaseBase { assertResultSet(res); cleanupQuery(res); } + + @Test + public final void testPythonUdaf() throws Exception { + ResultSet res = executeQuery(); + assertResultSet(res); + cleanupQuery(res); + } + + @Test + public final void testPythonUdaf2() throws Exception { + ResultSet res = executeQuery(); + assertResultSet(res); + 
cleanupQuery(res); + } + + @Test + public final void testPythonUdaf3() throws Exception { + ResultSet res = executeQuery(); + assertResultSet(res); + cleanupQuery(res); + } + + // TODO: this test cannot be executed due to the bug of logical planner +// @Test + public final void testPythonUdafWithHaving() throws Exception { + ResultSet res = executeQuery(); + assertResultSet(res); + cleanupQuery(res); + } + + @Test + public final void testPythonUdafWithNullData() throws Exception { + ResultSet res = executeQuery(); + assertResultSet(res); + cleanupQuery(res); + } + + // TODO: this test cannot be executed due to the bug of logical planner +// @Test + public final void testComplexTargetWithPythonUdaf() throws Exception { + ResultSet res = executeQuery(); + assertResultSet(res); + cleanupQuery(res); + } + + // TODO: this test cannot be executed due to the bug of logical planner +// @Test + public final void testDistinctPythonUdafWithUnion1() throws Exception { + ResultSet res = executeQuery(); + assertResultSet(res); + cleanupQuery(res); + } } http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/resources/python/test_funcs.py ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/resources/python/test_funcs.py b/tajo-core/src/test/resources/python/test_funcs.py index d6b7db5..1167afd 100644 --- a/tajo-core/src/test/resources/python/test_funcs.py +++ b/tajo-core/src/test/resources/python/test_funcs.py @@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from tajo_util import outputType +from tajo_util import output_type -@outputType('int4') +@output_type('int4') def return_one(): return 1 -@outputType("text") +@output_type("text") def helloworld(): return 'Hello, World' @@ -28,6 +28,6 @@ def helloworld(): def concat_py(str): return str+str -@outputType('int4') -def sum_py(a,b): +@output_type('int4') +def add_py(a,b): return a+b http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/resources/python/test_funcs.pyc ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/resources/python/test_funcs.pyc b/tajo-core/src/test/resources/python/test_funcs.pyc deleted file mode 100644 index cc84dc1..0000000 Binary files a/tajo-core/src/test/resources/python/test_funcs.pyc and /dev/null differ http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/resources/python/test_funcs2.py ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/resources/python/test_funcs2.py b/tajo-core/src/test/resources/python/test_funcs2.py index e8db7b5..8a6f608 100644 --- a/tajo-core/src/test/resources/python/test_funcs2.py +++ b/tajo-core/src/test/resources/python/test_funcs2.py @@ -14,19 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from tajo_util import outputType +from tajo_util import output_type #Percent- Percentage -@outputType("float8") +@output_type("float8") def percent(num, total): return num * 100 / float(total) #commaFormat- format a number with commas, 12345-> 12,345 -@outputType("text") +@output_type("text") def comma_format(num): return '{:,}'.format(num) #concatMultiple- concat multiple words -@outputType("text") +@output_type("text") def concat4(word1, word2, word3, word4): return word1 + " " + word2 + " " + word3 + " " + word4 \ No newline at end of file http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/resources/python/test_udaf.py ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/resources/python/test_udaf.py b/tajo-core/src/test/resources/python/test_udaf.py new file mode 100644 index 0000000..da5a3fd --- /dev/null +++ b/tajo-core/src/test/resources/python/test_udaf.py @@ -0,0 +1,76 @@ +############################################################################ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from tajo_util import output_type + + +class AvgPy: + sum = 0 + cnt = 0 + + def __init__(self): + self.reset() + + def reset(self): + self.sum = 0 + self.cnt = 0 + + # eval at the first stage + def eval(self, item): + self.sum += item + self.cnt += 1 + + # get intermediate result + def get_partial_result(self): + return [self.sum, self.cnt] + + # merge intermediate results + def merge(self, list): + self.sum += list[0] + self.cnt += list[1] + + # get final result + @output_type('float8') + def get_final_result(self): + return self.sum / float(self.cnt) + + +class CountPy: + cnt = 0 + + def __init__(self): + self.reset() + + def reset(self): + self.cnt = 0 + + # eval at the first stage + def eval(self): + self.cnt += 1 + + # get intermediate result + def get_partial_result(self): + return self.cnt + + # merge intermediate results + def merge(self, cnt): + self.cnt += cnt + + # get final result + @output_type('int4') + def get_final_result(self): + return self.cnt + http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/resources/queries/TestGroupByQuery/testComplexTargetWithPythonUdaf.sql ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/resources/queries/TestGroupByQuery/testComplexTargetWithPythonUdaf.sql b/tajo-core/src/test/resources/queries/TestGroupByQuery/testComplexTargetWithPythonUdaf.sql new file mode 100644 index 0000000..551655f --- /dev/null +++ b/tajo-core/src/test/resources/queries/TestGroupByQuery/testComplexTargetWithPythonUdaf.sql @@ -0,0 +1 @@ +select countpy() + max(l_orderkey) as merged from lineitem; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/resources/queries/TestGroupByQuery/testDistinctPythonUdafWithUnion1.sql ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/resources/queries/TestGroupByQuery/testDistinctPythonUdafWithUnion1.sql 
b/tajo-core/src/test/resources/queries/TestGroupByQuery/testDistinctPythonUdafWithUnion1.sql new file mode 100644 index 0000000..73e9579 --- /dev/null +++ b/tajo-core/src/test/resources/queries/TestGroupByQuery/testDistinctPythonUdafWithUnion1.sql @@ -0,0 +1,21 @@ +select + sum(distinct l_orderkey), + l_linenumber, + count(distinct l_orderkey), + countpy() as total +from + ( + select + * + from + lineitem + + union + + select + * + from + lineitem + ) t1 +group by + l_linenumber; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/resources/queries/TestGroupByQuery/testGroupbyWithPythonFunc.sql ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/resources/queries/TestGroupByQuery/testGroupbyWithPythonFunc.sql b/tajo-core/src/test/resources/queries/TestGroupByQuery/testGroupbyWithPythonFunc.sql index 888552a..8100f11 100644 --- a/tajo-core/src/test/resources/queries/TestGroupByQuery/testGroupbyWithPythonFunc.sql +++ b/tajo-core/src/test/resources/queries/TestGroupByQuery/testGroupbyWithPythonFunc.sql @@ -1 +1 @@ -select count(*) from nation where sum_py(n_nationkey, 1) > 2 group by n_regionkey \ No newline at end of file +select count(*) from nation where add_py(n_nationkey, 1) > 2 group by n_regionkey \ No newline at end of file http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/resources/queries/TestGroupByQuery/testPythonUdaf.sql ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/resources/queries/TestGroupByQuery/testPythonUdaf.sql b/tajo-core/src/test/resources/queries/TestGroupByQuery/testPythonUdaf.sql new file mode 100644 index 0000000..e29816b --- /dev/null +++ b/tajo-core/src/test/resources/queries/TestGroupByQuery/testPythonUdaf.sql @@ -0,0 +1 @@ +select avgpy(n_nationkey), avg(n_nationkey), countpy() from nation; \ No newline at end of file 
http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/resources/queries/TestGroupByQuery/testPythonUdaf2.sql ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/resources/queries/TestGroupByQuery/testPythonUdaf2.sql b/tajo-core/src/test/resources/queries/TestGroupByQuery/testPythonUdaf2.sql new file mode 100644 index 0000000..b08ff1e --- /dev/null +++ b/tajo-core/src/test/resources/queries/TestGroupByQuery/testPythonUdaf2.sql @@ -0,0 +1 @@ +select countpy(), count(*) from lineitem group by l_linenumber \ No newline at end of file http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/resources/queries/TestGroupByQuery/testPythonUdaf3.sql ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/resources/queries/TestGroupByQuery/testPythonUdaf3.sql b/tajo-core/src/test/resources/queries/TestGroupByQuery/testPythonUdaf3.sql new file mode 100644 index 0000000..406442f --- /dev/null +++ b/tajo-core/src/test/resources/queries/TestGroupByQuery/testPythonUdaf3.sql @@ -0,0 +1 @@ +select avgpy(o_totalprice), countpy(), avg(o_totalprice), count(*) from orders group by o_custkey, o_orderdate \ No newline at end of file http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/resources/queries/TestGroupByQuery/testPythonUdafWithHaving.sql ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/resources/queries/TestGroupByQuery/testPythonUdafWithHaving.sql b/tajo-core/src/test/resources/queries/TestGroupByQuery/testPythonUdafWithHaving.sql new file mode 100644 index 0000000..b7769b7 --- /dev/null +++ b/tajo-core/src/test/resources/queries/TestGroupByQuery/testPythonUdafWithHaving.sql @@ -0,0 +1,3 @@ +select l_orderkey, avgpy(l_partkey) total, sum(l_linenumber) as num from lineitem +group by l_orderkey +having avgpy(l_partkey) = 2.5 or num = 1; \ No newline at end of file 
http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/resources/queries/TestGroupByQuery/testPythonUdafWithNullData.sql ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/resources/queries/TestGroupByQuery/testPythonUdafWithNullData.sql b/tajo-core/src/test/resources/queries/TestGroupByQuery/testPythonUdafWithNullData.sql new file mode 100644 index 0000000..56fb65c --- /dev/null +++ b/tajo-core/src/test/resources/queries/TestGroupByQuery/testPythonUdafWithNullData.sql @@ -0,0 +1,4 @@ +select l_orderkey, count(distinct l_linenumber) as unique_key +from lineitem +where l_orderkey = 1000 +group by l_orderkey \ No newline at end of file http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/resources/queries/TestSelectQuery/testNestedPythonFunction.sql ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/resources/queries/TestSelectQuery/testNestedPythonFunction.sql b/tajo-core/src/test/resources/queries/TestSelectQuery/testNestedPythonFunction.sql index 75b33ae..02b2059 100644 --- a/tajo-core/src/test/resources/queries/TestSelectQuery/testNestedPythonFunction.sql +++ b/tajo-core/src/test/resources/queries/TestSelectQuery/testNestedPythonFunction.sql @@ -1 +1 @@ -select * from nation where sum_py(n_regionkey, return_one()) < 2 \ No newline at end of file +select * from nation where add_py(n_regionkey, return_one()) < 2 \ No newline at end of file http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/resources/queries/TestSelectQuery/testSelectPythonFuncs.sql ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/resources/queries/TestSelectQuery/testSelectPythonFuncs.sql b/tajo-core/src/test/resources/queries/TestSelectQuery/testSelectPythonFuncs.sql index bcb9806..5ae0d5e 100644 --- 
a/tajo-core/src/test/resources/queries/TestSelectQuery/testSelectPythonFuncs.sql +++ b/tajo-core/src/test/resources/queries/TestSelectQuery/testSelectPythonFuncs.sql @@ -1,2 +1,2 @@ -select helloworld(), sum_py(n_nationkey, n_regionkey) as sum, concat_py(n_name) as concat +select helloworld(), add_py(n_nationkey, n_regionkey) as sum, concat_py(n_name) as concat from nation where n_nationkey < 5 \ No newline at end of file http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/resources/queries/TestSelectQuery/testSelectWithPredicateOnPythonFunc.sql ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/resources/queries/TestSelectQuery/testSelectWithPredicateOnPythonFunc.sql b/tajo-core/src/test/resources/queries/TestSelectQuery/testSelectWithPredicateOnPythonFunc.sql index d2c5082..aa9feba 100644 --- a/tajo-core/src/test/resources/queries/TestSelectQuery/testSelectWithPredicateOnPythonFunc.sql +++ b/tajo-core/src/test/resources/queries/TestSelectQuery/testSelectWithPredicateOnPythonFunc.sql @@ -1 +1 @@ -select * from nation where sum_py(n_regionkey,1) > 2 \ No newline at end of file +select * from nation where add_py(n_regionkey,1) > 2 \ No newline at end of file http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/resources/results/TestGroupByQuery/testComplexTargetWithPythonUdaf.result ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/resources/results/TestGroupByQuery/testComplexTargetWithPythonUdaf.result b/tajo-core/src/test/resources/results/TestGroupByQuery/testComplexTargetWithPythonUdaf.result new file mode 100644 index 0000000..6ee9cb5 --- /dev/null +++ b/tajo-core/src/test/resources/results/TestGroupByQuery/testComplexTargetWithPythonUdaf.result @@ -0,0 +1,3 @@ +merged +------------------------------- +8 \ No newline at end of file 
http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/resources/results/TestGroupByQuery/testDistinctPythonUdafWithUnion1.result ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/resources/results/TestGroupByQuery/testDistinctPythonUdafWithUnion1.result b/tajo-core/src/test/resources/results/TestGroupByQuery/testDistinctPythonUdafWithUnion1.result new file mode 100644 index 0000000..16c5524 --- /dev/null +++ b/tajo-core/src/test/resources/results/TestGroupByQuery/testDistinctPythonUdafWithUnion1.result @@ -0,0 +1,4 @@ +?sum,l_linenumber,?count_1,total +------------------------------- +6,1,3,6 +4,2,2,4 \ No newline at end of file http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/resources/results/TestGroupByQuery/testPythonUdaf.result ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/resources/results/TestGroupByQuery/testPythonUdaf.result b/tajo-core/src/test/resources/results/TestGroupByQuery/testPythonUdaf.result new file mode 100644 index 0000000..e1ba22d --- /dev/null +++ b/tajo-core/src/test/resources/results/TestGroupByQuery/testPythonUdaf.result @@ -0,0 +1,3 @@ +?avgpy,?avg_1,?countpy_2 +------------------------------- +12.0,12.0,25 \ No newline at end of file http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/resources/results/TestGroupByQuery/testPythonUdaf2.result ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/resources/results/TestGroupByQuery/testPythonUdaf2.result b/tajo-core/src/test/resources/results/TestGroupByQuery/testPythonUdaf2.result new file mode 100644 index 0000000..2852167 --- /dev/null +++ b/tajo-core/src/test/resources/results/TestGroupByQuery/testPythonUdaf2.result @@ -0,0 +1,4 @@ +?countpy,?count_1 +------------------------------- +2,2 +3,3 \ No newline at end of file 
http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/resources/results/TestGroupByQuery/testPythonUdaf3.result ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/resources/results/TestGroupByQuery/testPythonUdaf3.result b/tajo-core/src/test/resources/results/TestGroupByQuery/testPythonUdaf3.result new file mode 100644 index 0000000..0607720 --- /dev/null +++ b/tajo-core/src/test/resources/results/TestGroupByQuery/testPythonUdaf3.result @@ -0,0 +1,5 @@ +?avgpy,?countpy_1,?avg_2,?count_3 +------------------------------- +193846.25,1,193846.25,1 +46929.18,1,46929.18,1 +173665.47,1,173665.47,1 \ No newline at end of file http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/resources/results/TestGroupByQuery/testPythonUdafWithHaving.result ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/resources/results/TestGroupByQuery/testPythonUdafWithHaving.result b/tajo-core/src/test/resources/results/TestGroupByQuery/testPythonUdafWithHaving.result new file mode 100644 index 0000000..b8369d2 --- /dev/null +++ b/tajo-core/src/test/resources/results/TestGroupByQuery/testPythonUdafWithHaving.result @@ -0,0 +1,4 @@ +l_orderkey,total,num +------------------------------- +3,2.5,3 +2,2.0,1 \ No newline at end of file http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-core/src/test/resources/results/TestGroupByQuery/testPythonUdafWithNullData.result ---------------------------------------------------------------------- diff --git a/tajo-core/src/test/resources/results/TestGroupByQuery/testPythonUdafWithNullData.result b/tajo-core/src/test/resources/results/TestGroupByQuery/testPythonUdafWithNullData.result new file mode 100644 index 0000000..1f6d988 --- /dev/null +++ b/tajo-core/src/test/resources/results/TestGroupByQuery/testPythonUdafWithNullData.result @@ -0,0 +1,2 @@ +l_orderkey,unique_key 
+------------------------------- \ No newline at end of file http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-docs/src/main/sphinx/functions.rst ---------------------------------------------------------------------- diff --git a/tajo-docs/src/main/sphinx/functions.rst b/tajo-docs/src/main/sphinx/functions.rst index 8200bfd..7e885da 100644 --- a/tajo-docs/src/main/sphinx/functions.rst +++ b/tajo-docs/src/main/sphinx/functions.rst @@ -18,63 +18,10 @@ Built-in Functions functions/json_func ============================== -Python User-defined Functions +User-defined Functions ============================== ------------------------ -Function registration ------------------------ - -To register Python UDFs, you must install script files in all cluster nodes. -After that, you can register your functions by specifying the paths to those script files in ``tajo-site.xml``. Here is an example of the configuration. - -.. code-block:: xml - - <property> - <name>tajo.function.python.code-dir</name> - <value>/path/to/script1.py,/path/to/script2.py</value> - </property> - -Please note that you can specify multiple paths with ``','`` as a delimiter. Each file can contain multiple functions. Here is a typical example of a script file. - -.. code-block:: python - - # /path/to/script1.py - - @outputType('int4') - def return_one(): - return 1 - - @outputType("text") - def helloworld(): - return 'Hello, World' - - # No decorator - blob - def concat_py(str): - return str+str - - @outputType('int4') - def sum_py(a,b): - return a+b - -If the configuration is set properly, every function in the script files are registered when the Tajo cluster starts up. - ------------------------ -Decorators and types ------------------------ - -By default, every function has a return type of ``BLOB``. -You can use Python decorators to define output types for the script functions. Tajo can figure out return types from the annotations of the Python script. 
- * ``outputType``: Defines the return data type for a script UDF in a format that Tajo can understand. The defined type must be one of the types supported by Tajo. For supported types, please refer to :doc:`/sql_language/data_model`. - ------------------------ -Query example ------------------------ - -Once the Python UDFs are successfully registered, you can use them as other built-in functions. - -.. code-block:: sql - - default> select concat_py(n_name)::text from nation where sum_py(n_regionkey,1) > 2; +.. toctree:: + :maxdepth: 1 + functions/python \ No newline at end of file http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-docs/src/main/sphinx/functions/python.rst ---------------------------------------------------------------------- diff --git a/tajo-docs/src/main/sphinx/functions/python.rst b/tajo-docs/src/main/sphinx/functions/python.rst new file mode 100644 index 0000000..83eb4e3 --- /dev/null +++ b/tajo-docs/src/main/sphinx/functions/python.rst @@ -0,0 +1,159 @@ +****************************** +Python Functions +****************************** + +======================= +User-defined Functions +======================= + +----------------------- +Function registration +----------------------- + +To register Python UDFs, you must install script files on all cluster nodes. +After that, you can register your functions by specifying the paths to those script files in ``tajo-site.xml``. Here is an example of the configuration. + +.. code-block:: xml + + <property> + <name>tajo.function.python.code-dir</name> + <value>/path/to/script1.py,/path/to/script2.py</value> + </property> + +Please note that you can specify multiple paths with ``','`` as a delimiter. Each file can contain multiple functions. Here is a typical example of a script file. + +..
code-block:: python + + # /path/to/script1.py + from tajo_util import output_type + @output_type('int4') + def return_one(): + return 1 + + @output_type("text") + def helloworld(): + return 'Hello, World' + + # No decorator - blob + def concat_py(str): + return str+str + + @output_type('int4') + def sum_py(a,b): + return a+b + +If the configuration is set properly, every function in the script files is registered when the Tajo cluster starts up. + +----------------------- +Decorators and types +----------------------- + +By default, every function has a return type of ``BLOB``. +You can use Python decorators to define output types for the script functions. Tajo determines return types from these decorators in the Python script. + +* ``output_type``: Defines the return data type for a script UDF in a format that Tajo can understand. The defined type must be one of the types supported by Tajo. For supported types, please refer to :doc:`/sql_language/data_model`. + +----------------------- +Query example +----------------------- + +Once the Python UDFs are successfully registered, you can use them like other built-in functions. + +.. code-block:: sql + + default> select concat_py(n_name)::text from nation where sum_py(n_regionkey,1) > 2; + +============================================== +User-defined Aggregation Functions +============================================== + +----------------------- +Function registration +----------------------- + +To define your Python aggregation functions, you should write a Python class for each function. +The following are typical examples of Python UDAFs. + +..
code-block:: python + + # /path/to/udaf1.py + from tajo_util import output_type + class AvgPy: + sum = 0 + cnt = 0 + + def __init__(self): + self.reset() + + def reset(self): + self.sum = 0 + self.cnt = 0 + + # eval at the first stage + def eval(self, item): + self.sum += item + self.cnt += 1 + + # get intermediate result + def get_partial_result(self): + return [self.sum, self.cnt] + + # merge intermediate results + def merge(self, list): + self.sum += list[0] + self.cnt += list[1] + + # get final result + @output_type('float8') + def get_final_result(self): + return self.sum / float(self.cnt) + + + class CountPy: + cnt = 0 + + def __init__(self): + self.reset() + + def reset(self): + self.cnt = 0 + + # eval at the first stage + def eval(self): + self.cnt += 1 + + # get intermediate result + def get_partial_result(self): + return self.cnt + + # merge intermediate results + def merge(self, cnt): + self.cnt += cnt + + # get final result + @output_type('int4') + def get_final_result(self): + return self.cnt + + +These classes must provide ``reset()``, ``eval()``, ``merge()``, ``get_partial_result()``, and ``get_final_result()`` methods. + +* ``reset()`` resets the aggregation state. +* ``eval()`` evaluates input tuples in the first stage. +* ``merge()`` merges intermediate results of the first stage. +* ``get_partial_result()`` returns intermediate results of the first stage. The value it returns is what ``merge()`` receives as input, so the two must agree in type. +* ``get_final_result()`` returns the final aggregation result. + +----------------------- +Query example +----------------------- + +Once the Python UDAFs are successfully registered, you can use them like other built-in aggregation functions. + +.. code-block:: sql + + default> select avgpy(n_nationkey), countpy() from nation; + +.. warning:: + + Currently, Python UDAFs cannot be used as window functions.
\ No newline at end of file http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-plan/src/main/java/org/apache/tajo/plan/ExprAnnotator.java ---------------------------------------------------------------------- diff --git a/tajo-plan/src/main/java/org/apache/tajo/plan/ExprAnnotator.java b/tajo-plan/src/main/java/org/apache/tajo/plan/ExprAnnotator.java index 0c5a012..e24cf6e 100644 --- a/tajo-plan/src/main/java/org/apache/tajo/plan/ExprAnnotator.java +++ b/tajo-plan/src/main/java/org/apache/tajo/plan/ExprAnnotator.java @@ -615,7 +615,7 @@ public class ExprAnnotator extends BaseAlgebraVisitor<ExprAnnotator.Context, Eva if (!ctx.currentBlock.hasNode(NodeType.GROUP_BY)) { ctx.currentBlock.setAggregationRequire(); } - return new AggregationFunctionCallEval(funcDesc, (AggFunction) funcDesc.newInstance(), givenArgs); + return new AggregationFunctionCallEval(funcDesc, givenArgs); } else if (functionType == FunctionType.DISTINCT_AGGREGATION || functionType == FunctionType.DISTINCT_UDA) { throw new PlanningException("Unsupported function: " + funcDesc.toString()); @@ -638,14 +638,8 @@ public class ExprAnnotator extends BaseAlgebraVisitor<ExprAnnotator.Context, Eva throw new NoSuchFunctionException(expr.getSignature(), new DataType[]{}); } - try { - ctx.currentBlock.setAggregationRequire(); - - return new AggregationFunctionCallEval(countRows, (AggFunction) countRows.newInstance(), - new EvalNode[] {}); - } catch (InternalException e) { - throw new NoSuchFunctionException(countRows.getFunctionName(), new DataType[]{}); - } + ctx.currentBlock.setAggregationRequire(); + return new AggregationFunctionCallEval(countRows, new EvalNode[] {}); } @Override @@ -674,11 +668,7 @@ public class ExprAnnotator extends BaseAlgebraVisitor<ExprAnnotator.Context, Eva ctx.currentBlock.setAggregationRequire(); } - try { - return new AggregationFunctionCallEval(funcDesc, (AggFunction) funcDesc.newInstance(), givenArgs); - } catch (InternalException e) { - throw new PlanningException(e); 
- } + return new AggregationFunctionCallEval(funcDesc, givenArgs); } public static final Set<String> WINDOW_FUNCTIONS = @@ -764,11 +754,7 @@ public class ExprAnnotator extends BaseAlgebraVisitor<ExprAnnotator.Context, Eva FunctionDesc funcDesc = catalog.getFunction(funcName, functionType, paramTypes); - try { - return new WindowFunctionEval(funcDesc, (AggFunction) funcDesc.newInstance(), givenArgs, frame); - } catch (InternalException e) { - throw new PlanningException(e); - } + return new WindowFunctionEval(funcDesc, givenArgs, frame); } /////////////////////////////////////////////////////////////////////////////////////////////////////////// http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-plan/src/main/java/org/apache/tajo/plan/expr/AggregationFunctionCallEval.java ---------------------------------------------------------------------- diff --git a/tajo-plan/src/main/java/org/apache/tajo/plan/expr/AggregationFunctionCallEval.java b/tajo-plan/src/main/java/org/apache/tajo/plan/expr/AggregationFunctionCallEval.java index 5549e2e..cfcc829 100644 --- a/tajo-plan/src/main/java/org/apache/tajo/plan/expr/AggregationFunctionCallEval.java +++ b/tajo-plan/src/main/java/org/apache/tajo/plan/expr/AggregationFunctionCallEval.java @@ -19,34 +19,62 @@ package org.apache.tajo.plan.expr; import com.google.gson.annotations.Expose; - import org.apache.tajo.catalog.FunctionDesc; +import org.apache.tajo.catalog.Schema; import org.apache.tajo.common.TajoDataTypes.DataType; import org.apache.tajo.datum.Datum; -import org.apache.tajo.plan.function.AggFunction; +import org.apache.tajo.exception.InternalException; +import org.apache.tajo.plan.function.AggFunctionInvoke; import org.apache.tajo.plan.function.FunctionContext; +import org.apache.tajo.plan.function.FunctionInvokeContext; import org.apache.tajo.storage.Tuple; import org.apache.tajo.util.TUtil; +import java.io.IOException; + public class AggregationFunctionCallEval extends FunctionEval implements Cloneable { 
@Expose boolean intermediatePhase = false; @Expose boolean finalPhase = true; @Expose String alias; - protected AggFunction instance; +// protected AggFunction instance; + @Expose protected FunctionInvokeContext invokeContext; + protected transient AggFunctionInvoke functionInvoke; - protected AggregationFunctionCallEval(EvalType type, FunctionDesc desc, AggFunction instance, EvalNode[] givenArgs) { + protected AggregationFunctionCallEval(EvalType type, FunctionDesc desc, EvalNode[] givenArgs) { super(type, desc, givenArgs); - this.instance = instance; + this.invokeContext = new FunctionInvokeContext(null, getParamType()); + try { + this.functionInvoke = AggFunctionInvoke.newInstance(funcDesc); + } catch (InternalException e) { + throw new RuntimeException(e); + } } - public AggregationFunctionCallEval(FunctionDesc desc, AggFunction instance, EvalNode[] givenArgs) { - super(EvalType.AGG_FUNCTION, desc, givenArgs); - this.instance = instance; + public AggregationFunctionCallEval(FunctionDesc desc, EvalNode[] givenArgs) { + this(EvalType.AGG_FUNCTION, desc, givenArgs); } public FunctionContext newContext() { - return instance.newContext(); + return functionInvoke.newContext(); + } + + @Override + public EvalNode bind(EvalContext evalContext, Schema schema) { + super.bind(evalContext, schema); + + try { + if (evalContext != null && evalContext.hasScriptEngine(this)) { + this.invokeContext.setScriptEngine(evalContext.getScriptEngine(this)); + this.invokeContext.getScriptEngine().setIntermediatePhase(intermediatePhase); + this.invokeContext.getScriptEngine().setFinalPhase(finalPhase); + } + this.functionInvoke.init(invokeContext); + } catch (IOException e) { + throw new RuntimeException(e); + } + + return this; } public void merge(FunctionContext context, Tuple tuple) { @@ -59,9 +87,9 @@ public class AggregationFunctionCallEval extends FunctionEval implements Cloneab protected void mergeParam(FunctionContext context, Tuple params) { if (!intermediatePhase && 
!finalPhase) { // firstPhase - instance.eval(context, params); + functionInvoke.eval(context, params); } else { - instance.merge(context, params); + functionInvoke.merge(context, params); } } @@ -75,16 +103,16 @@ public class AggregationFunctionCallEval extends FunctionEval implements Cloneab throw new IllegalStateException("bind() must be called before terminate()"); } if (!finalPhase) { - return instance.getPartialResult(context); + return functionInvoke.getPartialResult(context); } else { - return instance.terminate(context); + return functionInvoke.terminate(context); } } @Override public DataType getValueType() { if (!finalPhase) { - return instance.getPartialResultType(); + return functionInvoke.getPartialResultType(); } else { return funcDesc.getReturnType(); } @@ -104,7 +132,10 @@ public class AggregationFunctionCallEval extends FunctionEval implements Cloneab clone.finalPhase = finalPhase; clone.intermediatePhase = intermediatePhase; clone.alias = alias; - clone.instance = (AggFunction)instance.clone(); + clone.invokeContext = (FunctionInvokeContext) invokeContext.clone(); + if (functionInvoke != null) { + clone.functionInvoke = functionInvoke; + } return clone; } @@ -146,7 +177,6 @@ public class AggregationFunctionCallEval extends FunctionEval implements Cloneab int result = super.hashCode(); result = prime * result + ((alias == null) ? 0 : alias.hashCode()); result = prime * result + (finalPhase ? 1231 : 1237); - result = prime * result + ((instance == null) ? 0 : instance.hashCode()); result = prime * result + (intermediatePhase ? 
1231 : 1237); return result; } @@ -157,7 +187,6 @@ public class AggregationFunctionCallEval extends FunctionEval implements Cloneab AggregationFunctionCallEval other = (AggregationFunctionCallEval) obj; boolean eq = super.equals(other); - eq &= instance.equals(other.instance); eq &= intermediatePhase == other.intermediatePhase; eq &= finalPhase == other.finalPhase; eq &= TUtil.checkEquals(alias, other.alias); http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-plan/src/main/java/org/apache/tajo/plan/expr/EvalContext.java ---------------------------------------------------------------------- diff --git a/tajo-plan/src/main/java/org/apache/tajo/plan/expr/EvalContext.java b/tajo-plan/src/main/java/org/apache/tajo/plan/expr/EvalContext.java index 6a30e77..869dc73 100644 --- a/tajo-plan/src/main/java/org/apache/tajo/plan/expr/EvalContext.java +++ b/tajo-plan/src/main/java/org/apache/tajo/plan/expr/EvalContext.java @@ -32,7 +32,8 @@ public class EvalContext { } public boolean hasScriptEngine(EvalNode evalNode) { - return this.scriptEngineMap.containsKey(evalNode); + boolean contain = this.scriptEngineMap.containsKey(evalNode); + return contain; } public TajoScriptEngine getScriptEngine(EvalNode evalNode) { http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-plan/src/main/java/org/apache/tajo/plan/expr/WindowFunctionEval.java ---------------------------------------------------------------------- diff --git a/tajo-plan/src/main/java/org/apache/tajo/plan/expr/WindowFunctionEval.java b/tajo-plan/src/main/java/org/apache/tajo/plan/expr/WindowFunctionEval.java index a72d826..0b60d14 100644 --- a/tajo-plan/src/main/java/org/apache/tajo/plan/expr/WindowFunctionEval.java +++ b/tajo-plan/src/main/java/org/apache/tajo/plan/expr/WindowFunctionEval.java @@ -18,27 +18,25 @@ package org.apache.tajo.plan.expr; -import java.util.Arrays; - import com.google.gson.annotations.Expose; - import org.apache.tajo.catalog.FunctionDesc; import 
org.apache.tajo.catalog.SortSpec; import org.apache.tajo.common.TajoDataTypes.DataType; import org.apache.tajo.datum.Datum; -import org.apache.tajo.plan.function.AggFunction; import org.apache.tajo.plan.function.FunctionContext; import org.apache.tajo.plan.logical.WindowSpec; import org.apache.tajo.storage.Tuple; import org.apache.tajo.util.TUtil; +import java.util.Arrays; + public class WindowFunctionEval extends AggregationFunctionCallEval implements Cloneable { @Expose private SortSpec [] sortSpecs; @Expose WindowSpec.WindowFrame windowFrame; - public WindowFunctionEval(FunctionDesc desc, AggFunction instance, EvalNode[] givenArgs, + public WindowFunctionEval(FunctionDesc desc, EvalNode[] givenArgs, WindowSpec.WindowFrame windowFrame) { - super(EvalType.WINDOW_FUNCTION, desc, instance, givenArgs); + super(EvalType.WINDOW_FUNCTION, desc, givenArgs); this.windowFrame = windowFrame; } @@ -60,7 +58,7 @@ public class WindowFunctionEval extends AggregationFunctionCallEval implements C @Override protected void mergeParam(FunctionContext context, Tuple params) { - instance.eval(context, params); + functionInvoke.eval(context, params); } @Override @@ -68,7 +66,7 @@ public class WindowFunctionEval extends AggregationFunctionCallEval implements C if (!isBinded) { throw new IllegalStateException("bind() must be called before terminate()"); } - return instance.terminate(context); + return functionInvoke.terminate(context); } @Override http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-plan/src/main/java/org/apache/tajo/plan/function/AggFunction.java ---------------------------------------------------------------------- diff --git a/tajo-plan/src/main/java/org/apache/tajo/plan/function/AggFunction.java b/tajo-plan/src/main/java/org/apache/tajo/plan/function/AggFunction.java index 08ea6a7..9fa2369 100644 --- a/tajo-plan/src/main/java/org/apache/tajo/plan/function/AggFunction.java +++ b/tajo-plan/src/main/java/org/apache/tajo/plan/function/AggFunction.java @@ -35,8 
+35,18 @@ public abstract class AggFunction<T extends Datum> extends Function<T> { public abstract FunctionContext newContext(); + /** + * Called at the first stage. + * @param ctx + * @param params + */ public abstract void eval(FunctionContext ctx, Tuple params); + /** + * Called at all stages except the first one. + * @param ctx + * @param part + */ public void merge(FunctionContext ctx, Tuple part) { eval(ctx, part); } http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-plan/src/main/java/org/apache/tajo/plan/function/AggFunctionInvoke.java ---------------------------------------------------------------------- diff --git a/tajo-plan/src/main/java/org/apache/tajo/plan/function/AggFunctionInvoke.java b/tajo-plan/src/main/java/org/apache/tajo/plan/function/AggFunctionInvoke.java new file mode 100644 index 0000000..2c2afbe --- /dev/null +++ b/tajo-plan/src/main/java/org/apache/tajo/plan/function/AggFunctionInvoke.java @@ -0,0 +1,88 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.tajo.plan.function; + +import com.google.gson.annotations.Expose; +import org.apache.tajo.catalog.FunctionDesc; +import org.apache.tajo.common.TajoDataTypes; +import org.apache.tajo.datum.Datum; +import org.apache.tajo.exception.InternalException; +import org.apache.tajo.exception.UnsupportedException; +import org.apache.tajo.storage.Tuple; + +import java.io.IOException; + +public abstract class AggFunctionInvoke implements Cloneable { + @Expose protected FunctionDesc functionDesc; + + public AggFunctionInvoke(FunctionDesc functionDesc) { + this.functionDesc = functionDesc; + } + + public static AggFunctionInvoke newInstance(FunctionDesc desc) throws InternalException { + // TODO: The below line is due to the bug in the function type. The type of class-based functions is not set properly. + if (desc.getInvocation().hasLegacy()) { + return new ClassBasedAggFunctionInvoke(desc); + } else if (desc.getInvocation().hasPythonAggregation()) { + return new PythonAggFunctionInvoke(desc); + } else { + throw new UnsupportedException(desc.getInvocation() + " is not supported"); + } + } + + public void setFunctionDesc(FunctionDesc functionDesc) { + this.functionDesc = functionDesc; + } + + public abstract void init(FunctionInvokeContext context) throws IOException; + + public abstract FunctionContext newContext(); + + public abstract void eval(FunctionContext context, Tuple params); + + public abstract void merge(FunctionContext context, Tuple params); + + public abstract Datum getPartialResult(FunctionContext context); + + // TODO: use {@link IntermFunctionSignature} instead of this function. 
+ public abstract TajoDataTypes.DataType getPartialResultType(); + + public abstract Datum terminate(FunctionContext context); + + @Override + public boolean equals(Object o) { + if (o instanceof AggFunctionInvoke) { + AggFunctionInvoke other = (AggFunctionInvoke) o; + return this.functionDesc.equals(other.functionDesc); + } + return false; + } + + @Override + public int hashCode() { + return functionDesc.hashCode(); + } + + @Override + public Object clone() throws CloneNotSupportedException { + AggFunctionInvoke clone = (AggFunctionInvoke) super.clone(); + clone.functionDesc = (FunctionDesc) this.functionDesc.clone(); + return clone; + } +} http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-plan/src/main/java/org/apache/tajo/plan/function/ClassBasedAggFunctionInvoke.java ---------------------------------------------------------------------- diff --git a/tajo-plan/src/main/java/org/apache/tajo/plan/function/ClassBasedAggFunctionInvoke.java b/tajo-plan/src/main/java/org/apache/tajo/plan/function/ClassBasedAggFunctionInvoke.java new file mode 100644 index 0000000..6657871 --- /dev/null +++ b/tajo-plan/src/main/java/org/apache/tajo/plan/function/ClassBasedAggFunctionInvoke.java @@ -0,0 +1,82 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tajo.plan.function; + +import com.google.gson.annotations.Expose; +import org.apache.tajo.catalog.FunctionDesc; +import org.apache.tajo.common.TajoDataTypes; +import org.apache.tajo.datum.Datum; +import org.apache.tajo.exception.InternalException; +import org.apache.tajo.storage.Tuple; + +import java.io.IOException; + +/** + * This class invokes class-based aggregation functions. + */ +public class ClassBasedAggFunctionInvoke extends AggFunctionInvoke implements Cloneable { + @Expose private AggFunction function; + + public ClassBasedAggFunctionInvoke(FunctionDesc functionDesc) throws InternalException { + super(functionDesc); + function = (AggFunction) functionDesc.newInstance(); + } + + @Override + public void init(FunctionInvokeContext context) throws IOException { + // nothing to do + } + + @Override + public FunctionContext newContext() { + return function.newContext(); + } + + @Override + public void eval(FunctionContext context, Tuple params) { + function.eval(context, params); + } + + @Override + public void merge(FunctionContext context, Tuple params) { + function.merge(context, params); + } + + @Override + public Datum getPartialResult(FunctionContext context) { + return function.getPartialResult(context); + } + + @Override + public TajoDataTypes.DataType getPartialResultType() { + return function.getPartialResultType(); + } + + @Override + public Datum terminate(FunctionContext context) { + return function.terminate(context); + } + + @Override + public Object clone() throws CloneNotSupportedException { + ClassBasedAggFunctionInvoke clone = (ClassBasedAggFunctionInvoke) super.clone(); + clone.function = (AggFunction) function.clone(); + return clone; + } +} http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-plan/src/main/java/org/apache/tajo/plan/function/ClassBasedScalarFunctionInvoke.java 
---------------------------------------------------------------------- diff --git a/tajo-plan/src/main/java/org/apache/tajo/plan/function/ClassBasedScalarFunctionInvoke.java b/tajo-plan/src/main/java/org/apache/tajo/plan/function/ClassBasedScalarFunctionInvoke.java new file mode 100644 index 0000000..c3f4ad9 --- /dev/null +++ b/tajo-plan/src/main/java/org/apache/tajo/plan/function/ClassBasedScalarFunctionInvoke.java @@ -0,0 +1,80 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tajo.plan.function; + +import com.google.gson.annotations.Expose; +import org.apache.tajo.catalog.FunctionDesc; +import org.apache.tajo.datum.Datum; +import org.apache.tajo.exception.InternalException; +import org.apache.tajo.storage.Tuple; +import org.apache.tajo.util.TUtil; + +/** + * This class invokes class-based scala functions. 
+ */ +public class ClassBasedScalarFunctionInvoke extends FunctionInvoke implements Cloneable { + @Expose private GeneralFunction function; + + public ClassBasedScalarFunctionInvoke() { + + } + + public ClassBasedScalarFunctionInvoke(FunctionDesc funcDesc) throws InternalException { + super(funcDesc); + function = (GeneralFunction) funcDesc.newInstance(); + } + + @Override + public void setFunctionDesc(FunctionDesc desc) throws InternalException { + super.setFunctionDesc(desc); + function = (GeneralFunction) functionDesc.newInstance(); + } + + @Override + public void init(FunctionInvokeContext context) { + function.init(context.getQueryContext(), context.getParamTypes()); + } + + @Override + public Datum eval(Tuple tuple) { + return function.eval(tuple); + } + + @Override + public boolean equals(Object o) { + if (o instanceof ClassBasedScalarFunctionInvoke) { + ClassBasedScalarFunctionInvoke other = (ClassBasedScalarFunctionInvoke) o; + return super.equals(other) && + TUtil.checkEquals(function, other.function); + } + return false; + } + + @Override + public int hashCode() { + return function.hashCode(); + } + + @Override + public Object clone() throws CloneNotSupportedException { + ClassBasedScalarFunctionInvoke clone = (ClassBasedScalarFunctionInvoke) super.clone(); + clone.function = (GeneralFunction) function.clone(); + return clone; + } +} http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-plan/src/main/java/org/apache/tajo/plan/function/FunctionInvoke.java ---------------------------------------------------------------------- diff --git a/tajo-plan/src/main/java/org/apache/tajo/plan/function/FunctionInvoke.java b/tajo-plan/src/main/java/org/apache/tajo/plan/function/FunctionInvoke.java index 728ae10..b8b5cfe 100644 --- a/tajo-plan/src/main/java/org/apache/tajo/plan/function/FunctionInvoke.java +++ b/tajo-plan/src/main/java/org/apache/tajo/plan/function/FunctionInvoke.java @@ -23,10 +23,8 @@ import org.apache.tajo.catalog.FunctionDesc; import 
org.apache.tajo.datum.Datum; import org.apache.tajo.exception.InternalException; import org.apache.tajo.exception.UnsupportedException; -import org.apache.tajo.plan.expr.EvalContext; import org.apache.tajo.storage.Tuple; -import java.io.Closeable; import java.io.IOException; /** @@ -46,7 +44,7 @@ public abstract class FunctionInvoke implements Cloneable { public static FunctionInvoke newInstance(FunctionDesc desc) throws InternalException { if (desc.getInvocation().hasLegacy()) { - return new LegacyScalarFunctionInvoke(desc); + return new ClassBasedScalarFunctionInvoke(desc); } else if (desc.getInvocation().hasPython()) { return new PythonFunctionInvoke(desc); } else { http://git-wip-us.apache.org/repos/asf/tajo/blob/9540f16e/tajo-plan/src/main/java/org/apache/tajo/plan/function/FunctionInvokeContext.java ---------------------------------------------------------------------- diff --git a/tajo-plan/src/main/java/org/apache/tajo/plan/function/FunctionInvokeContext.java b/tajo-plan/src/main/java/org/apache/tajo/plan/function/FunctionInvokeContext.java index b938072..fb8dcea 100644 --- a/tajo-plan/src/main/java/org/apache/tajo/plan/function/FunctionInvokeContext.java +++ b/tajo-plan/src/main/java/org/apache/tajo/plan/function/FunctionInvokeContext.java @@ -20,20 +20,22 @@ package org.apache.tajo.plan.function; import com.google.common.base.Objects; import org.apache.tajo.OverridableConf; +import org.apache.tajo.annotation.Nullable; import org.apache.tajo.plan.expr.FunctionEval; import org.apache.tajo.plan.function.python.TajoScriptEngine; +import org.apache.tajo.util.TUtil; import java.util.Arrays; /** * This class contains some metadata need to execute functions. 
*/ -public class FunctionInvokeContext { - private final OverridableConf queryContext; - private final FunctionEval.ParamType[] paramTypes; +public class FunctionInvokeContext implements Cloneable { + private OverridableConf queryContext; + private FunctionEval.ParamType[] paramTypes; private TajoScriptEngine scriptEngine; - public FunctionInvokeContext(OverridableConf queryContext, FunctionEval.ParamType[] paramTypes) { + public FunctionInvokeContext(@Nullable OverridableConf queryContext, FunctionEval.ParamType[] paramTypes) { this.queryContext = queryContext; this.paramTypes = paramTypes; } @@ -67,8 +69,20 @@ public class FunctionInvokeContext { public boolean equals(Object o) { if (o instanceof FunctionInvokeContext) { FunctionInvokeContext other = (FunctionInvokeContext) o; - return queryContext.equals(other.queryContext) && Arrays.equals(paramTypes, other.paramTypes); + return TUtil.checkEquals(queryContext, other.queryContext) && + Arrays.equals(paramTypes, other.paramTypes); } return false; } + + @Override + public Object clone() throws CloneNotSupportedException { + FunctionInvokeContext clone = (FunctionInvokeContext) super.clone(); + clone.queryContext = queryContext; + clone.paramTypes = Arrays.copyOf(paramTypes, paramTypes.length); + if (scriptEngine != null) { + clone.scriptEngine = scriptEngine; + } + return clone; + } }
