jax.scipy.special.softmax
- jax.scipy.special.softmax(x, /, *, axis=None)
Softmax function.
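JAX implementation of scipy.special.softmax().

Computes the function which rescales elements to the range \([0, 1]\) such that the elements along axis sum to \(1\):

\[\mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]

where the sum over \(j\) runs over the elements along axis.

- Parameters:
  - x: input array.
  - axis: the axis or axes along which the elements are rescaled to sum to \(1\). With the default, None, the softmax is computed over all elements of x.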
- Returns:
  An array of the same shape as x.
- Return type:
  Array
Note
If any input values are +inf, the result will be all NaN: this reflects the fact that inf / inf is not well-defined in the context of floating-point math.

See also