Russell Jurney created DATAFU-150:
-------------------------------------
Summary: Add MultiLabelOneHotEncoder
Key: DATAFU-150
URL: https://issues.apache.org/jira/browse/DATAFU-150
Project: DataFu
Issue Type: Improvement
Reporter: Russell Jurney
Assignee: Russell Jurney
I have created the following code in Python to one-hot encode multilabel data
and would like to add it to DataFu:
{{
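from pyspark.sql import Row
import pyspark.sql.types as T

# filtered_lists, remaining_tags_df and spark are assumed to be defined earlier
questions_tags = filtered_lists.map(
    lambda x: Row(_Body=x[0], _Tags=x[1])
).toDF()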
# One-hot-encode the multilabel tags
enumerated_labels = [
    z for z in enumerate(
        sorted(
            remaining_tags_df.rdd
            .groupBy(lambda x: 1)
            .flatMap(lambda x: [y.tag for y in x[1]])
            .collect()
        )
    )
]
tag_index = {x: i for i, x in enumerated_labels}
index_tag = {i: x for i, x in enumerated_labels}
def one_hot_encode(tag_list, enumerated_labels):
    """PySpark can't one-hot-encode multilabel data, so we do it ourselves."""
    one_hot_row = []
    for i, label in enumerated_labels:
        if index_tag[i] in tag_list:
            one_hot_row.append(1)
        else:
            one_hot_row.append(0)
    assert(len(one_hot_row) == len(enumerated_labels))
    return one_hot_row
# One-hot-encode the tags of each question
one_hot_questions = questions_tags.rdd.map(
    lambda x: Row(
        _Body=x._Body,
        _Tags=one_hot_encode(x._Tags, enumerated_labels)
    )
)
# Create a DataFrame for persisting as Parquet format
schema = T.StructType([
    T.StructField("_Body", T.ArrayType(
        T.StringType()
    )),
    T.StructField("_Tags", T.ArrayType(
        T.IntegerType()
    ))
])
one_hot_df = spark.createDataFrame(
    one_hot_questions,
    schema
)
one_hot_df.show()
one_hot_df.show()
}}
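
For reference, the encoding itself can be sketched in plain Python without Spark, which may help when specifying what a MultiLabelOneHotEncoder should return. This is only a minimal sketch: the helper names (build_label_index, one_hot_decode) and the sample tags are illustrative assumptions, not part of the code above or of any existing DataFu API.

```python
# Minimal sketch in plain Python; helper names and sample data are hypothetical.

def build_label_index(tag_lists):
    """Map each distinct tag to a stable column index (alphabetical order)."""
    labels = sorted({tag for tags in tag_lists for tag in tags})
    return {label: i for i, label in enumerate(labels)}


def one_hot_encode(tag_list, tag_index):
    """Return a 0/1 list with a 1 in the column of every tag present."""
    row = [0] * len(tag_index)
    for tag in tag_list:
        i = tag_index.get(tag)  # tags outside the vocabulary are ignored
        if i is not None:
            row[i] = 1
    return row


def one_hot_decode(row, tag_index):
    """Invert one_hot_encode: recover the tags whose columns are set."""
    index_tag = {i: label for label, i in tag_index.items()}
    return [index_tag[i] for i, bit in enumerate(row) if bit]


# Illustrative data only
tag_lists = [["python", "spark"], ["java"], ["spark", "java", "scala"]]
tag_index = build_label_index(tag_lists)
encoded = [one_hot_encode(tags, tag_index) for tags in tag_lists]
```

Looking tags up through a dictionary keeps each row linear in the number of tags on that row, rather than scanning the full label list for every question.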
--
This message was sent by Atlassian Jira
(v8.3.4#803005)