Author: Carl Friedrich Bolz-Tereick <[email protected]>
Branch: py3.6
Changeset: r97749:6d2f8470165b
Date: 2019-10-09 15:58 +0200
http://bitbucket.org/pypy/pypy/changeset/6d2f8470165b/

Log:    switch to Utf8StringBuilder for the csv writer as well

diff --git a/pypy/module/_csv/interp_csv.py b/pypy/module/_csv/interp_csv.py
--- a/pypy/module/_csv/interp_csv.py
+++ b/pypy/module/_csv/interp_csv.py
@@ -1,3 +1,4 @@
+from rpython.rlib import rutf8
 from pypy.interpreter.baseobjspace import W_Root
 from pypy.interpreter.error import OperationError, oefmt
 from pypy.interpreter.typedef import TypeDef, interp_attrproperty
@@ -47,24 +48,26 @@
     if w_src is None:
         return default
     try:
-        return space.realunicode_w(w_src)
+        return space.text_w(w_src)
     except OperationError as e:
         if e.match(space, space.w_TypeError):
             raise oefmt(space.w_TypeError, '"%s" must be a string', attrname)
         raise
 
-def _get_char(space, w_src, default, name):
+def _get_codepoint(space, w_src, default, name):
     if w_src is None:
         return default
     if space.is_w(w_src, space.w_None):
-        return u'\0'
+        return 0
     if not space.isinstance_w(w_src, space.w_unicode):
         raise oefmt(space.w_TypeError, '"%s" must be string, not %T', name, w_src)
-    src = space.realunicode_w(w_src)
-    if len(src) == 1:
-        return src[0]
+    src, length = space.utf8_len_w(w_src)
+    if length == 1:
+        res = rutf8.codepoint_at_pos(src, 0)
+        assert res >= 0
+        return res
     if len(src) == 0:
