Author: Armin Rigo <[email protected]>
Branch: 
Changeset: r69572:680434495e1e
Date: 2014-03-01 15:00 +0100
http://bitbucket.org/pypy/pypy/changeset/680434495e1e/

Log:    Try to improve the timetable of the JIT counters: replace the two
        halves of 4096 entries each (loops and bridges) with a single
        5-way set-associative table of 2048 entries.

diff --git a/rpython/jit/metainterp/compile.py 
b/rpython/jit/metainterp/compile.py
--- a/rpython/jit/metainterp/compile.py
+++ b/rpython/jit/metainterp/compile.py
@@ -500,8 +500,9 @@
     ST_BUSY_FLAG    = 0x01     # if set, busy tracing from the guard
     ST_TYPE_MASK    = 0x06     # mask for the type (TY_xxx)
     ST_SHIFT        = 3        # in "status >> ST_SHIFT" is stored:
-                               # - if TY_NONE, the jitcounter index directly
+                               # - if TY_NONE, the jitcounter hash (status & ST_SHIFT_MASK)
                                # - otherwise, the guard_value failarg index
+    ST_SHIFT_MASK   = -(1 << ST_SHIFT)
     TY_NONE         = 0x00
     TY_INT          = 0x02
     TY_REF          = 0x04
@@ -514,8 +515,8 @@
         #
         if metainterp_sd.warmrunnerdesc is not None:   # for tests
             jitcounter = metainterp_sd.warmrunnerdesc.jitcounter
-            index = jitcounter.in_second_half(jitcounter.fetch_next_index())
-            self.status = index << self.ST_SHIFT
+            hash = jitcounter.fetch_next_hash()
+            self.status = hash & self.ST_SHIFT_MASK
 
     def make_a_counter_per_value(self, guard_value_op):
         assert guard_value_op.getopnum() == rop.GUARD_VALUE
@@ -566,7 +567,7 @@
             # common case: this is not a guard_value, and we are not
             # already busy tracing.  The rest of self.status stores a
             # valid per-guard index in the jitcounter.
-            index = self.status >> self.ST_SHIFT
+            hash = self.status & self.ST_SHIFT_MASK
         #
         # do we have the BUSY flag?  If so, we're tracing right now, e.g. in an
         # outer invocation of the same function, so don't trace again for now.
@@ -597,12 +598,11 @@
                     intval = llmemory.cast_adr_to_int(
                         llmemory.cast_int_to_adr(intval), "forced")
 
-            hash = (current_object_addr_as_int(self) * 777767777 +
-                    intval * 1442968193)
-            index = jitcounter.in_second_half(jitcounter.get_index(hash))
+            hash = r_uint(current_object_addr_as_int(self) * 777767777 +
+                          intval * 1442968193)
         #
         increment = jitdriver_sd.warmstate.increment_trace_eagerness
-        return jitcounter.tick(index, increment)
+        return jitcounter.tick(hash, increment)
 
     def start_compiling(self):
         # start tracing and compiling from this guard.
diff --git a/rpython/jit/metainterp/counter.py 
b/rpython/jit/metainterp/counter.py
--- a/rpython/jit/metainterp/counter.py
+++ b/rpython/jit/metainterp/counter.py
@@ -7,28 +7,32 @@
 assert r_uint32.BITS == 32
 UINT32MAX = 2 ** 32 - 1
 
+# keep in sync with the C code in pypy__decay_jit_counters
+ENTRY = lltype.Struct('timetable_entry',
+                      ('times', lltype.FixedSizeArray(rffi.FLOAT, 5)),
+                      ('subhashes', lltype.FixedSizeArray(rffi.USHORT, 5)))
+
 
 class JitCounter:
-    DEFAULT_SIZE = 4096
+    DEFAULT_SIZE = 2048
 
     def __init__(self, size=DEFAULT_SIZE, translator=None):
         "NOT_RPYTHON"
         self.size = size
-        self.shift = 1
+        self.shift = 16
         while (UINT32MAX >> self.shift) != size - 1:
             self.shift += 1
-            assert self.shift < 999, "size is not a power of two <= 2**31"
+            assert self.shift < 999, "size is not a power of two <= 2**16"
         #
