Author: rhs
Date: Thu Jun  4 15:57:48 2009
New Revision: 781786

URL: http://svn.apache.org/viewvc?rev=781786&view=rev
Log:
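Added commit and rollback to the Session API (enabled by a new
transactional flag on Connection.session()), fixed Session.acknowledge()
so it completes only the requested message(s) and iterates over a copy
of the unacked list, and streamlined some test utilities.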

Modified:
    qpid/trunk/qpid/python/qpid/messaging.py
    qpid/trunk/qpid/python/qpid/tests/messaging.py

Modified: qpid/trunk/qpid/python/qpid/messaging.py
URL: http://svn.apache.org/viewvc/qpid/trunk/qpid/python/qpid/messaging.py?rev=781786&r1=781785&r2=781786&view=diff
==============================================================================
--- qpid/trunk/qpid/python/qpid/messaging.py (original)
+++ qpid/trunk/qpid/python/qpid/messaging.py Thu Jun  4 15:57:48 2009
@@ -147,7 +147,7 @@
     self._condition = Condition(self._lock)
 
   @synchronized
-  def session(self, name=None):
+  def session(self, name=None, transactional=False):
     """
     Creates or retrieves the named session. If the name is omitted or
     None, then a unique name is chosen based on a randomly generated
@@ -168,7 +168,7 @@
     if self.sessions.has_key(name):
       return self.sessions[name]
     else:
-      ssn = Session(self, name, self.started)
+      ssn = Session(self, name, self.started, transactional=transactional)
       self.sessions[name] = ssn
       if self._conn is not None:
         ssn._attach()
@@ -268,10 +268,11 @@
   messages, and manage various Senders and Receivers.
   """
 
-  def __init__(self, connection, name, started):
+  def __init__(self, connection, name, started, transactional):
     self.connection = connection
     self.name = name
     self.started = started
+    self.transactional = transactional
     self._ssn = None
     self.senders = []
     self.receivers = []
@@ -279,6 +280,8 @@
     self.incoming = []
     self.closed = False
     self.unacked = []
+    if self.transactional:
+      self.acked = []
     self._lock = RLock()
     self._condition = Condition(self._lock)
     self.thread = Thread(target = self.run)
@@ -294,6 +297,8 @@
     self._ssn.invoke_lock = self._lock
     self._ssn.lock = self._lock
     self._ssn.condition = self._condition
+    if self.transactional:
+      self._ssn.tx_select()
     for link in self.senders + self.receivers:
       link._link()
 
@@ -414,17 +419,17 @@
   def acknowledge(self, message=None):
     """
     Acknowledge the given L{Message}. If message is None, then all
-    unackednowledged messages on the session are acknowledged.
+    unacknowledged messages on the session are acknowledged.
 
     @type message: Message
     @param message: the message to acknowledge or None
     """
     if message is None:
-      messages = self.unacked
+      messages = self.unacked[:]
     else:
       messages = [message]
 
-    ids = RangedSet(*[m._transfer_id for m in self.unacked])
+    ids = RangedSet(*[m._transfer_id for m in messages])
     for range in ids:
       self._ssn.receiver._completed.add_range(range)
     self._ssn.channel.session_completed(self._ssn.receiver._completed)
@@ -436,6 +441,46 @@
         self.unacked.remove(m)
       except ValueError:
         pass
+      if self.transactional:
+        self.acked.append(m)
+
+  @synchronized
+  def commit(self):
+    """
+    Commit outstanding transactional work. This consists of all
+    message sends and receives since the prior commit or rollback.
+    """
+    if not self.transactional:
+      raise NontransactionalSession()
+    if self._ssn is None:
+      raise Disconnected()
+    self._ssn.tx_commit(sync=True)
+    del self.acked[:]
+    self._ssn.sync()
+
+  @synchronized
+  def rollback(self):
+    """
+    Rollback outstanding transactional work. This consists of all
+    message sends and receives since the prior commit or rollback.
+    """
+    if not self.transactional:
+      raise NontransactionalSession()
+    if self._ssn is None:
+      raise Disconnected()
+
+    ids = RangedSet(*[m._transfer_id for m in self.acked + self.unacked + self.incoming])
+    for range in ids:
+      self._ssn.receiver._completed.add_range(range)
+    self._ssn.channel.session_completed(self._ssn.receiver._completed)
+    self._ssn.message_release(ids)
+    self._ssn.tx_rollback(sync=True)
+
+    del self.incoming[:]
+    del self.unacked[:]
+    del self.acked[:]
+
+    self._ssn.sync()
 
   @synchronized
   def start(self):
@@ -515,6 +560,13 @@
   """
   pass
 
+class NontransactionalSession(Exception):
+  """
+  Exception raised when commit or rollback is attempted on a non
+  transactional session.
+  """
+  pass
+
 class Sender(Lockable):
 
   """

Modified: qpid/trunk/qpid/python/qpid/tests/messaging.py
URL: http://svn.apache.org/viewvc/qpid/trunk/qpid/python/qpid/tests/messaging.py?rev=781786&r1=781785&r2=781786&view=diff
==============================================================================
--- qpid/trunk/qpid/python/qpid/tests/messaging.py (original)
+++ qpid/trunk/qpid/python/qpid/tests/messaging.py Thu Jun  4 15:57:48 2009
@@ -40,6 +40,7 @@
     return None
 
   def setup(self):
+    self.test_id = uuid4()
     self.broker = self.config.broker
     self.conn = self.setup_connection()
     self.ssn = self.setup_session()
@@ -50,10 +51,16 @@
     if self.conn is not None and self.conn.connected():
       self.conn.close()
 
+  def content(self, base, count = None):
+    if count is None:
+      return "%s[%s]" % (base, self.test_id)
+    else:
+      return "%s[%s, %s]" % (base, count, self.test_id)
+
   def ping(self, ssn):
     # send a message
     sender = ssn.sender("ping-queue")
-    content = "ping[%s]" % uuid4()
+    content = self.content("ping")
     sender.send(content)
     receiver = ssn.receiver("ping-queue")
     msg = receiver.fetch(timeout=0)
@@ -61,13 +68,17 @@
     assert msg.content == content
 
   def drain(self, rcv, limit=None):
-    msgs = []
+    contents = []
     try:
-      while limit is None or len(msgs) < limit:
-        msgs.append(rcv.fetch(0))
+      while limit is None or len(contents) < limit:
+        contents.append(rcv.fetch(0).content)
     except Empty:
       pass
-    return msgs
+    return contents
+
+  def assertEmpty(self, rcv):
+    contents = self.drain(rcv)
+    assert len(contents) == 0, "%s is supposed to be empty: %s" % (rcv, contents)
 
   def delay(self):
     d = float(self.config.defines.get("delay", "2"))
@@ -156,7 +167,7 @@
     assert snd is not snd2
     snd2.close()
 
-    content = "testSender[%s]" % uuid4()
+    content = self.content("testSender")
     snd.send(content)
     rcv = self.ssn.receiver(snd.target)
     msg = rcv.fetch(0)
@@ -169,7 +180,7 @@
     assert rcv is not rcv2
     rcv2.close()
 
-    content = "testReceiver[%s]" % uuid4()
+    content = self.content("testReceiver")
     snd = self.ssn.sender(rcv.source)
     snd.send(content)
     msg = rcv.fetch(0)
@@ -206,16 +217,14 @@
     # drain the queue, verify the messages are there and then close
     # without acking
     rcv = self.ssn.receiver(snd.target)
-    msgs = self.drain(rcv)
-    assert contents == [m.content for m in msgs]
+    assert contents == self.drain(rcv)
     self.ssn.close()
 
     # drain the queue again, verify that they are all the messages
     # were requeued, and ack this time before closing
     self.ssn = self.conn.session()
     rcv = self.ssn.receiver("test-ack-queue")
-    msgs = self.drain(rcv)
-    assert contents == [m.content for m in msgs]
+    assert contents == self.drain(rcv)
     self.ssn.acknowledge()
     self.ssn.close()
 
@@ -223,8 +232,69 @@
     # dequeued
     self.ssn = self.conn.session()
     rcv = self.ssn.receiver("test-ack-queue")
-    msgs = self.drain(rcv)
-    assert len(msgs) == 0
+    self.assertEmpty(rcv)
+
+  def send(self, ssn, queue, base, count=1):
+    snd = ssn.sender(queue)
+    contents = []
+    for i in range(count):
+      c = self.content(base, i)
+      snd.send(c)
+      contents.append(c)
+    snd.close()
+    return contents
+
+  def testCommitSend(self):
+    txssn = self.conn.session(transactional=True)
+    contents = self.send(txssn, "test-commit-send-queue", "testCommitSend", 3)
+    rcv = self.ssn.receiver("test-commit-send-queue")
+    self.assertEmpty(rcv)
+    txssn.commit()
+    assert contents == self.drain(rcv)
+    self.ssn.acknowledge()
+
+  def testCommitAck(self):
+    txssn = self.conn.session(transactional=True)
+    txrcv = txssn.receiver("test-commit-ack-queue")
+    self.assertEmpty(txrcv)
+    contents = self.send(self.ssn, "test-commit-ack-queue", "testCommitAck", 3)
+    assert contents == self.drain(txrcv)
+    txssn.acknowledge()
+    txssn.close()
+
+    txssn = self.conn.session(transactional=True)
+    txrcv = txssn.receiver("test-commit-ack-queue")
+    assert contents == self.drain(txrcv)
+    txssn.acknowledge()
+    txssn.commit()
+    rcv = self.ssn.receiver("test-commit-ack-queue")
+    self.assertEmpty(rcv)
+    txssn.close()
+    self.assertEmpty(rcv)
+
+  def testRollbackAck(self):
+    txssn = self.conn.session(transactional=True)
+    txrcv = txssn.receiver("test-rollback-ack-queue")
+    self.assertEmpty(txrcv)
+    contents = self.send(self.ssn, "test-rollback-ack-queue", "testRollbackAck", 3)
+    assert contents == self.drain(txrcv)
+    txssn.rollback()
+    assert contents == self.drain(txrcv)
+    txssn.acknowledge()
+    txssn.rollback()
+    assert contents == self.drain(txrcv)
+    txssn.commit() # commit without ack
+    self.assertEmpty(txrcv)
+    txssn.close()
+    txssn = self.conn.session(transactional=True)
+    txrcv = txssn.receiver("test-rollback-ack-queue")
+    assert contents == self.drain(txrcv)
+    txssn.acknowledge()
+    txssn.commit()
+    rcv = self.ssn.receiver("test-rollback-ack-queue")
+    self.assertEmpty(rcv)
+    txssn.close()
+    self.assertEmpty(rcv)
 
   def testClose(self):
     self.ssn.close()
@@ -249,10 +319,7 @@
     return self.ssn.receiver("test-receiver-queue")
 
   def send(self, base, count = None):
-    if count is None:
-      content = "%s[%s]" % (base, uuid4())
-    else:
-      content = "%s[%s, %s]" % (base, count, uuid4())
+    content = self.content(base, count)
     self.snd.send(content)
     return content
 
@@ -412,13 +479,13 @@
     self.ssn.acknowledge()
 
   def testSendString(self):
-    self.checkContent("testSendString[%s]" % uuid4())
+    self.checkContent(self.content("testSendString"))
 
   def testSendList(self):
-    self.checkContent(["testSendList", 1, 3.14, uuid4()])
+    self.checkContent(["testSendList", 1, 3.14, self.test_id])
 
   def testSendMap(self):
-    self.checkContent({"testSendMap": uuid4(), "pie": "blueberry", "pi": 3.14})
+    self.checkContent({"testSendMap": self.test_id, "pie": "blueberry", "pi": 3.14})
 
 class MessageTests(Base):
 
@@ -506,7 +573,7 @@
     msg = Message()
     msg.to = "to-address"
     msg.subject = "subject"
-    msg.correlation_id = str(uuid4())
+    msg.correlation_id = str(self.test_id)
     msg.properties = MessageEchoTests.TEST_MAP
     msg.reply_to = "reply-address"
     self.check(msg)



---------------------------------------------------------------------
Apache Qpid - AMQP Messaging Implementation
Project:      http://qpid.apache.org
Use/Interact: mailto:[email protected]

Reply via email to