Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/22365#discussion_r216233575
--- Diff: python/pyspark/sql/dataframe.py ---
@@ -880,18 +880,23 @@ def sampleBy(self, col, fractions, seed=None):
| 0| 5|
| 1| 9|
+---+-----+
+ >>> dataset.sampleBy(col("key"), fractions={2: 1.0}, seed=0).count()
+ 33
"""
- if not isinstance(col, basestring):
- raise ValueError("col must be a string, but got %r" % type(col))
+ if isinstance(col, basestring):
+ col = Column(col)
+ elif not isinstance(col, Column):
+ raise ValueError("col must be a string or a column, but got %r" % type(col))
if not isinstance(fractions, dict):
raise ValueError("fractions must be a dict but got %r" % type(fractions))
for k, v in fractions.items():
if not isinstance(k, (float, int, long, basestring)):
raise ValueError("key must be float, int, long, or string, but got %r" % type(k))
fractions[k] = float(v)
seed = seed if seed is not None else random.randint(0, sys.maxsize)
- return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
+ return DataFrame(self._jdf.stat()
+ .sampleBy(col._jc, self._jmap(fractions), seed), self.sql_ctx)
--- End diff --
I would just do `col = col._jc`
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]