jax.nn.standardize

jax.nn.standardize(x, axis=-1, mean=None, variance=None, epsilon=1e-05, where=None)

Standardizes input to zero mean and unit variance.

The standardization is given by:

\[x_{std} = \frac{x - \langle x\rangle}{\sqrt{\langle(x - \langle x\rangle)^2\rangle + \epsilon}}\]

where \(\langle \cdot \rangle\) denotes the mean taken along axis, and \(\epsilon\) is a small constant added to the variance to avoid division by zero.
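As an illustration of the formula above, the following minimal sketch recomputes it by hand and compares the result against jax.nn.standardize. The helper name manual_standardize is hypothetical, and the sketch assumes the variance is the population variance (mean squared deviation). Because the library's internal arithmetic may differ slightly, the comparison uses a tolerance rather than exact equality.

```python
import jax
import jax.numpy as jnp

def manual_standardize(x, axis=-1, epsilon=1e-5):
    # Hypothetical helper: a direct transcription of the formula, not the library code.
    # Mean along the chosen axis; keepdims=True so it broadcasts back against x.
    mean = jnp.mean(x, axis=axis, keepdims=True)
    # Population variance: mean of squared deviations from the mean.
    var = jnp.mean(jnp.square(x - mean), axis=axis, keepdims=True)
    return (x - mean) / jnp.sqrt(var + epsilon)

x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [10.0, 20.0, 30.0, 40.0]])
manual = manual_standardize(x)
builtin = jax.nn.standardize(x)
print(jnp.allclose(manual, builtin, atol=1e-5))  # True
```

Each row is standardized independently, so both rows map to roughly the same values even though the second row is ten times larger.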

Parameters:
  • x (ArrayLike) – input array to be standardized.

  • axis (Axis) – integer or tuple of integers representing the axes along which to standardize. Defaults to the last axis (-1).

  • mean (ArrayLike | None) – optionally specify the mean used for standardization. If not specified, then x.mean(axis, where=where) will be used.

  • variance (ArrayLike | None) – optionally specify the variance used for standardization. If not specified, then x.var(axis, where=where) will be used.

  • epsilon (ArrayLike) – correction factor added to the variance to avoid division by zero; defaults to 1e-5.

  • where (ArrayLike | None) – optional boolean mask specifying which elements to use when computing the mean and variance.
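The parameters above can be exercised as in the following sketch. It assumes that precomputed statistics should be passed with keepdims=True so they broadcast against x. For the where example, only the values at unmasked positions are relied on here.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

# Standardize each column (axis=0) instead of each row (the default axis=-1).
cols = jax.nn.standardize(x, axis=0)

# Supply precomputed statistics; keepdims=True keeps them broadcastable against x.
mu = jnp.mean(x, axis=-1, keepdims=True)
var = jnp.var(x, axis=-1, keepdims=True)
rows = jax.nn.standardize(x, mean=mu, variance=var)

# Boolean mask: only the first two entries of each row contribute
# to the mean and variance computed along the default axis.
mask = jnp.array([[True, True, False],
                  [True, True, False]])
masked = jax.nn.standardize(x, where=mask)
```

Passing mean and variance explicitly is useful when the statistics come from elsewhere, for example running averages kept outside the function.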

Returns:

An array of the same shape as x containing the standardized input.

Return type:

Array
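The following sketch shows that the output keeps the shape of x even when several axes are standardized jointly with a tuple passed as axis. The batch-of-images layout is only an illustrative assumption.

```python
import jax
import jax.numpy as jnp

# Illustrative batch of 2 "images", each 3x4 (shape: batch, height, width).
x = jnp.arange(24.0).reshape(2, 3, 4)

# Standardize over the last two axes jointly, i.e. separately for each image.
y = jax.nn.standardize(x, axis=(1, 2))

print(y.shape)  # (2, 3, 4)
```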