[ https://issues.apache.org/jira/browse/SPARK-22809?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
Cricket Temple updated SPARK-22809: ----------------------------------- Description: User code can fail with dotted imports. Here's a repro script. {noformat} import numpy as np import pandas as pd import pyspark import scipy.interpolate import scipy.interpolate as scipy_interpolate import py4j scipy_interpolate2 = scipy.interpolate sc = pyspark.SparkContext() spark_session = pyspark.SQLContext(sc) ####################################################### # The details of this dataset are irrelevant # # Sorry if you'd have preferred something more boring # ####################################################### x__ = np.linspace(0,10,1000) freq__ = np.arange(1,5) x_, freq_ = np.ix_(x__, freq__) y = np.sin(x_ * freq_).ravel() x = (x_ * np.ones(freq_.shape)).ravel() freq = (np.ones(x_.shape) * freq_).ravel() df_pd = pd.DataFrame(np.stack([x,y,freq]).T, columns=['x','y','freq']) df_sk = spark_session.createDataFrame(df_pd) assert(df_sk.toPandas() == df_pd).all().all() try: import matplotlib.pyplot as plt for f, data in df_pd.groupby("freq"): plt.plot(*data[['x','y']].values.T) plt.show() except: print("I guess we can't plot anything") def mymap(x, interp_fn): df = pd.DataFrame.from_records([row.asDict() for row in list(x)]) return interp_fn(df.x.values, df.y.values)(np.pi) df_by_freq = df_sk.rdd.keyBy(lambda x: x.freq).groupByKey() result = df_by_freq.mapValues(lambda x: mymap(x, scipy_interpolate.interp1d)).collect() assert(np.allclose(np.array(zip(*result)[1]), np.zeros(len(freq__)), atol=1e-6)) try: result = df_by_freq.mapValues(lambda x: mymap(x, scipy.interpolate.interp1d)).collect() raise Excpetion("Not going to reach this line") except py4j.protocol.Py4JJavaError, e: print("See?") result = df_by_freq.mapValues(lambda x: mymap(x, scipy_interpolate2.interp1d)).collect() assert(np.allclose(np.array(zip(*result)[1]), np.zeros(len(freq__)), atol=1e-6)) # But now it works! 
result = df_by_freq.mapValues(lambda x: mymap(x, scipy.interpolate.interp1d)).collect() assert(np.allclose(np.array(zip(*result)[1]), np.zeros(len(freq__)), atol=1e-6)) {noformat} was: User code can fail with dotted imports. Here's a repro script. {noformat} import numpy as np import pandas as pd import pyspark import scipy.interpolate import scipy.interpolate as scipy_interpolate import py4j sc = pyspark.SparkContext() spark_session = pyspark.SQLContext(sc) ####################################################### # The details of this dataset are irrelevant # # Sorry if you'd have preferred something more boring # ####################################################### x__ = np.linspace(0,10,1000) freq__ = np.arange(1,5) x_, freq_ = np.ix_(x__, freq__) y = np.sin(x_ * freq_).ravel() x = (x_ * np.ones(freq_.shape)).ravel() freq = (np.ones(x_.shape) * freq_).ravel() df_pd = pd.DataFrame(np.stack([x,y,freq]).T, columns=['x','y','freq']) df_sk = spark_session.createDataFrame(df_pd) assert(df_sk.toPandas() == df_pd).all().all() try: import matplotlib.pyplot as plt for f, data in df_pd.groupby("freq"): plt.plot(*data[['x','y']].values.T) plt.show() except: print("I guess we can't plot anything") def mymap(x, interp_fn): df = pd.DataFrame.from_records([row.asDict() for row in list(x)]) return interp_fn(df.x.values, df.y.values)(np.pi) df_by_freq = df_sk.rdd.keyBy(lambda x: x.freq).groupByKey() result = df_by_freq.mapValues(lambda x: mymap(x, scipy_interpolate.interp1d)).collect() assert(np.allclose(np.array(zip(*result)[1]), np.zeros(len(freq__)), atol=1e-6)) try: result = df_by_freq.mapValues(lambda x: mymap(x, scipy.interpolate.interp1d)).collect() assert(False, "Not going to reach this line") except py4j.protocol.Py4JJavaError, e: print("See?") {noformat} > pyspark is sensitive to imports with dots > ----------------------------------------- > > Key: SPARK-22809 > URL: https://issues.apache.org/jira/browse/SPARK-22809 > Project: Spark > Issue Type: Bug > Components: 
PySpark > Affects Versions: 2.2.0 > Reporter: Cricket Temple > > User code can fail with dotted imports. Here's a repro script. > {noformat} > import numpy as np > import pandas as pd > import pyspark > import scipy.interpolate > import scipy.interpolate as scipy_interpolate > import py4j > scipy_interpolate2 = scipy.interpolate > sc = pyspark.SparkContext() > spark_session = pyspark.SQLContext(sc) > ####################################################### > # The details of this dataset are irrelevant # > # Sorry if you'd have preferred something more boring # > ####################################################### > x__ = np.linspace(0,10,1000) > freq__ = np.arange(1,5) > x_, freq_ = np.ix_(x__, freq__) > y = np.sin(x_ * freq_).ravel() > x = (x_ * np.ones(freq_.shape)).ravel() > freq = (np.ones(x_.shape) * freq_).ravel() > df_pd = pd.DataFrame(np.stack([x,y,freq]).T, columns=['x','y','freq']) > df_sk = spark_session.createDataFrame(df_pd) > assert(df_sk.toPandas() == df_pd).all().all() > try: > import matplotlib.pyplot as plt > for f, data in df_pd.groupby("freq"): > plt.plot(*data[['x','y']].values.T) > plt.show() > except: > print("I guess we can't plot anything") > def mymap(x, interp_fn): > df = pd.DataFrame.from_records([row.asDict() for row in list(x)]) > return interp_fn(df.x.values, df.y.values)(np.pi) > df_by_freq = df_sk.rdd.keyBy(lambda x: x.freq).groupByKey() > result = df_by_freq.mapValues(lambda x: mymap(x, > scipy_interpolate.interp1d)).collect() > assert(np.allclose(np.array(zip(*result)[1]), np.zeros(len(freq__)), > atol=1e-6)) > try: > result = df_by_freq.mapValues(lambda x: mymap(x, > scipy.interpolate.interp1d)).collect() > raise Excpetion("Not going to reach this line") > except py4j.protocol.Py4JJavaError, e: > print("See?") > result = df_by_freq.mapValues(lambda x: mymap(x, > scipy_interpolate2.interp1d)).collect() > assert(np.allclose(np.array(zip(*result)[1]), np.zeros(len(freq__)), > atol=1e-6)) > # But now it works! 
> result = df_by_freq.mapValues(lambda x: mymap(x, > scipy.interpolate.interp1d)).collect() > assert(np.allclose(np.array(zip(*result)[1]), np.zeros(len(freq__)), > atol=1e-6)) > {noformat} -- This message was sent by Atlassian JIRA (v6.4.14#64029) --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscribe@spark.apache.org For additional commands, e-mail: issues-help@spark.apache.org