The dtype of grad_steps and s_ is float64, while self.truncate_gradient is
a Python float.
Sorry, I didn't answer this properly before.
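
In case it helps to reproduce the check, here is a minimal sketch of how the types can be inspected. It uses NumPy scalars as stand-ins, which is an assumption: the real grad_steps and s_ come from the code under discussion, and describe() is a hypothetical helper, not part of any library.

```python
import numpy as np

# Hypothetical stand-ins (assumption): the real grad_steps and s_ come
# from the code under discussion; the values here are only illustrative.
grad_steps = np.float64(5.0)
s_ = np.float64(3.0)
truncate_gradient = 5.0  # a plain Python float


def describe(value):
    """Return the dtype name if the value has one, else the Python type name."""
    dtype = getattr(value, "dtype", None)
    return dtype.name if dtype is not None else type(value).__name__


report = {
    "grad_steps": describe(grad_steps),
    "s_": describe(s_),
    "truncate_gradient": describe(truncate_gradient),
}
print(report)
```

A plain Python float has no dtype attribute, which is why the helper falls back to the type name for it.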



