This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6ad80fd0368 [SPARK-46084][PS] Refactor data type casting operation for Categorical type
6ad80fd0368 is described below
commit 6ad80fd036834e7291c572335e318096781a7ae4
Author: Haejoon Lee <[email protected]>
AuthorDate: Thu Nov 23 22:05:52 2023 -0800
[SPARK-46084][PS] Refactor data type casting operation for Categorical type
### What changes were proposed in this pull request?
The PR proposes to refactor the data type casting operation - especially
`DataTypeOps.astype` - for the Categorical type.
### Why are the changes needed?
To improve performance, debuggability, and readability by using official APIs. We
can leverage the PySpark functions `coalesce` and `create_map` instead of
implementing the comparison logic in Python from scratch.
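As an illustration of the technique (not the patch itself), the semantics of `create_map` plus `coalesce` can be sketched in plain Python: build a category-to-code mapping once, then look each value up with a `-1` fallback for values that are not known categories. The function name `categorical_codes` is hypothetical and exists only for this sketch.

```python
from itertools import chain

def categorical_codes(values, categories):
    """Plain-Python analogue (hypothetical helper) of the create_map +
    coalesce approach: map each category to its integer code, and fall
    back to -1 for values outside the known categories."""
    # chain(*[(category, code), ...]) flattens the pairs the same way the
    # patch builds the arguments passed to F.create_map
    kvs = list(chain(*[(category, code) for code, category in enumerate(categories)]))
    mapping = dict(zip(kvs[::2], kvs[1::2]))
    # coalesce(map_scol[col], lit(-1)) corresponds to mapping.get(v, -1)
    return [mapping.get(v, -1) for v in values]
```

For example, `categorical_codes(["b", "a", "x"], ["a", "b"])` yields `[1, 0, -1]`: known categories map to their codes, and the unseen value `"x"` falls back to `-1`, matching the `F.lit(-1)` default in the patch.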
### Does this PR introduce _any_ user-facing change?
No, it's internal optimization.
### How was this patch tested?
The existing CI should pass.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #43993 from itholic/refactor_cat.
Authored-by: Haejoon Lee <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
python/pyspark/pandas/data_type_ops/base.py | 26 ++++++--------------------
1 file changed, 6 insertions(+), 20 deletions(-)
diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py
index 4f57aa65be7..5a4cd7a1eb0 100644
--- a/python/pyspark/pandas/data_type_ops/base.py
+++ b/python/pyspark/pandas/data_type_ops/base.py
@@ -18,6 +18,7 @@
 import numbers
 from abc import ABCMeta
 from typing import Any, Optional, Union
+from itertools import chain
 
 import numpy as np
 import pandas as pd
@@ -129,26 +130,11 @@ def _as_categorical_type(
     if len(categories) == 0:
         scol = F.lit(-1)
     else:
-        scol = F.lit(-1)
-        if isinstance(
-            index_ops._internal.spark_type_for(index_ops._internal.column_labels[0]), BinaryType
-        ):
-            from pyspark.sql.functions import base64
-
-            stringified_column = base64(index_ops.spark.column)
-            for code, category in enumerate(categories):
-                # Convert each category to base64 before comparison
-                base64_category = F.base64(F.lit(category))
-                scol = F.when(stringified_column == base64_category, F.lit(code)).otherwise(
-                    scol
-                )
-        else:
-            stringified_column = F.format_string("%s", index_ops.spark.column)
-
-            for code, category in enumerate(categories):
-                scol = F.when(stringified_column == F.lit(category), F.lit(code)).otherwise(
-                    scol
-                )
+        kvs = chain(
+            *[(F.lit(category), F.lit(code)) for code, category in enumerate(categories)]
+        )
+        map_scol = F.create_map(*kvs)
+        scol = F.coalesce(map_scol[index_ops.spark.column], F.lit(-1))
     return index_ops._with_new_scol(
         scol.cast(spark_type),
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]