westonpace commented on a change in pull request #11223:
URL: https://github.com/apache/arrow/pull/11223#discussion_r717532248



##########
File path: python/pyarrow/tests/test_compute_kernels.py
##########
@@ -0,0 +1,1237 @@
+import re
+from abc import abstractmethod, ABC
+from typing import List, Dict, Set, Tuple
+
+import pytest
+
+import pyarrow as pa
+import pyarrow.compute as pc
+
+
+def sample_integral_types():
+    """Return the signed, then the unsigned, integer types (8-64 bits)."""
+    signed = [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
+    unsigned = [pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64()]
+    return signed + unsigned
+
+
+def sample_signed_integral_types():
+    """Return the signed integer types, from 8 to 64 bits wide."""
+    return [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
+
+
+def sample_simple_float_types():
+    """Return the 32-bit and 64-bit floating point types."""
+    return [pa.float32(), pa.float64()]
+
+
+def sample_decimal_types():
+    """Return two decimal128 types: (precision, scale) (7, 3) and (10, 4)."""
+    precision_and_scale = [(7, 3), (10, 4)]
+    return [pa.decimal128(p, s) for p, s in precision_and_scale]
+
+
+def sample_float_types():
+    """Return float32 and float64 followed by the sample decimal types."""
+    return sample_simple_float_types() + sample_decimal_types()
+
+
+def sample_simple_numeric_types():
+    """Return the integer types followed by float32 and float64."""
+    integral = sample_integral_types()
+    simple_floats = sample_simple_float_types()
+    return integral + simple_floats
+
+
+def sample_numeric_types():
+    """Return the integer types followed by the float and decimal types."""
+    integral = sample_integral_types()
+    floating = sample_float_types()
+    return integral + floating
+
+
+def sample_signed_numeric_types():
+    """Return the signed integer types followed by float/decimal types."""
+    signed_integral = sample_signed_integral_types()
+    floating = sample_float_types()
+    return signed_integral + floating
+
+
+def sample_timestamp_no_tz_types():
+    """Return timezone-less timestamps for the units s, ms, us and ns."""
+    return [pa.timestamp(unit) for unit in ('s', 'ms', 'us', 'ns')]
+
+
+def sample_timestamptz_types():
+    """Return zoned timestamps for s/ms/us/ns in New York, then in UTC."""
+    zones = ('America/New_York', 'UTC')
+    units = ('s', 'ms', 'us', 'ns')
+    # Zone is the outer loop so all units of one zone stay together.
+    return [pa.timestamp(unit, zone) for zone in zones for unit in units]
+
+
+def sample_timestamp_types():
+    """Return the zoned timestamp types followed by the zone-less ones."""
+    zoned = sample_timestamptz_types()
+    naive = sample_timestamp_no_tz_types()
+    return zoned + naive
+
+
+def sample_date_only_types():
+    """Return the date32 and date64 types."""
+    return [pa.date32(), pa.date64()]
+
+
+def sample_date_types():
+    """Return the date-only types followed by every timestamp type."""
+    dates = sample_date_only_types()
+    timestamps = sample_timestamp_types()
+    return dates + timestamps
+
+
+def sample_time_only_types():
+    """Return time32 for s and ms followed by time64 for us and ns."""
+    coarse = [pa.time32(unit) for unit in ('s', 'ms')]
+    fine = [pa.time64(unit) for unit in ('us', 'ns')]
+    return coarse + fine
+
+
+def sample_time_types():
+    """Return the time-only types followed by every timestamp type."""
+    times = sample_time_only_types()
+    timestamps = sample_timestamp_types()
+    return times + timestamps
+
+
+def sample_temporal_types():
+    """Return the date-only, time-only and timestamp types, in that order."""
+    return (sample_date_only_types() +
+            sample_time_only_types() +
+            sample_timestamp_types())
+
+
+def sample_logical_types():
+    """Return a one-element list holding the boolean type."""
+    boolean = pa.bool_()
+    return [boolean]
+
+
+def sample_bytes_types():
+    """Return the binary and string sample types.
+
+    The order is binary, 32-byte fixed-width binary, large binary, string
+    and large string.
+    """
+    binary_types = [pa.binary(), pa.binary(32), pa.large_binary()]
+    text_types = [pa.string(), pa.large_string()]
+    return binary_types + text_types
+
+
+def sample_fixed_bytes_types():
+    """Return a one-element list holding the ``pa.binary(32)`` type."""
+    return [pa.binary(32)]
+
+
+def sample_string_types():
+    """Return the string and large string types."""
+    return [pa.string(), pa.large_string()]
+
+
+def sample_primitive_types():
+    """Return the numeric, temporal and bytes-like sample types.
+
+    ``sample_temporal_types()`` already contains every timestamp type, so
+    the timestamp samples are not appended a second time; doing so only
+    produced duplicate entries in the result.
+    """
+    return (sample_numeric_types() +
+            sample_temporal_types() +
+            sample_bytes_types())
+
+
+def __listify_types(types):
+    """Wrap each type in a list, a 32-item fixed-size list and a large list.
+
+    All ``pa.list_(t)`` results come first, then all ``pa.list_(t, 32)``
+    results, then all ``pa.large_list(t)`` results.
+    """
+    variable_size = [pa.list_(value_type) for value_type in types]
+    fixed_size = [pa.list_(value_type, 32) for value_type in types]
+    large = [pa.large_list(value_type) for value_type in types]
+    return variable_size + fixed_size + large
+
+
+def __structify_types(types):
+    """Wrap each type in a struct with a single field named ``data``."""
+    structs = []
+    for field_type in types:
+        structs.append(pa.struct([pa.field('data', field_type)]))
+    return structs
+
+
+def sample_sortable_types():
+    """Return the sortable sample types (currently the primitive types)."""
+    sortable = sample_primitive_types()
+    return sortable
+
+
+def sample_list_types():
+    """Return list types whose values are a primitive or the null type."""
+    value_types = sample_primitive_types() + [pa.null()]
+    return __listify_types(value_types)
+
+
+def sample_struct_types():
+    """Return struct types whose field is a primitive or the null type."""
+    field_types = sample_primitive_types() + [pa.null()]
+    return __structify_types(field_types)
+
+
+def sample_all_types():
+    """Return the primitive, list and struct sample types, in that order."""
+    return (sample_primitive_types() +
+            sample_list_types() +
+            sample_struct_types())
+
+
+# Maps a type category name to the sample data types that belong to it.
+# get_sample_types() appends pa.null() to whichever list it looks up, which
+# is why the 'null' category itself is empty here.
+type_categories = {
+    'boolean': sample_logical_types(),
+    'bytes': sample_bytes_types(),
+    'date': sample_date_only_types(),
+    'datelike': sample_date_types(),
+    'decimal': sample_decimal_types(),
+    'equatable': sample_sortable_types(),
+    'fixed_bytes': sample_fixed_bytes_types(),
+    'floating': sample_float_types(),
+    'integral': sample_integral_types(),
+    'list': sample_list_types(),
+    'logical': sample_logical_types(),
+    'null': [],
+    'numeric': sample_numeric_types(),
+    'signed_numeric': sample_signed_numeric_types(),
+    'simple_numeric': sample_simple_numeric_types(),
+    'sortable': sample_sortable_types(),
+    'string': sample_string_types(),
+    'struct': sample_struct_types(),
+    'temporal': sample_temporal_types(),
+    'time': sample_time_only_types(),
+    'timelike': sample_time_types(),
+    'timestamp': sample_timestamp_no_tz_types(),
+    'timestamptz': sample_timestamptz_types(),
+    'timestamp_all': sample_timestamp_types(),
+}
+
+
+def get_sample_types(category):
+    """Return the sample types for ``category`` plus the null type.
+
+    Raises an ``Exception`` if ``category`` is not a key of
+    ``type_categories``.
+    """
+    if category not in type_categories:
+        raise Exception(f'Unrecognized type category {category}')
+    return type_categories[category] + [pa.null()]
+
+
+class DynamicParameter(ABC):
+    """Base class for a parameter whose type is computed from a type map."""
+
+    def __init__(self, key: str):
+        # Name under which this parameter is known.
+        self.key = key
+
+    @abstractmethod
+    def compute_type(self, parameters_map: Dict[str, pa.DataType]):
+        """Return this parameter's type given the types in the map."""
+
+
+class DecimalDynamicParameter(DynamicParameter):
+    """Dynamic parameter computed from two decimal operand parameters.
+
+    Subclasses supply the scale/precision arithmetic in ``_do_compute``.
+    """
+
+    def __init__(self, key, left_name, right_name):
+        super().__init__(key)
+        self.left_name = left_name
+        self.right_name = right_name
+
+    def _ensure_decimal(self, type_):
+        """Raise an ``Exception`` unless ``type_`` is a decimal type."""
+        if pa.types.is_decimal(type_):
+            return
+        raise Exception(
+            'DECIMAL_* type function was used for a type '
+            f'{type_} which is not decimal')
+
+    def compute_type(self, parameters_map):
+        """Return the result type for the two referenced operand types.
+
+        A null operand yields the other operand's type unchanged.
+        Otherwise both operands must be decimals; the result is a
+        decimal128 only when both operands are decimal128 and the computed
+        precision is at most 38, and a decimal256 in every other case.
+        """
+        lhs = parameters_map[self.left_name]
+        rhs = parameters_map[self.right_name]
+        if pa.types.is_null(lhs):
+            return rhs
+        if pa.types.is_null(rhs):
+            return lhs
+        for operand in (lhs, rhs):
+            self._ensure_decimal(operand)
+        scale, precision = self._do_compute(
+            lhs.scale, lhs.precision, rhs.scale, rhs.precision)
+        both_128 = (pa.types.is_decimal128(lhs) and
+                    pa.types.is_decimal128(rhs))
+        if precision <= 38 and both_128:
+            return pa.decimal128(precision, scale)
+        return pa.decimal256(precision, scale)
+
+    @abstractmethod
+    def _do_compute(self, s1, p1, s2, p2):
+        """Return ``(scale, precision)`` of the result type.
+
+        ``s1``/``p1`` are the scale and precision of the left operand and
+        ``s2``/``p2`` those of the right operand.
+        """
+
+
+class DecimalAddDynamicParameter(DecimalDynamicParameter):
+    """Decimal result type registered under the name ``DECIMAL_ADD``."""
+
+    def __init__(self, key, left_name, right_name):
+        super().__init__(key, left_name, right_name)
+
+    def _do_compute(self, s1, p1, s2, p2):
+        """Return ``(scale, precision)`` from the operand scales/precisions.
+
+        The scale is the larger operand scale; the precision is the larger
+        count of integer digits plus that scale plus one.
+        """
+        result_scale = max(s1, s2)
+        integer_digits = max(p1 - s1, p2 - s2)
+        return result_scale, integer_digits + result_scale + 1
+
+
+class DecimalMultiplyDynamicParameter(DecimalDynamicParameter):
+    """Decimal result type registered under the name ``DECIMAL_MULTIPLY``."""
+
+    def __init__(self, key, left_name, right_name):
+        super().__init__(key, left_name, right_name)
+
+    def _do_compute(self, s1, p1, s2, p2):
+        """Return ``(s1 + s2, p1 + p2 + 1)`` as ``(scale, precision)``."""
+        result_scale = s1 + s2
+        result_precision = p1 + p2 + 1
+        return result_scale, result_precision
+
+
+class DecimalDivideDynamicParameter(DecimalDynamicParameter):
+    """Decimal result type registered under the name ``DECIMAL_DIVIDE``."""
+
+    def __init__(self, key, left_name, right_name):
+        super().__init__(key, left_name, right_name)
+
+    def _do_compute(self, s1, p1, s2, p2):
+        """Return ``(scale, precision)`` from the operand scales/precisions.
+
+        The scale is ``s1 + p2 - s2 + 1`` but never less than 4; the
+        precision is ``p1 - s1 + s2`` plus that scale.
+        """
+        result_scale = max(4, s1 + p2 - s2 + 1)
+        return result_scale, p1 - s1 + s2 + result_scale
+
+
+class StructifyDynamicParameter(DynamicParameter):
+    """Dynamic parameter that packs every mapped parameter into a struct."""
+
+    def __init__(self, key):
+        super().__init__(key)
+
+    def compute_type(self, parameters_map):
+        """Return a struct with one field per entry of ``parameters_map``.
+
+        Field names are the map keys and field types are the map values.
+        """
+        return pa.struct([
+            pa.field(name, data_type)
+            for name, data_type in parameters_map.items()
+        ])
+
+
+class WithTzParameter(DynamicParameter):
+    """Dynamic parameter that is a UTC timestamp based on another one."""
+
+    def __init__(self, key, source_name):
+        super().__init__(key)
+        # Name of the parameter whose unit is reused.
+        self.name = source_name
+
+    def compute_type(self, parameters_map: Dict[str, pa.DataType]):
+        """Return a UTC timestamp with the source parameter's unit.
+
+        A null source type is passed through as the null type.
+        """
+        source = parameters_map[self.name]
+        if pa.types.is_null(source):
+            return pa.null()
+        return pa.timestamp(source.unit, 'UTC')
+
+
+# Maps the upper-cased function name found in a dynamic parameter string to
+# the class that implements it; parse_dynamic_parameter() does the lookup.
+dynamic_parameter_types = {
+    'DECIMAL_ADD': DecimalAddDynamicParameter,
+    'DECIMAL_MULTIPLY': DecimalMultiplyDynamicParameter,
+    'DECIMAL_DIVIDE': DecimalDivideDynamicParameter,
+    'STRUCTIFY': StructifyDynamicParameter,
+    'WITH_TZ': WithTzParameter
+}
+
+
+class ConstrainedParameter(ABC):
+    """Base class for a parameter restricted to certain data types."""
+
+    def __init__(self, key: str):
+        # Name under which this parameter is known.
+        self.key = key
+
+    @abstractmethod
+    def sample(self, parameters_map: Dict[str, pa.DataType]) -> List[
+            pa.DataType]:
+        """Return example types for this parameter given the type map."""
+
+    @abstractmethod
+    def satisfied_with(self, data_type: pa.DataType,
+                       parameters_map: Dict[str, pa.DataType]) -> bool:
+        """Return whether ``data_type`` is acceptable for this parameter."""
+
+
+class IsListOfGivenType(ConstrainedParameter):
+    """Constraint: a list whose value type is another parameter's type."""
+
+    def __init__(self, key, parameter_name):
+        super().__init__(key)
+        # Name of the parameter that supplies the list's value type.
+        self.name = parameter_name
+
+    def sample(self, parameters_map):
+        """Return ``pa.list_(t)`` and ``pa.list_(t, 32)`` for the value type.
+
+        NOTE(review): the second sample is built like the fixed-size list in
+        ``IsFixedSizeListOfGivenType``; whether ``satisfied_with`` accepts it
+        depends on ``pa.types.is_list`` - confirm.
+        """
+        value_type = parameters_map[self.name]
+        return [pa.list_(value_type), pa.list_(value_type, 32)]
+
+    def satisfied_with(self, data_type: pa.DataType,
+                       parameters_map: Dict[str, pa.DataType]) -> bool:
+        """Return True for a list type holding the referenced value type.
+
+        False is returned when ``data_type`` is not a list or when the
+        referenced parameter is missing from ``parameters_map``.
+        """
+        is_candidate = (pa.types.is_list(data_type) and
+                        self.name in parameters_map)
+        if not is_candidate:
+            return False
+        return parameters_map[self.name] == data_type.value_type
+
+
+class IsFixedSizeListOfGivenType(ConstrainedParameter):
+    """Constraint: a fixed-size list of another parameter's type."""
+
+    def __init__(self, key, parameter_name):
+        super().__init__(key)
+        # Name of the parameter that supplies the list's value type.
+        self.name = parameter_name
+
+    def sample(self, parameters_map):
+        """Return a single ``pa.list_(t, 32)`` for the referenced type."""
+        value_type = parameters_map[self.name]
+        return [pa.list_(value_type, 32)]
+
+    def satisfied_with(self, data_type: pa.DataType,
+                       parameters_map: Dict[str, pa.DataType]) -> bool:
+        """Return True for a fixed-size list holding the referenced type.
+
+        False is returned when ``data_type`` is not a fixed-size list or
+        when the referenced parameter is missing from ``parameters_map``.
+        """
+        is_candidate = (pa.types.is_fixed_size_list(data_type) and
+                        self.name in parameters_map)
+        if not is_candidate:
+            return False
+        return parameters_map[self.name] == data_type.value_type
+
+
+class IsCaseWhen(ConstrainedParameter):
+    """Constraint: a struct type whose fields are all boolean."""
+
+    def __init__(self, key, *args):
+        # Extra positional arguments parsed from the condition string are
+        # accepted but not used.
+        super().__init__(key)
+
+    def sample(self, parameters_map):
+        """Return one struct with a bool field per entry of the map.
+
+        Fields are named ``f0``, ``f1``, ... and there are
+        ``len(parameters_map)`` of them.
+        """
+        fields = [pa.field(f'f{idx}', pa.bool_())
+                  for idx in range(len(parameters_map))]
+        return [pa.struct(fields)]
+
+    def satisfied_with(self, data_type: pa.DataType,
+                       parameters_map: Dict[str, pa.DataType]) -> bool:
+        """Return True if ``data_type`` is a struct of only boolean fields.
+
+        ``parameters_map`` is not consulted.
+        """
+        if not pa.types.is_struct(data_type):
+            return False
+        # Iterating a struct type yields ``pa.Field`` objects, so the type
+        # predicate has to be applied to ``field.type``, not to the field.
+        return all(pa.types.is_boolean(field.type) for field in data_type)
+
+
+# Maps the upper-cased function name found in a parameter condition string
+# to the constraint class that implements it; the lookup is done by
+# parse_parameter_condition_func().
+condition_types = {
+    'LIST': IsListOfGivenType,
+    'FIXED_SIZE_LIST': IsFixedSizeListOfGivenType,
+    'CASE_WHEN': IsCaseWhen
+}
+
+
+class InSetOfTypes(ConstrainedParameter):
+    """Constraint satisfied by any member of a given list of types."""
+
+    def __init__(self, key, example_types):
+        super().__init__(key)
+        self.example_types = example_types
+
+    def sample(self, _):
+        """Return the example types; the parameters map is ignored."""
+        return self.example_types
+
+    def satisfied_with(self, data_type: pa.DataType,
+                       parameters_map: Dict[str, pa.DataType]) -> bool:
+        """Return True if ``data_type`` is one of the example types."""
+        candidates = self.example_types
+        return data_type in candidates
+
+
+class IsAnyType(ConstrainedParameter):
+    """Constraint that accepts every data type."""
+
+    def __init__(self, key):
+        super().__init__(key)
+
+    def sample(self, _):
+        """Return ``sample_all_types()``; the parameters map is ignored."""
+        every_type = sample_all_types()
+        return every_type
+
+    def satisfied_with(self, data_type: pa.DataType,
+                       parameters_map: Dict[str, pa.DataType]) -> bool:
+        """Return True regardless of the arguments."""
+        return True
+
+
+def parse_parameter_condition_func(key, value):
+    """Build the constraint described by a function-style condition.
+
+    The first character of ``value`` is skipped, the text up to ``(`` is
+    upper-cased and looked up in ``condition_types``, and the comma
+    separated text between the parentheses is passed on as positional
+    arguments. Raises an ``Exception`` for an unknown function name.
+    """
+    open_idx = value.index('(')
+    close_idx = value.index(')')
+    func_name = value[1:open_idx].upper()
+    func_args = value[open_idx + 1:close_idx].split(',')
+    if func_name not in condition_types:
+        raise Exception(
+            f'Unrecognized parameter condition function ({func_name}) on '
+            f'key {key}')
+    return condition_types[func_name](key, *func_args)
+
+
+def parse_parameter_condition_typed(key, value):
+    """Build a constraint from a type category name.
+
+    An empty ``value`` means any type is allowed; otherwise ``value`` is a
+    category passed to ``get_sample_types``.
+    """
+    if not value:
+        return IsAnyType(key)
+    return InSetOfTypes(key, get_sample_types(value))
+
+
+def parse_parameter_condition(pstr):
+    """Parse a ``key:value`` condition string into a constraint object.
+
+    A value starting with ``~`` is handled as a function-style condition,
+    anything else as a type category.
+    """
+    key, _, value = pstr.partition(':')
+    parser = (parse_parameter_condition_func if value.startswith('~')
+              else parse_parameter_condition_typed)
+    return parser(key, value)
+
+
+def parse_dynamic_parameter(pstr):
+    """Parse a ``key=NAME(arg|arg)`` string into a dynamic parameter.
+
+    The function name is upper-cased and looked up in
+    ``dynamic_parameter_types``. Arguments are separated by ``|`` and any
+    argument containing ``...`` is dropped before the class is called with
+    ``key`` and the remaining arguments.
+
+    Raises an ``Exception`` for an unknown function name, and ``ValueError``
+    (from ``str.index``) when the parentheses are missing.
+    """
+    key, _, value = pstr.partition('=')
+    open_idx = value.index('(')
+    close_idx = value.index(')')
+    # Already upper-cased here, so no second upper() is needed for lookup.
+    func_name = value[:open_idx].upper()
+    func_args = [arg for arg in value[open_idx + 1:close_idx].split('|')
+                 if '...' not in arg]
+    if func_name not in dynamic_parameter_types:
+        # The original message had a stray, unbalanced quote after "for".
+        raise Exception(
+            f"Unrecognized dynamic parameter function {func_name} for "
+            f"parameter {key}")
+    return dynamic_parameter_types[func_name](key, *func_args)
+
+
+def parse_parameters_string(parameters_str):
+    """Split a parameters string into conditions and dynamic parameters.
+
+    The first and last characters are discarded and the rest is split on
+    commas. Pieces containing ``=`` are parsed as dynamic parameters, all
+    others as parameter conditions; every condition is parsed before any
+    dynamic parameter.
+
+    Returns a ``(conditions, dynamic_parameters)`` tuple of lists.
+    """
+    pieces = parameters_str[1:-1].split(',')
+    condition_strs = [piece for piece in pieces if '=' not in piece]
+    dynamic_strs = [piece for piece in pieces if '=' in piece]
+    conditions = [parse_parameter_condition(s) for s in condition_strs]
+    dynamic_parameters = [parse_dynamic_parameter(s) for s in dynamic_strs]
+    return conditions, dynamic_parameters
+
+
+class FunctionSignatureArg:
+    """Plain holder for an argument's key and its variadic flag."""
+
+    def __init__(self, key, variadic):
+        # Both values are stored exactly as given.
+        self.key, self.variadic = key, variadic
+
+
+class FunctionSignature:

Review comment:
       Switched, although this was the only class that was a pure fit.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to