Signed-off-by: Yuichi Ito <[email protected]>
---
 ryu/lib/packet/sctp.py             |   89 ++++++++++++++++++++++----------
 ryu/tests/unit/packet/test_sctp.py |  100 ++++++++++++++++++------------------
 2 files changed, 110 insertions(+), 79 deletions(-)

diff --git a/ryu/lib/packet/sctp.py b/ryu/lib/packet/sctp.py
index 6a6e89b..666e1e9 100644
--- a/ryu/lib/packet/sctp.py
+++ b/ryu/lib/packet/sctp.py
@@ -1125,7 +1125,7 @@ class cause(stringify.StringifyMixin):
     def cause_code(cls):
         pass

-    def __init__(self, length):
+    def __init__(self, length=0):
         self.length = length

     @classmethod
@@ -1149,7 +1149,7 @@ class cause_with_value(cause):

     __metaclass__ = abc.ABCMeta

-    def __init__(self, length, value):
+    def __init__(self, value, length=0):
         super(cause_with_value, self).__init__(length)
         self.value = value

@@ -1160,13 +1160,16 @@ class cause_with_value(cause):
         if (cls._MIN_LEN < length):
             fmt = '%ds' % (length - cls._MIN_LEN)
             (value, ) = struct.unpack_from(fmt, buf, cls._MIN_LEN)
-        return cls(length, value)
+        return cls(value, length)

     def serialize(self):
         buf = bytearray(struct.pack(
             self._PACK_STR, self.cause_code(), self.length))
         if self.value:
             buf.extend(self.value)
+        if 0 == self.length:
+            self.length = len(buf)
+            struct.pack_into('!H', buf, 2, self.length)
         mod = len(buf) % 4
         if mod:
             buf.extend(bytearray(4 - mod))
@@ -1206,9 +1209,11 @@ class cause_invalid_stream_id(cause_with_value):
     @classmethod
     def parser(cls, buf):
         (_, length, value) = struct.unpack_from(cls._PACK_STR, buf)
-        return cls(length, value)
+        return cls(value, length)

     def serialize(self):
+        if 0 == self.length:
+            self.length = self._MIN_LEN
         buf = struct.pack(
             self._PACK_STR, self.cause_code(), self.length, self.value)
         return buf
@@ -1232,9 +1237,9 @@ class cause_missing_param(cause):
     ============== =====================================================
     Attribute      Description
     ============== =====================================================
-    length         length of this cause containing this header.
-    num            Number of missing params.
     types          a list of missing params.
