Diff comments:
> === modified file 'cloudinit/sources/DataSourceSmartOS.py' > --- cloudinit/sources/DataSourceSmartOS.py 2015-01-27 20:03:52 +0000 > +++ cloudinit/sources/DataSourceSmartOS.py 2015-03-04 16:40:58 +0000 > @@ -29,9 +29,10 @@ > # http://us-east.manta.joyent.com/jmc/public/mdata/datadict.html > # Comments with "@datadictionary" are snippets of the definition > > -import base64 > import binascii > import os > +import random > +import re > import serial > > from cloudinit import log as logging > @@ -301,6 +302,59 @@ > return ser > > > +class JoyentMetadataFetchException(Exception): > + pass > + > + > +class JoyentMetadataClient(object): > + > + def __init__(self, serial): > + self.serial = serial > + > + def _checksum(self, body): > + return '{0:08x}'.format( > + binascii.crc32(body.encode('utf-8')) & 0xffffffff) > + > + def _get_value_from_frame(self, expected_request_id, frame): > + regex = ( > + r'V2 (?P<length>\d+) (?P<checksum>[0-9a-f]+)' Nope, it's actually defined here: http://eng.joyent.com/mdata/protocol.html The example of a frame they give is: ,- Body (CRC32/Length is of _____________________ <--/ this string) V2 21 265ae1d8 dc4fae17 SUCCESS W10=\n ^ ^ ^ ^ ^ ^ ^ | | | | | | \--- Terminating Linefeed | | | | | \------- Payload | | | | \--------------- Code | | | \------------------------ Request ID | | \--------------------------------- Body Checksum (CRC32) | \------------------------------------ Body Length \--------------------------------------- So that this can be a V1 command as well, we start with "V2". This will be a FAILURE on a host that only supports the V1 protocol. 
> + r' (?P<body>(?P<request_id>[0-9a-f]+) > (?P<status>SUCCESS|NOTFOUND)' > + r'( (?P<payload>.+))?)') > + frame_data = re.match(regex, frame).groupdict() > + if int(frame_data['length']) != len(frame_data['body']): > + raise JoyentMetadataFetchException( > + 'Incorrect frame length given ({0} != {1}).'.format( > + frame_data['length'], len(frame_data['body']))) > + expected_checksum = self._checksum(frame_data['body']) > + if frame_data['checksum'] != expected_checksum: > + raise JoyentMetadataFetchException( > + 'Invalid checksum (expected: {0}; got {1}).'.format( > + expected_checksum, frame_data['checksum'])) > + if frame_data['request_id'] != expected_request_id: > + raise JoyentMetadataFetchException( > + 'Request ID mismatch (expected: {0}; got {1}).'.format( > + expected_request_id, frame_data['request_id'])) > + if not frame_data.get('payload', None): > + LOG.info('No value found.') > + return None > + value = util.b64d(frame_data['payload']) > + LOG.info('Value "%s" found.', value) > + return value > + > + def get_metadata(self, metadata_key): > + LOG.info('Fetching metadata key "%s"...', metadata_key) > + request_id = '{0:08x}'.format(random.randint(0, 0xffffffff)) > + message_body = '{0} GET {1}'.format(request_id, > + util.b64e(metadata_key)) > + msg = 'V2 {0} {1} {2}\n'.format( > + len(message_body), self._checksum(message_body), message_body) > + LOG.debug('Writing "%s" to serial port.', msg) > + self.serial.write(msg) > + response = self.serial.readline() > + LOG.debug('Read "%s" from serial port.', response) > + return self._get_value_from_frame(request_id, response) > + > + > def query_data(noun, seed_device, seed_timeout, strip=False, default=None, > b64=None): > """Makes a request to via the serial console via "GET <NOUN>" > @@ -314,33 +368,21 @@ > encoded, so this method relies on being told if the data is base64 or > not. 
> """ > - > if not noun: > return False > > ser = get_serial(seed_device, seed_timeout) > - ser.write("GET %s\n" % noun.rstrip()) > - status = str(ser.readline()).rstrip() > - response = [] > - eom_found = False > > - if 'SUCCESS' not in status: > - ser.close() > + client = JoyentMetadataClient(ser) Yep, good call. > + response = client.get_metadata(noun) > + ser.close() > + if response is None: > return default > > - while not eom_found: > - m = ser.readline() > - if m.rstrip() == ".": > - eom_found = True > - else: > - response.append(m) > - > - ser.close() > - > if b64 is None: > b64 = query_data('b64-%s' % noun, seed_device=seed_device, > - seed_timeout=seed_timeout, b64=False, > - default=False, strip=True) > + seed_timeout=seed_timeout, b64=False, > + default=False, strip=True) > b64 = util.is_true(b64) > > resp = None > > === modified file 'tests/unittests/test_datasource/test_smartos.py' > --- tests/unittests/test_datasource/test_smartos.py 2015-01-27 20:03:52 > +0000 > +++ tests/unittests/test_datasource/test_smartos.py 2015-03-04 16:40:58 > +0000 > @@ -24,18 +24,27 @@ > > from __future__ import print_function > > -from cloudinit import helpers as c_helpers > -from cloudinit.sources import DataSourceSmartOS > -from cloudinit.util import b64e > -from .. import helpers > import os > import os.path > import re > import shutil > +import stat > import tempfile > -import stat > import uuid > - > +from binascii import crc32 > + > +import serial > + > +from cloudinit import helpers as c_helpers > +from cloudinit.sources import DataSourceSmartOS > +from cloudinit.util import b64e > + > +from .. 
import helpers > + > +try: > + from unittest import mock > +except ImportError: > + import mock > > MOCK_RETURNS = { > 'hostname': 'test-host', > @@ -54,60 +63,15 @@ > DMI_DATA_RETURN = (str(uuid.uuid4()), 'smartdc') > > > -class MockSerial(object): > - """Fake a serial terminal for testing the code that > - interfaces with the serial""" > - > - port = None > - > - def __init__(self, mockdata): > - self.last = None > - self.last = None > - self.new = True > - self.count = 0 > - self.mocked_out = [] > - self.mockdata = mockdata > - > - def open(self): > - return True > - > - def close(self): > - return True > - > - def isOpen(self): > - return True > - > - def write(self, line): > - line = line.replace('GET ', '') > - self.last = line.rstrip() > - > - def readline(self): > - if self.new: > - self.new = False > - if self.last in self.mockdata: > - return 'SUCCESS\n' > - else: > - return 'NOTFOUND %s\n' % self.last > - > - if self.last in self.mockdata: > - if not self.mocked_out: > - self.mocked_out = [x for x in self._format_out()] > - > - if len(self.mocked_out) > self.count: > - self.count += 1 > - return self.mocked_out[self.count - 1] > - > - def _format_out(self): > - if self.last in self.mockdata: > - _mret = self.mockdata[self.last] > - try: > - for l in _mret.splitlines(): > - yield "%s\n" % l.rstrip() > - except: > - yield "%s\n" % _mret.rstrip() > - > - yield '.' 
> - yield '\n' > +def get_mock_client(mockdata): > + class MockMetadataClient(object): > + > + def __init__(self, serial): > + pass > + > + def get_metadata(self, metadata_key): > + return mockdata.get(metadata_key) > + return MockMetadataClient > > > class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): > @@ -155,9 +119,6 @@ > if dmi_data is None: > dmi_data = DMI_DATA_RETURN > > - def _get_serial(*_): > - return MockSerial(mockdata) > - > def _dmi_data(): > return dmi_data > > @@ -174,7 +135,9 @@ > sys_cfg['datasource']['SmartOS'] = ds_cfg > > self.apply_patches([(mod, 'LEGACY_USER_D', self.legacy_user_d)]) > - self.apply_patches([(mod, 'get_serial', _get_serial)]) > + self.apply_patches([(mod, 'get_serial', mock.MagicMock())]) > + self.apply_patches([ > + (mod, 'JoyentMetadataClient', get_mock_client(mockdata))]) > self.apply_patches([(mod, 'dmi_data', _dmi_data)]) > self.apply_patches([(os, 'uname', _os_uname)]) > self.apply_patches([(mod, 'device_exists', lambda d: True)]) > @@ -453,3 +416,129 @@ > setattr(ref, name, replace) > ret.append((ref, name, orig)) > return ret > + > + > +class TestJoyentMetadataClient(helpers.FilesystemMockingTestCase): > + > + def setUp(self): > + super(TestJoyentMetadataClient, self).setUp() > + self.serial = mock.MagicMock(spec=serial.Serial) > + self.request_id = 0xabcdef12 > + self.metadata_value = 'value' > + self.response_parts = { > + 'command': 'SUCCESS', > + 'crc': 'b5a9ff00', > + 'length': 17 + len(b64e(self.metadata_value)), > + 'payload': b64e(self.metadata_value), > + 'request_id': '{0:08x}'.format(self.request_id), > + } > + > + def make_response(): > + payload = '' > + if self.response_parts['payload']: > + payload = ' {0}'.format(self.response_parts['payload']) > + del self.response_parts['payload'] > + return ( > + 'V2 {length} {crc} {request_id} {command}{payload}\n'.format( > + payload=payload, **self.response_parts)) > + self.serial.readline.side_effect = make_response > + 
self.patched_funcs.enter_context( > + mock.patch('cloudinit.sources.DataSourceSmartOS.random.randint', > + mock.Mock(return_value=self.request_id))) > + > + def _get_client(self): > + return DataSourceSmartOS.JoyentMetadataClient(self.serial) > + > + def assertEndsWith(self, haystack, prefix): > + self.assertTrue(haystack.endswith(prefix), > + "{0} does not end with '{1}'".format( > + repr(haystack), prefix)) > + > + def assertStartsWith(self, haystack, prefix): > + self.assertTrue(haystack.startswith(prefix), > + "{0} does not start with '{1}'".format( > + repr(haystack), prefix)) > + > + def test_get_metadata_writes_a_single_line(self): > + client = self._get_client() > + client.get_metadata('some_key') > + self.assertEqual(1, self.serial.write.call_count) > + written_line = self.serial.write.call_args[0][0] > + self.assertEndsWith(written_line, '\n') > + self.assertEqual(1, written_line.count('\n')) > + > + def _get_written_line(self, key='some_key'): > + client = self._get_client() > + client.get_metadata(key) > + return self.serial.write.call_args[0][0] > + > + def test_get_metadata_line_starts_with_v2(self): > + self.assertStartsWith(self._get_written_line(), 'V2') > + > + def test_get_metadata_uses_get_command(self): > + parts = self._get_written_line().strip().split(' ') > + self.assertEqual('GET', parts[4]) > + > + def test_get_metadata_base64_encodes_argument(self): > + key = 'my_key' > + parts = self._get_written_line(key).strip().split(' ') > + self.assertEqual(b64e(key), parts[5]) > + > + def test_get_metadata_calculates_length_correctly(self): > + parts = self._get_written_line().strip().split(' ') > + expected_length = len(' '.join(parts[3:])) > + self.assertEqual(expected_length, int(parts[1])) > + > + def test_get_metadata_uses_appropriate_request_id(self): > + parts = self._get_written_line().strip().split(' ') > + request_id = parts[3] > + self.assertEqual(8, len(request_id)) > + self.assertEqual(request_id, request_id.lower()) > + > + def 
test_get_metadata_uses_random_number_for_request_id(self): > + request_id = self._get_written_line().strip().split(' ')[3] > + self.assertEqual('{0:08x}'.format(self.request_id), request_id) > + > + def test_get_metadata_checksums_correctly(self): > + parts = self._get_written_line().strip().split(' ') > + expected_checksum = '{0:08x}'.format( > + crc32(' '.join(parts[3:]).encode('utf-8')) & 0xffffffff) > + checksum = parts[2] > + self.assertEqual(expected_checksum, checksum) > + > + def test_get_metadata_reads_a_line(self): > + client = self._get_client() > + client.get_metadata('some_key') > + self.assertEqual(1, self.serial.readline.call_count) > + > + def test_get_metadata_returns_valid_value(self): > + client = self._get_client() > + value = client.get_metadata('some_key') > + self.assertEqual(self.metadata_value, value) > + > + def test_get_metadata_throws_exception_for_incorrect_length(self): > + self.response_parts['length'] = 0 > + client = self._get_client() > + self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException, > + client.get_metadata, 'some_key') > + > + def test_get_metadata_throws_exception_for_incorrect_crc(self): > + self.response_parts['crc'] = 'deadbeef' > + client = self._get_client() > + self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException, > + client.get_metadata, 'some_key') > + > + def test_get_metadata_throws_exception_for_request_id_mismatch(self): > + self.response_parts['request_id'] = 'deadbeef' > + client = self._get_client() > + client._checksum = lambda _: self.response_parts['crc'] > + self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException, > + client.get_metadata, 'some_key') > + > + def test_get_metadata_returns_None_if_value_not_found(self): > + self.response_parts['payload'] = '' > + self.response_parts['command'] = 'NOTFOUND' > + self.response_parts['length'] = 17 > + client = self._get_client() > + client._checksum = lambda _: self.response_parts['crc'] > + 
self.assertIsNone(client.get_metadata('some_key')) > -- https://code.launchpad.net/~daniel-thewatkins/cloud-init/smartos-v2-metadata/+merge/251775 Your team cloud init development team is requested to review the proposed merge of lp:~daniel-thewatkins/cloud-init/smartos-v2-metadata into lp:cloud-init. _______________________________________________ Mailing list: https://launchpad.net/~cloud-init-dev Post to : cloud-init-dev@lists.launchpad.net Unsubscribe : https://launchpad.net/~cloud-init-dev More help : https://help.launchpad.net/ListHelp

