Ryan Gambord has uploaded this change for review. ( https://gem5-review.googlesource.com/c/public/gem5/+/17610

Change subject: sim,python: Multi-inheritance support for SimObjects
......................................................................

sim,python: Multi-inheritance support for SimObjects

Change-Id: I95a252f8a5bf6cadbc0f8b07be84b4a75bf4f4ba
Signed-off-by: Ryan Gambord <[email protected]>
---
M src/python/m5/SimObject.py
M src/python/m5/util/multidict.py
M src/sim/init.cc
M src/sim/init.hh
4 files changed, 211 insertions(+), 188 deletions(-)



diff --git a/src/python/m5/SimObject.py b/src/python/m5/SimObject.py
index 0e29980..1c7f9a3 100644
--- a/src/python/m5/SimObject.py
+++ b/src/python/m5/SimObject.py
@@ -471,46 +471,38 @@

         # initialize required attributes

-        # class-only attributes
-        cls._params = multidict() # param descriptions
-        cls._ports = multidict()  # port descriptions
+        msos = [base for base in bases if isinstance(base, MetaSimObject)]
+        cls._bases = msos

-        # class or instance attributes
-        cls._values = multidict()   # param values
-        cls._hr_values = multidict() # human readable param values
-        cls._children = multidict() # SimObject children
-        cls._port_refs = multidict() # port ref objects
-        cls._instantiated = False # really instantiated, cloned, or subclassed
+        # Initialize the class-only attributes

-        # We don't support multiple inheritance of sim objects.  If you want
-        # to, you must fix multidict to deal with it properly. Non sim-objects
-        # are ok, though
-        bTotal = 0
-        for c in bases:
-            if isinstance(c, MetaSimObject):
-                bTotal += 1
-            if bTotal > 1:
-                raise TypeError(
-                      "SimObjects do not support multiple inheritance")
+        # param descriptions
+        cls._params = multidict(*[mso._params for mso in msos])

-        base = bases[0]
+        # port descriptions
+        cls._ports = multidict(*[mso._ports for mso in msos])

-        # Set up general inheritance via multidicts.  A subclass will
-        # inherit all its settings from the base class.  The only time
-        # the following is not true is when we define the SimObject
-        # class itself (in which case the multidicts have no parent).
-        if isinstance(base, MetaSimObject):
-            cls._base = base
-            cls._params.parent = base._params
-            cls._ports.parent = base._ports
-            cls._values.parent = base._values
-            cls._hr_values.parent = base._hr_values
-            cls._children.parent = base._children
-            cls._port_refs.parent = base._port_refs
-            # mark base as having been subclassed
-            base._instantiated = True
-        else:
-            cls._base = None
+
+        # Initialize class or instance attributes
+
+        # param values
+        cls._values = multidict(*[mso._values for mso in msos])
+
+        # human readable param values
+        cls._hr_values = multidict(*[mso._hr_values for mso in msos])
+
+        # SimObject children
+        cls._children = multidict(*[mso._children for mso in msos])
+
+        # port ref objects
+        cls._port_refs = multidict(*[mso._port_refs for mso in msos])
+
+        cls._instantiated = False
+
+        for mso in msos:
+            # Mark all bases as instantiated/cloned/subclassed
+            mso._instantiated = True
+

         # default keyword values
         if 'type' in cls._value_dict:
@@ -714,14 +706,12 @@
     py::module m = m_internal.def_submodule("param_${cls}");
 ''')
         code.indent()
-        if cls._base:
-            code('py::class_<${cls}Params, ${{cls._base.type}}Params, ' \
-                 'std::unique_ptr<${{cls}}Params, py::nodelete>>(' \
-                 'm, "${cls}Params")')
-        else:
-            code('py::class_<${cls}Params, ' \
-                 'std::unique_ptr<${cls}Params, py::nodelete>>(' \
-                 'm, "${cls}Params")')
+        base_params_str = ''.join(
+            '%sParams, ' % base.type for base in cls._bases)
+        code('py::class_<${cls}Params, ${base_params_str}' \
+             'std::unique_ptr<${{cls}}Params, py::nodelete>>(' \
+             'm, "${cls}Params", py::multiple_inheritance())')
+

         code.indent()
         if not hasattr(cls, 'abstract') or not cls.abstract:
@@ -748,22 +738,19 @@
             # overridden, use that value.
             if cls.cxx_base:
                 bases.append(cls.cxx_base)
-        elif cls._base:
-            # If not and if there was a SimObject base, use its c++ class
+        elif cls._bases:
+            # If not and if there were SimObject bases, use their c++ classes
             # as this class' base.
-            bases.append(cls._base.cxx_class)
+            bases = [base.cxx_class for base in cls._bases]
         # Add in any extra bases that were requested.
         bases.extend(cls.cxx_extra_bases)

-        if bases:
-            base_str = ", ".join(bases)
-            code('py::class_<${{cls.cxx_class}}, ${base_str}, ' \
-                 'std::unique_ptr<${{cls.cxx_class}}, py::nodelete>>(' \
-                 'm, "${py_class_name}")')
-        else:
-            code('py::class_<${{cls.cxx_class}}, ' \
-                 'std::unique_ptr<${{cls.cxx_class}}, py::nodelete>>(' \
-                 'm, "${py_class_name}")')
+
+        base_str = ''.join('%s, ' % base for base in bases)
+        code('py::class_<${{cls.cxx_class}}, ${base_str}' \
+             'std::unique_ptr<${{cls.cxx_class}}, py::nodelete>>(' \
+             'm, "${py_class_name}")')
+
         code.indent()
         for exp in cls.cxx_exports:
             exp.export(code, cls.cxx_class)
@@ -773,8 +760,9 @@
         code.dedent()
         code('}')
         code()
-        code('static EmbeddedPyBind embed_obj("${0}", module_init, "${1}");',
-             cls, cls._base.type if cls._base else "")
+        code('static EmbeddedPyBind embed_obj("${0}", module_init, '
+             '{${1}});',
+             cls, ', '.join('"' + base.type + '"' for base in cls._bases))

     _warned_about_nested_templates = False

@@ -894,8 +882,8 @@
             port.cxx_predecls(code)
         code()

-        if cls._base:
-            code('#include "params/${{cls._base.type}}.hh"')
+        for base in cls._bases:
+            code('#include "params/${{base.type}}.hh"')
             code()

         for ptype in ptypes:
@@ -905,8 +893,10 @@

         # now generate the actual param struct
         code("struct ${cls}Params")
-        if cls._base:
-            code("    : public ${{cls._base.type}}Params")
+        if cls._bases:
+            code("    : " + ', '.join(
+                ['public virtual ' + base.type + 'Params' \
+                for base in cls._bases]))
         code("{")
         if not hasattr(cls, 'abstract') or not cls.abstract:
             if 'type' in cls.__dict__:
diff --git a/src/python/m5/util/multidict.py b/src/python/m5/util/multidict.py
old mode 100644
new mode 100755
index 2330156..a3d50dd
--- a/src/python/m5/util/multidict.py
+++ b/src/python/m5/util/multidict.py
@@ -1,3 +1,5 @@
+#!/usr/bin/env python2.7
+#
 # Copyright (c) 2005 The Regents of The University of Michigan
 # All rights reserved.
 #
@@ -27,149 +29,196 @@
 # Authors: Nathan Binkert

 from __future__ import print_function

 __all__ = [ 'multidict' ]

-class multidict(object):
-    def __init__(self, parent = {}, **kwargs):
-        self.local = dict(**kwargs)
-        self.parent = parent
-        self.deleted = {}
+class metamultidict(type):
+    """Metaclass that gives multidict classes a dict-like interface.
+
+    A multidict is itself a class: its own entries are stored in its
+    ``local`` dict and the keys deleted from it in its ``ignored`` set.
+    Lookups walk the class MRO, so a multidict with several bases sees
+    the entries of all of them, with earlier MRO entries taking
+    precedence.  All multidicts share this single metaclass, which is
+    what allows any of them to be combined as bases of a new one.
+    """
+    def __call__(cls, *args, **kw):
+        raise TypeError('\'{}\' object is not callable'
+                        .format(cls.__name__))

-    def __str__(self):
-        return str(dict(self.items()))
+    def __contains__(cls, key):
+        for md in cls.__mro__:
+            if not isinstance(md, metamultidict):
+                break # Stop when we reach the 'object' class in mro
+            if key in md.ignored:
+                return False
+            if key in md.local:
+                return True
+        return False

-    def __repr__(self):
-        return repr(dict(list(self.items())))
+    def __delitem__(cls, key):
+        if key not in cls:
+            raise KeyError(key)
+        cls.ignored.add(key)

-    def __contains__(self, key):
-        return key in self.local or key in self.parent
+    def __getitem__(cls, key):
+        for md in cls.__mro__:
+            if not isinstance(md, metamultidict):
+                break
+            if key in md.ignored:
+                raise KeyError(key)
+            if key in md.local:
+                return md.local[key]
+        raise KeyError(key)

-    def __delitem__(self, key):
-        try:
-            del self.local[key]
-        except KeyError as e:
-            if key in self.parent:
-                self.deleted[key] = True
-            else:
-                raise KeyError(e)
+    def __iter__(cls):
+        ignored = set()
+        for md in cls.__mro__:
+            if not isinstance(md, metamultidict):
+                break
+            ignored |= md.ignored
+            for key in md.local:
+                if key not in ignored:
+                    yield key
+                    ignored.add(key)
+
+    def __len__(cls):
+        return sum(1 for _ in cls)

-    def __setitem__(self, key, value):
-        self.deleted.pop(key, False)
-        self.local[key] = value
+    def __repr__(cls):
+        args = [repr(b) for b in cls.bases]
+        args += ['{}={!r}'.format(k, v) for k, v in cls.local.items()]

-    def __getitem__(self, key):
-        try:
-            return self.local[key]
-        except KeyError as e:
-            if not self.deleted.get(key, False) and key in self.parent:
-                return self.parent[key]
-            else:
-                raise KeyError(e)
+        text = 'multidict(' + ', '.join(args) + ')'
+        if cls.ignored:
+            text += '.ignore(' + ', '.join(map(repr, cls.ignored)) + ')'
+        return text

-    def __len__(self):
-        return len(self.local) + len(self.parent)
+    def __setitem__(cls, key, value):
+        cls.ignored.discard(key)
+        cls.local[key] = value

-    def next(self):
-        for key,value in self.local.items():
-            yield key,value
+    def __str__(cls):
+        return '<multidict instance \'' + str(dict(cls.items())) + '\'>'

-        if self.parent:
-            for key,value in self.parent.next():
-                if key not in self.local and key not in self.deleted:
-                    yield key,value
-
-    def has_key(self, key):
-        return key in self
+    def clear(cls):
+        for key in cls.keys():
+            del cls[key]

-    def items(self):
-        for item in self.next():
-            yield item
+    def copy(cls):
+        return _make_multidict(cls.bases, cls.local, cls.ignored)
+
+    def get(cls, key, *args, **kw):
+        if kw:
+            raise TypeError('get() takes no keyword arguments')
+        if len(args) > 1:
+            raise TypeError('get() expected at most 2 arguments, got {}'
+                            .format(len(args) + 1))
+        try:
+            return cls[key]
+        except KeyError:
+            return args[0] if args else None

-    def keys(self):
-        for key,value in self.next():
-            yield key
+    has_key = __contains__

-    def values(self):
-        for key,value in self.next():
-            yield value
+    def ignore(cls, *args):
+        cls.ignored.update(args)
+        return cls

-    def get(self, key, default=None):
-        try:
-            return self[key]
-        except KeyError as e:
-            return default
+    def items(cls):
+        return [(k, cls[k]) for k in cls]

-    def setdefault(self, key, default):
-        try:
-            return self[key]
-        except KeyError:
-            self.deleted.pop(key, False)
-            self.local[key] = default
-            return default
+    def iteritems(cls):
+        for k in cls:
+            yield (k, cls[k])

-    def _dump(self):
-        print('multidict dump')
-        node = self
-        while isinstance(node, multidict):
-            print('    ', node.local)
-            node = node.parent
+    def iterkeys(cls):
+        for k in cls:
+            yield k

-    def _dumpkey(self, key):
-        values = []
-        node = self
-        while isinstance(node, multidict):
-            if key in node.local:
-                values.append(node.local[key])
-            node = node.parent
-        print(key, values)
+    def itervalues(cls):
+        for k in cls:
+            yield cls[k]

-if __name__ == '__main__':
-    test1 = multidict()
-    test2 = multidict(test1)
-    test3 = multidict(test2)
-    test4 = multidict(test3)
+    def keys(cls):
+        return [k for k in cls]

-    test1['a'] = 'test1_a'
-    test1['b'] = 'test1_b'
-    test1['c'] = 'test1_c'
-    test1['d'] = 'test1_d'
-    test1['e'] = 'test1_e'
+    def pop(cls, key, *args, **kw):
+        if kw:
+            raise TypeError('pop() takes no keyword arguments')
+        if len(args) > 1:
+            raise TypeError('pop() expected at most 2 arguments, got {}'
+                            .format(len(args) + 1))
+        try:
+            value = cls[key]
+        except KeyError:
+            if args:
+                return args[0]
+            raise
+        del cls[key]
+        return value

-    test2['a'] = 'test2_a'
-    del test2['b']
-    test2['c'] = 'test2_c'
-    del test1['a']
+    def popitem(cls):
+        try:
+            key = next(iter(cls))
+        except StopIteration:
+            raise KeyError('popitem(): dictionary is empty')
+        return (key, cls.pop(key))

-    test2.setdefault('f', multidict)
+    def setdefault(cls, key, *args, **kw):
+        if kw:
+            raise TypeError('setdefault() takes no keyword arguments')
+        if len(args) > 1:
+            raise TypeError('setdefault() expected at most 2 arguments, '
+                            'got {}'.format(len(args) + 1))
+        if key not in cls:
+            cls[key] = args[0] if args else None
+        return cls[key]

-    print('test1>', list(test1.items()))
-    print('test2>', list(test2.items()))
-    #print(test1['a'])
-    print(test1['b'])
-    print(test1['c'])
-    print(test1['d'])
-    print(test1['e'])
+    def update(cls, *args, **kw):
+        new_items = dict(*args, **kw)
+        # Re-adding a key undoes an earlier deletion, as __setitem__ does
+        cls.ignored.difference_update(new_items)
+        cls.local.update(new_items)

-    print(test2['a'])
-    #print(test2['b'])
-    print(test2['c'])
-    print(test2['d'])
-    print(test2['e'])
+    def values(cls):
+        return [cls[k] for k in cls]

-    for key in test2.keys():
-        print(key)
+
+# Root of every multidict hierarchy.  It is created by calling the
+# metaclass directly (a __metaclass__ attribute is ignored by Python 3)
+# and supplies empty defaults for the attributes that lookups read.
+_multidict_root = metamultidict('multidict', (object,),
+                                {'ignored': set(), 'local': {}, 'bases': ()})
+
+
+def _make_multidict(bases, local, ignored=()):
+    """Return a new multidict class over ``bases`` owning copies of
+    ``local`` and ``ignored``."""
+    bases = tuple(bases)
+    return type('multidict', bases or (_multidict_root,),
+                {'local': dict(local), 'ignored': set(ignored),
+                 'bases': bases})

-    test2.get('g', 'foo')
-    #test2.get('b')
-    test2.get('b', 'bar')
-    test2.setdefault('b', 'blah')
-    print(test1)
-    print(test2)
-    print(repr(test2))
+
+def multidict(*args, **kw):
+    """Create a new multidict.
+
+    Each positional argument is a base to inherit entries from: either
+    another multidict or a plain dict (which is wrapped in a new
+    multidict).  Keyword arguments become the new multidict's own
+    entries.  The result is a class whose metaclass supplies the dict
+    interface.
+    """
+    bases = []
+    for arg in args:
+        if isinstance(arg, dict):
+            # Wrap plain dicts so that they can take part in the MRO
+            bases.append(_make_multidict((), arg))
+        elif isinstance(arg, metamultidict):
+            bases.append(arg)
+        else:
+            raise TypeError('Invalid base type \'{}\''
+                            .format(type(arg).__name__))

-    print(len(test2))
-
-    test3['a'] = [ 0, 1, 2, 3 ]
-
-    print(test4)
+    return _make_multidict(bases, kw)
diff --git a/src/sim/init.cc b/src/sim/init.cc
index 1fb7e6e..579d7f9 100644
--- a/src/sim/init.cc
+++ b/src/sim/init.cc
@@ -154,15 +154,15 @@

 EmbeddedPyBind::EmbeddedPyBind(const char *_name,
                                void (*init_func)(py::module &),
-                               const char *_base)
-    : initFunc(init_func), registered(false), name(_name), base(_base)
+                               const std::vector<std::string> &_bases)
+    : initFunc(init_func), registered(false), name(_name), bases(_bases)
 {
     getMap()[_name] = this;
 }

 EmbeddedPyBind::EmbeddedPyBind(const char *_name,
                                void (*init_func)(py::module &))
-    : initFunc(init_func), registered(false), name(_name), base("")
+    : initFunc(init_func), registered(false), name(_name), bases()
 {
     getMap()[_name] = this;
 }
@@ -181,7 +181,11 @@
 bool
 EmbeddedPyBind::depsReady() const
 {
-    return base.empty() || getMap()[base]->registered;
+    for (const auto &base : bases) {
+        if (!getMap().at(base)->registered)
+            return false;
+    }
+    return true;
 }

 std::map<std::string, EmbeddedPyBind *> &
@@ -211,7 +215,7 @@

     for (auto &kv : getMap()) {
         auto &obj = kv.second;
-        if (obj->base.empty()) {
+        if (obj->bases.empty()) {
             obj->init(m_m5);
         } else {
             pending.push_back(obj);
diff --git a/src/sim/init.hh b/src/sim/init.hh
index 40ff9ae..f0bfe12 100644
--- a/src/sim/init.hh
+++ b/src/sim/init.hh
@@ -45,12 +45,13 @@

 #include "pybind11/pybind11.h"

+#include <inttypes.h>
+
 #include <list>
 #include <map>
 #include <string>
+#include <vector>

-#include <inttypes.h>
-
 #ifndef PyObject_HEAD
 struct _object;
 typedef _object PyObject;
@@ -85,7 +86,7 @@
   public:
     EmbeddedPyBind(const char *_name,
                    void (*init_func)(pybind11::module &),
-                   const char *_base);
+                   const std::vector<std::string> &_bases);

     EmbeddedPyBind(const char *_name,
                    void (*init_func)(pybind11::module &));
@@ -104,7 +105,7 @@

     bool registered;
     const std::string name;
-    const std::string base;
+    const std::vector<std::string> bases;

     static std::map<std::string, EmbeddedPyBind *> &getMap();
 };

--
To view, visit https://gem5-review.googlesource.com/c/public/gem5/+/17610
To unsubscribe, or for help writing mail filters, visit https://gem5-review.googlesource.com/settings

Gerrit-Project: public/gem5
Gerrit-Branch: master
Gerrit-Change-Id: I95a252f8a5bf6cadbc0f8b07be84b4a75bf4f4ba
Gerrit-Change-Number: 17610
Gerrit-PatchSet: 1
Gerrit-Owner: Ryan Gambord <[email protected]>
Gerrit-MessageType: newchange
_______________________________________________
gem5-dev mailing list
[email protected]
http://m5sim.org/mailman/listinfo/gem5-dev

Reply via email to