+    num            Number of missing params.
+    length         length of this cause containing this header.
     ============== =====================================================
     """

@@ -1245,14 +1250,16 @@ class cause_missing_param(cause):
     def cause_code(cls):
         return CCODE_MISSING_PARAM

-    def __init__(self, length, num, types=None):
+    def __init__(self, types=None, num=0, length=0):
         super(cause_missing_param, self).__init__(length)
-        self.num = num
         types = types or []
         assert isinstance(types, list)
         for one in types:
             assert isinstance(one, int)
         self.types = types
+        if 0 == num:
+            num = len(self.types)
+        self.num = num

     @classmethod
     def parser(cls, buf):
@@ -1263,13 +1270,16 @@ class cause_missing_param(cause):
             offset = cls._MIN_LEN + (struct.calcsize('!H') * count)
             (one, ) = struct.unpack_from('!H', buf, offset)
             types.append(one)
-        return cls(length, num, types)
+        return cls(types, num, length)

     def serialize(self):
         buf = bytearray(struct.pack(
             self._PACK_STR, self.cause_code(), self.length, self.num))
         for one in self.types:
             buf.extend(struct.pack('!H', one))
+        if 0 == self.length:
+            self.length = len(buf)
+            struct.pack_into('!H', buf, 2, self.length)
         mod = len(buf) % 4
         if mod:
             buf.extend(bytearray(4 - mod))
@@ -1336,6 +1346,8 @@ class cause_out_of_resource(cause):
         return cls(length)

     def serialize(self):
+        if 0 == self.length:
+            self.length = self._MIN_LEN
         buf = struct.pack(
             self._PACK_STR, self.cause_code(), self.length)
         return buf
@@ -1389,13 +1401,15 @@ class cause_unresolvable_addr(cause_with_value):
         (ptype, ) = struct.unpack_from('!H', buf, cls._MIN_LEN)
         cls_ = cls._RECOGNIZED_PARAMS.get(ptype)
         value = cls_.parser(buf[cls._MIN_LEN:])
-        return cls(length, value)
+        return cls(value, length)

     def serialize(self):
         buf = bytearray(struct.pack(
             self._PACK_STR, self.cause_code(), self.length))
-        if self.value:
-            buf.extend(self.value.serialize())
+        buf.extend(self.value.serialize())
+        if 0 == self.length:
+            self.length = len(buf)
+            struct.pack_into('!H', buf, 2, self.length)
         mod = len(buf) % 4
         if mod:
             buf.extend(bytearray(4 - mod))
@@ -1462,6 +1476,8 @@ class cause_invalid_param(cause):
         return cls(length)

     def serialize(self):
+        if 0 == self.length:
+            self.length = self._MIN_LEN
         buf = struct.pack(
             self._PACK_STR, self.cause_code(), self.length)
         return buf
@@ -1557,6 +1573,8 @@ class cause_cookie_while_shutdown(cause):
         return cls(length)

     def serialize(self):
+        if 0 == self.length:
+            self.length = self._MIN_LEN
         buf = struct.pack(
             self._PACK_STR, self.cause_code(), self.length)
         return buf
@@ -1600,10 +1618,10 @@ class cause_restart_with_new_addr(cause_with_value):
     def cause_code(cls):
         return CCODE_RESTART_WITH_NEW_ADDR

-    def __init__(self, length, value):
+    def __init__(self, value, length=0):
         if not isinstance(value, list):
             value = [value]
-        super(cause_restart_with_new_addr, self).__init__(length, value)
+        super(cause_restart_with_new_addr, self).__init__(value, length)

     @classmethod
     def parser(cls, buf):
@@ -1618,13 +1636,16 @@ class cause_restart_with_new_addr(cause_with_value):
             ins = cls_.parser(buf[offset:])
             value.append(ins)
             offset += len(ins)
-        return cls(length, value)
+        return cls(value, length)

     def serialize(self):
         buf = bytearray(struct.pack(
             self._PACK_STR, self.cause_code(), self.length))
         for one in self.value:
             buf.extend(one.serialize())
+        if 0 == self.length:
+            self.length = len(buf)
+            struct.pack_into('!H', buf, 2, self.length)
         mod = len(buf) % 4
         if mod:
             buf.extend(bytearray(4 - mod))
@@ -1703,7 +1724,7 @@ class param(stringify.StringifyMixin):
     def param_type(cls):
         pass

-    def __init__(self, length, value):
+    def __init__(self, value, length=0):
         self.length = length
         self.value = value

@@ -1714,13 +1735,16 @@ class param(stringify.StringifyMixin):
         if (cls._MIN_LEN < length):
             fmt = '%ds' % (length - cls._MIN_LEN)
             (value, ) = struct.unpack_from(fmt, buf, cls._MIN_LEN)
-        return cls(length, value)
+        return cls(value, length)

     def serialize(self):
         buf = bytearray(struct.pack(
             self._PACK_STR, self.param_type(), self.length))
         if self.value:
             buf.extend(self.value)
+        if 0 == self.length:
+            self.length = len(buf)
+            struct.pack_into('!H', buf, 2, self.length)
         mod = len(buf) % 4
         if mod:
             buf.extend(bytearray(4 - mod))
@@ -1845,9 +1869,11 @@ class param_cookie_preserve(param):
     @classmethod
     def parser(cls, buf):
         (_, length, value) = struct.unpack_from(cls._PACK_STR, buf)
-        return cls(length, value)
+        return cls(value, length)

     def serialize(self):
+        if 0 == self.length:
+            self.length = self._MIN_LEN
         buf = struct.pack(
             self._PACK_STR, self.param_type(), self.length, self.value)
         return buf
@@ -1880,9 +1906,9 @@ class param_ecn(param):
     def param_type(cls):
         return PTYPE_ECN

-    def __init__(self, length, value):
-        super(param_ecn, self).__init__(length, value)
-        assert 4 == length
+    def __init__(self, value=None, length=0):
+        super(param_ecn, self).__init__(value, length)
+        assert 4 == length or 0 == length
         assert None is value


@@ -1944,14 +1970,12 @@ class param_supported_addr(param):
     def param_type(cls):
         return PTYPE_SUPPORTED_ADDR

-    def __init__(self, length, value):
-        super(param_supported_addr, self).__init__(length, value)
+    def __init__(self, value, length=0):
         if not isinstance(value, list):
             value = [value]
         for one in value:
             assert isinstance(one, int)
-        self.length = length
-        self.value = value
+        super(param_supported_addr, self).__init__(value, length)

     @classmethod
     def parser(cls, buf):
@@ -1962,13 +1986,16 @@ class param_supported_addr(param):
             (one, ) = struct.unpack_from(cls._VALUE_STR, buf, offset)
             value.append(one)
             offset += cls._VALUE_LEN
-        return cls(length, value)
+        return cls(value, length)

     def serialize(self):
         buf = bytearray(struct.pack(
             self._PACK_STR, self.param_type(), self.length))
         for one in self.value:
             buf.extend(struct.pack(param_supported_addr._VALUE_STR, one))
+        if 0 == self.length:
+            self.length = len(buf)
+            struct.pack_into('!H', buf, 2, self.length)
         mod = len(buf) % 4
         if mod:
             buf.extend(bytearray(4 - mod))
@@ -2013,13 +2040,16 @@ class param_ipv4(param):
         if (cls._MIN_LEN < length):
             fmt = '%ds' % (length - cls._MIN_LEN)
             (value, ) = struct.unpack_from(fmt, buf, cls._MIN_LEN)
-        return cls(length, addrconv.ipv4.bin_to_text(value))
+        return cls(addrconv.ipv4.bin_to_text(value), length)

     def serialize(self):
         buf = bytearray(struct.pack(
             self._PACK_STR, self.param_type(), self.length))
         if self.value:
             buf.extend(addrconv.ipv4.text_to_bin(self.value))
+        if 0 == self.length:
+            self.length = len(buf)
+            struct.pack_into('!H', buf, 2, self.length)
         return str(buf)


@@ -2061,11 +2091,14 @@ class param_ipv6(param):
         if (cls._MIN_LEN < length):
             fmt = '%ds' % (length - cls._MIN_LEN)
             (value, ) = struct.unpack_from(fmt, buf, cls._MIN_LEN)
-        return cls(length, addrconv.ipv6.bin_to_text(value))
+        return cls(addrconv.ipv6.bin_to_text(value), length)

     def serialize(self):
         buf = bytearray(struct.pack(
             self._PACK_STR, self.param_type(), self.length))
         if self.value:
             buf.extend(addrconv.ipv6.text_to_bin(self.value))
+        if 0 == self.length:
+            self.length = len(buf)
+            struct.pack_into('!H', buf, 2, self.length)
         return str(buf)
diff --git a/ryu/tests/unit/packet/test_sctp.py b/ryu/tests/unit/packet/test_sctp.py
index adacb96..6f6132a 100644
--- a/ryu/tests/unit/packet/test_sctp.py
+++ b/ryu/tests/unit/packet/test_sctp.py
@@ -81,13 +81,12 @@ class Test_sctp(unittest.TestCase):
         self.mis = 3
         self.i_tsn = 123456

-        self.p_ipv4 = sctp.param_ipv4(8, '192.168.1.1')
-        self.p_ipv6 = sctp.param_ipv6(20, 'fe80::647e:1aff:fec4:8284')
-        self.p_cookie_preserve = sctp.param_cookie_preserve(8, 5000)
-        self.p_ecn = sctp.param_ecn(4, None)
-        self.p_host_addr = sctp.param_host_addr(14, 'test host\x00')
+        self.p_ipv4 = sctp.param_ipv4('192.168.1.1')
+        self.p_ipv6 = sctp.param_ipv6('fe80::647e:1aff:fec4:8284')
+        self.p_cookie_preserve = sctp.param_cookie_preserve(5000)
+        self.p_ecn = sctp.param_ecn()
+        self.p_host_addr = sctp.param_host_addr('test host\x00')
         self.p_support_type = sctp.param_supported_addr(
-            14,
             [sctp.PTYPE_IPV4, sctp.PTYPE_IPV6, sctp.PTYPE_COOKIE_PRESERVE,
              sctp.PTYPE_ECN, sctp.PTYPE_HOST_ADDR])

@@ -125,14 +124,13 @@ class Test_sctp(unittest.TestCase):
         self.mis = 3
         self.i_tsn = 123456

-        self.p_state_cookie = sctp.param_state_cookie(
-            7, '\x01\x02\x03')
-        self.p_ipv4 = sctp.param_ipv4(8, '192.168.1.1')
-        self.p_ipv6 = sctp.param_ipv6(20, 'fe80::647e:1aff:fec4:8284')
+        self.p_state_cookie = sctp.param_state_cookie('\x01\x02\x03')
+        self.p_ipv4 = sctp.param_ipv4('192.168.1.1')
+        self.p_ipv6 = sctp.param_ipv6('fe80::647e:1aff:fec4:8284')
         self.p_unrecognized_param = sctp.param_unrecognized_param(
-            8, '\xff\xff\x00\x04')
-        self.p_ecn = sctp.param_ecn(4, None)
-        self.p_host_addr = sctp.param_host_addr(14, 'test host\x00')
+            '\xff\xff\x00\x04')
+        self.p_ecn = sctp.param_ecn()
+        self.p_host_addr = sctp.param_host_addr('test host\x00')

         self.init_ack = sctp.chunk_init_ack(
             self.flags, self.length, self.init_tag, self.a_rwnd,
@@ -189,7 +187,7 @@ class Test_sctp(unittest.TestCase):
         self.flags = 0
         self.length = 4 + 8

-        self.p_heartbeat = sctp.param_heartbeat(8, '\x01\x02\x03\x04')
+        self.p_heartbeat = sctp.param_heartbeat('\x01\x02\x03\x04')

         self.heartbeat = sctp.chunk_heartbeat(
             self.flags, self.length, self.p_heartbeat)
@@ -209,7 +207,7 @@ class Test_sctp(unittest.TestCase):
         self.length = 4 + 12

         self.p_heartbeat = sctp.param_heartbeat(
-            12, '\xff\xee\xdd\xcc\xbb\xaa\x99\x88')
+            '\xff\xee\xdd\xcc\xbb\xaa\x99\x88')

         self.heartbeat_ack = sctp.chunk_heartbeat_ack(
             self.flags, self.length, self.p_heartbeat)
@@ -229,28 +227,27 @@ class Test_sctp(unittest.TestCase):
         self.length = 4 + 8 + 16 + 8 + 4 + 20 + 8 + 4 + 8 + 8 + 4 + 12 \
             + 20 + 20

-        self.c_invalid_stream_id = sctp.cause_invalid_stream_id(
-            8, 4096)
+        self.c_invalid_stream_id = sctp.cause_invalid_stream_id(4096)
         self.c_missing_param = sctp.cause_missing_param(
-            16, 4, [sctp.PTYPE_IPV4, sctp.PTYPE_IPV6,
-                    sctp.PTYPE_COOKIE_PRESERVE, sctp.PTYPE_HOST_ADDR])
-        self.c_stale_cookie = sctp.cause_stale_cookie(8, '\x00\x00\x13\x88')
-        self.c_out_of_resource = sctp.cause_out_of_resource(4)
+            [sctp.PTYPE_IPV4, sctp.PTYPE_IPV6,
+             sctp.PTYPE_COOKIE_PRESERVE, sctp.PTYPE_HOST_ADDR])
+        self.c_stale_cookie = sctp.cause_stale_cookie('\x00\x00\x13\x88')
+        self.c_out_of_resource = sctp.cause_out_of_resource()
         self.c_unresolvable_addr = sctp.cause_unresolvable_addr(
-            20, sctp.param_host_addr(14, 'test host\x00'))
+            sctp.param_host_addr('test host\x00'))
         self.c_unrecognized_chunk = sctp.cause_unrecognized_chunk(
-            8, '\xff\x00\x00\x04')
-        self.c_invalid_param = sctp.cause_invalid_param(4)
+            '\xff\x00\x00\x04')
+        self.c_invalid_param = sctp.cause_invalid_param()
         self.c_unrecognized_param = sctp.cause_unrecognized_param(
-            8, '\xff\xff\x00\x04')
-        self.c_no_userdata = sctp.cause_no_userdata(8, '\x00\x01\xe2\x40')
-        self.c_cookie_while_shutdown = sctp.cause_cookie_while_shutdown(4)
+            '\xff\xff\x00\x04')
+        self.c_no_userdata = sctp.cause_no_userdata('\x00\x01\xe2\x40')
+        self.c_cookie_while_shutdown = sctp.cause_cookie_while_shutdown()
         self.c_restart_with_new_addr = sctp.cause_restart_with_new_addr(
-            12, sctp.param_ipv4(8, '192.168.1.1'))
+            sctp.param_ipv4('192.168.1.1'))
         self.c_user_initiated_abort = sctp.cause_user_initiated_abort(
-            19, 'Key Interrupt.\x00')
+            'Key Interrupt.\x00')
         self.c_protocol_violation = sctp.cause_protocol_violation(
-            20, 'Unknown reason.\x00')
+            'Unknown reason.\x00')

         self.causes = [
             self.c_invalid_stream_id, self.c_missing_param,
@@ -328,28 +325,27 @@ class Test_sctp(unittest.TestCase):
         self.length = 4 + 8 + 16 + 8 + 4 + 20 + 8 + 4 + 8 + 8 + 4 + 12 \
             + 20 + 20

-        self.c_invalid_stream_id = sctp.cause_invalid_stream_id(
-            8, 4096)
+        self.c_invalid_stream_id = sctp.cause_invalid_stream_id(4096)
         self.c_missing_param = sctp.cause_missing_param(
-            16, 4, [sctp.PTYPE_IPV4, sctp.PTYPE_IPV6,
-                    sctp.PTYPE_COOKIE_PRESERVE, sctp.PTYPE_HOST_ADDR])
-        self.c_stale_cookie = sctp.cause_stale_cookie(8, '\x00\x00\x13\x88')
-        self.c_out_of_resource = sctp.cause_out_of_resource(4)
+            [sctp.PTYPE_IPV4, sctp.PTYPE_IPV6,
+             sctp.PTYPE_COOKIE_PRESERVE, sctp.PTYPE_HOST_ADDR])
+        self.c_stale_cookie = sctp.cause_stale_cookie('\x00\x00\x13\x88')
+        self.c_out_of_resource = sctp.cause_out_of_resource()
         self.c_unresolvable_addr = sctp.cause_unresolvable_addr(
-            20, sctp.param_host_addr(16, 'test host\x00\x00\x00'))
+            sctp.param_host_addr('test host\x00'))
         self.c_unrecognized_chunk = sctp.cause_unrecognized_chunk(
-            8, '\xff\x00\x00\x04')
-        self.c_invalid_param = sctp.cause_invalid_param(4)
+            '\xff\x00\x00\x04')
+        self.c_invalid_param = sctp.cause_invalid_param()
         self.c_unrecognized_param = sctp.cause_unrecognized_param(
-            8, '\xff\xff\x00\x04')
-        self.c_no_userdata = sctp.cause_no_userdata(8, '\x00\x01\xe2\x40')
-        self.c_cookie_while_shutdown = sctp.cause_cookie_while_shutdown(4)
+            '\xff\xff\x00\x04')
+        self.c_no_userdata = sctp.cause_no_userdata('\x00\x01\xe2\x40')
+        self.c_cookie_while_shutdown = sctp.cause_cookie_while_shutdown()
         self.c_restart_with_new_addr = sctp.cause_restart_with_new_addr(
-            12, sctp.param_ipv4(8, '192.168.1.1'))
+            sctp.param_ipv4('192.168.1.1'))
         self.c_user_initiated_abort = sctp.cause_user_initiated_abort(
-            20, 'Key Interrupt.\x00\x00')
+            'Key Interrupt.\x00')
         self.c_protocol_violation = sctp.cause_protocol_violation(
-            20, 'Unknown reason.\x00')
+            'Unknown reason.\x00')

         self.causes = [
             self.c_invalid_stream_id, self.c_missing_param,
@@ -375,7 +371,7 @@ class Test_sctp(unittest.TestCase):
             '\x00\x03\x00\x08\x00\x00\x13\x88' + \
             '\x00\x04\x00\x04' + \
             '\x00\x05\x00\x14' + \
-            '\x00\x0b\x00\x10' + \
+            '\x00\x0b\x00\x0e' + \
             '\x74\x65\x73\x74\x20\x68\x6f\x73\x74\x00\x00\x00' + \
             '\x00\x06\x00\x08\xff\x00\x00\x04' + \
             '\x00\x07\x00\x04' + \
@@ -384,7 +380,7 @@ class Test_sctp(unittest.TestCase):
             '\x00\x0a\x00\x04' + \
             '\x00\x0b\x00\x0c' + \
             '\x00\x05\x00\x08\xc0\xa8\x01\x01' + \
-            '\x00\x0c\x00\x14' + \
+            '\x00\x0c\x00\x13' + \
             '\x4b\x65\x79\x20\x49\x6e\x74\x65' + \
             '\x72\x72\x75\x70\x74\x2e\x00\x00' + \
             '\x00\x0d\x00\x14' + \
@@ -607,6 +603,8 @@ class Test_sctp(unittest.TestCase):
             res = _res[0]
         else:
             res = _res
+        # to calculate the lengths of parameters.
+        self.sc.serialize(None, None)

         eq_(self.src_port, res.src_port)
         eq_(self.dst_port, res.dst_port)
@@ -1072,7 +1070,7 @@ class Test_sctp(unittest.TestCase):
             sctp.cause_unresolvable_addr._PACK_STR, buf)
         eq_(sctp.cause_unresolvable_addr.cause_code(), res5[0])
         eq_(20, res5[1])
-        eq_('\x00\x0b\x00\x10\x74\x65\x73\x74' +
+        eq_('\x00\x0b\x00\x0e\x74\x65\x73\x74' +
             '\x20\x68\x6f\x73\x74\x00\x00\x00',
             buf[sctp.cause_unresolvable_addr._MIN_LEN:
                 sctp.cause_unresolvable_addr._MIN_LEN + 16])
@@ -1127,10 +1125,10 @@ class Test_sctp(unittest.TestCase):
         res12 = struct.unpack_from(
             sctp.cause_user_initiated_abort._PACK_STR, buf)
         eq_(sctp.cause_user_initiated_abort.cause_code(), res12[0])
-        eq_(20, res12[1])
-        eq_('Key Interrupt.\x00\x00',
+        eq_(19, res12[1])
+        eq_('Key Interrupt.\x00',
             buf[sctp.cause_user_initiated_abort._MIN_LEN:
-                sctp.cause_user_initiated_abort._MIN_LEN + 16])
+                sctp.cause_user_initiated_abort._MIN_LEN + 15])

         buf = buf[20:]
         res13 = struct.unpack_from(
-- 
1.7.10.4


------------------------------------------------------------------------------
October Webinars: Code for Performance
Free Intel webinars can help you accelerate application performance.
Explore tips for MPI, OpenMP, advanced profiling, and more. Get the most from 
the latest Intel processors and coprocessors. See abstracts and register >
http://pubads.g.doubleclick.net/gampad/clk?id=60135991&iu=/4140/ostg.clktrk
_______________________________________________
Ryu-devel mailing list
[email protected]
https://lists.sourceforge.net/lists/listinfo/ryu-devel

Reply via email to