I was about to point out the same thing.

   a =: 0.5 0.6 0.23 0.66
   sm=:(] % +/ )@:^ NB. softmax
   softmax=: { sm
   dsoftmax=: 4 : 0
     idx=. x
     vals=. y
     smx=. idx softmax vals
     rx=. ''
     for_j. i.#vals do.
       if. j = idx do. rx=. rx , smx * (1 - smx)
       elseif. 1 do. rx=. rx ,(j softmax vals)* (0 - smx) end.
     end.
     rx
  )

   sm D.1 a
  0.186192 _0.0676431 _0.0467234 _0.0718259
_0.0676431    0.19866 _0.0516374 _0.0793799
_0.0467234 _0.0516374   0.153191 _0.0548305
_0.0718259 _0.0793799 _0.0548305   0.206036
   (i.# a) dsoftmax"0 _  a
  0.186192 _0.0676431 _0.0467234 _0.0718259
_0.0676431    0.19866 _0.0516374 _0.0793799
_0.0467234 _0.0516374   0.153191 _0.0548305
_0.0718259 _0.0793799 _0.0548305   0.206036

The speedup is not too impressive (roughly 2.4x here), but it is a
speedup. That is probably because the D. version retains and reuses
the results from sm instead of recomputing them for every element of
the matrix. I imagine it uses sm directly and lifts it out of the loop:

   timespacex '(i.# a) dsoftmax"0 _  a'
5.5e_5 6528
   timespacex 'sm D.1 a'
2.3e_5 7424

That said, note that we can approximate this speedup by using a
variant on what Pascal proposed:

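d_softmax=: 4 : 0
  NB. x: row indices, y: input vector; one result row per index in x
  rx=. i.0 0
  smv=. sm y              NB. compute softmax once, outside both loops
  for_i. x do.
    ry=. i.0
    smx=. i { smv
    for_j. i.#y do.
      NB. compare index values, not loop positions, so x need not be i.#y
      if. j = i do. ry=. ry , smx * (1 - smx)
      else.  ry=. ry ,(j {smv)* (0 - smx) end.
    end.
    rx=.rx, ry
  end.
  rx
)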

   (i.# a) d_softmax  a
  0.186192 _0.0676431 _0.0467234 _0.0718259
_0.0676431    0.19866 _0.0516374 _0.0793799
_0.0467234 _0.0516374   0.153191 _0.0548305
_0.0718259 _0.0793799 _0.0548305   0.206036
   timespacex '(i.# a) d_softmax  a'
3e_5 6016
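
Since the original question asked for a version without iterating and
branching, here is a minimal loop-free sketch. It relies on the
identity J[i,j] = s[i] * ((i = j) - s[j]), where s is sm y, which is
just the two pseudo-code cases from the original post folded into one
expression. The name jac_sm is my own (hypothetical, not from this
thread), and I have not timed it against the versions above.

```j
NB. Sketch: Jacobian of softmax without explicit loops.
NB. jac_sm is a hypothetical name, not something from the thread.
a  =: 0.5 0.6 0.23 0.66
sm =: (] % +/ )@:^          NB. softmax, as defined in the thread

jac_sm =: 3 : 0
  s =. sm y                 NB. evaluate softmax once
  (s * = i. # s) - s */ s   NB. diag(s) minus the outer product of s with itself
)

jac_sm a
```

The diagonal entries come out as s[i] * (1 - s[i]) and the off-diagonal
entries as -s[i] * s[j], so this should reproduce the 4x4 matrix shown
above.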

(Remember that it's generally a good idea to ignore speedups of less
than a factor of 2, because of scheduling noise within the machine
itself. You can see this by inspecting multiple timing runs:)

   timespacex '(i.# a) d_softmax  a'
3e_5 6016
   timespacex 'sm D.1 a'
2.3e_5 7424
   timespacex '(i.# a) d_softmax  a'
2.9e_5 6016
   timespacex 'sm D.1 a'
2.3e_5 7424
   timespacex '(i.# a) d_softmax  a'
2.8e_5 6016
   timespacex 'sm D.1 a'
3.7e_5 7424
   timespacex '(i.# a) d_softmax  a'
3.2e_5 6016
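
One way to smooth out this kind of noise is to give timespacex a left
argument. In the standard library it is built on 6!:2, which treats x
as a repetition count and reports the average time per execution. A
small sketch follows; the count of 1000 is an arbitrary choice on my
part, and any of the sentences timed above can be substituted for
'sm a'.

```j
a  =: 0.5 0.6 0.23 0.66
sm =: (] % +/ )@:^          NB. softmax, as defined in the thread

NB. left argument = number of repetitions; result is average time, space
1000 timespacex 'sm a'
```

The result has the same two-item shape as before (seconds, bytes), so
the numbers can be compared directly with the single-run timings.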

I hope this helps,

-- 
Raul

On Mon, Feb 27, 2017 at 8:05 AM, Louis de Forcrand wrote:
> You probably know about it, but I'll mention it anyway: there's a primitive 
> partial derivative operator in J. I think it would do exactly what you want 
> (numerically), and it's probably reasonably fast. It's not too hard to use 
> either:
>
> dsoftmax=: sm D.1
>
> Louis
>
>> On 27 Feb 2017, at 10:20, 'Pascal Jasmin' via Programming wrote:
>>
>> One optimization is removing the rank ("0 _), so that the function need 
>> not be reparsed for each x.
>>
>> untested.
>>
>>
>> dsoftmax=: 4 : 0
>> rx=. ''
>> for_i. x do.
>> smx=. i softmax y
>> for_j. i.#y do.
>> if. j = i do. rx=. rx , smx * (1 - smx)
>> else.  rx=. rx ,(j softmax y)* (0 - smx) end.
>> end. end.
>> rx
>> )
>>
>>
>> ----- Original Message -----
>> From: 'Jon Hough' via Programming
>> To: Programming Forum
>> Sent: Monday, February 27, 2017 3:09 AM
>> Subject: [Jprogramming] Fast derivative of Softmax function
>>
>> Given an array, we can calculate the softmax function
>> https://en.wikipedia.org/wiki/Softmax_function
>>
>> a =: 0.5 0.6 0.23 0.66
>> sm=:(] % +/ )@:^ NB. softmax
>>
>> sm a
>> 0.247399 0.273418 0.188859 0.290325
>>
>> The (partial) derivative of softmax is a little more complicated:
>>
>> If the array is of length N, we need an NxN matrix of partial derivatives 
>> where (in pseudo code)
>>
>> derivatives[i,j] = sm(array)[i] * (1 - sm(array)[j])   if i == j
>> or
>> derivatives[i,j] = -1 * sm(array)[i] * sm(array)[j]   if i != j
>>
>> ( see here for the reasoning: 
>> http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/ )
>>
>> My implementation of the partial derivatives is this:
>>
>>
>> NB. x value is index, y value is the whole array
>> dsoftmax=: 4 : 0
>> idx=. x
>> vals=. y
>> smx=. idx softmax vals
>> rx=. ''
>> for_j. i.#vals do.
>>  if. j = idx do. rx=. rx , smx * (1 - smx)
>>  elseif. 1 do. rx=. rx ,(j softmax vals)* (0 - smx) end.
>> end.
>> rx
>> )
>>
>>
>> Then, for example using above array a,
>>
>> (i.# a) dsoftmax"0 _  a
>>
>> gives the values, in a 4x4 matrix.
>>
>> This is quite slow. I have tried to do this without iterating and branching, 
>> but cannot figure out a way to do it.
>> Any help appreciated.
>> Thanks,
>>
>> Jon
>> ----------------------------------------------------------------------
>> For information about J forums see http://www.jsoftware.com/forums.htm