[ https://issues.apache.org/jira/browse/SPARK-22809?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16336979#comment-16336979 ]
holdenk commented on SPARK-22809:
---------------------------------
[~ueshin]: We can push this out to 2.3.1, given that we are already in the RC
process for 2.3.0. I'm not sure whether this is fixed by the 0.4.2 upgrade (I
think the fix would be in 0.5.2, but we're only partially upgrading). I'm
already looking at the upgrade PR.
> pyspark is sensitive to imports with dots
> -----------------------------------------
>
> Key: SPARK-22809
> URL: https://issues.apache.org/jira/browse/SPARK-22809
> Project: Spark
> Issue Type: Bug
> Components: PySpark
> Affects Versions: 2.2.0, 2.2.1
> Reporter: Cricket Temple
> Assignee: holdenk
> Priority: Major
>
> User code can fail with dotted imports. Here's a repro script.
> {noformat}
> import numpy as np
> import pandas as pd
> import pyspark
> import scipy.interpolate
> import scipy.interpolate as scipy_interpolate
> import py4j
> scipy_interpolate2 = scipy.interpolate
> sc = pyspark.SparkContext()
> spark_session = pyspark.SQLContext(sc)
> #######################################################
> # The details of this dataset are irrelevant #
> # Sorry if you'd have preferred something more boring #
> #######################################################
> x__ = np.linspace(0,10,1000)
> freq__ = np.arange(1,5)
> x_, freq_ = np.ix_(x__, freq__)
> y = np.sin(x_ * freq_).ravel()
> x = (x_ * np.ones(freq_.shape)).ravel()
> freq = (np.ones(x_.shape) * freq_).ravel()
> df_pd = pd.DataFrame(np.stack([x,y,freq]).T, columns=['x','y','freq'])
> df_sk = spark_session.createDataFrame(df_pd)
> assert(df_sk.toPandas() == df_pd).all().all()
> try:
>     import matplotlib.pyplot as plt
>     for f, data in df_pd.groupby("freq"):
>         plt.plot(*data[['x','y']].values.T)
>     plt.show()
> except Exception:
>     print("I guess we can't plot anything")
> def mymap(x, interp_fn):
>     df = pd.DataFrame.from_records([row.asDict() for row in list(x)])
>     return interp_fn(df.x.values, df.y.values)(np.pi)
> df_by_freq = df_sk.rdd.keyBy(lambda x: x.freq).groupByKey()
> # Works: interp1d is reached through the aliased import.
> result = df_by_freq.mapValues(lambda x: mymap(x,
>     scipy_interpolate.interp1d)).collect()
> assert(np.allclose(np.array(list(zip(*result))[1]), np.zeros(len(freq__)),
>     atol=1e-6))
> # Fails: interp1d is reached through the dotted name scipy.interpolate.
> try:
>     result = df_by_freq.mapValues(lambda x: mymap(x,
>         scipy.interpolate.interp1d)).collect()
>     raise Exception("Not going to reach this line")
> except py4j.protocol.Py4JJavaError as e:
>     print("See?")
> result = df_by_freq.mapValues(lambda x: mymap(x,
>     scipy_interpolate2.interp1d)).collect()
> assert(np.allclose(np.array(list(zip(*result))[1]), np.zeros(len(freq__)),
>     atol=1e-6))
> # But now it works!
> result = df_by_freq.mapValues(lambda x: mymap(x,
>     scipy.interpolate.interp1d)).collect()
> assert(np.allclose(np.array(list(zip(*result))[1]), np.zeros(len(freq__)),
>     atol=1e-6))
> {noformat}
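
One possible reading of the repro above, which nothing in this thread confirms:
`import scipy.interpolate` binds only the top-level name `scipy` in the caller's
namespace, so a lambda that mentions `scipy.interpolate.interp1d` refers to
`scipy` alone, whereas `import scipy.interpolate as scipy_interpolate` binds the
submodule itself. If that reference is shipped by module name and the worker
re-imports just `scipy`, the submodule may never be loaded there. The sketch
below shows only the underlying Python behaviour, using the standard library's
`xml.dom` as a stand-in for `scipy.interpolate`. The helper names
`submodule_visible` and `uses_dotted` are hypothetical, and no Spark is involved.

```python
import subprocess
import sys

import xml.dom  # dotted import: binds the name "xml" here, not "xml.dom"


def uses_dotted():
    # Stand-in for a lambda that mentions scipy.interpolate.interp1d.
    return xml.dom.Node.ELEMENT_NODE


# The only global this function refers to is the top-level package name;
# "dom" is looked up later as an attribute of that package.
print(uses_dotted.__code__.co_names)


def submodule_visible(statement):
    """Run `statement` in a fresh interpreter (a stand-in for a worker
    process) and report whether xml.dom is then reachable as an attribute."""
    code = statement + "; import sys; print(hasattr(sys.modules['xml'], 'dom'))"
    # -S skips site customisation so nothing else pre-imports the submodule.
    out = subprocess.check_output([sys.executable, "-S", "-c", code])
    return out.decode().strip() == "True"


plain = submodule_visible("import xml")       # top-level package only
dotted = submodule_visible("import xml.dom")  # submodule imported explicitly
print(plain, dotted)
```

If this reading is right, it would also be consistent with the final call in the
repro succeeding: once an earlier task has imported the submodule in a reused
worker process, the dotted attribute lookup works there. That, too, is an
assumption rather than something established in this thread.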