https://github.com/python/cpython/commit/4cc82ffa377db5073fdc6f85c6f35f9c47397796
commit: 4cc82ffa377db5073fdc6f85c6f35f9c47397796
branch: main
author: Ɓukasz Langa <luk...@langa.pl>
committer: ambv <luk...@langa.pl>
date: 2025-03-21T18:27:35+01:00
summary:

gh-131507: Refactor screen and cursor position calculations (GH-131547)

This is based off #131509.

files:
M Lib/_pyrepl/reader.py
M Lib/_pyrepl/types.py
M Lib/_pyrepl/utils.py

diff --git a/Lib/_pyrepl/reader.py b/Lib/_pyrepl/reader.py
index b38f0bf82db331..7fc2422dac9c3f 100644
--- a/Lib/_pyrepl/reader.py
+++ b/Lib/_pyrepl/reader.py
@@ -25,12 +25,11 @@
 
 from contextlib import contextmanager
 from dataclasses import dataclass, field, fields
-import unicodedata
 from _colorize import can_colorize, ANSIColors
 
 
 from . import commands, console, input
-from .utils import wlen, unbracket, str_width
+from .utils import wlen, unbracket, disp_str
 from .trace import trace
 
 
@@ -39,36 +38,6 @@
 from .types import Callback, SimpleContextManager, KeySpec, CommandName
 
 
-def disp_str(buffer: str) -> tuple[str, list[int]]:
-    """disp_str(buffer:string) -> (string, [int])
-
-    Return the string that should be the printed representation of
-    |buffer| and a list detailing where the characters of |buffer|
-    get used up.  E.g.:
-
-    >>> disp_str(chr(3))
-    ('^C', [1, 0])
-
-    """
-    b: list[int] = []
-    s: list[str] = []
-    for c in buffer:
-        if c == '\x1a':
-            s.append(c)
-            b.append(2)
-        elif ord(c) < 128:
-            s.append(c)
-            b.append(1)
-        elif unicodedata.category(c).startswith("C"):
-            c = r"\u%04x" % ord(c)
-            s.append(c)
-            b.append(len(c))
-        else:
-            s.append(c)
-            b.append(str_width(c))
-    return "".join(s), b
-
-
 # syntax classes:
 
 SYNTAX_WHITESPACE, SYNTAX_WORD, SYNTAX_SYMBOL = range(3)
@@ -347,14 +316,12 @@ def calc_screen(self) -> list[str]:
         pos -= offset
 
         prompt_from_cache = (offset and self.buffer[offset - 1] != "\n")
-
         lines = "".join(self.buffer[offset:]).split("\n")
-
         cursor_found = False
         lines_beyond_cursor = 0
         for ln, line in enumerate(lines, num_common_lines):
-            ll = len(line)
-            if 0 <= pos <= ll:
+            line_len = len(line)
+            if 0 <= pos <= line_len:
                 self.lxy = pos, ln
                 cursor_found = True
             elif cursor_found:
@@ -368,34 +335,34 @@ def calc_screen(self) -> list[str]:
                 prompt_from_cache = False
                 prompt = ""
             else:
-                prompt = self.get_prompt(ln, ll >= pos >= 0)
+                prompt = self.get_prompt(ln, line_len >= pos >= 0)
             while "\n" in prompt:
                 pre_prompt, _, prompt = prompt.partition("\n")
                 last_refresh_line_end_offsets.append(offset)
                 screen.append(pre_prompt)
                 screeninfo.append((0, []))
-            pos -= ll + 1
-            prompt, lp = self.process_prompt(prompt)
-            l, l2 = disp_str(line)
-            wrapcount = (wlen(l) + lp) // self.console.width
-            if wrapcount == 0:
-                offset += ll + 1  # Takes all of the line plus the newline
+            pos -= line_len + 1
+            prompt, prompt_len = self.process_prompt(prompt)
+            chars, char_widths = disp_str(line)
+            wrapcount = (sum(char_widths) + prompt_len) // self.console.width
+            trace("wrapcount = {wrapcount}", wrapcount=wrapcount)
+            if wrapcount == 0 or not char_widths:
+                offset += line_len + 1  # Takes all of the line plus the newline
                 last_refresh_line_end_offsets.append(offset)
