Yes, thanks. This seems to be what I need!
--------------------------------------------
On Mon, 2/27/17, Louis de Forcrand <[email protected]> wrote:

 Subject: Re: [Jprogramming] Fast derivative of Softmax function
 To: [email protected]
 Date: Monday, February 27, 2017, 10:05 PM

 You probably know about it, but I'll mention it anyway: there's a primitive partial derivative operator in J. I think it would do exactly what you want (numerically), and it's probably reasonably fast. It's not too hard to use either:

 dsoftmax=: sm D.1

 Louis

 > On 27 Feb 2017, at 10:20, 'Pascal Jasmin' via Programming <[email protected]> wrote:
 >
 > one optimization is removing the rank "0 _, so that the function does not need to be reparsed for each x
 >
 > untested.
 >
 >
 > dsoftmax=: 4 : 0
 > rx=. ''
 > for_i. x do.
 >   smx=. i softmax y
 >   for_j. i.#y do.
 >     if. j_index = i_index do. rx=. rx , smx * (1 - smx)
 >     else. rx=. rx ,(j softmax y)* (0 - smx) end.
 >   end.
 > end.
 > rx
 > )
 >
 >
 > ----- Original Message -----
 > From: 'Jon Hough' via Programming <[email protected]>
 > To: Programming Forum <[email protected]>
 > Sent: Monday, February 27, 2017 3:09 AM
 > Subject: [Jprogramming] Fast derivative of Softmax function
 >
 > Given an array, we can calculate the softmax function
 > https://en.wikipedia.org/wiki/Softmax_function
 >
 > a =: 0.5 0.6 0.23 0.66
 > sm=:(] % +/ )@:^ NB. softmax
 >
 > sm a
 > 0.247399 0.273418 0.188859 0.290325
 >
 > The (partial) derivative of softmax is a little more complicated:
 >
 > If the array is of length N, we need an NxN matrix of partial derivatives where (in pseudo code)
 >
 > derivatives[i,j] = sm(array[i]) * (1 - sm(array[j])) if i == j
 > or
 > derivatives[i,j] = -1 * sm(array[i]) * sm(array[j]) if i != j
 >
 > ( see here for the reasoning: http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/ )
 >
 > My implementation of the partial derivatives is this:
 >
 >
 > NB. x value is index, y value is the whole array
 > dsoftmax=: 4 : 0
 > idx=. x
 > vals=. y
 > smx=. idx softmax vals
 > rx=. ''
 > for_j. i.#vals do.
 >   if. j = idx do. rx=. rx , smx * (1 - smx)
 >   elseif. 1 do. rx=. rx ,(j softmax vals)* (0 - smx) end.
 > end.
 > rx
 > )
 >
 >
 > Then, for example using above array a,
 >
 > (i.# a) dsoftmax"0 _ a
 >
 > gives the values, in a 4x4 matrix.
 >
 > This is quite slow. I have tried to do this without iterating and branching, but cannot figure out a way to do it.
 > Any help appreciated.
 > Thanks,
 >
 > Jon
 > ----------------------------------------------------------------------
 > For information about J forums see http://www.jsoftware.com/forums.htm
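
For reference, the pseudo code in the original question is the matrix diag(s) - s s^T, where s is the softmax of the whole array, so it can be built without loops or branching. The following is a minimal sketch of that closed form rather than anything posted in the thread. It assumes the sm verb as defined in the question; dsm and jac are hypothetical names introduced here for illustration.

```j
NB. Softmax exactly as defined in the original question.
sm =: (] % +/)@:^

NB. dsm (hypothetical name): NxN matrix of partial derivatives of sm at y.
NB. Entry [i,j] is s_i * ((i=j) - s_j), i.e. diag(s) - outer product of s with itself.
dsm =: 3 : 0
s =. sm y
NB. (= i. # s) is the NxN identity matrix; s * identity scales row i by s_i.
NB. s */ s is the table of all products s_i * s_j.
(s * = i. # s) - s */ s
)

a =: 0.5 0.6 0.23 0.66
jac =: dsm a

NB. Sanity checks: the result is 4x4, symmetric, and each row sums to (nearly) zero.
$ jac
jac -: |: jac
*./ 1e_10 > | +/"1 jac
```

The row-sum check follows from the softmax outputs summing to 1, so it is a cheap way to catch a sign or indexing mistake. Because the whole matrix comes from one identity table and one outer product, there is no per-index call to softmax.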
