Author: Armin Rigo <[email protected]>
Branch:
Changeset: r69572:680434495e1e
Date: 2014-03-01 15:00 +0100
http://bitbucket.org/pypy/pypy/changeset/680434495e1e/
Log: Try to improve the timetable of the jit counters: replace the two
tables of 4096 entries with a single 5-ways-associative table of
2048 entries.
diff --git a/rpython/jit/metainterp/compile.py
b/rpython/jit/metainterp/compile.py
--- a/rpython/jit/metainterp/compile.py
+++ b/rpython/jit/metainterp/compile.py
@@ -500,8 +500,9 @@
ST_BUSY_FLAG = 0x01 # if set, busy tracing from the guard
ST_TYPE_MASK = 0x06 # mask for the type (TY_xxx)
ST_SHIFT = 3 # in "status >> ST_SHIFT" is stored:
- # - if TY_NONE, the jitcounter index directly
+ # - if TY_NONE, the jitcounter hash directly
# - otherwise, the guard_value failarg index
+ ST_SHIFT_MASK = -(1 << ST_SHIFT)
TY_NONE = 0x00
TY_INT = 0x02
TY_REF = 0x04
@@ -514,8 +515,8 @@
#
if metainterp_sd.warmrunnerdesc is not None: # for tests
jitcounter = metainterp_sd.warmrunnerdesc.jitcounter
- index = jitcounter.in_second_half(jitcounter.fetch_next_index())
- self.status = index << self.ST_SHIFT
+ hash = jitcounter.fetch_next_hash()
+ self.status = hash & self.ST_SHIFT_MASK
def make_a_counter_per_value(self, guard_value_op):
assert guard_value_op.getopnum() == rop.GUARD_VALUE
@@ -566,7 +567,7 @@
# common case: this is not a guard_value, and we are not
# already busy tracing. The rest of self.status stores a
# valid per-guard index in the jitcounter.
- index = self.status >> self.ST_SHIFT
+ hash = self.status & self.ST_SHIFT_MASK
#
# do we have the BUSY flag? If so, we're tracing right now, e.g. in an
# outer invocation of the same function, so don't trace again for now.
@@ -597,12 +598,11 @@
intval = llmemory.cast_adr_to_int(
llmemory.cast_int_to_adr(intval), "forced")
- hash = (current_object_addr_as_int(self) * 777767777 +
- intval * 1442968193)
- index = jitcounter.in_second_half(jitcounter.get_index(hash))
+ hash = r_uint(current_object_addr_as_int(self) * 777767777 +
+ intval * 1442968193)
#
increment = jitdriver_sd.warmstate.increment_trace_eagerness
- return jitcounter.tick(index, increment)
+ return jitcounter.tick(hash, increment)
def start_compiling(self):
# start tracing and compiling from this guard.
diff --git a/rpython/jit/metainterp/counter.py
b/rpython/jit/metainterp/counter.py
--- a/rpython/jit/metainterp/counter.py
+++ b/rpython/jit/metainterp/counter.py
@@ -7,28 +7,32 @@
assert r_uint32.BITS == 32
UINT32MAX = 2 ** 32 - 1
+# keep in sync with the C code in pypy__decay_jit_counters
+ENTRY = lltype.Struct('timetable_entry',
+ ('times', lltype.FixedSizeArray(rffi.FLOAT, 5)),
+ ('subhashes', lltype.FixedSizeArray(rffi.USHORT, 5)))
+
class JitCounter:
- DEFAULT_SIZE = 4096
+ DEFAULT_SIZE = 2048
def __init__(self, size=DEFAULT_SIZE, translator=None):
"NOT_RPYTHON"
self.size = size
- self.shift = 1
+ self.shift = 16
while (UINT32MAX >> self.shift) != size - 1:
self.shift += 1
- assert self.shift < 999, "size is not a power of two <= 2**31"
+ assert self.shift < 999, "size is not a power of two <= 2**16"
#
- # The table of timings. The first half is used for starting the
- # compilation of new loops. The second half is used for turning
- # failing guards into bridges. The two halves are split to avoid
- # too much interference.
- self.timetablesize = size * 2
- self.timetable = lltype.malloc(rffi.CArray(rffi.FLOAT),
- self.timetablesize,
+ # The table of timings. This is a 5-ways associative cache.
+ # We index into it using a number between 0 and (size - 1),
+ # and we're getting a 32-bytes-long entry; then this entry
+ # contains 5 possible ways, each occupying 6 bytes: 4 bytes
+ # for a float, and the 2 lowest bytes from the original hash.
+ self.timetable = lltype.malloc(rffi.CArray(ENTRY), self.size,
flavor='raw', zero=True,
track_allocation=False)
- self._nextindex = r_uint(0)
+ self._nexthash = r_uint(0)
#
# The table of JitCell entries, recording already-compiled loops
self.celltable = [None] * size
@@ -56,46 +60,92 @@
return 0.0 # no increment, never reach 1.0
return 1.0 / (threshold - 0.001)
- def get_index(self, hash):
- """Return the index (< self.size) from a hash value. This truncates
+ def _get_index(self, hash):
+ """Return the index (< self.size) from a hash. This truncates
the hash to 32 bits, and then keeps the *highest* remaining bits.
- Be sure that hash is computed correctly."""
+ Be sure that hash is computed correctly, by multiplying with
+ a large odd number or by fetch_next_hash()."""
hash32 = r_uint(r_uint32(hash)) # mask off the bits higher than 32
index = hash32 >> self.shift # shift, resulting in a value < size
return index # return the result as a r_uint
- get_index._always_inline_ = True
+ _get_index._always_inline_ = True
- def fetch_next_index(self):
- result = self._nextindex
- self._nextindex = (result + 1) & self.get_index(-1)
+ @staticmethod
+ def _get_subhash(hash):
+ return hash & 65535
+
+ def fetch_next_hash(self):
+ result = self._nexthash
+ # note: all three "1" bits in the following constant are needed
+ # to make test_counter.test_fetch_next_hash pass. The first
+ # is to increment the "subhash" (lower 16 bits of the hash).
+ # The second is to increment the "index" portion of the hash.
+ # The third is so that after 65536 passes, the "index" is
+ # incremented by one more (by overflow), so that the next
+ # 65536 passes don't end up with the same subhashes.
+ self._nexthash = result + r_uint(1 | (1 << self.shift) |
+ (1 << (self.shift - 16)))
return result
- def in_second_half(self, index):
- assert index < r_uint(self.size)
- return self.size + index
+ def _swap(self, p_entry, n):
+ if float(p_entry.times[n]) > float(p_entry.times[n + 1]):
+ return n + 1
+ else:
+ x = p_entry.times[n]
+ p_entry.times[n] = p_entry.times[n + 1]
+ p_entry.times[n + 1] = x
+ x = p_entry.subhashes[n]
+ p_entry.subhashes[n] = p_entry.subhashes[n + 1]
+ p_entry.subhashes[n + 1] = x
+ return n
+ _swap._always_inline_ = True
- def tick(self, index, increment):
- counter = float(self.timetable[index]) + increment
+ def tick(self, hash, increment):
+ p_entry = self.timetable[self._get_index(hash)]
+ subhash = self._get_subhash(hash)
+ #
+ if p_entry.subhashes[0] == subhash:
+ n = 0
+ elif p_entry.subhashes[1] == subhash:
+ n = self._swap(p_entry, 0)
+ elif p_entry.subhashes[2] == subhash:
+ n = self._swap(p_entry, 1)
+ elif p_entry.subhashes[3] == subhash:
+ n = self._swap(p_entry, 2)
+ elif p_entry.subhashes[4] == subhash:
+ n = self._swap(p_entry, 3)
+ else:
+ n = 4
+ while n > 0 and float(p_entry.times[n - 1]) == 0.0:
+ n -= 1
+ p_entry.subhashes[n] = rffi.cast(rffi.USHORT, subhash)
+ p_entry.times[n] = r_singlefloat(0.0)
+ #
+ counter = float(p_entry.times[n]) + increment
if counter < 1.0:
- self.timetable[index] = r_singlefloat(counter)
+ p_entry.times[n] = r_singlefloat(counter)
return False
else:
# when the bound is reached, we immediately reset the value to 0.0
- self.reset(index)
+ self.reset(hash)
return True
- tick._always_inline_ = True
- def reset(self, index):
- self.timetable[index] = r_singlefloat(0.0)
+ def reset(self, hash):
+ p_entry = self.timetable[self._get_index(hash)]
+ subhash = self._get_subhash(hash)
+ for i in range(5):
+ if p_entry.subhashes[i] == subhash:
+ p_entry.times[i] = r_singlefloat(0.0)
- def lookup_chain(self, index):
- return self.celltable[index]
+ def lookup_chain(self, hash):
+ return self.celltable[self._get_index(hash)]
- def cleanup_chain(self, index):
- self.reset(index)
- self.install_new_cell(index, None)
+ def cleanup_chain(self, hash):
+ self.reset(hash)
+ self.install_new_cell(hash, None)
- def install_new_cell(self, index, newcell):
+ def install_new_cell(self, hash, newcell):
+ index = self._get_index(hash)
cell = self.celltable[index]
keep = newcell
while cell is not None:
@@ -125,22 +175,29 @@
# important in corner cases where we would suddenly compile more
# than one loop because all counters reach the bound at the same
# time, but where compiling all but the first one is pointless.
- size = self.timetablesize
- pypy__decay_jit_counters(self.timetable, self.decay_by_mult, size)
+ p = rffi.cast(rffi.CCHARP, self.timetable)
+ pypy__decay_jit_counters(p, self.decay_by_mult, self.size)
# this function is written directly in C; gcc will optimize it using SSE
eci = ExternalCompilationInfo(post_include_bits=["""
-static void pypy__decay_jit_counters(float table[], double f1, long size1) {
+static void pypy__decay_jit_counters(char *data, double f1, long size) {
+ struct { float times[5]; unsigned short subhashes[5]; } *p = data;
float f = (float)f1;
- int i, size = (int)size1;
- for (i=0; i<size; i++)
- table[i] *= f;
+ long i;
+ for (i=0; i<size; i++) {
+ p->times[0] *= f;
+ p->times[1] *= f;
+ p->times[2] *= f;
+ p->times[3] *= f;
+ p->times[4] *= f;
+ ++p;
+ }
}
"""])
pypy__decay_jit_counters = rffi.llexternal(
- "pypy__decay_jit_counters", [rffi.FLOATP, lltype.Float, lltype.Signed],
+ "pypy__decay_jit_counters", [rffi.CCHARP, lltype.Float, lltype.Signed],
lltype.Void, compilation_info=eci, _nowrapper=True, sandboxsafe=True)
@@ -153,11 +210,12 @@
def __init__(self):
from collections import defaultdict
JitCounter.__init__(self, size=8)
- zero = r_singlefloat(0.0)
- self.timetable = defaultdict(lambda: zero)
+ def make_null_entry():
+ return lltype.malloc(ENTRY, immortal=True, zero=True)
+ self.timetable = defaultdict(make_null_entry)
self.celltable = defaultdict(lambda: None)
- def get_index(self, hash):
+ def _get_index(self, hash):
"NOT_RPYTHON"
return hash
@@ -165,10 +223,6 @@
"NOT_RPYTHON"
pass
- def in_second_half(self, index):
- "NOT_RPYTHON"
- return index + 12345
-
def _clear_all(self):
self.timetable.clear()
self.celltable.clear()
diff --git a/rpython/jit/metainterp/test/test_counter.py
b/rpython/jit/metainterp/test/test_counter.py
--- a/rpython/jit/metainterp/test/test_counter.py
+++ b/rpython/jit/metainterp/test/test_counter.py
@@ -5,30 +5,77 @@
jc = JitCounter(size=128) # 7 bits
for i in range(10):
hash = 400000001 * i
- index = jc.get_index(hash)
+ index = jc._get_index(hash)
assert index == (hash >> (32 - 7))
-def test_fetch_next_index():
- jc = JitCounter(size=4)
- lst = [jc.fetch_next_index() for i in range(10)]
- assert lst == [0, 1, 2, 3, 0, 1, 2, 3, 0, 1]
+def test_get_subhash():
+ assert JitCounter._get_subhash(0x518ebd) == 0x8ebd
+
+def test_fetch_next_hash():
+ jc = JitCounter(size=2048)
+ # check the distribution of "fetch_next_hash() & ~7".
+ blocks = [[jc.fetch_next_hash() & ~7 for i in range(65536)]
+ for j in range(2)]
+ for block in blocks:
+ assert 0 <= jc._get_index(block[0]) < 2048
+ assert 0 <= jc._get_index(block[-1]) < 2048
+ assert 0 <= jc._get_index(block[2531]) < 2048
+ assert 0 <= jc._get_index(block[45981]) < 2048
+ # should be correctly distributed: ideally 2047 or 2048 different
+ # values
+ assert len(set([jc._get_index(x) for x in block])) >= 2040
+ # check that the subkeys are distinct for same-block entries
+ subkeys = {}
+ for block in blocks:
+ for x in block:
+ idx = jc._get_index(x)
+ subkeys.setdefault(idx, []).append(jc._get_subhash(x))
+ collisions = 0
+ for idx, sks in subkeys.items():
+ collisions += len(sks) - len(set(sks))
+ assert collisions < 5
+
+def index2hash(jc, index, subhash=0):
+ assert 0 <= subhash < 65536
+ return (index << jc.shift) | subhash
def test_tick():
jc = JitCounter()
incr = jc.compute_threshold(4)
for i in range(5):
- r = jc.tick(104, incr)
+ r = jc.tick(index2hash(jc, 104), incr)
assert r is (i == 3)
for i in range(5):
- r = jc.tick(108, incr)
- s = jc.tick(109, incr)
+ r = jc.tick(index2hash(jc, 108), incr)
+ s = jc.tick(index2hash(jc, 109), incr)
assert r is (i == 3)
assert s is (i == 3)
- jc.reset(108)
+ jc.reset(index2hash(jc, 108))
for i in range(5):
- r = jc.tick(108, incr)
+ r = jc.tick(index2hash(jc, 108), incr)
assert r is (i == 3)
+def test_collisions():
+ jc = JitCounter(size=4) # 2 bits
+ incr = jc.compute_threshold(4)
+ for i in range(5):
+ for sk in range(100, 105):
+ r = jc.tick(index2hash(jc, 3, subhash=sk), incr)
+ assert r is (i == 3)
+
+ jc = JitCounter()
+ incr = jc.compute_threshold(4)
+ misses = 0
+ for i in range(5):
+ for sk in range(100, 106):
+ r = jc.tick(index2hash(jc, 3, subhash=sk), incr)
+ if r:
+ assert i == 3
+ elif i == 3:
+ misses += 1
+ assert misses < 5
+
+
def test_install_new_chain():
class Dead:
next = None
diff --git a/rpython/jit/metainterp/warmstate.py
b/rpython/jit/metainterp/warmstate.py
--- a/rpython/jit/metainterp/warmstate.py
+++ b/rpython/jit/metainterp/warmstate.py
@@ -7,7 +7,7 @@
from rpython.rlib.jit import PARAMETERS
from rpython.rlib.nonconst import NonConstant
from rpython.rlib.objectmodel import specialize, we_are_translated, r_dict
-from rpython.rlib.rarithmetic import intmask
+from rpython.rlib.rarithmetic import intmask, r_uint
from rpython.rlib.unroll import unrolling_iterable
from rpython.rtyper.annlowlevel import (hlstr, cast_base_ptr_to_instance,
cast_object_to_ptr)
@@ -312,7 +312,7 @@
#
assert 0, "should have raised"
- def bound_reached(index, cell, *args):
+ def bound_reached(hash, cell, *args):
if not confirm_enter_jit(*args):
return
jitcounter.decay_all_counters()
@@ -322,7 +322,7 @@
greenargs = args[:num_green_args]
if cell is None:
cell = JitCell(*greenargs)
- jitcounter.install_new_cell(index, cell)
+ jitcounter.install_new_cell(hash, cell)
cell.flags |= JC_TRACING
try:
metainterp.compile_and_run_once(jitdriver_sd, *args)
@@ -339,16 +339,16 @@
# These few lines inline some logic that is also on the
# JitCell class, to avoid computing the hash several times.
greenargs = args[:num_green_args]
- index = JitCell.get_index(*greenargs)
- cell = jitcounter.lookup_chain(index)
+ hash = JitCell.get_uhash(*greenargs)
+ cell = jitcounter.lookup_chain(hash)
while cell is not None:
if isinstance(cell, JitCell) and cell.comparekey(*greenargs):
break # found
cell = cell.next
else:
# not found. increment the counter
- if jitcounter.tick(index, increment_threshold):
- bound_reached(index, None, *args)
+ if jitcounter.tick(hash, increment_threshold):
+ bound_reached(hash, None, *args)
return
# Here, we have found 'cell'.
@@ -359,15 +359,15 @@
# this function. don't trace a second time.
return
# attached by compile_tmp_callback(). count normally
- if jitcounter.tick(index, increment_threshold):
- bound_reached(index, cell, *args)
+ if jitcounter.tick(hash, increment_threshold):
+ bound_reached(hash, cell, *args)
return
# machine code was already compiled for these greenargs
procedure_token = cell.get_procedure_token()
if procedure_token is None:
# it was an aborted compilation, or maybe a weakref that
# has been freed
- jitcounter.cleanup_chain(index)
+ jitcounter.cleanup_chain(hash)
return
if not confirm_enter_jit(*args):
return
@@ -422,7 +422,6 @@
green_args_name_spec = unrolling_iterable([('g%d' % i, TYPE)
for i, TYPE in enumerate(jitdriver_sd._green_args_spec)])
unwrap_greenkey = self.make_unwrap_greenkey()
- random_initial_value = hash(self)
#
class JitCell(BaseJitCell):
def __init__(self, *greenargs):
@@ -441,20 +440,20 @@
return True
@staticmethod
- def get_index(*greenargs):
- x = random_initial_value
+ def get_uhash(*greenargs):
+ x = r_uint(-1888132534)
i = 0
for _, TYPE in green_args_name_spec:
item = greenargs[i]
- y = hash_whatever(TYPE, item)
- x = intmask((x ^ y) * 1405695061) # prime number, 2**30~31
+ y = r_uint(hash_whatever(TYPE, item))
+ x = (x ^ y) * r_uint(1405695061) # prime number, 2**30~31
i = i + 1
- return jitcounter.get_index(x)
+ return x
@staticmethod
def get_jitcell(*greenargs):
- index = JitCell.get_index(*greenargs)
- cell = jitcounter.lookup_chain(index)
+ hash = JitCell.get_uhash(*greenargs)
+ cell = jitcounter.lookup_chain(hash)
while cell is not None:
if (isinstance(cell, JitCell) and
cell.comparekey(*greenargs)):
@@ -470,15 +469,15 @@
@staticmethod
def ensure_jit_cell_at_key(greenkey):
greenargs = unwrap_greenkey(greenkey)
- index = JitCell.get_index(*greenargs)
- cell = jitcounter.lookup_chain(index)
+ hash = JitCell.get_uhash(*greenargs)
+ cell = jitcounter.lookup_chain(hash)
while cell is not None:
if (isinstance(cell, JitCell) and
cell.comparekey(*greenargs)):
return cell
cell = cell.next
newcell = JitCell(*greenargs)
- jitcounter.install_new_cell(index, newcell)
+ jitcounter.install_new_cell(hash, newcell)
return newcell
#
self.JitCell = JitCell
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit