Ryan Gambord has uploaded this change for review. ( https://gem5-review.googlesource.com/c/public/gem5/+/17068 )

Change subject: sim: Updated SimObject.py and multidict.py to support multiple inheritance
......................................................................

sim: Updated SimObject.py and multidict.py to support multiple inheritance

Change-Id: Ib88f666b79a7038cade6c74234f5badf8eefd278
---
M src/python/m5/SimObject.py
M src/python/m5/util/multidict.py
2 files changed, 211 insertions(+), 172 deletions(-)



diff --git a/src/python/m5/SimObject.py b/src/python/m5/SimObject.py
index 7f19c07..4962ad4 100644
--- a/src/python/m5/SimObject.py
+++ b/src/python/m5/SimObject.py
@@ -468,46 +468,38 @@

         # initialize required attributes

-        # class-only attributes
-        cls._params = multidict() # param descriptions
-        cls._ports = multidict()  # port descriptions
+        msos = [base for base in bases if isinstance(base, MetaSimObject)]
+        cls._bases = msos

-        # class or instance attributes
-        cls._values = multidict()   # param values
-        cls._hr_values = multidict() # human readable param values
-        cls._children = multidict() # SimObject children
-        cls._port_refs = multidict() # port ref objects
-        cls._instantiated = False # really instantiated, cloned, or subclassed
+        # Initialize the class-only attributes

-        # We don't support multiple inheritance of sim objects. If you want
-        # to, you must fix multidict to deal with it properly. Non sim-objects
-        # are ok, though
-        bTotal = 0
-        for c in bases:
-            if isinstance(c, MetaSimObject):
-                bTotal += 1
-            if bTotal > 1:
-                raise TypeError(
-                      "SimObjects do not support multiple inheritance")
+        # param descriptions
+        cls._params = multidict(*[mso._params for mso in msos])

-        base = bases[0]
+        # port descriptions
+        cls._ports = multidict(*[mso._ports for mso in msos])

-        # Set up general inheritance via multidicts.  A subclass will
-        # inherit all its settings from the base class.  The only time
-        # the following is not true is when we define the SimObject
-        # class itself (in which case the multidicts have no parent).
-        if isinstance(base, MetaSimObject):
-            cls._base = base
-            cls._params.parent = base._params
-            cls._ports.parent = base._ports
-            cls._values.parent = base._values
-            cls._hr_values.parent = base._hr_values
-            cls._children.parent = base._children
-            cls._port_refs.parent = base._port_refs
-            # mark base as having been subclassed
-            base._instantiated = True
-        else:
-            cls._base = None
+
+        # Initialize class or instance attributes
+
+        # param values
+        cls._values = multidict(*[mso._values for mso in msos])
+
+        # human readable param values
+        cls._hr_values = multidict(*[mso._hr_values for mso in msos])
+
+        # SimObject children
+        cls._children = multidict(*[mso._children for mso in msos])
+
+        # port ref objects
+        cls._port_refs = multidict(*[mso._port_refs for mso in msos])
+
+        cls._instantiated = False
+
+        for mso in msos:
+            # Mark all bases as instantiated/cloned/subclassed
+            mso._instantiated = True
+

         # default keyword values
         if 'type' in cls._value_dict:
@@ -711,14 +703,12 @@
     py::module m = m_internal.def_submodule("param_${cls}");
 ''')
         code.indent()
-        if cls._base:
-            code('py::class_<${cls}Params, ${{cls._base.type}}Params, ' \
-                 'std::unique_ptr<${{cls}}Params, py::nodelete>>(' \
-                 'm, "${cls}Params")')
-        else:
-            code('py::class_<${cls}Params, ' \
-                 'std::unique_ptr<${cls}Params, py::nodelete>>(' \
-                 'm, "${cls}Params")')
+        base_params_str = 'Params, '.join(
+                                          [base.type for base in cls._bases] + [''])
+        code('py::class_<${cls}Params, ${base_params_str}' \
+             'std::unique_ptr<${{cls}}Params, py::nodelete>>(' \
+             'm, "${cls}Params")')
+

         code.indent()
         if not hasattr(cls, 'abstract') or not cls.abstract:
@@ -745,22 +735,19 @@
             # overridden, use that value.
             if cls.cxx_base:
                 bases.append(cls.cxx_base)
-        elif cls._base:
-            # If not and if there was a SimObject base, use its c++ class
+        elif cls._bases:
+            # If not and if there were SimObject bases, use their c++ classes
             # as this class' base.
-            bases.append(cls._base.cxx_class)
+            bases = [base.cxx_class for base in cls._bases]
         # Add in any extra bases that were requested.
         bases.extend(cls.cxx_extra_bases)

-        if bases:
-            base_str = ", ".join(bases)
-            code('py::class_<${{cls.cxx_class}}, ${base_str}, ' \
-                 'std::unique_ptr<${{cls.cxx_class}}, py::nodelete>>(' \
-                 'm, "${py_class_name}")')
-        else:
-            code('py::class_<${{cls.cxx_class}}, ' \
-                 'std::unique_ptr<${{cls.cxx_class}}, py::nodelete>>(' \
-                 'm, "${py_class_name}")')
+
+        base_str = ", ".join(bases + [''])
+        code('py::class_<${{cls.cxx_class}}, ${base_str}' \
+             'std::unique_ptr<${{cls.cxx_class}}, py::nodelete>>(' \
+             'm, "${py_class_name}")')
+
         code.indent()
         for exp in cls.cxx_exports:
             exp.export(code, cls.cxx_class)
@@ -770,8 +757,11 @@
         code.dedent()
         code('}')
         code()
-        code('static EmbeddedPyBind embed_obj("${0}", module_init, "${1}");',
-             cls, cls._base.type if cls._base else "")
+        code('static EmbeddedPyBind embed_obj("${0}", module_init, {${1}});',
+             cls,
+             ', '.join(['"' + base.type + '"' for base in cls._bases])
+            )
+


     # Generate the C++ declaration (.hh file) for this SimObject's
@@ -821,8 +811,8 @@
             port.cxx_predecls(code)
         code()

-        if cls._base:
-            code('#include "params/${{cls._base.type}}.hh"')
+        for base in cls._bases:
+            code('#include "params/${{base.type}}.hh"')
             code()

         for ptype in ptypes:
@@ -832,8 +822,9 @@

         # now generate the actual param struct
         code("struct ${cls}Params")
-        if cls._base:
-            code("    : public ${{cls._base.type}}Params")
+        if cls._bases:
+            code("    : " + ', '.join(
+                ['public ' + base.type + 'Params' for base in cls._bases]))
         code("{")
         if not hasattr(cls, 'abstract') or not cls.abstract:
             if 'type' in cls.__dict__:
diff --git a/src/python/m5/util/multidict.py b/src/python/m5/util/multidict.py
index 2330156..45edd7e 100644
--- a/src/python/m5/util/multidict.py
+++ b/src/python/m5/util/multidict.py
@@ -27,149 +27,197 @@
 # Authors: Nathan Binkert

 from __future__ import print_function
+from abc import ABCMeta
+

 __all__ = [ 'multidict' ]

-class multidict(object):
-    def __init__(self, parent = {}, **kwargs):
-        self.local = dict(**kwargs)
-        self.parent = parent
-        self.deleted = {}

-    def __str__(self):
-        return str(dict(self.items()))
+
+class multidict(dict):
+
+    def _mro(obj):
+        mros = []
+        for p in obj.parents:
+            try:
+                mros.append([i for i in p.mro])
+            except AttributeError as e:
+                pass
+
+        seqs = [[obj]] + mros + [list(obj.parents)]
+        res = []
+        while 1:
+            nonemptyseqs=[seq for seq in seqs if seq]
+            if not nonemptyseqs:
+                return tuple(res)
+            for seq in nonemptyseqs: # find merge candidates among seq heads
+                cand = seq[0];
+                for s in nonemptyseqs:
+                    for item in s[1:]:
+                        if cand is item:
+                            cand = None
+                if cand is not None:
+                    break
+            if cand is None:
+                raise TypeError("Inconsistent hierarchy")
+
+
+            res.append(cand)
+            for seq in nonemptyseqs:
+                if seq[0] is cand:
+                    del seq[0]

     def __repr__(self):
-        return repr(dict(list(self.items())))
+        args = []
+        for p in self.parents:
+            args.append(''.join(
+                ['    ' + line for line in repr(p).splitlines(True)]))
+        for k, v in dict.iteritems(self):
+            args.append('  ' + k + '=' + repr(v))
+        str = self.__class__.__name__ + '(\n' + ',\n'.join(args) + '\n)'
+        if self.ignored: str += '.ignore(' + repr(self.ignored) + ')'
+        return str

-    def __contains__(self, key):
-        return key in self.local or key in self.parent
+    def __str__(self):
+        return '<multidict instance \'' + str(dict(self.items())) + '\'>'

-    def __delitem__(self, key):
-        try:
-            del self.local[key]
-        except KeyError as e:
-            if key in self.parent:
-                self.deleted[key] = True
-            else:
-                raise KeyError(e)
+    def __new__(cls, *parents, **kw):
+        for p in parents:
+            if not isinstance(p, dict):
+                raise TypeError("'{}' cannot have parent of type '{}'"\
+                        .format(cls.__name__, type(p).__name__))
+        return dict.__new__(cls, *parents, **kw)

-    def __setitem__(self, key, value):
-        self.deleted.pop(key, False)
-        self.local[key] = value
+    def __init__(self, *parents, **kw):
+
+        self.parents = parents
+        self.ignored = set([])
+        self.mro = self._mro()
+        super(self.__class__, self).__init__(**kw)
+
+    def __iter__(self):
+        ignored = set([])
+        for md in self.mro:
+            ignored |= getattr(md, 'ignored', set([]))
+            for key in dict.__iter__(md):
+                if key not in ignored:
+                    yield key
+                    ignored.add(key) # Don't yield ancestral copies of this key

     def __getitem__(self, key):
-        try:
-            return self.local[key]
-        except KeyError as e:
-            if not self.deleted.get(key, False) and key in self.parent:
-                return self.parent[key]
-            else:
-                raise KeyError(e)
+        for md in self.mro:
+            try:
+                return dict.__getitem__(md, key)
+            except KeyError:
+                pass # Keep searching up the hierarchy

-    def __len__(self):
-        return len(self.local) + len(self.parent)
+            if key in getattr(md, 'ignored', set([])):
+                raise KeyError(key) # Stop searching up the hierarchy

-    def next(self):
-        for key,value in self.local.items():
-            yield key,value
+        raise KeyError(key) # Key not found in hierarchy

-        if self.parent:
-            for key,value in self.parent.next():
-                if key not in self.local and key not in self.deleted:
-                    yield key,value
+    def __contains__(self, key):
+        for md in self.mro:
+            if dict.__contains__(md, key):
+                return True

-    def has_key(self, key):
-        return key in self
+            if key in getattr(md, 'ignored', set([])):
+                return False # Stop searching up the hierarchy

-    def items(self):
-        for item in self.next():
-            yield item
+        return False # Key not found in hierarchy

-    def keys(self):
-        for key,value in self.next():
-            yield key
+    def __delitem__(self, key):
+        # If the item is locally defined, delete it using super
+        if super(self.__class__, self).__contains__(key):
+            super(self.__class__, self).__delitem__(key)
+        # Ignore this key from all ancestors
+        self.ignored.add(key)

-    def values(self):
-        for key,value in self.next():
-            yield value
+    def clear(self):
+        for key in self:
+            del self[key]

-    def get(self, key, default=None):
-        try:
-            return self[key]
-        except KeyError as e:
-            return default
+    def copy(self):
+        ret = multidict(*self.parents, **self)
+        ret.ignored = self.ignored.copy()
+        return ret

-    def setdefault(self, key, default):
+    def get(self, key, *args, **kw):
+        if kw:
+            raise TypeError('get() takes no keyword arguments')
+        if len(args) > 1:
+            raise TypeError('get() expected at most 2 arguments, got {}'
+                            .format(len(args)))
         try:
             return self[key]
         except KeyError:
-            self.deleted.pop(key, False)
-            self.local[key] = default
-            return default
+            return args[0] if len(args) > 0 else None

-    def _dump(self):
-        print('multidict dump')
-        node = self
-        while isinstance(node, multidict):
-            print('    ', node.local)
-            node = node.parent
+    has_key = __contains__

-    def _dumpkey(self, key):
-        values = []
-        node = self
-        while isinstance(node, multidict):
-            if key in node.local:
-                values.append(node.local[key])
-            node = node.parent
-        print(key, values)
+    def keys(self):
+        return [k for k in self]

-if __name__ == '__main__':
-    test1 = multidict()
-    test2 = multidict(test1)
-    test3 = multidict(test2)
-    test4 = multidict(test3)
+    def values(self):
+        return [self[k] for k in self]

-    test1['a'] = 'test1_a'
-    test1['b'] = 'test1_b'
-    test1['c'] = 'test1_c'
-    test1['d'] = 'test1_d'
-    test1['e'] = 'test1_e'
+    def items(self):
+        return [(k, self[k]) for k in self]

-    test2['a'] = 'test2_a'
-    del test2['b']
-    test2['c'] = 'test2_c'
-    del test1['a']
+    def iterkeys(self):
+        for k in self:
+            yield k

-    test2.setdefault('f', multidict)
+    def itervalues(self):
+        for k in self:
+            yield self[k]

-    print('test1>', list(test1.items()))
-    print('test2>', list(test2.items()))
-    #print(test1['a'])
-    print(test1['b'])
-    print(test1['c'])
-    print(test1['d'])
-    print(test1['e'])
+    def iteritems(self):
+        for k in self:
+            yield (k,self[k])

-    print(test2['a'])
-    #print(test2['b'])
-    print(test2['c'])
-    print(test2['d'])
-    print(test2['e'])
+    def pop(self, key, *args, **kw):
+        if kw:
+            raise TypeError('pop() takes no keyword arguments')
+        if len(args) > 1:
+            raise TypeError('pop() expected at most 2 arguments, got {}'
+                            .format(len(args)))
+        try:
+            value = self[key]
+            del self[key]
+            return value
+        except KeyError:
+            if len(args) > 0: return args[0]
+            else: raise

-    for key in test2.keys():
-        print(key)
+    def popitem(self):
+        print(list(self.iterkeys()))
+        try:
+            return self.pop(self.iterkeys().next())
+        except StopIteration:
+            raise KeyError('popitem(): dictionary is empty')

-    test2.get('g', 'foo')
-    #test2.get('b')
-    test2.get('b', 'bar')
-    test2.setdefault('b', 'blah')
-    print(test1)
-    print(test2)
-    print(repr(test2))

-    print(len(test2))
+    def setdefault(self, key, *args, **kw):
+        if kw:
+            raise TypeError('setdefault() takes no keyword arguments')
+        if len(args) > 1:
+            raise TypeError('setdefault() expected at most 2 arguments, got {}'
+                            .format(len(args)))
+        if key not in self:
+            self[key] = args[0] if len(args) > 0 else None
+        return self[key]

-    test3['a'] = [ 0, 1, 2, 3 ]
+    def update(self, *args, **kw):
+        if len(args) > 1:
+            raise TypeError('setdefault() expected at most 1 arguments, got {}'
+                            .format(len(args)))
+        self.__init__(args[0] if args else {}, self.copy(), **kw)

-    print(test4)
+    def ignore(self, *args):
+        self.ignored.update(*args)
+        return self
+
+    def __getattr__(self, attr):
+        if attr == 'local': # Return local variables as a regular dict
+            return dict(self)

--
To view, visit https://gem5-review.googlesource.com/c/public/gem5/+/17068
To unsubscribe, or for help writing mail filters, visit https://gem5-review.googlesource.com/settings

Gerrit-Project: public/gem5
Gerrit-Branch: master
Gerrit-Change-Id: Ib88f666b79a7038cade6c74234f5badf8eefd278
Gerrit-Change-Number: 17068
Gerrit-PatchSet: 1
Gerrit-Owner: Ryan Gambord <[email protected]>
Gerrit-MessageType: newchange
_______________________________________________
gem5-dev mailing list
[email protected]
http://m5sim.org/mailman/listinfo/gem5-dev

Reply via email to