-                screen.append(prompt + l)
-                screeninfo.append((lp, l2))
+                screen.append(prompt + "".join(chars))
+                screeninfo.append((prompt_len, char_widths))
             else:
-                i = 0
-                while l:
-                    prelen = lp if i == 0 else 0
+                pre = prompt
+                prelen = prompt_len
+                for wrap in range(wrapcount + 1):
                     index_to_wrap_before = 0
                     column = 0
-                    for character_width in l2:
-                        if column + character_width >= self.console.width - prelen:
+                    for char_width in char_widths:
+                        if column + char_width + prelen >= self.console.width:
                             break
                         index_to_wrap_before += 1
-                        column += character_width
-                    pre = prompt if i == 0 else ""
-                    if len(l) > index_to_wrap_before:
+                        column += char_width
+                    if len(chars) > index_to_wrap_before:
                         offset += index_to_wrap_before
                         post = "\\"
                         after = [1]
@@ -404,11 +371,14 @@ def calc_screen(self) -> list[str]:
                         post = ""
                         after = []
                     last_refresh_line_end_offsets.append(offset)
-                    screen.append(pre + l[:index_to_wrap_before] + post)
-                    screeninfo.append((prelen, l2[:index_to_wrap_before] + after))
-                    l = l[index_to_wrap_before:]
-                    l2 = l2[index_to_wrap_before:]
-                    i += 1
+                    render = pre + "".join(chars[:index_to_wrap_before]) + post
+                    render_widths = char_widths[:index_to_wrap_before] + after
+                    screen.append(render)
+                    screeninfo.append((prelen, render_widths))
+                    chars = chars[index_to_wrap_before:]
+                    char_widths = char_widths[index_to_wrap_before:]
+                    pre = ""
+                    prelen = 0
         self.screeninfo = screeninfo
         self.cxy = self.pos2xy()
         if self.msg:
@@ -537,9 +507,9 @@ def setpos_from_xy(self, x: int, y: int) -> None:
         pos = 0
         i = 0
         while i < y:
-            prompt_len, character_widths = self.screeninfo[i]
-            offset = len(character_widths) - character_widths.count(0)
-            in_wrapped_line = prompt_len + sum(character_widths) >= self.console.width
+            prompt_len, char_widths = self.screeninfo[i]
+            offset = len(char_widths)
+            in_wrapped_line = prompt_len + sum(char_widths) >= self.console.width
             if in_wrapped_line:
                 pos += offset - 1  # -1 cause backslash is not in buffer
             else:
@@ -560,29 +530,33 @@ def setpos_from_xy(self, x: int, y: int) -> None:
 
     def pos2xy(self) -> tuple[int, int]:
         """Return the x, y coordinates of position 'pos'."""
-        # this *is* incomprehensible, yes.
-        p, y = 0, 0
-        l2: list[int] = []
+
+        prompt_len, y = 0, 0
+        char_widths: list[int] = []
         pos = self.pos
         assert 0 <= pos <= len(self.buffer)
+
+        # optimize for the common case: typing at the end of the buffer
         if pos == len(self.buffer) and len(self.screeninfo) > 0:
             y = len(self.screeninfo) - 1
-            p, l2 = self.screeninfo[y]
-            return p + sum(l2) + l2.count(0), y
+            prompt_len, char_widths = self.screeninfo[y]
+            return prompt_len + sum(char_widths), y
+
+        for prompt_len, char_widths in self.screeninfo:
+            offset = len(char_widths)
+            in_wrapped_line = prompt_len + sum(char_widths) >= self.console.width
+            if in_wrapped_line:
+                offset -= 1  # need to remove line-wrapping backslash
 
-        for p, l2 in self.screeninfo:
-            l = len(l2) - l2.count(0)
-            in_wrapped_line = p + sum(l2) >= self.console.width
-            offset = l - 1 if in_wrapped_line else l  # need to remove backslash
             if offset >= pos:
                 break
 
