On Mon, Nov 10, 2014 at 2:33 PM, Ondřej Čertík <[email protected]> wrote:
> On Mon, Nov 10, 2014 at 10:54 AM, Amir Farbin <[email protected]> wrote:
>> Hi,
>>
>>
>> I'm trying to convert SymPy expressions to Theano, following:
>>
>>
>> http://matthewrocklin.com/blog/work/2013/03/28/SymPy-Theano-part-2/
>>
>>
>> And I run into a failure that appears simple to fix. Here's a simple
>> example:
>>
>>
>> import sympy as sp
>> from sympy.printing.theanocode import theano_function
>>
>>
>> im=sp.sqrt(-1)
>
> Note that you can use "I", i.e. sp.I.
>
>>
>> x=sp.symbols("x")
>>
>> y=x+im*x
>>
>> fn_theano = theano_function([x], [y], dims={x: 1}, dtypes={x: 'float64'})
>>
>>
>> ends with:
>>
>>
>> KeyError: <class 'sympy.core.numbers.ImaginaryUnit'>
>>
>>
>> which appears to my naive eyes to be just a missing entry in a conversion
>> map. Can someone help?
>
> There seem to be more bugs. With the latest master, I am getting:
>
>
> In [1]: from sympy.printing.theanocode import theano_function
>
> In [3]: import sympy as sp
>
> In [4]: im=sp.sqrt(-1)
>
> In [5]: x=sp.symbols("x")
>
> In [6]: y=x+im*x
>
> In [7]: fn_theano = theano_function([x], [y], dims={x: 1}, dtypes={x: 'float64'})
> ---------------------------------------------------------------------------
> NameError Traceback (most recent call last)
> <ipython-input-7-8d3dbb9c015b> in <module>()
> ----> 1 fn_theano = theano_function([x], [y], dims={x: 1}, dtypes={x: 'float64'})
>
> /home/certik/repos/sympy/sympy/printing/theanocode.py in
> theano_function(inputs, outputs, dtypes, cache, **kwargs)
> 224 code = partial(theano_code, cache=cache, dtypes=dtypes,
> 225 broadcastables=broadcastables)
> --> 226 tinputs = list(map(code, inputs))
> 227 toutputs = list(map(code, outputs))
> 228 toutputs = toutputs[0] if len(toutputs) == 1 else toutputs
>
> /home/certik/repos/sympy/sympy/printing/theanocode.py in
> theano_code(expr, cache, **kwargs)
> 192
> 193 def theano_code(expr, cache=global_cache, **kwargs):
> --> 194 return TheanoPrinter(cache=cache,
> settings={}).doprint(expr, **kwargs)
> 195
> 196
>
> /home/certik/repos/sympy/sympy/printing/theanocode.py in doprint(self,
> expr, **kwargs)
> 187 def doprint(self, expr, **kwargs):
> 188 """Returns printer's representation for expr (as a string)"""
> --> 189 return self._print(expr, **kwargs)
> 190
> 191 global_cache = {}
>
> /home/certik/repos/sympy/sympy/printing/printer.pyc in _print(self,
> expr, *args, **kwargs)
> 255 printmethod = '_print_' + cls.__name__
> 256 if hasattr(self, printmethod):
> --> 257 return getattr(self, printmethod)(expr,
> *args, **kwargs)
> 258
> 259 # Unknown object, fall back to the emptyPrinter.
>
> /home/certik/repos/sympy/sympy/printing/theanocode.py in
> _print_Symbol(self, s, dtypes, broadcastables)
> 78 return self.cache[key]
> 79 else:
> ---> 80 value = tt.tensor(name=s.name, dtype=dtype,
> broadcastable=broadcastable)
> 81 self.cache[key] = value
> 82 return value
>
> NameError: global name 'tt' is not defined

Ah, ok, this is caused by me not having the theano module installed.
The beginning of sympy/printing/theanocode.py has:

    theano = import_module('theano')
    if theano:
        ts = theano.scalar
        tt = theano.tensor

so when theano is missing, 'tt' is never defined, and the first use of
it fails with the NameError above.
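
To illustrate the optional-import pattern on its own, here is a minimal stdlib-only sketch. It is not the SymPy code: optional_import and require are hypothetical helpers, and the second module name is made up so that the import fails. The idea is that a missing optional dependency can be reported by name instead of surfacing later as a NameError.

```python
import importlib


def optional_import(name):
    """Return the imported module, or None if it is not installed."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


def require(module, name):
    """Raise a readable error up front instead of a NameError later on."""
    if module is None:
        raise ImportError("%s is required for this feature" % name)
    return module


# 'json' ships with Python; the second name is a made-up module that
# is assumed not to exist, standing in for a missing theano install.
present = optional_import("json")
missing = optional_import("hypothetical_module_that_is_not_installed")
```

Calling require(missing, "theano") here raises an ImportError that names the missing package, which is easier to diagnose than "global name 'tt' is not defined".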

Anyway, Amir, try adding ImaginaryUnit to the "mapping" dictionary
in sympy/printing/theanocode.py; that should do it. Send us a PR with
the fix.
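
Here is a toy sketch of why a missing entry shows up as a KeyError, and how adding one fixes it. It is a stand-in model only: the classes, the convert function and this small mapping are hypothetical and use plain Python numbers instead of Theano variables. The real entry for ImaginaryUnit in theanocode.py may need a different form.

```python
import operator
from functools import reduce


# Toy expression classes; each node keeps its children in .args.
class Sym:
    args = ()

    def __init__(self, name):
        self.name = name


class ImaginaryUnit:
    args = ()


class Add:
    def __init__(self, *args):
        self.args = args


class Mul:
    def __init__(self, *args):
        self.args = args


# Conversion table keyed by expression class (toy version).
mapping = {
    Add: lambda *xs: sum(xs),
    Mul: lambda *xs: reduce(operator.mul, xs),
}


def convert(expr, env):
    """Convert a toy expression to a number using the mapping table."""
    if isinstance(expr, Sym):
        return env[expr.name]
    op = mapping[type(expr)]  # KeyError if the class has no entry
    return op(*(convert(arg, env) for arg in expr.args))


x = Sym("x")
y = Add(x, Mul(ImaginaryUnit(), x))  # y = x + I*x

try:
    convert(y, {"x": 2.0})
    failed = False
except KeyError as err:
    failed = True
    missing_cls = err.args[0]  # the class that had no entry

# A leaf has no children, so its entry is called with no arguments.
mapping[ImaginaryUnit] = lambda: 1j
result = convert(y, {"x": 2.0})
```

In this toy model the first conversion fails on the ImaginaryUnit class, and after the entry is added the same expression converts to a complex number.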
Ondrej
>
>
>
>
> Thanks for letting us know. This might be quite easy to fix, just go
> into the source file, do the fix and test it. We can help you out. If
> you fix it, then just send us a pull request.
>
> Ondrej