I'm trying to slice a 3D Theano tensor (theano.tensor.tensor3) so that every index of the first and second dimensions is selected, while the third dimension is indexed by another tensor variable.
For clarity, I'm trying to generalise something like this:
import theano
import theano.tensor as T
X = T.fmatrix('X')
y = T.ivector('y')
rslt = X[T.arange(X.shape[0]), y]  # rslt[i] = X[i, y[i]] for every row i
fn = theano.function([X, y], rslt)
into something like this:
import theano
import theano.tensor as T
X = T.ftensor3('X')
y = T.ivector('y')
rslt = X[T.arange(X.shape[0]), T.arange(X.shape[1]), y]
fn = theano.function([X, y], rslt)
but the latter unfortunately throws an error like this:
Traceback (most recent call last):
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/interactiveshell.py", line 2885, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-73-e2af657671ee>", line 1, in <module>
    fn(xx3, yy)
  File "/usr/local/lib/python2.7/dist-packages/theano/compile/function_module.py", line 871, in __call__
    storage_map=getattr(self.fn, 'storage_map', None))
  File "/usr/local/lib/python2.7/dist-packages/theano/gof/link.py", line 314, in raise_with_op
    reraise(exc_type, exc_value, exc_trace)
  File "/usr/local/lib/python2.7/dist-packages/theano/compile/function_module.py", line 859, in __call__
    outputs = self.fn()
  File "/usr/local/lib/python2.7/dist-packages/theano/gof/op.py", line 912, in rval
    r = p(n, [x[0] for x in i], o)
  File "/usr/local/lib/python2.7/dist-packages/theano/tensor/subtensor.py", line 2166, in perform
    out[0] = inputs[0].__getitem__(inputs[1:])
IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (2,) (3,) (3,)
Apply node that caused the error: AdvancedSubtensor(X, ARange{dtype='int64'}.0, ARange{dtype='int64'}.0, y)
Toposort index: 4
Inputs types: [TensorType(float32, 3D), TensorType(int64, vector), TensorType(int64, vector), TensorType(int32, vector)]
Inputs shapes: [(2, 3, 3), (2,), (3,), (3,)]
Inputs strides: [(36, 12, 4), (8,), (8,), (4,)]
Inputs values: ['not shown', array([0, 1]), array([0, 1, 2]), array([0, 1, 2], dtype=int32)]
Outputs clients: [['output']]
Backtrace when the node is created(use Theano flag traceback.limit=N to make it longer):
  File "/home/jirka/.IntelliJIdea2016.2/config/plugins/python/helpers/pydev/pydevconsole.py", line 213, in process_exec_queue
    more = interpreter.add_exec(code_fragment)
  File "/home/jirka/.IntelliJIdea2016.2/config/plugins/python/helpers/pydev/_pydev_bundle/pydev_console_utils.py", line 236, in add_exec
    more = self.do_add_exec(code_fragment)
  File "/home/jirka/.IntelliJIdea2016.2/config/plugins/python/helpers/pydev/_pydev_bundle/pydev_ipython_console.py", line 42, in do_add_exec
    res = bool(self.interpreter.add_exec(codeFragment.text))
  File "/home/jirka/.IntelliJIdea2016.2/config/plugins/python/helpers/pydev/_pydev_bundle/pydev_ipython_console_011.py", line 436, in add_exec
    self.ipython.run_cell(line, store_history=True)
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/interactiveshell.py", line 2723, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/interactiveshell.py", line 2825, in run_ast_nodes
    if self.run_code(code, result):
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/interactiveshell.py", line 2885, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-69-d2e570619075>", line 4, in <module>
    rslt = X[T.arange(X.shape[0]), T.arange(X.shape[1]), y]

HINT: Use the Theano flag 'exception_verbosity=high' for a debugprint and storage map footprint of this apply node.
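The IndexError is raised by NumPy itself: the traceback shows AdvancedSubtensor.perform calling `inputs[0].__getitem__(inputs[1:])` on the runtime arrays, so the three index vectors follow NumPy's advanced-indexing rules and must broadcast to one common shape. Below is a minimal NumPy-only sketch (not Theano code) that reproduces this with the shapes from the error message, next to the 2-D case that works; the array values are arbitrary.

```python
import numpy as np

# Shapes taken from the error message: X is (2, 3, 3) and y has length 3.
# The values themselves are arbitrary.
X = np.arange(18, dtype=np.float32).reshape(2, 3, 3)
y = np.array([0, 1, 2], dtype=np.int32)

# 2-D analogue: both index vectors have length M.shape[0], so they broadcast
# to a common shape (2,) and the result is M[i, y2[i]] for every row i.
M = X[:, :, 0]                         # a (2, 3) matrix: [[0, 3, 6], [9, 12, 15]]
y2 = np.array([2, 0], dtype=np.int32)
picked = M[np.arange(M.shape[0]), y2]  # [M[0, 2], M[1, 0]]

# 3-D attempt: index vectors of lengths 2, 3 and 3 have no common broadcast
# shape, so NumPy raises the same IndexError that Theano reports.
error_message = None
try:
    X[np.arange(X.shape[0]), np.arange(X.shape[1]), y]
except IndexError as err:
    error_message = str(err)

print(picked)
print(error_message)
```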
My current solution is to convert y to a one-hot encoding, add a broadcastable dimension, and do an elementwise multiplication. This does yield the desired result, but it is inefficient.
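A NumPy-only sketch of the one-hot workaround described above, next to a plain indexing version. It assumes the intended result is out[i, j] = X[i, j, y[j]] (consistent with y having length X.shape[1] in the error message) and that the elementwise product is followed by a sum over the last axis; the helper names `one_hot_select` and `index_select` are made up for illustration. The indexing version avoids the broadcast error by giving the index arrays shapes (n0, 1), (1, n1) and (1, n1), which broadcast to (n0, n1).

```python
import numpy as np

def one_hot_select(X, y):
    """Hypothetical helper: one-hot mask, elementwise product, sum over last axis.

    Assumes out[i, j] = X[i, j, y[j]], i.e. len(y) == X.shape[1].
    """
    n2 = X.shape[2]
    one_hot = np.eye(n2, dtype=X.dtype)[y]  # (n1, n2): row j is the one-hot of y[j]
    # Add a leading broadcastable axis, multiply elementwise, then sum out axis 2.
    return (X * one_hot[np.newaxis, :, :]).sum(axis=2)

def index_select(X, y):
    """Hypothetical helper: the same result via advanced indexing."""
    n0, n1 = X.shape[0], X.shape[1]
    rows = np.arange(n0)[:, np.newaxis]     # shape (n0, 1)
    cols = np.arange(n1)[np.newaxis, :]     # shape (1, n1)
    # All three index arrays broadcast to (n0, n1), so no shape mismatch.
    return X[rows, cols, y[np.newaxis, :]]

X = np.arange(18, dtype=np.float32).reshape(2, 3, 3)
y = np.array([2, 0, 1], dtype=np.int32)

a = one_hot_select(X, y)
b = index_select(X, y)
# Shorter spelling: a full slice on axis 0 plus two equal-length index vectors.
c = X[:, np.arange(X.shape[1]), y]
print(b)
```

The one-hot version builds an n0 x n1 x n2 intermediate before reducing it, while the indexing versions gather only n0 x n1 values, which is presumably the inefficiency mentioned above. Whether the Theano version in use accepts the same broadcast index pattern (for example via `dimshuffle(0, 'x')` on the `arange` vectors) or the `X[:, T.arange(X.shape[1]), y]` form has not been checked here.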
Is there a better solution?
You received this message because you are subscribed to the Google Groups
"theano-users" group.