I was about to point out the same thing.
a =: 0.5 0.6 0.23 0.66
sm=:(] % +/ )@:^ NB. softmax
softmax=: { sm
dsoftmax=: 4 : 0
idx=. x
vals=. y
smx=. idx softmax vals
rx=. ''
for_j. i.#vals do.
if. j = idx do. rx=. rx , smx * (1 - smx)
elseif. 1 do. rx=. rx ,(j softmax vals)* (0 - smx) end.
end.
rx
)
sm D.1 a
0.186192 _0.0676431 _0.0467234 _0.0718259
_0.0676431 0.19866 _0.0516374 _0.0793799
_0.0467234 _0.0516374 0.153191 _0.0548305
_0.0718259 _0.0793799 _0.0548305 0.206036
(i.# a) dsoftmax"0 _ a
0.186192 _0.0676431 _0.0467234 _0.0718259
_0.0676431 0.19866 _0.0516374 _0.0793799
_0.0467234 _0.0516374 0.153191 _0.0548305
_0.0718259 _0.0793799 _0.0548305 0.206036
The speedup is not too impressive, but it is a speedup (probably
because D. retains and reuses the results from sm rather than
recomputing them for every element; I imagine it uses sm directly and
lifts it out of the loop):
timespacex '(i.# a) dsoftmax"0 _ a'
5.5e_5 6528
timespacex 'sm D.1 a'
2.3e_5 7424
That said, note that we can approximate this speedup by using a
variant on what Pascal proposed:
d_softmax=: 4 : 0
rx=. i.0 0
smv=. sm y
for_i. x do.
ry=. i.0
smx=. i { smv
for_j. i.#y do.
if. j = i do. ry=. ry , smx * (1 - smx)
else. ry=. ry ,(j {smv)* (0 - smx) end.
end.
rx=.rx, ry
end.
rx
)
(i.# a) d_softmax a
0.186192 _0.0676431 _0.0467234 _0.0718259
_0.0676431 0.19866 _0.0516374 _0.0793799
_0.0467234 _0.0516374 0.153191 _0.0548305
_0.0718259 _0.0793799 _0.0548305 0.206036
timespacex '(i.# a) d_softmax a'
3e_5 6016
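
The loop and the branch can also be dropped altogether. What follows is only a sketch, not something proposed in this thread: the names jac and dsm are made up for illustration. It relies on the matrix of partials being diag(s) - (outer product of s with itself), where s is sm y:

```j
NB. Sketch only: jac and dsm are hypothetical names, not from the thread.
sm=: (] % +/)@:^          NB. softmax, as defined earlier
jac=: (* =@i.@#) - */~    NB. s -> diag(s) - outer product of s with itself
dsm=: jac@:sm             NB. matrix of partials of softmax at y, no loop or branch

a=: 0.5 0.6 0.23 0.66
dsm a                     NB. 4x4 matrix, to compare with (i.# a) d_softmax a
+/"1 dsm a                NB. row sums; these should be 0 up to rounding
```

Here sm is evaluated once, and the diagonal and off-diagonal cases fall out of the same subtraction, so no if. is needed. Whether it is measurably faster at this size would have to be timed.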
(Remember that it's generally a good idea to ignore speedups of less
than a factor of 2, because of scheduling noise within the machine
itself. You can see this by inspecting multiple timing runs.)
timespacex '(i.# a) d_softmax a'
3e_5 6016
timespacex 'sm D.1 a'
2.3e_5 7424
timespacex '(i.# a) d_softmax a'
2.9e_5 6016
timespacex 'sm D.1 a'
2.3e_5 7424
timespacex '(i.# a) d_softmax a'
2.8e_5 6016
timespacex 'sm D.1 a'
3.7e_5 7424
timespacex '(i.# a) d_softmax a'
3.2e_5 6016
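
One way to dampen this kind of run-to-run noise, assuming the standard library definition of timespacex, is to give it a repeat count as its left argument, so that the reported time is an average over that many executions. A minimal sketch, with sm a standing in for whichever sentence is being timed:

```j
NB. Sketch: the left argument of timespacex is a repeat count;
NB. the first number returned is the average time per execution.
sm=: (] % +/)@:^
a=: 0.5 0.6 0.23 0.66
1000 timespacex 'sm a'    NB. returns two numbers: seconds, bytes
```

The same form applies to the sentences timed above. Averages still drift between sessions, so the factor-of-2 rule of thumb remains a sensible guard.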
I hope this helps,
--
Raul
On Mon, Feb 27, 2017 at 8:05 AM, Louis de Forcrand <[email protected]> wrote:
> You probably know about it, but I'll mention it anyway: there's a primitive
> partial derivative operator in J. I think it would do exactly what you want
> (numerically), and it's probably reasonably fast. It's not too hard to use
> either:
>
> dsoftmax=: sm D.1
>
> Louis
>
>> On 27 Feb 2017, at 10:20, 'Pascal Jasmin' via Programming
>> <[email protected]> wrote:
>>
>> One optimization is removing the rank "0 _, so that the function does not
>> need to be reparsed for each x.
>>
>> untested.
>>
>>
>> dsoftmax=: 4 : 0
>> rx=. ''
>> for_i. x do.
>> smx=. i softmax y
>> for_j. i.#y do.
>> if. j = i do. rx=. rx , smx * (1 - smx)
>> else. rx=. rx ,(j softmax y)* (0 - smx) end.
>> end. end.
>> rx
>> )
>>
>>
>> ----- Original Message -----
>> From: 'Jon Hough' via Programming <[email protected]>
>> To: Programming Forum <[email protected]>
>> Sent: Monday, February 27, 2017 3:09 AM
>> Subject: [Jprogramming] Fast derivative of Softmax function
>>
>> Given an array, we can calculate the softmax function
>> https://en.wikipedia.org/wiki/Softmax_function
>>
>> a =: 0.5 0.6 0.23 0.66
>> sm=:(] % +/ )@:^ NB. softmax
>>
>> sm a
>> 0.247399 0.273418 0.188859 0.290325
>>
>> The (partial) derivative of softmax is a little more complicated:
>>
>> If the array is of length N, we need an NxN matrix of partial derivatives
>> where (in pseudo code)
>>
>> derivatives[i,j] = sm(array)[i] * (1 - sm(array)[j])    if i == j
>> or
>> derivatives[i,j] = -1 * sm(array)[i] * sm(array)[j]    if i != j
>>
>> ( see here for the reasoning:
>> http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/ )
>>
>> My implementation of the partial derivatives is this:
>>
>>
>> NB. x value is index, y value is the whole array
>> dsoftmax=: 4 : 0
>> idx=. x
>> vals=. y
>> smx=. idx softmax vals
>> rx=. ''
>> for_j. i.#vals do.
>> if. j = idx do. rx=. rx , smx * (1 - smx)
>> elseif. 1 do. rx=. rx ,(j softmax vals)* (0 - smx) end.
>> end.
>> rx
>> )
>>
>>
>> Then, for example using above array a,
>>
>> (i.# a) dsoftmax"0 _ a
>>
>> gives the values, in a 4x4 matrix.
>>
>> This is quite slow. I have tried to do this without iterating and branching,
>> but cannot figure out a way to do it.
>> Any help appreciated.
>> Thanks,
>>
>> Jon
>> ----------------------------------------------------------------------
>> For information about J forums see http://www.jsoftware.com/forums.htm