GitHub user ueshin commented on a diff in the pull request:
https://github.com/apache/spark/pull/19630#discussion_r149873461
--- Diff: python/pyspark/sql/udf.py ---
@@ -0,0 +1,136 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+User-defined function related classes and functions
+"""
+import functools
+
+from pyspark import SparkContext
+from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType
+from pyspark.sql.column import Column, _to_java_column, _to_seq
+from pyspark.sql.types import StringType, DataType, _parse_datatype_string
+
+
+def _wrap_function(sc, func, returnType):
+    command = (func, returnType)
+    pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
+    return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
+                                  sc.pythonVer, broadcast_vars, sc._javaAccumulator)
+
+
+def _create_udf(f, returnType, udfType):
+    if udfType in (PythonEvalType.PANDAS_SCALAR_UDF,
+                   PythonEvalType.PANDAS_GROUP_FLATMAP_UDF):
+        import inspect
+        argspec = inspect.getargspec(f)
+        if len(argspec.args) == 0 and argspec.varargs is None:
+            raise ValueError(
+                "0-arg pandas_udfs are not supported. "
+                "Instead, create a 1-arg pandas_udf and ignore the arg in your function."
+            )
+    udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, udfType=udfType)
+    return udf_obj._wrapped()
+
+
+class UserDefinedFunction(object):
+ """
+ User defined function in Python
+
+ .. versionadded:: 1.3
+ """
+    def __init__(self, func,
+                 returnType=StringType(), name=None,
+                 udfType=PythonEvalType.SQL_BATCHED_UDF):
+        if not callable(func):
+            raise TypeError(
+                "Not a function or callable (__call__ is not defined): "
+                "{0}".format(type(func)))
+
+        self.func = func
+        self._returnType = returnType
+        self._returnType_placeholder = None
+        # Stores the UserDefinedPythonFunction jobj, once initialized
+        self._judf_placeholder = None
+        self._name = name or (
+            func.__name__ if hasattr(func, '__name__')
+            else func.__class__.__name__)
+        self.udfType = udfType
+
+    @property
+    def returnType(self):
+        # This makes sure this is called after the SparkContext is initialized.
+        # ``_parse_datatype_string`` accesses the JVM to parse a DDL formatted string.
+        if self._returnType_placeholder is None:
+            if isinstance(self._returnType, DataType):
+                self._returnType_placeholder = self._returnType
+            else:
+                self._returnType_placeholder = _parse_datatype_string(self._returnType)
+        return self._returnType_placeholder
+
+    @property
+    def _judf(self):
+        # It is possible that concurrent access to a newly created UDF
+        # will initialize multiple UserDefinedPythonFunctions.
+        # This is unlikely, doesn't affect correctness,
+        # and should have a minimal performance impact.
+        if self._judf_placeholder is None:
+            self._judf_placeholder = self._create_judf()
+        return self._judf_placeholder
+
+    def _create_judf(self):
+        from pyspark.sql import SparkSession
+
+        spark = SparkSession.builder.getOrCreate()
+        sc = spark.sparkContext
+
+        wrapped_func = _wrap_function(sc, self.func, self.returnType)
+        jdt = spark._jsparkSession.parseDataType(self.returnType.json())
+        judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
+            self._name, wrapped_func, jdt, self.udfType)
+        return judf
+
+    def __call__(self, *cols):
+        judf = self._judf
+        sc = SparkContext._active_spark_context
+        return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))
+
+    def _wrapped(self):
+        """
+        Wrap this udf with a function and attach docstring from func
+        """
+
+        # It is possible for a callable instance without a __name__ and/or
+        # __module__ attribute to be wrapped here, for example, functools.partial.
+        # In this case, we should avoid copying these attributes from the wrapped
+        # function to the wrapper function. So, we take these attribute names out
+        # of the default names to assign and then manually set them after wrapping.
+        assignments = tuple(
+            a for a in functools.WRAPPER_ASSIGNMENTS
+            if a != '__name__' and a != '__module__')
+
+        @functools.wraps(self.func, assigned=assignments)
+        def wrapper(*args):
+            return self(*args)
+
+        wrapper.__name__ = self._name
+        wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__')
+                              else self.func.__class__.__module__)
+
+        wrapper.func = self.func
+        wrapper.returnType = self.returnType
+        wrapper.udfType = self.udfType
+
+        return wrapper
--- End diff ---
Shall we add a blank line at the end of the file?
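
For readers trying this change locally, here is a minimal usage sketch of the code path above (a sketch only: it assumes a local SparkSession, and `slen` plus the sample data are made up for illustration). Passing a DDL-formatted string as the return type exercises the lazy `returnType` property, which defers parsing via the JVM until first access:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import udf

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    # A DDL-formatted string return type ("int") is parsed lazily by the
    # returnType property, so a SparkContext must exist before first access.
    slen = udf(lambda s: len(s), "int")

    df = spark.createDataFrame([("abc",)], ["s"])
    df.select(slen(df.s)).show()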
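
Similarly, a quick check of what `_wrapped` preserves on the returned function (again just a sketch; `to_upper` is a hypothetical example, and a DataType instance is used as the return type so no SparkContext is needed at definition time):

    from pyspark.sql.functions import udf
    from pyspark.sql.types import StringType

    @udf(returnType=StringType())
    def to_upper(s):
        """Uppercases a string column."""
        return s.upper()

    # functools.wraps copies __doc__; __name__ and __module__ are assigned
    # manually so that callables without them (e.g. functools.partial) work too.
    print(to_upper.__name__)    # to_upper
    print(to_upper.__doc__)     # Uppercases a string column.
    print(to_upper.returnType)  # StringType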