roshanjrajan-zip commented on code in PR #2825:
URL: https://github.com/apache/thrift/pull/2825#discussion_r1248042615


##########
compiler/cpp/src/thrift/generate/t_py_generator.cc:
##########
@@ -905,6 +905,32 @@ void 
t_py_generator::generate_py_struct_definition(ostream& out,
     }
 
     out << "))" << endl;
+  } else if (gen_enum_) {
+    bool has_enum = false;
+    for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
+      t_type* type = (*m_iter)->get_type();
+      if (type->is_enum()) {
+        has_enum = true;
+        break;
+      }
+    }
+
+    if (has_enum) {
+      out << endl;
+      indent(out) << "def __setattr__(self, name, value):" << endl;
+      indent_up();
+      for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {

Review Comment:
   Updated the test! Here is the entire generated code file:
   ```python
   #
   # Autogenerated by Thrift Compiler (0.19.0)
   #
   # DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
   #
   #  options string: py:enum
   #
   
   from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, 
TApplicationException
   from thrift.protocol.TProtocol import TProtocolException
   from thrift.TRecursive import fix_spec
   from enum import IntEnum
   
   import sys
   import shared_types.ttypes
   
   from thrift.transport import TTransport
   # Module-level registry of generated struct classes: each struct appends
   # itself right after its definition, and the list is passed to fix_spec()
   # and then deleted at the end of the module.
   all_structs = []
   
   
   # Generated from the Thrift enum TestEnum. With the ``py:enum`` option (see
   # the options string in the header) enums are emitted as IntEnum subclasses,
   # so members also compare equal to their integer values.
   class TestEnum(IntEnum):
       TestEnum0 = 0
       TestEnum1 = 1
   
   
   
   class TestStruct(object):
       """
       Attributes:
        - param1
        - param2
        - param3
   
       """
   
   
       # Each assignment below goes through __setattr__, so param2/param3 are
       # normalized at construction time exactly as on later assignment.
       def __init__(self, param1=None, param2=None, param3=None,):
           self.param1 = param1
           self.param2 = param2
           self.param3 = param3
   
       # Intercepts assignments to the enum-typed fields param2 (TestEnum) and
       # param3 (shared_types.ttypes.SharedEnum). A value that has a ``.value``
       # attribute (e.g. an enum member) is stored as-is; anything else is looked
       # up by member *name* in ``__members__``. Every other attribute is
       # assigned unchanged through object.__setattr__.
       # NOTE(review): because ``__members__`` is keyed by name, a plain int such
       # as ``obj.param2 = 1`` (or ``TestStruct(param2=1)``) and an unknown name
       # both silently become None, and any object with a ``.value`` attribute
       # (including a member of a different enum) passes through unchecked.
       # Coercing ints via ``TestEnum(value)`` in the generator would avoid the
       # silent data loss.
       def __setattr__(self, name, value):
           if name == "param2":
               super().__setattr__(name, value if hasattr(value, 'value') else 
TestEnum.__members__.get(value))
               return
           if name == "param3":
               super().__setattr__(name, value if hasattr(value, 'value') else 
shared_types.ttypes.SharedEnum.__members__.get(value))
               return
           super().__setattr__(name, value)
   
   
       # Deserializes this struct from the input protocol ``iprot``. When the
       # protocol exposes an accelerated decoder (``_fast_decode``), the
       # transport is a CReadableTransport and a thrift_spec is present, decoding
       # is delegated to it; otherwise fields are read one by one, and unknown
       # field ids or unexpected wire types are skipped.
       # NOTE(review): ``TestEnum(x)`` / ``SharedEnum(x)`` raise ValueError for an
       # i32 that is not a defined member, so a peer sending a newer enum value
       # makes the whole read fail. Also confirm how the _fast_decode path
       # assigns param2/param3: if it uses setattr with plain ints, they would
       # go through __setattr__ above and become None.
       def read(self, iprot):
           if iprot._fast_decode is not None and isinstance(iprot.trans, 
TTransport.CReadableTransport) and self.thrift_spec is not None:
               iprot._fast_decode(self, iprot, [self.__class__, 
self.thrift_spec])
               return
           iprot.readStructBegin()
           while True:
               (fname, ftype, fid) = iprot.readFieldBegin()
               if ftype == TType.STOP:
                   break
               if fid == 1:
                   if ftype == TType.STRING:
                       self.param1 = iprot.readString().decode('utf-8', 
errors='replace') if sys.version_info[0] == 2 else iprot.readString()
                   else:
                       iprot.skip(ftype)
               elif fid == 2:
                   if ftype == TType.I32:
                       self.param2 = TestEnum(iprot.readI32())
                   else:
                       iprot.skip(ftype)
               elif fid == 3:
                   if ftype == TType.I32:
                       self.param3 = 
shared_types.ttypes.SharedEnum(iprot.readI32())
                   else:
                       iprot.skip(ftype)
               else:
                   iprot.skip(ftype)
               iprot.readFieldEnd()
           iprot.readStructEnd()
   
       # Serializes this struct to the output protocol ``oprot``, delegating to
       # the accelerated encoder (``_fast_encode``) when available. Fields that
       # are None are omitted; enum fields are written as the i32 taken from
       # ``.value``, which assumes they hold enum members (see __setattr__).
       def write(self, oprot):
           if oprot._fast_encode is not None and self.thrift_spec is not None:
               oprot.trans.write(oprot._fast_encode(self, [self.__class__, 
self.thrift_spec]))
               return
           oprot.writeStructBegin('TestStruct')
           if self.param1 is not None:
               oprot.writeFieldBegin('param1', TType.STRING, 1)
               oprot.writeString(self.param1.encode('utf-8') if 
sys.version_info[0] == 2 else self.param1)
               oprot.writeFieldEnd()
           if self.param2 is not None:
               oprot.writeFieldBegin('param2', TType.I32, 2)
               oprot.writeI32(self.param2.value)
               oprot.writeFieldEnd()
           if self.param3 is not None:
               oprot.writeFieldBegin('param3', TType.I32, 3)
               oprot.writeI32(self.param3.value)
               oprot.writeFieldEnd()
           oprot.writeFieldStop()
           oprot.writeStructEnd()
   
       # Generated validation hook; this struct performs no checks.
       def validate(self):
           return
   
       # Debug representation listing every instance attribute as name=repr(value).
       def __repr__(self):
           L = ['%s=%r' % (key, value)
                for key, value in self.__dict__.items()]
           return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
   
       # Equal when ``other`` is an instance of this class (or a subclass) whose
       # instance attributes are all equal.
       def __eq__(self, other):
           return isinstance(other, self.__class__) and self.__dict__ == 
other.__dict__
   
       # Negation of __eq__, kept explicit because Python 2 (still handled by
       # the sys.version_info checks above) does not derive __ne__ from __eq__.
       def __ne__(self, other):
           return not (self == other)
   # Register the struct so the fix_spec() call below processes its spec.
   all_structs.append(TestStruct)
   # Per-field spec tuple indexed by field id; slot 0 is unused (None).
   # Entries appear to be (id, TType, name, type argument, default); the layout
   # is inferred from the generator's conventions, so confirm against thrift.
   # NOTE(review): param2/param3 are described only as TType.I32 with no enum
   # class attached, so spec-driven paths (_fast_encode/_fast_decode) presumably
   # treat them as plain ints rather than TestEnum/SharedEnum members. Verify.
   TestStruct.thrift_spec = (
       None,  # 0
       (1, TType.STRING, 'param1', 'UTF8', None, ),  # 1
       (2, TType.I32, 'param2', None, None, ),  # 2
       (3, TType.I32, 'param3', None, None, ),  # 3
   )
   # Post-process every registered struct's spec (thrift.TRecursive.fix_spec),
   # then drop the temporary registry so it is not left as a module attribute.
   fix_spec(all_structs)
   del all_structs
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to