Github user icexelloss commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19630#discussion_r150311184
  
    --- Diff: python/pyspark/sql/udf.py ---
    @@ -0,0 +1,136 @@
    +#
    +# Licensed to the Apache Software Foundation (ASF) under one or more
    +# contributor license agreements.  See the NOTICE file distributed with
    +# this work for additional information regarding copyright ownership.
    +# The ASF licenses this file to You under the Apache License, Version 2.0
    +# (the "License"); you may not use this file except in compliance with
    +# the License.  You may obtain a copy of the License at
    +#
    +#    http://www.apache.org/licenses/LICENSE-2.0
    +#
    +# Unless required by applicable law or agreed to in writing, software
    +# distributed under the License is distributed on an "AS IS" BASIS,
    +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    +# See the License for the specific language governing permissions and
    +# limitations under the License.
    +#
    +"""
    +User-defined function related classes and functions
    +"""
    +import functools
    +
    +from pyspark import SparkContext
    +from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType
    +from pyspark.sql.column import Column, _to_java_column, _to_seq
    +from pyspark.sql.types import StringType, DataType, _parse_datatype_string
    +
    +
    +def _wrap_function(sc, func, returnType):
    +    command = (func, returnType)
    +    pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
    +    return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
    +                                  sc.pythonVer, broadcast_vars, sc._javaAccumulator)
    +
    +
    +def _create_udf(f, *, returnType, udfType):
    +    if udfType in (PythonEvalType.PANDAS_SCALAR_UDF, PythonEvalType.PANDAS_GROUP_FLATMAP_UDF):
    +        import inspect
    +        argspec = inspect.getargspec(f)
    +        if len(argspec.args) == 0 and argspec.varargs is None:
    +            raise ValueError(
    +                "0-arg pandas_udfs are not supported. "
    +                "Instead, create a 1-arg pandas_udf and ignore the arg in your function."
    +            )
    +    udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, udfType=udfType)
    +    return udf_obj._wrapped()
    +
    +
    +class UserDefinedFunction(object):
    +    """
    +    User defined function in Python
    +
    +    .. versionadded:: 1.3
    +    """
    +    def __init__(self, func,
    +                 returnType=StringType(), name=None,
    +                 udfType=PythonEvalType.SQL_BATCHED_UDF):
    +        if not callable(func):
    +            raise TypeError(
    +                "Not a function or callable (__call__ is not defined): "
    +                "{0}".format(type(func)))
    +
    +        self.func = func
    +        self._returnType = returnType
    +        # Stores UserDefinedPythonFunctions jobj, once initialized
    +        self._returnType_placeholder = None
    +        self._judf_placeholder = None
    +        self._name = name or (
    +            func.__name__ if hasattr(func, '__name__')
    +            else func.__class__.__name__)
    +        self.udfType = udfType
    +
    +
    +    @property
    +    def returnType(self):
    +        # This makes sure this is called after SparkContext is initialized.
    +        # ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string.
    +        if self._returnType_placeholder is None:
    +            if isinstance(self._returnType, DataType):
    +                self._returnType_placeholder = self._returnType
    +            else:
    +                self._returnType_placeholder = _parse_datatype_string(self._returnType)
    +        return self._returnType_placeholder
    +
    +    @property
    +    def _judf(self):
    +        # It is possible that concurrent access, to newly created UDF,
    +        # will initialize multiple UserDefinedPythonFunctions.
    +        # This is unlikely, doesn't affect correctness,
    +        # and should have a minimal performance impact.
    +        if self._judf_placeholder is None:
    +            self._judf_placeholder = self._create_judf()
    +        return self._judf_placeholder
    +
    +    def _create_judf(self):
    +        from pyspark.sql import SparkSession
    +
    +        spark = SparkSession.builder.getOrCreate()
    +        sc = spark.sparkContext
    +
    +        wrapped_func = _wrap_function(sc, self.func, self.returnType)
    +        jdt = spark._jsparkSession.parseDataType(self.returnType.json())
    +        judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
    +            self._name, wrapped_func, jdt, self.udfType)
    +        return judf
    +
    +    def __call__(self, *cols):
    +        judf = self._judf
    +        sc = SparkContext._active_spark_context
    +        return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))
    +
    +    def _wrapped(self):
    +        """
    +        Wrap this udf with a function and attach docstring from func
    +        """
    +
    +        # It is possible for a callable instance without __name__ attribute or/and
    +        # __module__ attribute to be wrapped here. For example, functools.partial. In this case,
    +        # we should avoid wrapping the attributes from the wrapped function to the wrapper
    +        # function. So, we take out these attribute names from the default names to set and
    +        # then manually assign it after being wrapped.
    +        assignments = tuple(
    +            a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__')
    +
    +        @functools.wraps(self.func, assigned=assignments)
    +        def wrapper(*args):
    +            return self(*args)
    +
    +        wrapper.__name__ = self._name
    +        wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__')
    +                              else self.func.__class__.__module__)
    +
    +        wrapper.func = self.func
    +        wrapper.returnType = self.returnType
    +        wrapper.udfType = self.udfType
    +
    +        return wrapper
    --- End diff --
    
    Added


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to