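Fix CheckDomain and CheckPortType classes to properly deal with type aliases.

Resolve the entered name with the new get_real_type_name() helper before
checking it against the list of valid domains / port types, so that aliases
of valid types are accepted; gen_short_name() is updated the same way.
Also drop an unused get_all_domains() call from CheckType.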
Resolves: rhbz#1600009
---
python/sepolicy/sepolicy.py | 8 +++-----
python/sepolicy/sepolicy/__init__.py | 10 +++++++++-
2 files changed, 12 insertions(+), 6 deletions(-)
diff --git a/python/sepolicy/sepolicy.py b/python/sepolicy/sepolicy.py
index a000c1ad..01380fbe 100755
--- a/python/sepolicy/sepolicy.py
+++ b/python/sepolicy/sepolicy.py
@@ -60,8 +60,6 @@ class CheckPath(argparse.Action):
class CheckType(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
- domains = sepolicy.get_all_domains()
-
if isinstance(values, str):
setattr(namespace, self.dest, values)
else:
@@ -103,7 +101,7 @@ class CheckDomain(argparse.Action):
domains = sepolicy.get_all_domains()
if isinstance(values, str):
- if values not in domains:
+ if sepolicy.get_real_type_name(values) not in domains:
raise ValueError("%s must be an SELinux process domain:\nValid
domains: %s" % (values, ", ".join(domains)))
setattr(namespace, self.dest, values)
else:
@@ -112,7 +110,7 @@ class CheckDomain(argparse.Action):
newval = []
for v in values:
- if v not in domains:
+ if sepolicy.get_real_type_name(v) not in domains:
raise ValueError("%s must be an SELinux process
domain:\nValid domains: %s" % (v, ", ".join(domains)))
newval.append(v)
setattr(namespace, self.dest, newval)
@@ -167,7 +165,7 @@ class CheckPortType(argparse.Action):
if not newval:
newval = []
for v in values:
- if v not in port_types:
+ if sepolicy.get_real_type_name(v) not in port_types:
raise ValueError("%s must be an SELinux port type:\nValid port
types: %s" % (v, ", ".join(port_types)))
newval.append(v)
setattr(namespace, self.dest, values)
diff --git a/python/sepolicy/sepolicy/__init__.py b/python/sepolicy/sepolicy/__init__.py
index 8484b28c..0da3917b 100644
--- a/python/sepolicy/sepolicy/__init__.py
+++ b/python/sepolicy/sepolicy/__init__.py
@@ -447,6 +447,14 @@ def get_file_types(setype):
return mpaths
+# Map a type name or alias to the "name" that info(TYPE, ...) reports for it;
+# returns None if no type matches (StopIteration) or info() raises RuntimeError.
+def get_real_type_name(name):
+ try:
+ return next(info(TYPE, name))["name"]
+ except (RuntimeError, StopIteration):
+ return None
+
def get_writable_files(setype):
file_types = get_all_file_types()
all_writes = []
@@ -1061,7 +1069,7 @@ def gen_short_name(setype):
domainname = setype[:-2]
else:
domainname = setype
- if domainname + "_t" not in all_domains:
+ if get_real_type_name(domainname + "_t") not in all_domains:
raise ValueError("domain %s_t does not exist" % domainname)
if domainname[-1] == 'd':
short_name = domainname[:-1] + "_"
--
2.17.1
_______________________________________________
Selinux mailing list
[email protected]
To unsubscribe, send email to [email protected].
To get help, send an email containing "help" to [email protected].