[ 
https://issues.apache.org/jira/browse/SPARK-22809?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel
 ]

Cricket Temple updated SPARK-22809:
-----------------------------------
    Description: 
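User code can fail when a closure shipped to the workers references a function through a dotted module path (e.g. `scipy.interpolate.interp1d` after `import scipy.interpolate`), while the same function referenced through an `import ... as` alias works.  Here's a repro script.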

{noformat}
import numpy as np
import pandas as pd
import pyspark
import scipy.interpolate
import scipy.interpolate as scipy_interpolate
import py4j

# Plain-name alias for the dotted module, bound after `import scipy.interpolate`;
# used later in the repro to compare against the dotted reference.
scipy_interpolate2 = scipy.interpolate

sc = pyspark.SparkContext()
spark_session = pyspark.SQLContext(sc)

#######################################################
# The details of this dataset are irrelevant          #
# Sorry if you'd have preferred something more boring #
#######################################################
# sin(freq * x) for freq in 1..4, sampled at 1000 points on [0, 10], flattened
# into long-format columns x, y, freq (np.ix_ makes the outer-product grid).
x__ = np.linspace(0,10,1000)
freq__ = np.arange(1,5)
x_, freq_ = np.ix_(x__, freq__)
y = np.sin(x_ * freq_).ravel()
x = (x_ * np.ones(freq_.shape)).ravel()
freq = (np.ones(x_.shape) * freq_).ravel()
df_pd = pd.DataFrame(np.stack([x,y,freq]).T, columns=['x','y','freq'])
df_sk = spark_session.createDataFrame(df_pd)
# Sanity check: converting back from Spark yields element-wise equal data.
assert(df_sk.toPandas() == df_pd).all().all()

# Optional visual check; plotting is best-effort and must not abort the repro.
try:
    import matplotlib.pyplot as plt
    for f, data in df_pd.groupby("freq"):
        plt.plot(*data[['x','y']].values.T)
    plt.show()
except Exception:
    # Exception (not a bare except) so Ctrl-C / SystemExit still stop the script.
    print("I guess we can't plot anything")

def mymap(x, interp_fn):
    """Build an interpolant from a group of rows and evaluate it at pi.

    ``x`` is an iterable of row objects exposing ``asDict()`` with at least
    ``x`` and ``y`` fields. ``interp_fn`` is called as ``interp_fn(xs, ys)``
    with NumPy arrays and must return a callable; its value at ``np.pi`` is
    returned.
    """
    rows = list(x)
    frame = pd.DataFrame.from_records([r.asDict() for r in rows])
    interpolant = interp_fn(frame.x.values, frame.y.values)
    return interpolant(np.pi)

# Group rows by their freq value; each group's value is an iterable of Rows.
df_by_freq = df_sk.rdd.keyBy(lambda x: x.freq).groupByKey()

# Step 1: referencing interp1d through the `import ... as` alias works.
result = df_by_freq.mapValues(
    lambda x: mymap(x, scipy_interpolate.interp1d)).collect()
# Each group evaluates sin(freq * x) at pi, so every value should be ~0.
# list(zip(...)) keeps this working on both Python 2 and Python 3.
assert np.allclose(np.array(list(zip(*result))[1]), np.zeros(len(freq__)),
                   atol=1e-6)

# Step 2: referencing it through the dotted module path is expected (per this
# report) to fail with a Py4JJavaError.
try:
    result = df_by_freq.mapValues(
        lambda x: mymap(x, scipy.interpolate.interp1d)).collect()
    # Only reached if the bug is gone; Exception is not caught below.
    raise Exception("Not going to reach this line")
except py4j.protocol.Py4JJavaError:
    print("See?")

# Step 3: a plain-name alias of the same module also works.
result = df_by_freq.mapValues(
    lambda x: mymap(x, scipy_interpolate2.interp1d)).collect()
assert np.allclose(np.array(list(zip(*result))[1]), np.zeros(len(freq__)),
                   atol=1e-6)

# But now it works!
# Step 4: the same dotted-path call that failed in step 2 now succeeds.
result = df_by_freq.mapValues(
    lambda x: mymap(x, scipy.interpolate.interp1d)).collect()
assert np.allclose(np.array(list(zip(*result))[1]), np.zeros(len(freq__)),
                   atol=1e-6)
{noformat}

  was:
User code can fail with dotted imports.  Here's a repro script.

{noformat}
import numpy as np
import pandas as pd
import pyspark
import scipy.interpolate
import scipy.interpolate as scipy_interpolate
import py4j

sc = pyspark.SparkContext()
spark_session = pyspark.SQLContext(sc)

#######################################################
# The details of this dataset are irrelevant          #
# Sorry if you'd have preferred something more boring #
#######################################################
# sin(freq * x) for freq in 1..4, sampled at 1000 points on [0, 10], flattened
# into long-format columns x, y, freq (np.ix_ makes the outer-product grid).
x__ = np.linspace(0,10,1000)
freq__ = np.arange(1,5)
x_, freq_ = np.ix_(x__, freq__)
y = np.sin(x_ * freq_).ravel()
x = (x_ * np.ones(freq_.shape)).ravel()
freq = (np.ones(x_.shape) * freq_).ravel()
df_pd = pd.DataFrame(np.stack([x,y,freq]).T, columns=['x','y','freq'])
df_sk = spark_session.createDataFrame(df_pd)
# Sanity check: converting back from Spark yields element-wise equal data.
assert(df_sk.toPandas() == df_pd).all().all()

# Optional visual check; plotting is best-effort and must not abort the repro.
try:
    import matplotlib.pyplot as plt
    for f, data in df_pd.groupby("freq"):
        plt.plot(*data[['x','y']].values.T)
    plt.show()
except Exception:
    # Exception (not a bare except) so Ctrl-C / SystemExit still stop the script.
    print("I guess we can't plot anything")

def mymap(x, interp_fn):
    """Build an interpolant from a group of rows and evaluate it at pi.

    ``x`` is an iterable of row objects exposing ``asDict()`` with at least
    ``x`` and ``y`` fields. ``interp_fn`` is called as ``interp_fn(xs, ys)``
    with NumPy arrays and must return a callable; its value at ``np.pi`` is
    returned.
    """
    rows = list(x)
    frame = pd.DataFrame.from_records([r.asDict() for r in rows])
    interpolant = interp_fn(frame.x.values, frame.y.values)
    return interpolant(np.pi)

# Group rows by their freq value; each group's value is an iterable of Rows.
df_by_freq = df_sk.rdd.keyBy(lambda x: x.freq).groupByKey()

# Step 1: referencing interp1d through the `import ... as` alias works.
result = df_by_freq.mapValues(
    lambda x: mymap(x, scipy_interpolate.interp1d)).collect()
# Each group evaluates sin(freq * x) at pi, so every value should be ~0.
# list(zip(...)) keeps this working on both Python 2 and Python 3.
assert np.allclose(np.array(list(zip(*result))[1]), np.zeros(len(freq__)),
                   atol=1e-6)

# Step 2: referencing it through the dotted module path is expected (per this
# report) to fail with a Py4JJavaError.
try:
    result = df_by_freq.mapValues(
        lambda x: mymap(x, scipy.interpolate.interp1d)).collect()
    # `assert(False, "...")` asserted a non-empty tuple, which is always true;
    # raise explicitly so this line actually fails if reached.
    raise AssertionError("Not going to reach this line")
except py4j.protocol.Py4JJavaError:
    print("See?")
{noformat}


> pyspark is sensitive to imports with dots
> -----------------------------------------
>
>                 Key: SPARK-22809
>                 URL: https://issues.apache.org/jira/browse/SPARK-22809
>             Project: Spark
>          Issue Type: Bug
>          Components: PySpark
>    Affects Versions: 2.2.0
>            Reporter: Cricket Temple
>
> User code can fail with dotted imports.  Here's a repro script.
> {noformat}
> import numpy as np
> import pandas as pd
> import pyspark
> import scipy.interpolate
> import scipy.interpolate as scipy_interpolate
> import py4j
> scipy_interpolate2 = scipy.interpolate
> sc = pyspark.SparkContext()
> spark_session = pyspark.SQLContext(sc)
> #######################################################
> # The details of this dataset are irrelevant          #
> # Sorry if you'd have preferred something more boring #
> #######################################################
> x__ = np.linspace(0,10,1000)
> freq__ = np.arange(1,5)
> x_, freq_ = np.ix_(x__, freq__)
> y = np.sin(x_ * freq_).ravel()
> x = (x_ * np.ones(freq_.shape)).ravel()
> freq = (np.ones(x_.shape) * freq_).ravel()
> df_pd = pd.DataFrame(np.stack([x,y,freq]).T, columns=['x','y','freq'])
> df_sk = spark_session.createDataFrame(df_pd)
> assert(df_sk.toPandas() == df_pd).all().all()
> try:
>     import matplotlib.pyplot as plt
>     for f, data in df_pd.groupby("freq"):
>         plt.plot(*data[['x','y']].values.T)
>     plt.show()
> except:
>     print("I guess we can't plot anything")
> def mymap(x, interp_fn):
>     df = pd.DataFrame.from_records([row.asDict() for row in list(x)])
>     return interp_fn(df.x.values, df.y.values)(np.pi)
> df_by_freq = df_sk.rdd.keyBy(lambda x: x.freq).groupByKey()
> result = df_by_freq.mapValues(lambda x: mymap(x, 
> scipy_interpolate.interp1d)).collect()
> assert(np.allclose(np.array(zip(*result)[1]), np.zeros(len(freq__)), 
> atol=1e-6))
> try:
>     result = df_by_freq.mapValues(lambda x: mymap(x, 
> scipy.interpolate.interp1d)).collect()
>     raise Exception("Not going to reach this line")
> except py4j.protocol.Py4JJavaError, e:
>     print("See?")
> result = df_by_freq.mapValues(lambda x: mymap(x, 
> scipy_interpolate2.interp1d)).collect()
> assert(np.allclose(np.array(zip(*result)[1]), np.zeros(len(freq__)), 
> atol=1e-6))
> # But now it works!
> result = df_by_freq.mapValues(lambda x: mymap(x, 
> scipy.interpolate.interp1d)).collect()
> assert(np.allclose(np.array(zip(*result)[1]), np.zeros(len(freq__)), 
> atol=1e-6))
> {noformat}



--
This message was sent by Atlassian JIRA
(v6.4.14#64029)

---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@spark.apache.org
For additional commands, e-mail: issues-h...@spark.apache.org

Reply via email to