-            if p + sum(l2) >= self.console.width:
-                pos -= l - 1  # -1 cause backslash is not in buffer
-            else:
-                pos -= l + 1  # +1 cause newline is in buffer
+            if not in_wrapped_line:
+                offset += 1  # there's a newline in buffer
+
+            pos -= offset
             y += 1
-        return p + sum(l2[:pos]), y
+        return prompt_len + sum(char_widths[:pos]), y
 
     def insert(self, text: str | list[str]) -> None:
         """Insert 'text' at the insertion point."""
diff --git a/Lib/_pyrepl/types.py b/Lib/_pyrepl/types.py
index f9d48b828c720b..c5b7ebc1a406bd 100644
--- a/Lib/_pyrepl/types.py
+++ b/Lib/_pyrepl/types.py
@@ -1,8 +1,10 @@
 from collections.abc import Callable, Iterator
 
-Callback = Callable[[], object]
-SimpleContextManager = Iterator[None]
-KeySpec = str  # like r"\C-c"
-CommandName = str  # like "interrupt"
-EventTuple = tuple[CommandName, str]
-Completer = Callable[[str, int], str | None]
+type Callback = Callable[[], object]
+type SimpleContextManager = Iterator[None]
+type KeySpec = str  # like r"\C-c"
+type CommandName = str  # like "interrupt"
+type EventTuple = tuple[CommandName, str]
+type Completer = Callable[[str, int], str | None]
+type CharBuffer = list[str]
+type CharWidths = list[int]
diff --git a/Lib/_pyrepl/utils.py b/Lib/_pyrepl/utils.py
index 0eb5f8c0097f41..7437fbe1ab9371 100644
--- a/Lib/_pyrepl/utils.py
+++ b/Lib/_pyrepl/utils.py
@@ -2,6 +2,9 @@
 import unicodedata
 import functools
 
+from .types import CharBuffer, CharWidths
+from .trace import trace
+
 ANSI_ESCAPE_SEQUENCE = re.compile(r"\x1b\[[ -@]*[A-~]")
 ZERO_WIDTH_BRACKET = re.compile(r"\x01.*?\x02")
 ZERO_WIDTH_TRANS = str.maketrans({"\x01": "", "\x02": ""})
@@ -36,3 +39,39 @@ def unbracket(s: str, including_content: bool = False) -> str:
     if including_content:
         return ZERO_WIDTH_BRACKET.sub("", s)
     return s.translate(ZERO_WIDTH_TRANS)
+
+
+def disp_str(buffer: str) -> tuple[CharBuffer, CharWidths]:
+    r"""Decompose the input buffer into a printable variant.
+
+    Returns a tuple of two lists:
+    - the first list is the input buffer, character by character;
+    - the second list is the visible width of each character in the input
+      buffer.
+
+    Examples:
+    >>> utils.disp_str("a = 9")
+    (['a', ' ', '=', ' ', '9'], [1, 1, 1, 1, 1])
+    """
+    chars: CharBuffer = []
+    char_widths: CharWidths = []
+
+    if not buffer:
+        return chars, char_widths
+
+    for c in buffer:
+        if c == "\x1a":  # CTRL-Z on Windows
+            chars.append(c)
+            char_widths.append(2)
+        elif ord(c) < 128:
+            chars.append(c)
+            char_widths.append(1)
+        elif unicodedata.category(c).startswith("C"):
+            c = r"\u%04x" % ord(c)
+            chars.append(c)
+            char_widths.append(len(c))
+        else:
+            chars.append(c)
+            char_widths.append(str_width(c))
+    trace("disp_str({buffer}) = {s}, {b}", buffer=repr(buffer), s=chars, b=char_widths)
+    return chars, char_widths

_______________________________________________
Python-checkins mailing list -- python-checkins@python.org
To unsubscribe send an email to python-checkins-le...@python.org
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: arch...@mail-archive.com

Reply via email to