> You should also decide where the main cse() function ought to be
> exposed. I punted, and didn't expose it in sympy.* or
> sympy.simplify.*.
I'd add the following patch:
# HG changeset patch
# User Ondrej Certik <[EMAIL PROTECTED]>
# Date 1215346162 -7200
# Node ID a899701f1e4f303dcc41fd4a1e00795936a035c5
# Parent d799b0e232f2d15f7c771aec318d30840506a681
Make cse accept a single expression as well.
diff --git a/sympy/simplify/cse.py b/sympy/simplify/cse.py
--- a/sympy/simplify/cse.py
+++ b/sympy/simplify/cse.py
@@ -1,7 +1,7 @@
""" Tools for doing common subexpression elimination.
"""
-from sympy import Symbol
+from sympy import Symbol, Basic
from sympy.utilities.iterables import postorder_traversal
import cse_opts
@@ -91,7 +91,7 @@ def cse(exprs, symbols=None, optimizatio
Parameters
----------
- exprs : list of sympy expressions
+ exprs : list of sympy expressions, or a single sympy expression
The expressions to reduce.
symbols : infinite iterator yielding unique Symbols
The symbols used to label the common subexpressions which are pulled
@@ -124,6 +124,9 @@ def cse(exprs, symbols=None, optimizatio
# manipulations of the module-level list in some other thread.
optimizations = list(cse_optimizations)
+ # Handle the case if just one expression was passed.
+ if isinstance(exprs, Basic):
+ exprs = [exprs]
# Preprocess the expressions to give us better optimization opportunities.
exprs = [preprocess_for_cse(e, optimizations) for e in exprs]
diff --git a/sympy/simplify/tests/test_cse.py b/sympy/simplify/tests/test_cse.py
--- a/sympy/simplify/tests/test_cse.py
+++ b/sympy/simplify/tests/test_cse.py
@@ -45,6 +45,13 @@ def test_cse_single():
assert substs == [(x0, x+y)]
assert reduced == [sqrt(x0) + x0**2]
+def test_cse_single2():
+ # Simple substitution, test for being able to pass the expression directly
+ e = Add(Pow(x+y,2), sqrt(x+y))
+ substs, reduced = cse.cse(e, optimizations=[])
+ assert substs == [(x0, x+y)]
+ assert reduced == [sqrt(x0) + x0**2]
+
def test_cse_not_possible():
# No substitution possible.
e = Add(x,y)
and export it using this patch:
# HG changeset patch
# User Ondrej Certik <[EMAIL PROTECTED]>
# Date 1215346276 -7200
# Node ID 049e4c8c531bacf82216959d2c6bfa004a662961
# Parent a899701f1e4f303dcc41fd4a1e00795936a035c5
Export cse() by default.
diff --git a/sympy/simplify/__init__.py b/sympy/simplify/__init__.py
--- a/sympy/simplify/__init__.py
+++ b/sympy/simplify/__init__.py
@@ -10,3 +10,5 @@ from rewrite import cancel, trim, apart
from rewrite import cancel, trim, apart
from sqrtdenest import sqrtdenest
+
+from cse import cse
Let me know what you think.
Ondrej
--~--~---------~--~----~------------~-------~--~----~
You received this message because you are subscribed to the Google Groups
"sympy" group.
To post to this group, send email to sympy@googlegroups.com
To unsubscribe from this group, send email to [EMAIL PROTECTED]
For more options, visit this group at http://groups.google.com/group/sympy?hl=en
-~----------~----~----~----~------~----~------~--~---