-        return u'\0'
+        return 0
     raise oefmt(space.w_TypeError, '"%s" must be a 1-character string', name)
 
 def _build_dialect(space, w_dialect, w_delimiter, w_doublequote,
@@ -104,11 +107,11 @@
             w_strict = _fetch(space, w_dialect, 'strict')
 
     dialect = W_Dialect()
-    dialect.delimiter = _get_char(space, w_delimiter, u',', 'delimiter')
+    dialect.delimiter = _get_codepoint(space, w_delimiter, ord(u','), 'delimiter')
     dialect.doublequote = _get_bool(space, w_doublequote, True)
-    dialect.escapechar = _get_char(space, w_escapechar, u'\0', 'escapechar')
-    dialect.lineterminator = _get_str(space, w_lineterminator, u'\r\n', 'lineterminator')
-    dialect.quotechar = _get_char(space, w_quotechar, u'"', 'quotechar')
+    dialect.escapechar = _get_codepoint(space, w_escapechar, ord(u'\0'), 'escapechar')
+    dialect.lineterminator = _get_str(space, w_lineterminator, '\r\n', 'lineterminator')
+    dialect.quotechar = _get_codepoint(space, w_quotechar, ord(u'"'), 'quotechar')
     tmp_quoting = _get_int(space, w_quoting, QUOTE_MINIMAL, 'quoting')
     dialect.skipinitialspace = _get_bool(space, w_skipinitialspace, False)
     dialect.strict = _get_bool(space, w_strict, False)
@@ -117,13 +120,13 @@
     if not (0 <= tmp_quoting < 4):
         raise oefmt(space.w_TypeError, 'bad "quoting" value')
 
-    if dialect.delimiter == u'\0':
+    if dialect.delimiter == 0:
         raise oefmt(space.w_TypeError,
                     '"delimiter" must be a 1-character string')
 
     if space.is_w(w_quotechar, space.w_None) and w_quoting is None:
         tmp_quoting = QUOTE_NONE
-    if tmp_quoting != QUOTE_NONE and dialect.quotechar == u'\0':
+    if tmp_quoting != QUOTE_NONE and dialect.quotechar == 0:
         raise oefmt(space.w_TypeError,
                     "quotechar must be set if quoting enabled")
     dialect.quoting = tmp_quoting
@@ -158,14 +161,20 @@
 
 
 def _get_escapechar(space, dialect):
-    if dialect.escapechar == u'\0':
+    if dialect.escapechar == 0:
         return space.w_None
-    return space.newtext(dialect.escapechar)
+    s = rutf8.unichr_as_utf8(dialect.escapechar)
+    return space.newutf8(s, 1)
 
 def _get_quotechar(space, dialect):
-    if dialect.quotechar == u'\0':
+    if dialect.quotechar == 0:
         return space.w_None
-    return space.newtext(dialect.quotechar)
+    s = rutf8.unichr_as_utf8(dialect.quotechar)
+    return space.newutf8(s, 1)
+
+def _get_delimiter(space, dialect):
+    s = rutf8.unichr_as_utf8(dialect.delimiter)
+    return space.newutf8(s, 1)
 
 
 W_Dialect.typedef = TypeDef(
@@ -173,8 +182,7 @@
         __new__ = interp2app(W_Dialect___new__),
         __reduce_ex__ = interp2app(W_Dialect.reduce_ex_w),
 
-        delimiter        = interp_attrproperty('delimiter', W_Dialect,
-            wrapfn='newtext'),
+        delimiter        = GetSetProperty(_get_delimiter, cls=W_Dialect),
         doublequote      = interp_attrproperty('doublequote', W_Dialect,
             wrapfn='newbool'),
         escapechar       = GetSetProperty(_get_escapechar, cls=W_Dialect),
diff --git a/pypy/module/_csv/interp_reader.py b/pypy/module/_csv/interp_reader.py
--- a/pypy/module/_csv/interp_reader.py
+++ b/pypy/module/_csv/interp_reader.py
@@ -96,17 +96,17 @@
                         # save empty field
                         self.save_field(field_builder)
                         state = EAT_CRNL
-                    elif (c == ord(dialect.quotechar) and
+                    elif (c == dialect.quotechar and
                               dialect.quoting != QUOTE_NONE):
                         # start quoted field
                         state = IN_QUOTED_FIELD
-                    elif c == ord(dialect.escapechar):
+                    elif c == dialect.escapechar:
                         # possible escaped character
                         state = ESCAPED_CHAR
                     elif c == ord(u' ') and dialect.skipinitialspace:
                         # ignore space at start of field
                         pass
-                    elif c == ord(dialect.delimiter):
+                    elif c == dialect.delimiter:
                         # save empty field
                         self.save_field(field_builder)
                     else:
@@ -130,10 +130,10 @@
                         # end of line
                         self.save_field(field_builder)
                         state = EAT_CRNL
-                    elif c == ord(dialect.escapechar):
+                    elif c == dialect.escapechar:
                         # possible escaped character
                         state = ESCAPED_CHAR
-                    elif c == ord(dialect.delimiter):
+                    elif c == dialect.delimiter:
                         # save field - wait for new field
                         self.save_field(field_builder)
                         state = START_FIELD
@@ -143,10 +143,10 @@
 
                 elif state == IN_QUOTED_FIELD:
                     # in quoted field
-                    if c == ord(dialect.escapechar):
+                    if c == dialect.escapechar:
                         # Possible escape character
                         state = ESCAPE_IN_QUOTED_FIELD
-                    elif (c == ord(dialect.quotechar) and
+                    elif (c == dialect.quotechar and
                               dialect.quoting != QUOTE_NONE):
                         if dialect.doublequote:
                             # doublequote; " represented by ""
@@ -165,11 +165,11 @@
                 elif state == QUOTE_IN_QUOTED_FIELD:
                     # doublequote - seen a quote in an quoted field
                     if (dialect.quoting != QUOTE_NONE and
-                            c == ord(dialect.quotechar)):
+                            c == dialect.quotechar):
                         # save "" as "
                         self.add_char(field_builder, c)
                         state = IN_QUOTED_FIELD
-                    elif c == ord(dialect.delimiter):
+                    elif c == dialect.delimiter:
                         # save field - wait for new field
                         self.save_field(field_builder)
                         state = START_FIELD
@@ -183,7 +183,7 @@
                     else:
                         # illegal
                         raise self.error(u"'%s' expected after '%s'" % (
-                            dialect.delimiter, dialect.quotechar))
+                            unichr(dialect.delimiter), unichr(dialect.quotechar)))
 
                 elif state == EAT_CRNL:
                     if not (c == ord(u'\n') or c == ord(u'\r')):
diff --git a/pypy/module/_csv/interp_writer.py b/pypy/module/_csv/interp_writer.py
--- a/pypy/module/_csv/interp_writer.py
+++ b/pypy/module/_csv/interp_writer.py
@@ -1,4 +1,4 @@
-from rpython.rlib.rstring import UnicodeBuilder
+from rpython.rlib.rutf8 import Utf8StringIterator, Utf8StringBuilder
 from rpython.rlib import objectmodel
 from pypy.interpreter.baseobjspace import W_Root
 from pypy.interpreter.error import OperationError
@@ -15,11 +15,13 @@
         self.dialect = dialect
         self.w_filewrite = space.getattr(w_fileobj, space.newtext('write'))
         # precompute this
-        special = dialect.delimiter + dialect.lineterminator
-        if dialect.escapechar != '\0':
-            special += dialect.escapechar
-        if dialect.quotechar != '\0':
-            special += dialect.quotechar
+        special = [dialect.delimiter]
+        for c in Utf8StringIterator(dialect.lineterminator):
+            special.append(c)
+        if dialect.escapechar != 0:
+            special.append(dialect.escapechar)
+        if dialect.quotechar != 0:
+            special.append(dialect.quotechar)
         self.special_characters = special
 
     @objectmodel.dont_inline
@@ -35,16 +37,17 @@
         space = self.space
         fields_w = space.listview(w_fields)
         dialect = self.dialect
-        rec = UnicodeBuilder(80)
+        rec = Utf8StringBuilder(80)
         #
         for field_index in range(len(fields_w)):
             w_field = fields_w[field_index]
             if space.is_w(w_field, space.w_None):
-                field = u""
+                field = ""
+                length = 0
             elif space.isinstance_w(w_field, space.w_float):
-                field = space.realunicode_w(space.repr(w_field))
+                field, length = space.utf8_len_w(space.repr(w_field))
             else:
-                field = space.realunicode_w(space.str(w_field))
+                field, length = space.utf8_len_w(space.str(w_field))
             #
             if dialect.quoting == QUOTE_NONNUMERIC:
                 try:
@@ -57,9 +60,9 @@
             elif dialect.quoting == QUOTE_ALL:
                 quoted = True
             elif dialect.quoting == QUOTE_MINIMAL:
-                # Find out if we really quoting
+                # Find out if we really need quoting.
                 special_characters = self.special_characters
-                for c in field:
+                for c in Utf8StringIterator(field):
                     if c in special_characters:
                         if c != dialect.quotechar or dialect.doublequote:
                             quoted = True
@@ -78,15 +81,15 @@
 
             # If this is not the first field we need a field separator
             if field_index > 0:
-                rec.append(dialect.delimiter)
+                rec.append_code(dialect.delimiter)
 
             # Handle preceding quote
             if quoted:
-                rec.append(dialect.quotechar)
+                rec.append_code(dialect.quotechar)
 
             # Copy field data
             special_characters = self.special_characters
-            for c in field:
+            for c in Utf8StringIterator(field):
                 if c in special_characters:
                     if dialect.quoting == QUOTE_NONE:
                         want_escape = True
@@ -94,28 +97,28 @@
                         want_escape = False
                         if c == dialect.quotechar:
                             if dialect.doublequote:
-                                rec.append(dialect.quotechar)
+                                rec.append_code(dialect.quotechar)
                             else:
                                 want_escape = True
                     if want_escape:
-                        if dialect.escapechar == u'\0':
+                        if dialect.escapechar == 0:
                             raise self.error("need to escape, "
                                              "but no escapechar set")
-                        rec.append(dialect.escapechar)
+                        rec.append_code(dialect.escapechar)
                     else:
                         assert quoted
                 # Copy field character into record buffer
-                rec.append(c)
+                rec.append_code(c)
 
             # Handle final quote
             if quoted:
-                rec.append(dialect.quotechar)
+                rec.append_code(dialect.quotechar)
 
         # Add line terminator
         rec.append(dialect.lineterminator)
 
         line = rec.build()
-        return space.call_function(self.w_filewrite, space.newtext(line))
+        return space.call_function(self.w_filewrite, space.newutf8(line, rec.getlength()))
 
     def writerows(self, w_seqseq):
         """Construct and write a series of sequences to a csv file.
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to