This patch handles QAPI union types and generates the equivalent data structures and methods in Go to handle them.
The QAPI union type has two kinds of fields: the @base members and the @variants members. The @base fields are common to every value of the union, while at most one of the @variants fields is set. The QAPI specification defines a @discriminator field, which is an enum type; its purpose is to identify which @variants member is in use. Note that the @discriminator's enum may have more values than the union's data struct. This is fine: the union does not need to handle every case of the enum, but it should accept them without error. For this specific case, we keep the @discriminator field in every union type. The union types implement the Marshaler and Unmarshaler interfaces to seamlessly decode JSON objects into Go structs and vice versa. qapi: | { 'union': 'SetPasswordOptions', | 'base': { 'protocol': 'DisplayProtocol', | 'password': 'str', | '*connected': 'SetPasswordAction' }, | 'discriminator': 'protocol', | 'data': { 'vnc': 'SetPasswordOptionsVnc' } } go: | type SetPasswordOptions struct { | Protocol DisplayProtocol `json:"protocol"` | Password string `json:"password"` | Connected *SetPasswordAction `json:"connected,omitempty"` | | // Variants fields | Vnc *SetPasswordOptionsVnc `json:"-"` | } Signed-off-by: Victor Toso <victort...@redhat.com> --- scripts/qapi/golang.py | 170 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 167 insertions(+), 3 deletions(-) diff --git a/scripts/qapi/golang.py b/scripts/qapi/golang.py index 8320af99b6..343c9c9b95 100644 --- a/scripts/qapi/golang.py +++ b/scripts/qapi/golang.py @@ -52,6 +52,17 @@ } return nil } + +// This helper is used to move struct's fields into a map. +// This function is useful to merge JSON objects.
+func unwrapToMap(m map[string]any, data any) error {
+    if bytes, err := json.Marshal(&data); err != nil {
+        return fmt.Errorf("unwrapToMap: %s", err)
+    } else if err := json.Unmarshal(bytes, &m); err != nil {
+        return fmt.Errorf("unwrapToMap: %s, data=%s", err, string(bytes))
+    }
+    return nil
+}
 '''
 
 TEMPLATE_ALTERNATE = '''
@@ -131,6 +142,62 @@
 }}
 '''
 
+TEMPLATE_UNION_CHECK_FIELD = '''
+if s.{field} != nil && err == nil {{
+    if len(bytes) != 0 {{
+        err = errors.New(`multiple variant fields set`)
+    }} else if err = unwrapToMap(m, s.{field}); err == nil {{
+        s.{discriminator} = {go_enum_value}
+        m["{member_name}"] = {go_enum_value}
+        bytes, err = json.Marshal(m)
+    }}
+}}
+'''
+
+TEMPLATE_UNION_DRIVER_CASE = '''
+case {go_enum_value}:
+    s.{field} = new({member_type})
+    if err := json.Unmarshal(data, s.{field}); err != nil {{
+        s.{field} = nil
+        return err
+    }}'''
+
+TEMPLATE_UNION_METHODS = '''
+func (s {type_name}) MarshalJSON() ([]byte, error) {{
+    var bytes []byte
+    var err error
+    type Alias {type_name}
+    v := Alias(s)
+    m := make(map[string]any)
+    err = unwrapToMap(m, &v)
+    {check_fields}
+    {check_non_fields_marshal}
+    if err != nil {{
+        return nil, fmt.Errorf("error: marshal: {type_name}: reason='%s', struct='%+v'", err, s)
+    }} else if len(bytes) == 0 {{
+        return nil, fmt.Errorf("error: marshal: {type_name} unsupported, struct='%+v'", s)
+    }}
+    return bytes, nil
+}}
+
+func (s *{type_name}) UnmarshalJSON(data []byte) error {{{base_type_def}
+    tmp := struct {{
+        {base_type_name}
+    }}{{}}
+
+    if err := json.Unmarshal(data, &tmp); err != nil {{
+        return err
+    }}
+    {base_type_assign_unmarshal}
+    switch tmp.{discriminator} {{
+    {driver_cases}
+    {check_non_fields_unmarshal}
+    }}
+    return nil
+}}
+'''
+
+
 def gen_golang(schema: QAPISchema,
                output_dir: str,
                prefix: str) -> None:
@@ -428,6 +495,98 @@ def qapi_to_golang_struct(self: QAPISchemaGenGolangVisitor,
                                    variants)
     return content
 
+def qapi_to_golang_methods_union(self: QAPISchemaGenGolangVisitor,
+                                 name: str,
+                                 base: Optional[QAPISchemaObjectType],
+                                 variants: Optional[QAPISchemaVariants]
+                                 ) -> str:
+    """Generate Go MarshalJSON/UnmarshalJSON methods for union @name."""
+    type_name = qapi_to_go_type_name(name)
+
+    assert base
+    base_type_assign_unmarshal = ""
+    base_type_name = qapi_to_go_type_name(base.name)
+    base_type_def = qapi_to_golang_struct(self,
+                                          base.name,
+                                          base.info,
+                                          base.ifcond,
+                                          base.features,
+                                          base.base,
+                                          base.members,
+                                          base.variants)
+    for member in base.local_members:
+        field = qapi_to_field_name(member.name)
+        base_type_assign_unmarshal += f'''s.{field} = tmp.{field}
+'''
+    # Go snippets per variant: marshal field checks and unmarshal cases.
+    driver_cases = ""
+    check_fields = ""
+    exists = set()
+    assert variants  # variants.tag_member is read unconditionally below
+    enum_name = variants.tag_member.type.name
+    discriminator = qapi_to_field_name(variants.tag_member.name)
+    for var in variants.variants:
+        if var.type.is_implicit():  # no Go field is generated for these
+            continue
+
+        field = qapi_to_field_name(var.name)
+        enum_value = qapi_to_field_name_enum(var.name)
+        member_type = qapi_schema_type_to_go_type(var.type.name)
+        go_enum_value = f'''{enum_name}{enum_value}'''
+        exists.add(go_enum_value)
+
+        check_fields += TEMPLATE_UNION_CHECK_FIELD.format(field=field,
+                                                          discriminator=discriminator,
+                                                          member_name=variants.tag_member.name,
+                                                          go_enum_value=go_enum_value)
+        driver_cases += TEMPLATE_UNION_DRIVER_CASE.format(go_enum_value=go_enum_value,
+                                                          field=field,
+                                                          member_type=member_type)
+    # Enum values without a variant field are accepted but carry no data.
+    check_non_fields_marshal = ""
+    check_non_fields_unmarshal = ""
+    enum_obj = self.schema.lookup_entity(enum_name)
+    if len(exists) != len(enum_obj.members):
+        driver_cases += '''\ndefault:'''
+        check_non_fields_marshal = '''
+    // Check for valid values without field members
+    if len(bytes) == 0 && err == nil &&
+        ('''
+        check_non_fields_unmarshal = '''
+    // Check for valid values without field members
+    if '''
+
+        for member in enum_obj.members:
+            value = qapi_to_field_name_enum(member.name)
+            go_enum_value = f'''{enum_name}{value}'''
+
+            if go_enum_value in exists:
+                continue
+
+            check_non_fields_marshal += f'''s.{discriminator} == {go_enum_value} ||\n'''
+            check_non_fields_unmarshal += f'''tmp.{discriminator} != {go_enum_value} &&\n'''
+        # Drop the dangling '||\n' / '&&\n' (and the unmarshal leading '\n').
+        check_non_fields_marshal = f'''{check_non_fields_marshal[:-3]}) {{
+    type Alias {type_name}
+    bytes, err = json.Marshal(Alias(s))
+    }}
+'''
+        check_non_fields_unmarshal = f'''{check_non_fields_unmarshal[1:-3]} {{
+        return fmt.Errorf("error: unmarshal: {type_name}: received unrecognized value: '%s'",
+            tmp.{discriminator})
+    }}
+'''
+
+    return TEMPLATE_UNION_METHODS.format(type_name=type_name,
+                                         check_fields=check_fields,
+                                         check_non_fields_marshal=check_non_fields_marshal,
+                                         base_type_def=base_type_def,
+                                         base_type_name=base_type_name,
+                                         base_type_assign_unmarshal=base_type_assign_unmarshal,
+                                         discriminator=discriminator,
+                                         driver_cases=driver_cases[1:],
+                                         check_non_fields_unmarshal=check_non_fields_unmarshal)
+
 def generate_template_alternate(self: QAPISchemaGenGolangVisitor,
                                 name: str,
                                 variants: Optional[QAPISchemaVariants]) -> str:
@@ -490,7 +649,7 @@ class QAPISchemaGenGolangVisitor(QAPISchemaVisitor):
 
     def __init__(self, _: str):
         super().__init__()
-        types = ["alternate", "enum", "helper", "struct"]
+        types = ["alternate", "enum", "helper", "struct", "union"]
         self.target = {name: "" for name in types}
         self.objects_seen = {}
         self.schema = None
@@ -530,10 +689,10 @@ def visit_object_type(self: QAPISchemaGenGolangVisitor,
                           members: List[QAPISchemaObjectTypeMember],
                           variants: Optional[QAPISchemaVariants]
                           ) -> None:
-        # Do not handle anything besides struct.
+        # Do not handle anything besides structs and unions.
         if (name == self.schema.the_empty_object_type.name or
                 not isinstance(name, str) or
-                info.defn_meta not in ["struct"]):
+                info.defn_meta not in ["struct", "union"]):
             return
 
         # Base structs are embed
@@ -566,6 +725,11 @@ def visit_object_type(self: QAPISchemaGenGolangVisitor,
                                                     base,
                                                     members,
                                                     variants)
+        if info.defn_meta == "union":
+            self.target[info.defn_meta] += qapi_to_golang_methods_union(self,
+                                                                        name,
+                                                                        base,
+                                                                        variants)
 
     def visit_alternate_type(self: QAPISchemaGenGolangVisitor,
                              name: str,
-- 
2.41.0