Revision: 5007
          http://matplotlib.svn.sourceforge.net/matplotlib/?rev=5007&view=rev
Author:   sameerd
Date:     2008-03-19 11:07:49 -0700 (Wed, 19 Mar 2008)

Log Message:
-----------
Added outerjoin, leftjoin and rightjoin support to rec_join (a right join is done by swapping r1 and r2)

Modified Paths:
--------------
    trunk/matplotlib/lib/matplotlib/mlab.py

Added Paths:
-----------
    trunk/matplotlib/examples/rec_join_demo.py

Added: trunk/matplotlib/examples/rec_join_demo.py
===================================================================
--- trunk/matplotlib/examples/rec_join_demo.py                          (rev 0)
+++ trunk/matplotlib/examples/rec_join_demo.py  2008-03-19 18:07:49 UTC (rev 5007)
@@ -0,0 +1,27 @@
+import numpy as np
+import matplotlib.mlab as mlab
+
+# Demonstrate mlab.rec_join with the 'inner', 'outer' and 'leftouter' jointypes.
+r = mlab.csv2rec('data/aapl.csv')
+r.sort()
+r1 = r[-10:]
+
+# Build r2 from rows r[-17:-5]; it only partly overlaps r1 (= r[-10:]).
+r2 = np.empty(12, dtype=[('date', '|O4'), ('high', np.float), 
+                            ('marker', np.float)])
+r2 = r2.view(np.recarray)
+r2.date = r.date[-17:-5]
+r2.high = r.high[-17:-5]
+r2.marker = np.arange(12)
+
+print "r1:"
+print mlab.rec2txt(r1)
+print "r2:"
+print mlab.rec2txt(r2)
+
+defaults = {'marker':-1, 'close':np.NaN, 'low':-4444.}
+# Join on the ('date', 'high') key with each jointype and print the result.
+for s in ('inner', 'outer', 'leftouter'):
+    rec = mlab.rec_join(['date', 'high'], r1, r2, 
+            jointype=s, defaults=defaults) 
+    print "\n%sjoin :\n%s" % (s, mlab.rec2txt(rec))

Modified: trunk/matplotlib/lib/matplotlib/mlab.py
===================================================================
--- trunk/matplotlib/lib/matplotlib/mlab.py     2008-03-19 14:36:57 UTC (rev 5006)
+++ trunk/matplotlib/lib/matplotlib/mlab.py     2008-03-19 18:07:49 UTC (rev 5007)
@@ -2044,12 +2044,19 @@
 
     return npy.rec.fromarrays(arrays, names=names)
 
-def rec_join(key, r1, r2):
+
+def rec_join(key, r1, r2, jointype='inner', defaults=None):
     """
     join record arrays r1 and r2 on key; key is a tuple of field
     names.  if r1 and r2 have equal values on all the keys in the key
     tuple, then their fields will be merged into a new record array
     containing the intersection of the fields of r1 and r2
+
+    The jointype keyword can be 'inner', 'outer', 'leftouter'.
+    To do a rightouter join just reverse r1 and r2.
+
+    The defaults keyword is a dictionary filled with
+    {column_name:default_value} pairs.
     """
 
     for name in key:
@@ -2067,17 +2074,22 @@
     r1keys = set(r1d.keys())
     r2keys = set(r2d.keys())
 
-    keys = r1keys & r2keys
+    common_keys = r1keys & r2keys
 
-    r1ind = npy.array([r1d[k] for k in keys])
-    r2ind = npy.array([r2d[k] for k in keys])
+    r1ind = npy.array([r1d[k] for k in common_keys])
+    r2ind = npy.array([r2d[k] for k in common_keys])
 
-    # Make sure that the output rows have the same relative order as r1
-    sortind = r1ind.argsort()
+    common_len = len(common_keys)
+    left_len = right_len = 0
+    if jointype == "outer" or jointype == "leftouter":
+        left_keys = r1keys.difference(r2keys)
+        left_ind = npy.array([r1d[k] for k in left_keys])
+        left_len = len(left_ind)
+    if jointype == "outer":
+        right_keys = r2keys.difference(r1keys)
+        right_ind = npy.array([r2d[k] for k in right_keys])
+        right_len = len(right_ind)
 
-    r1 = r1[r1ind[sortind]]
-    r2 = r2[r2ind[sortind]]
-
     r2 = rec_drop_fields(r2, r1.dtype.names)
 
 
@@ -2103,13 +2115,31 @@
                          [desc for desc in r2.dtype.descr if desc[0] not in 
key ] )
 
 
-    newrec = npy.empty(len(r1), dtype=newdtype)
+    newrec = npy.empty(common_len + left_len + right_len, dtype=newdtype)
+
+    if jointype != 'inner' and defaults is not None: # fill in the defaults 
enmasse
+        newrec_fields = newrec.dtype.fields.keys()
+        for k, v in defaults.items():
+            if k in newrec_fields:
+                newrec[k] = v
+
     for field in r1.dtype.names:
-        newrec[field] = r1[field]
+        newrec[field][:common_len] = r1[field][r1ind]
+        if jointype == "outer" or jointype == "leftouter":
+            newrec[field][common_len:(common_len+left_len)] = 
r1[field][left_ind]
 
     for field in r2.dtype.names:
-        newrec[field] = r2[field]
+        newrec[field][:common_len] = r2[field][r2ind]
+        if jointype == "outer":
+            newrec[field][-right_len:] = 
r2[field][right_ind[right_ind.argsort()]]
 
+    # sort newrec using the same order as r1
+    sort_indices = r1ind.copy()
+    if jointype == "outer" or jointype == "leftouter":
+        sort_indices = npy.append(sort_indices, left_ind)
+    newrec[:(common_len+left_len)] = newrec[sort_indices.argsort()]
+
+
     return newrec.view(npy.recarray)
 
 


This was sent by the SourceForge.net collaborative development platform, the 
world's largest Open Source development site.

-------------------------------------------------------------------------
This SF.net email is sponsored by: Microsoft
Defy all challenges. Microsoft(R) Visual Studio 2008.
http://clk.atdmt.com/MRT/go/vse0120000070mrt/direct/01/
_______________________________________________
Matplotlib-checkins mailing list
[email protected]
https://lists.sourceforge.net/lists/listinfo/matplotlib-checkins

Reply via email to