Author: russellm
Date: 2010-05-28 04:52:24 -0500 (Fri, 28 May 2010)
New Revision: 13308

Modified:
   django/trunk/tests/regressiontests/serializers_regress/tests.py
Log:
Fixed #13638 -- Refactored the serializers_regress tests to avoid the use of
flush and to make better use of the transactional capabilities of
django.test.TestCase. Thanks to Alex Gaynor for the patch.

Modified: django/trunk/tests/regressiontests/serializers_regress/tests.py
===================================================================
--- django/trunk/tests/regressiontests/serializers_regress/tests.py	2010-05-27 09:38:59 UTC (rev 13307)
+++ django/trunk/tests/regressiontests/serializers_regress/tests.py	2010-05-28 09:52:24 UTC (rev 13308)
@@ -10,14 +10,16 @@
 
 import datetime
 import decimal
-import unittest
-from cStringIO import StringIO
+try:
+    from cStringIO import StringIO
+except ImportError:
+    from StringIO import StringIO
 
-from django.utils.functional import curry
-from django.core import serializers
-from django.db import transaction, DEFAULT_DB_ALIAS
-from django.core import management
 from django.conf import settings
+from django.core import serializers, management
+from django.db import transaction, DEFAULT_DB_ALIAS
+from django.test import TestCase
+from django.utils.functional import curry
 
 from models import *
 
@@ -59,10 +61,10 @@
 
 def im_create(pk, klass, data):
     instance = klass(id=pk)
-    setattr(instance, 'right_id', data['right'])
-    setattr(instance, 'left_id', data['left'])
+    instance.right_id = data['right']
+    instance.left_id = data['left']
     if 'extra' in data:
-        setattr(instance, 'extra', data['extra'])
+        instance.extra = data['extra']
     models.Model.save_base(instance, raw=True)
     return [instance]
 
@@ -96,7 +98,9 @@
 def data_compare(testcase, pk, klass, data):
     instance = klass.objects.get(id=pk)
     testcase.assertEqual(data, instance.data,
-                         "Objects with PK=%d not equal; expected '%s' (%s), got '%s' (%s)" % (pk,data, type(data), instance.data, type(instance.data)))
+         "Objects with PK=%d not equal; expected '%s' (%s), got '%s' (%s)" % (
+            pk, data, type(data), instance.data, type(instance.data))
+    )
 
 def generic_compare(testcase, pk, klass, data):
     instance = klass.objects.get(id=pk)
@@ -348,28 +352,16 @@
 
 # Dynamically create serializer tests to ensure that all
 # registered serializers are automatically tested.
-class SerializerTests(unittest.TestCase):
+class SerializerTests(TestCase):
     pass
 
 def serializerTest(format, self):
-    # Clear the database first
-    management.call_command('flush', verbosity=0, interactive=False)
 
     # Create all the objects defined in the test data
     objects = []
     instance_count = {}
-    transaction.enter_transaction_management()
-    try:
-        transaction.managed(True)
-        for (func, pk, klass, datum) in test_data:
-            objects.extend(func[0](pk, klass, datum))
-            instance_count[klass] = 0
-        transaction.commit()
-    except:
-        transaction.rollback()
-        transaction.leave_transaction_management()
-        raise
-    transaction.leave_transaction_management()
+    for (func, pk, klass, datum) in test_data:
+        objects.extend(func[0](pk, klass, datum))
 
     # Get a count of the number of objects created for each class
     for klass in instance_count:
@@ -381,19 +373,8 @@
     # Serialize the test database
     serialized_data = serializers.serialize(format, objects, indent=2)
 
-    # Flush the database and recreate from the serialized data
-    management.call_command('flush', verbosity=0, interactive=False)
-    transaction.enter_transaction_management()
-    try:
-        transaction.managed(True)
-        for obj in serializers.deserialize(format, serialized_data):
-            obj.save()
-        transaction.commit()
-    except:
-        transaction.rollback()
-        transaction.leave_transaction_management()
-        raise
-    transaction.leave_transaction_management()
+    for obj in serializers.deserialize(format, serialized_data):
+        obj.save()
 
     # Assert that the deserialized data is the same
     # as the original source
@@ -406,10 +387,7 @@
         self.assertEquals(count, klass.objects.count())
 
 def fieldsTest(format, self):
-    # Clear the database first
-    management.call_command('flush', verbosity=0, interactive=False)
-
-    obj = ComplexModel(field1='first',field2='second',field3='third')
+    obj = ComplexModel(field1='first', field2='second', field3='third')
     obj.save_base(raw=True)
 
     # Serialize then deserialize the test database
@@ -422,9 +400,6 @@
     self.assertEqual(result.object.field3, 'third')
 
 def streamTest(format, self):
-    # Clear the database first
-    management.call_command('flush', verbosity=0, interactive=False)
-
     obj = ComplexModel(field1='first',field2='second',field3='third')
     obj.save_base(raw=True)
 
@@ -440,7 +415,7 @@
     stream.close()
 
 for format in serializers.get_serializer_formats():
-    setattr(SerializerTests, 'test_'+format+'_serializer', curry(serializerTest, format))
-    setattr(SerializerTests, 'test_'+format+'_serializer_fields', curry(fieldsTest, format))
+    setattr(SerializerTests, 'test_' + format + '_serializer', curry(serializerTest, format))
+    setattr(SerializerTests, 'test_' + format + '_serializer_fields', curry(fieldsTest, format))
     if format != 'python':
-        setattr(SerializerTests, 'test_'+format+'_serializer_stream', curry(streamTest, format))
+        setattr(SerializerTests, 'test_' + format + '_serializer_stream', curry(streamTest, format))

-- 
You received this message because you are subscribed to the Google Groups 
"Django updates" group.
To post to this group, send email to django-updates@googlegroups.com.
To unsubscribe from this group, send email to 
django-updates+unsubscribe@googlegroups.com.
For more options, visit this group at 
http://groups.google.com/group/django-updates?hl=en.

Reply via email to