-        # The table of timings.  The first half is used for starting the
-        # compilation of new loops.  The second half is used for turning
-        # failing guards into bridges.  The two halves are split to avoid
-        # too much interference.
-        self.timetablesize = size * 2
-        self.timetable = lltype.malloc(rffi.CArray(rffi.FLOAT),
-                                       self.timetablesize,
+        # The table of timings.  This is a 5-way associative cache.
+        # We index into it using a number between 0 and (size - 1),
+        # and get a 32-byte entry: 5 ways of 6 bytes each (4 bytes
+        # for a float, and the 2 lowest bytes from the original hash),
+        # plus 2 bytes of alignment padding.
+        self.timetable = lltype.malloc(rffi.CArray(ENTRY), self.size,
                                        flavor='raw', zero=True,
                                        track_allocation=False)
-        self._nextindex = r_uint(0)
+        self._nexthash = r_uint(0)
         #
         # The table of JitCell entries, recording already-compiled loops
         self.celltable = [None] * size
@@ -56,46 +60,92 @@
             return 0.0   # no increment, never reach 1.0
         return 1.0 / (threshold - 0.001)
 
-    def get_index(self, hash):
-        """Return the index (< self.size) from a hash value.  This truncates
+    def _get_index(self, hash):
+        """Return the index (< self.size) from a hash.  This truncates
         the hash to 32 bits, and then keep the *highest* remaining bits.
-        Be sure that hash is computed correctly."""
+        Be sure that hash is computed correctly, by multiplying with
+        a large odd number or by fetch_next_hash()."""
         hash32 = r_uint(r_uint32(hash))  # mask off the bits higher than 32
         index = hash32 >> self.shift     # shift, resulting in a value < size
         return index                     # return the result as a r_uint
-    get_index._always_inline_ = True
+    _get_index._always_inline_ = True
 
-    def fetch_next_index(self):
-        result = self._nextindex
-        self._nextindex = (result + 1) & self.get_index(-1)
+    @staticmethod
+    def _get_subhash(hash):
+        return hash & 65535
+
+    def fetch_next_hash(self):
+        result = self._nexthash
+        # note: all three "1" bits in the following constant are needed
+        # to make test_counter.test_fetch_next_hash pass.  The first
+        # is to increment the "subhash" (lower 16 bits of the hash).
+        # The second is to increment the "index" portion of the hash.
+        # The third is so that after 65536 passes, the "index" is
+        # incremented by one more (by overflow), so that the next
+        # 65536 passes don't end up with the same subhashes.
+        self._nexthash = result + r_uint(1 | (1 << self.shift) |
+                                         (1 << (self.shift - 16)))
         return result
 
-    def in_second_half(self, index):
-        assert index < r_uint(self.size)
-        return self.size + index
+    def _swap(self, p_entry, n):
+        if float(p_entry.times[n]) > float(p_entry.times[n + 1]):
+            return n + 1
+        else:
+            x = p_entry.times[n]
+            p_entry.times[n] = p_entry.times[n + 1]
+            p_entry.times[n + 1] = x
+            x = p_entry.subhashes[n]
+            p_entry.subhashes[n] = p_entry.subhashes[n + 1]
+            p_entry.subhashes[n + 1] = x
+            return n
+    _swap._always_inline_ = True
 
-    def tick(self, index, increment):
-        counter = float(self.timetable[index]) + increment
+    def tick(self, hash, increment):
+        p_entry = self.timetable[self._get_index(hash)]
+        subhash = self._get_subhash(hash)
+        #
+        if p_entry.subhashes[0] == subhash:
+            n = 0
+        elif p_entry.subhashes[1] == subhash:
+            n = self._swap(p_entry, 0)
+        elif p_entry.subhashes[2] == subhash:
+            n = self._swap(p_entry, 1)
+        elif p_entry.subhashes[3] == subhash:
+            n = self._swap(p_entry, 2)
+        elif p_entry.subhashes[4] == subhash:
+            n = self._swap(p_entry, 3)
+        else:
+            n = 4
+            while n > 0 and float(p_entry.times[n - 1]) == 0.0:
+                n -= 1
+            p_entry.subhashes[n] = rffi.cast(rffi.USHORT, subhash)
+            p_entry.times[n] = r_singlefloat(0.0)
+        #
+        counter = float(p_entry.times[n]) + increment
         if counter < 1.0:
-            self.timetable[index] = r_singlefloat(counter)
+            p_entry.times[n] = r_singlefloat(counter)
             return False
         else:
             # when the bound is reached, we immediately reset the value to 0.0
-            self.reset(index)
+            self.reset(hash)
             return True
-    tick._always_inline_ = True
 
-    def reset(self, index):
-        self.timetable[index] = r_singlefloat(0.0)
+    def reset(self, hash):
+        p_entry = self.timetable[self._get_index(hash)]
+        subhash = self._get_subhash(hash)
+        for i in range(5):
+            if p_entry.subhashes[i] == subhash:
+                p_entry.times[i] = r_singlefloat(0.0)
 
-    def lookup_chain(self, index):
-        return self.celltable[index]
+    def lookup_chain(self, hash):
+        return self.celltable[self._get_index(hash)]
 
-    def cleanup_chain(self, index):
-        self.reset(index)
-        self.install_new_cell(index, None)
+    def cleanup_chain(self, hash):
+        self.reset(hash)
+        self.install_new_cell(hash, None)
 
-    def install_new_cell(self, index, newcell):
+    def install_new_cell(self, hash, newcell):
+        index = self._get_index(hash)
         cell = self.celltable[index]
         keep = newcell
         while cell is not None:
@@ -125,22 +175,29 @@
         # important in corner cases where we would suddenly compile more
         # than one loop because all counters reach the bound at the same
         # time, but where compiling all but the first one is pointless.
-        size = self.timetablesize
-        pypy__decay_jit_counters(self.timetable, self.decay_by_mult, size)
+        p = rffi.cast(rffi.CCHARP, self.timetable)
+        pypy__decay_jit_counters(p, self.decay_by_mult, self.size)
 
 
 # this function is written directly in C; gcc will optimize it using SSE
 eci = ExternalCompilationInfo(post_include_bits=["""
-static void pypy__decay_jit_counters(float table[], double f1, long size1) {
+static void pypy__decay_jit_counters(char *data, double f1, long size) {
+    struct { float times[5]; unsigned short subhashes[5]; } *p = data;
     float f = (float)f1;
-    int i, size = (int)size1;
-    for (i=0; i<size; i++)
-        table[i] *= f;
+    long i;
+    for (i=0; i<size; i++) {
+        p->times[0] *= f;
+        p->times[1] *= f;
+        p->times[2] *= f;
+        p->times[3] *= f;
+        p->times[4] *= f;
+        ++p;
+    }
 }
 """])
 
 pypy__decay_jit_counters = rffi.llexternal(
-    "pypy__decay_jit_counters", [rffi.FLOATP, lltype.Float, lltype.Signed],
+    "pypy__decay_jit_counters", [rffi.CCHARP, lltype.Float, lltype.Signed],
     lltype.Void, compilation_info=eci, _nowrapper=True, sandboxsafe=True)
 
 
@@ -153,11 +210,12 @@
     def __init__(self):
         from collections import defaultdict
         JitCounter.__init__(self, size=8)
-        zero = r_singlefloat(0.0)
-        self.timetable = defaultdict(lambda: zero)
+        def make_null_entry():
+            return lltype.malloc(ENTRY, immortal=True, zero=True)
+        self.timetable = defaultdict(make_null_entry)
         self.celltable = defaultdict(lambda: None)
 
-    def get_index(self, hash):
+    def _get_index(self, hash):
         "NOT_RPYTHON"
         return hash
 
@@ -165,10 +223,6 @@
         "NOT_RPYTHON"
         pass
 
-    def in_second_half(self, index):
-        "NOT_RPYTHON"
-        return index + 12345
-
     def _clear_all(self):
         self.timetable.clear()
         self.celltable.clear()
diff --git a/rpython/jit/metainterp/test/test_counter.py 
b/rpython/jit/metainterp/test/test_counter.py
--- a/rpython/jit/metainterp/test/test_counter.py
+++ b/rpython/jit/metainterp/test/test_counter.py
@@ -5,30 +5,77 @@
     jc = JitCounter(size=128)    # 7 bits
     for i in range(10):
         hash = 400000001 * i
-        index = jc.get_index(hash)
+        index = jc._get_index(hash)
         assert index == (hash >> (32 - 7))
 
-def test_fetch_next_index():
-    jc = JitCounter(size=4)
-    lst = [jc.fetch_next_index() for i in range(10)]
-    assert lst == [0, 1, 2, 3, 0, 1, 2, 3, 0, 1]
+def test_get_subhash():
+    assert JitCounter._get_subhash(0x518ebd) == 0x8ebd
+
+def test_fetch_next_hash():
+    jc = JitCounter(size=2048)
+    # check the distribution of "fetch_next_hash() & ~7".
+    blocks = [[jc.fetch_next_hash() & ~7 for i in range(65536)]
+              for j in range(2)]
+    for block in blocks:
+        assert 0 <= jc._get_index(block[0]) < 2048
+        assert 0 <= jc._get_index(block[-1]) < 2048
+        assert 0 <= jc._get_index(block[2531]) < 2048
+        assert 0 <= jc._get_index(block[45981]) < 2048
+        # should be correctly distributed: ideally 2047 or 2048 different
+        # values
+        assert len(set([jc._get_index(x) for x in block])) >= 2040
+    # check that the subkeys are distinct for same-block entries
+    subkeys = {}
+    for block in blocks:
+        for x in block:
+            idx = jc._get_index(x)
+            subkeys.setdefault(idx, []).append(jc._get_subhash(x))
+    collisions = 0
+    for idx, sks in subkeys.items():
+        collisions += len(sks) - len(set(sks))
+    assert collisions < 5
+
+def index2hash(jc, index, subhash=0):
+    assert 0 <= subhash < 65536
+    return (index << jc.shift) | subhash
 
 def test_tick():
     jc = JitCounter()
     incr = jc.compute_threshold(4)
     for i in range(5):
-        r = jc.tick(104, incr)
+        r = jc.tick(index2hash(jc, 104), incr)
         assert r is (i == 3)
     for i in range(5):
-        r = jc.tick(108, incr)
-        s = jc.tick(109, incr)
+        r = jc.tick(index2hash(jc, 108), incr)
+        s = jc.tick(index2hash(jc, 109), incr)
         assert r is (i == 3)
         assert s is (i == 3)
-    jc.reset(108)
+    jc.reset(index2hash(jc, 108))
     for i in range(5):
-        r = jc.tick(108, incr)
+        r = jc.tick(index2hash(jc, 108), incr)
         assert r is (i == 3)
 
+def test_collisions():
+    jc = JitCounter(size=4)     # 2 bits
+    incr = jc.compute_threshold(4)
+    for i in range(5):
+        for sk in range(100, 105):
+            r = jc.tick(index2hash(jc, 3, subhash=sk), incr)
+            assert r is (i == 3)
+
+    jc = JitCounter()
+    incr = jc.compute_threshold(4)
+    misses = 0
+    for i in range(5):
+        for sk in range(100, 106):
+            r = jc.tick(index2hash(jc, 3, subhash=sk), incr)
+            if r:
+                assert i == 3
+            elif i == 3:
+                misses += 1
+    assert misses < 5
+
+
 def test_install_new_chain():
     class Dead:
         next = None
diff --git a/rpython/jit/metainterp/warmstate.py 
b/rpython/jit/metainterp/warmstate.py
--- a/rpython/jit/metainterp/warmstate.py
+++ b/rpython/jit/metainterp/warmstate.py
@@ -7,7 +7,7 @@
 from rpython.rlib.jit import PARAMETERS
 from rpython.rlib.nonconst import NonConstant
 from rpython.rlib.objectmodel import specialize, we_are_translated, r_dict
-from rpython.rlib.rarithmetic import intmask
+from rpython.rlib.rarithmetic import intmask, r_uint
 from rpython.rlib.unroll import unrolling_iterable
 from rpython.rtyper.annlowlevel import (hlstr, cast_base_ptr_to_instance,
     cast_object_to_ptr)
@@ -312,7 +312,7 @@
             #
             assert 0, "should have raised"
 
-        def bound_reached(index, cell, *args):
+        def bound_reached(hash, cell, *args):
             if not confirm_enter_jit(*args):
                 return
             jitcounter.decay_all_counters()
@@ -322,7 +322,7 @@
             greenargs = args[:num_green_args]
             if cell is None:
                 cell = JitCell(*greenargs)
-                jitcounter.install_new_cell(index, cell)
+                jitcounter.install_new_cell(hash, cell)
             cell.flags |= JC_TRACING
             try:
                 metainterp.compile_and_run_once(jitdriver_sd, *args)
@@ -339,16 +339,16 @@
             # These few lines inline some logic that is also on the
             # JitCell class, to avoid computing the hash several times.
             greenargs = args[:num_green_args]
-            index = JitCell.get_index(*greenargs)
-            cell = jitcounter.lookup_chain(index)
+            hash = JitCell.get_uhash(*greenargs)
+            cell = jitcounter.lookup_chain(hash)
             while cell is not None:
                 if isinstance(cell, JitCell) and cell.comparekey(*greenargs):
                     break    # found
                 cell = cell.next
             else:
                 # not found. increment the counter
-                if jitcounter.tick(index, increment_threshold):
-                    bound_reached(index, None, *args)
+                if jitcounter.tick(hash, increment_threshold):
+                    bound_reached(hash, None, *args)
                 return
 
             # Here, we have found 'cell'.
@@ -359,15 +359,15 @@
                     # this function. don't trace a second time.
                     return
                 # attached by compile_tmp_callback().  count normally
-                if jitcounter.tick(index, increment_threshold):
-                    bound_reached(index, cell, *args)
+                if jitcounter.tick(hash, increment_threshold):
+                    bound_reached(hash, cell, *args)
                 return
             # machine code was already compiled for these greenargs
             procedure_token = cell.get_procedure_token()
             if procedure_token is None:
                 # it was an aborted compilation, or maybe a weakref that
                 # has been freed
-                jitcounter.cleanup_chain(index)
+                jitcounter.cleanup_chain(hash)
                 return
             if not confirm_enter_jit(*args):
                 return
@@ -422,7 +422,6 @@
         green_args_name_spec = unrolling_iterable([('g%d' % i, TYPE)
                      for i, TYPE in enumerate(jitdriver_sd._green_args_spec)])
         unwrap_greenkey = self.make_unwrap_greenkey()
-        random_initial_value = hash(self)
         #
         class JitCell(BaseJitCell):
             def __init__(self, *greenargs):
@@ -441,20 +440,20 @@
                 return True
 
             @staticmethod
-            def get_index(*greenargs):
-                x = random_initial_value
+            def get_uhash(*greenargs):
+                x = r_uint(-1888132534)
                 i = 0
                 for _, TYPE in green_args_name_spec:
                     item = greenargs[i]
-                    y = hash_whatever(TYPE, item)
-                    x = intmask((x ^ y) * 1405695061)  # prime number, 2**30~31
+                    y = r_uint(hash_whatever(TYPE, item))
+                    x = (x ^ y) * r_uint(1405695061)  # prime number, 2**30~31
                     i = i + 1
-                return jitcounter.get_index(x)
+                return x
 
             @staticmethod
             def get_jitcell(*greenargs):
-                index = JitCell.get_index(*greenargs)
-                cell = jitcounter.lookup_chain(index)
+                hash = JitCell.get_uhash(*greenargs)
+                cell = jitcounter.lookup_chain(hash)
                 while cell is not None:
                     if (isinstance(cell, JitCell) and
                             cell.comparekey(*greenargs)):
@@ -470,15 +469,15 @@
             @staticmethod
             def ensure_jit_cell_at_key(greenkey):
                 greenargs = unwrap_greenkey(greenkey)
-                index = JitCell.get_index(*greenargs)
-                cell = jitcounter.lookup_chain(index)
+                hash = JitCell.get_uhash(*greenargs)
+                cell = jitcounter.lookup_chain(hash)
                 while cell is not None:
                     if (isinstance(cell, JitCell) and
                             cell.comparekey(*greenargs)):
                         return cell
                     cell = cell.next
                 newcell = JitCell(*greenargs)
-                jitcounter.install_new_cell(index, newcell)
+                jitcounter.install_new_cell(hash, newcell)
                 return newcell
         #
         self.JitCell = JitCell
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to