jax.scipy.cluster.vq.vq
- jax.scipy.cluster.vq.vq(obs, code_book, check_finite=True)
Assign codes from a code book to a set of observations.
JAX implementation of scipy.cluster.vq.vq().
Assigns each observation vector in obs to a code from code_book based on the nearest Euclidean distance.
- Parameters:
  - obs (ArrayLike) – array of observation vectors of shape (M, N). Each row represents a single observation. If obs is one-dimensional, then each entry is treated as a length-1 observation.
  - code_book (ArrayLike) – array of codes with shape (K, N). Each row represents a single code vector. If code_book is one-dimensional, then each entry is treated as a length-1 code.
  - check_finite (bool) – unused in JAX.
- Returns:
  A tuple of arrays (code, dist):
  - code is an integer array of shape (M,) containing indices 0 <= i < K of the closest entry in code_book for the given entry in obs.
  - dist is a float array of shape (M,) containing the Euclidean distance between each observation and the nearest code.
- Return type:
  tuple[Array, Array]
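The assignment rule described above can be written out directly with jax.numpy. The following is a minimal brute-force sketch, not necessarily how JAX implements vq: the helper name nearest_code_sketch is hypothetical, and it simply forms every observation-to-code Euclidean distance, takes the argmin over codes, and checks the result against jax.scipy.cluster.vq.vq on a small input.

```python
import jax.numpy as jnp
from jax.scipy.cluster.vq import vq


def nearest_code_sketch(obs, code_book):
    """Hypothetical helper: brute-force nearest-code assignment by Euclidean distance."""
    obs = jnp.asarray(obs)
    code_book = jnp.asarray(code_book)
    # Per the parameter docs, 1-D inputs are treated as collections of length-1 vectors.
    if obs.ndim == 1:
        obs = obs[:, None]
    if code_book.ndim == 1:
        code_book = code_book[:, None]
    # Broadcast to pairwise differences of shape (M, K, N).
    diff = obs[:, None, :] - code_book[None, :, :]
    # Euclidean distance from every observation to every code: shape (M, K).
    all_dists = jnp.sqrt(jnp.sum(diff ** 2, axis=-1))
    code = jnp.argmin(all_dists, axis=1)  # index of the nearest code, shape (M,)
    dist = jnp.min(all_dists, axis=1)     # distance to that code, shape (M,)
    return code, dist


obs = jnp.array([[1.1, 2.1, 3.1],
                 [5.9, 4.8, 6.2]])
code_book = jnp.array([[1., 2., 3.],
                       [2., 3., 4.],
                       [3., 4., 5.],
                       [4., 5., 6.]])

codes, dists = nearest_code_sketch(obs, code_book)
ref_codes, ref_dists = vq(obs, code_book)
print(codes, ref_codes)  # both should be [0 3]
```

Materializing the (M, K, N) difference array keeps the sketch short but costs O(M·K·N) memory; for large inputs, expanding ||x - c||² = ||x||² - 2 x·c + ||c||² avoids that intermediate.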
Examples
>>> import jax.numpy as jnp
>>> import jax.scipy.cluster.vq
>>> obs = jnp.array([[1.1, 2.1, 3.1],
...                  [5.9, 4.8, 6.2]])
>>> code_book = jnp.array([[1., 2., 3.],
...                        [2., 3., 4.],
...                        [3., 4., 5.],
...                        [4., 5., 6.]])
>>> codes, distances = jax.scipy.cluster.vq.vq(obs, code_book)
>>> print(codes)
[0 3]
>>> print(distances)
[0.17320499 1.9209373 ]
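The parameter notes state that one-dimensional inputs are treated as collections of length-1 vectors. The sketch below exercises that case on small illustrative arrays and also wraps vq in jax.jit. That vq traces cleanly under jit is an assumption based on it being a jax.numpy function, not something this page states.

```python
import jax
import jax.numpy as jnp
from jax.scipy.cluster.vq import vq

# Assumption: vq is traceable, so it can be compiled with jax.jit.
vq_jit = jax.jit(vq)

# One-dimensional inputs: each entry is a single length-1 observation or code.
obs_1d = jnp.array([1.0, 4.0, 9.0])
code_book_1d = jnp.array([0.0, 5.0, 10.0])

codes_1d, dists_1d = vq_jit(obs_1d, code_book_1d)
print(codes_1d)  # expected [0 1 2]
print(dists_1d)  # expected [1. 1. 1.]
```

Each observation is exactly 1.0 away from its nearest code here, so there are no ties and the assigned indices are unambiguous.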