jax.nn.standardize
- jax.nn.standardize(x, axis=-1, mean=None, variance=None, epsilon=1e-05, where=None)
Standardizes input to zero mean and unit variance.
The standardization is given by:
\[x_{\text{std}} = \frac{x - \langle x\rangle}{\sqrt{\langle(x - \langle x\rangle)^2\rangle + \epsilon}}\]
where \(\langle x\rangle\) denotes the mean of \(x\) along the standardized axes, and \(\epsilon\) is a small correction factor that avoids division by zero.
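As a sanity check, the formula can be transcribed directly with `jax.numpy` and compared against `jax.nn.standardize`. This is a minimal sketch: the input shape, random seed, and tolerances are arbitrary choices, and the variance is the biased (population) variance, which is what \(\langle(x - \langle x\rangle)^2\rangle\) denotes.

```python
import jax
import jax.numpy as jnp

# Arbitrary example input: 4 rows of 8 values, mean ~5 and std ~3.
x = 3.0 * jax.random.normal(jax.random.PRNGKey(0), (4, 8)) + 5.0
eps = 1e-5

# <x> and <(x - <x>)^2> along the last axis; keepdims=True keeps the
# statistics broadcastable against x.
mean = jnp.mean(x, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)

# Direct transcription of the formula.
manual = (x - mean) / jnp.sqrt(var + eps)

# Library version with the same axis and epsilon.
out = jax.nn.standardize(x, axis=-1, epsilon=eps)

print(jnp.allclose(out, manual, atol=1e-4, rtol=1e-4))  # True
print(jnp.mean(out, axis=-1))  # each row close to 0
print(jnp.var(out, axis=-1))   # each row close to (just under) 1
```

The per-row variance of the output lands slightly below 1 because \(\epsilon\) is added to the denominator.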
- Parameters:
x (ArrayLike) – input array to be standardized.
axis (Axis) – integer or tuple of integers representing the axes along which to standardize. Defaults to the last axis (-1).
mean (ArrayLike | None) – optionally specify the mean used for standardization. If not specified, then x.mean(axis, where=where) will be used.
variance (ArrayLike | None) – optionally specify the variance used for standardization. If not specified, then x.var(axis, where=where) will be used.
epsilon (ArrayLike) – correction factor added to the variance to avoid division by zero; defaults to 1e-5.
where (ArrayLike | None) – optional boolean mask specifying which elements to use when computing the mean and variance.
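The sketch below exercises `axis`, `where`, `mean`, and `variance` on a small array. Two points are assumptions drawn from the formula and parameter descriptions rather than stated above: supplied statistics are given `keepdims=True` shapes so they broadcast against `x` (they are subtracted from and divided into `x` elementwise), and `where` only affects the statistics, so masked-out entries are still standardized using the statistics of the unmasked ones.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(24.0).reshape(2, 3, 4)

# axis as a tuple: one mean/variance per x[i], pooled over axes 1 and 2.
y = jax.nn.standardize(x, axis=(1, 2))

# where: only the first two entries of each row feed the statistics.
# Every entry of the output is still standardized with those statistics.
mask = jnp.broadcast_to(jnp.array([True, True, False, False]), x.shape)
z = jax.nn.standardize(x, axis=-1, where=mask)

# mean/variance: precomputed statistics (e.g. running averages kept by a
# normalization layer). keepdims=True gives them shape (2, 3, 1), which
# broadcasts against x along the standardized axis.
mean = jnp.mean(x, axis=-1, keepdims=True)
variance = jnp.var(x, axis=-1, keepdims=True)
w = jax.nn.standardize(x, axis=-1, mean=mean, variance=variance)

# First row is [0, 1, 2, 3]; masked mean 0.5, variance 0.25,
# so this is approximately [-1, 1, 3, 5].
print(z[0, 0])
```

Because the entries excluded by `where` did not contribute to the statistics, they can fall well outside the range of the unmasked outputs, as the last two entries of `z[0, 0]` do.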
- Returns:
An array of the same shape as x containing the standardized input.
- Return type: