You probably know about it, but I'll mention it anyway: there's a primitive 
partial derivative operator in J. I think it would do exactly what you want 
(numerically), and it's probably reasonably fast. It's not too hard to use 
either:

dsoftmax=: sm D.1
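As a cross-check on the numeric result, the Jacobian also has a closed form: J[i,j] = s[i] * (delta[i,j] - s[j]) with s = sm y, that is, diag(s) minus the outer product of s with itself. That form needs no loops or branches. Below is a minimal sketch under that assumption; the name `jac` is hypothetical and not from the thread:

```j
sm=: (] % +/)@:^                  NB. softmax, as defined in the original post

NB. Hypothetical loop-free Jacobian: J = diag(s) - (s outer s), where s = sm y
NB.   =/~@i.@# s        is the n x n identity table (i = j for i, j in i.n)
NB.   (] * =/~@i.@#) s  scales row i of that table by s[i], giving diag(s)
NB.   */~ s             is the table s[i] * s[j]
jac=: ((] * =/~@i.@#) - */~)@sm

a=: 0.5 0.6 0.23 0.66
jac a                             NB. 4x4 matrix of partial derivatives
```

Because the softmax Jacobian is symmetric, the comparison with `sm D.1 a` does not depend on which axis is taken as "output" and which as "input".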

Louis

> On 27 Feb 2017, at 10:20, 'Pascal Jasmin' via Programming 
> <[email protected]> wrote:
> 
> One optimization is to remove the rank "0 _, so that the function need not
> be reparsed for each x.
> 
> untested.
> 
> 
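> NB. x is a list of indices, y is the whole array
> dsoftmax=: 4 : 0
> s=. sm y                  NB. softmax of the whole array, computed once
> rx=. ''
> for_i. x do.
>  smx=. i { s
>  for_j. i.#y do.
>   if. j = i do. rx=. rx , smx * (1 - smx)
>   else. rx=. rx , (j { s) * (0 - smx) end.
>  end.
> end.
> ((#x),#y) $ rx            NB. reshape the flat list into a table
> )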
> 
> 
> ----- Original Message -----
> From: 'Jon Hough' via Programming <[email protected]>
> To: Programming Forum <[email protected]>
> Sent: Monday, February 27, 2017 3:09 AM
> Subject: [Jprogramming] Fast derivative of Softmax function
> 
> Given an array, we can calculate the softmax function
> https://en.wikipedia.org/wiki/Softmax_function
> 
> a =: 0.5 0.6 0.23 0.66
> sm=:(] % +/ )@:^ NB. softmax
> 
> sm a 
> 0.247399 0.273418 0.188859 0.290325
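If the inputs can be large, exponentiating them directly overflows double precision (e^1000 is far above the largest double). A common remedy, sketched here as an assumption rather than anything proposed in the thread, is to subtract the maximum before exponentiating; the shift cancels in the ratio, so the values are unchanged. The name `smstable` is hypothetical:

```j
sm=: (] % +/)@:^                  NB. softmax, as in the original post

NB. Hypothetical variant: the hook (- >./) computes y - (>./ y), so every
NB. argument to ^ is <: 0 and cannot overflow; the largest term is exactly 1,
NB. so the denominator is at least 1.
smstable=: (] % +/)@:^@:(- >./)

a=: 0.5 0.6 0.23 0.66
smstable a                        NB. same values as sm a
smstable 1000 1001                NB. finite; same values as sm 0 1
```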
> 
> The (partial) derivative of softmax is a little more complicated:
> 
> If the array has length N, we need an NxN matrix of partial derivatives,
> where (in pseudocode, writing S = sm array)
> 
> derivatives[i,j] = S[i] * (1 - S[j])    if i == j
> or
> derivatives[i,j] = -1 * S[i] * S[j]     if i != j
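Written with the Kronecker delta (1 when i = j, 0 otherwise), the two cases collapse into a single formula, which is what makes a branch-free version possible:

```latex
S_i = \frac{e^{a_i}}{\sum_k e^{a_k}}, \qquad
\frac{\partial S_i}{\partial a_j} = S_i\,(\delta_{ij} - S_j)
\quad\Longrightarrow\quad
J = \operatorname{diag}(S) - S\,S^{\mathsf{T}}, \qquad
\sum_j J_{ij} = S_i\Bigl(1 - \sum_j S_j\Bigr) = 0
```

Since the S_j sum to 1, every row of the matrix sums to 0, and the matrix is symmetric; both make convenient sanity checks for any implementation.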
> 
> ( see here for the reasoning: 
> http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/ )
> 
> My implementation of the partial derivatives is this:
> 
> 
> NB. x value is index, y value is the whole array
> dsoftmax=: 4 : 0
> idx=. x
> vals=. y
> s=. sm vals               NB. softmax of the whole array, computed once
> smx=. idx { s             NB. softmax value at position idx
> rx=. ''
> for_j. i.#vals do.
>  if. j = idx do. rx=. rx , smx * (1 - smx)
>  else. rx=. rx , (j { s) * (0 - smx) end.
> end.
> rx
> )
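Keeping the same interface (index on the left, whole array on the right), the inner loop and the if./else. branch can be replaced by one vector expression, because row x of the matrix is s[x] * (delta[x,j] - s[j]) for all j at once. A minimal sketch; the name `drow` is hypothetical:

```j
sm=: (] % +/)@:^                  NB. softmax, as in the original post

NB. Hypothetical branch-free per-row verb:
NB. x is an index, y is the whole array; the result is row x of the Jacobian
drow=: 4 : 0
s=. sm y                          NB. softmax of the whole array
(x { s) * (x = i. # s) - s        NB. (x = i.#s) is 1 only at position x
)

a=: 0.5 0.6 0.23 0.66
(i. # a) drow"0 _ a               NB. 4x4 matrix, same call pattern as before
```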
> 
> 
> Then, for example, using the array a above,
> 
> (i.# a) dsoftmax"0 _  a
> 
> gives the values as a 4x4 matrix.
> 
> This is quite slow. I have tried to do this without iterating and branching, 
> but cannot figure out a way to do it.
> Any help appreciated.
> Thanks,
> 
> Jon
> ----------------------------------------------------------------------
> For information about J forums see http://www.jsoftware.com/forums.htm 