jax.numpy.arctan2

jax.numpy.arctan2(x1, x2, /)

Compute the arctangent of x1/x2, choosing the correct quadrant.

JAX implementation of numpy.arctan2().

Parameters:
  • x1 (ArrayLike) – numerator array.

  • x2 (ArrayLike) – denominator array; should be broadcast-compatible with x1.

Returns:

The elementwise arctangent of x1 / x2, tracking the correct quadrant.

Return type:

Array
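As a quick illustration of the quadrant handling and of broadcasting between x1 and x2, the sketch below evaluates arctan2 on one point in each quadrant. The input values are illustrative choices, not taken from the examples below; note that x1 plays the role of the y-coordinate and x2 the x-coordinate.

```python
import jax.numpy as jnp

# y-coordinates as a column and x-coordinates as a row: the two inputs
# broadcast against each other to a (2, 2) grid of (y, x) pairs.
y = jnp.array([[1.0], [-1.0]])
x = jnp.array([1.0, -1.0])

angles = jnp.arctan2(y, x)
print(angles.shape)  # (2, 2)

# Row 0 (y = 1):  quadrants I and II  -> pi/4, 3*pi/4
# Row 1 (y = -1): quadrants IV and III -> -pi/4, -3*pi/4
print(angles / jnp.pi)
```

All four points have \(|y / x| = 1\), yet each lands in a different quadrant because the signs of both arguments are taken into account.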


Examples

Consider a sequence of angles in radians between \(-\pi\) and \(\pi\):

>>> import jax.numpy as jnp
>>> theta = jnp.linspace(-jnp.pi, jnp.pi, 9)
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(theta)
[-3.14 -2.36 -1.57 -0.79  0.    0.79  1.57  2.36  3.14]

These angles can equivalently be represented by (x, y) coordinates on a unit circle:

>>> x, y = jnp.cos(theta), jnp.sin(theta)

To reconstruct the input angle, we might be tempted to use the identity \(\tan(\theta) = y / x\) and compute \(\theta = \tan^{-1}(y/x)\). Unfortunately, this does not recover the input angle:

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.arctan(y / x))
[-0.    0.79  1.57 -0.79  0.    0.79  1.57 -0.79  0.  ]

The problem is that \(y/x\) is ambiguous: although \((y, x) = (-1, -1)\) and \((y, x) = (1, 1)\) represent different points in Cartesian space, in both cases \(y / x = 1\), so the simple arctan approach loses information about which quadrant the angle lies in. arctan2() addresses this by taking \(y\) and \(x\) as separate arguments, so it can use their signs to select the quadrant:

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.arctan2(y, x))
[ 3.14 -2.36 -1.57 -0.79  0.    0.79  1.57  2.36 -3.14]

The results match the input theta except at the endpoints, where \(+\pi\) and \(-\pi\) represent the same point on the unit circle, so either value is a valid answer. By convention, arctan2() always returns values between \(-\pi\) and \(+\pi\) inclusive.
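To make the quadrant correction concrete, here is a minimal sketch that rebuilds arctan2 from arctan: take \(\tan^{-1}(y/x)\), which is only correct when \(x > 0\), shift it by \(\pm\pi\) when \(x < 0\), and handle \(x = 0\) separately. The helper arctan2_sketch is hypothetical and purely illustrative; it is not how JAX implements jax.numpy.arctan2, and it ignores signed zeros, NaNs, and infinities.

```python
import jax.numpy as jnp

def arctan2_sketch(y, x):
    """Hypothetical quadrant-aware arctangent built from jnp.arctan.

    Illustrative only: not JAX's implementation; ignores signed zeros,
    NaNs and infinities.
    """
    y, x = jnp.asarray(y), jnp.asarray(x)
    # jnp.where evaluates every branch, so replace x == 0 with 1 before
    # dividing; those entries are overwritten by the on-axis case below.
    safe_x = jnp.where(x == 0, 1.0, x)
    base = jnp.arctan(y / safe_x)  # correct only when x > 0 (quadrants I, IV)
    # For x < 0, shift by +pi or -pi depending on the sign of y.
    left = jnp.where(y >= 0, base + jnp.pi, base - jnp.pi)
    # On the y-axis the angle is +pi/2 or -pi/2 (and 0 at the origin).
    on_axis = jnp.sign(y) * (jnp.pi / 2)
    return jnp.where(x > 0, base, jnp.where(x < 0, left, on_axis))

# Same angles as in the example above.
theta = jnp.linspace(-jnp.pi, jnp.pi, 9)
x, y = jnp.cos(theta), jnp.sin(theta)
approx = arctan2_sketch(y, x)
# Agrees with jnp.arctan2 up to float rounding, endpoints included.
print(bool(jnp.allclose(approx, jnp.arctan2(y, x), atol=1e-6)))  # True

# Points on the coordinate axes, plus the origin.
pts_y = jnp.array([0.0, 1.0, 0.0, -1.0, 0.0])
pts_x = jnp.array([1.0, 0.0, -1.0, 0.0, 0.0])
print(arctan2_sketch(pts_y, pts_x) / jnp.pi)
```

The sketch makes the information loss explicit: jnp.arctan sees only the ratio \(y/x\), and the sign tests on x and y restore the quadrant that the ratio discarded.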