FYI

Hopefully others will find this example helpful

Andy

Example of a Trivial Custom PySpark Transformer
ref:
* 
* NLTKWordPunctTokenizer example
<http://stackoverflow.com/questions/32331848/create-a-custom-transformer-in-pyspark-ml>
* 
* pyspark.sql.functions.udf
<http://spark.apache.org/docs/latest/api/python/pyspark.sql.html?highlight=udf#pyspark.sql.functions.udf>

In [12]:
1
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, Param
2
from pyspark.ml.util import keyword_only
3
 
4
from pyspark.sql.functions import udf
5
from pyspark.sql.types import FloatType
6
from pyspark.ml.pipeline import Transformer
7
 
8
class TrivialTokenizer(Transformer, HasInputCol, HasOutputCol):
    """Trivial custom Transformer: adds a constant to a numeric column.

    Despite its name it does no tokenizing — it reads ``inputCol`` (a
    numeric column), adds the ``constant`` param to every value, and
    writes the result to ``outputCol`` as a FloatType column.
    """

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None, constant=None):
        super(TrivialTokenizer, self).__init__()
        # FIX: Param's third argument is the *documentation string*, not a
        # default value (the original passed 0 here). The default itself is
        # registered via _setDefault below.
        self.constant = Param(self, "constant",
                              "constant added to every input value")
        self._setDefault(constant=0)
        # Compatibility: Spark >= 2.1 keyword_only stores the captured
        # kwargs on the instance; older Spark stored them on the method.
        try:
            kwargs = self._input_kwargs
        except AttributeError:
            kwargs = self.__init__._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None, constant=None):
        """Set params for this transformer; returns self for chaining."""
        try:
            kwargs = self._input_kwargs
        except AttributeError:
            kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)

    def setConstant(self, value):
        """Set the ``constant`` param; returns self for chaining."""
        # FIX: go through _set (as setParams does) instead of mutating the
        # private _paramMap dict directly.
        self._set(constant=value)
        return self

    def getConstant(self):
        """Return the current value of ``constant`` (or its default, 0)."""
        return self.getOrDefault(self.constant)

    def _transform(self, dataset):
        """Return *dataset* with outputCol = inputCol + constant (FloatType)."""
        const = self.getConstant()

        def add_const(v):
            return v + const

        out_col = self.getOutputCol()
        in_col = dataset[self.getInputCol()]
        return dataset.withColumn(out_col, udf(add_const, FloatType())(in_col))
40
    
41
# Demo: build a small DataFrame, then add 1.0 to column "x1" into "x2".
rows = [
    (0, 1.1, "Hi I heard who the about Spark"),
    (0, 1.2, "I wish Java could use case classes"),
    (1, 1.3, "Logistic regression models are neat"),
]
sentenceDataFrame = sqlContext.createDataFrame(rows, ["label", "x1", "sentence"])

testTokenizer = TrivialTokenizer(inputCol="x1", outputCol="x2", constant=1.0)
testTokenizer.transform(sentenceDataFrame).show()
+-----+---+--------------------+---+
|label| x1|            sentence| x2|
+-----+---+--------------------+---+
|    0|1.1|Hi I heard who th...|2.1|
|    0|1.2|I wish Java could...|2.2|
|    1|1.3|Logistic regressi...|2.3|
+-----+---+--------------------+---+


Reply via email to