jax.numpy.repeat#

jax.numpy.repeat(a, repeats, axis=None, *, total_repeat_length=None)[原始碼]#

從重複的元素建構陣列。

numpy.repeat() 的 JAX 實作。

參數:
  • a (ArrayLike) – N 維陣列

  • repeats (ArrayLike) – 1 維整數陣列,指定重複次數。必須符合重複軸的長度。

  • axis (int | None | None) – 整數,指定要沿著哪個軸在 a 中建構重複陣列。如果為 None (預設),則先將 a 平坦化。

  • total_repeat_length (int | None | None) – 為了使 jnp.repeatjit() 和其他 JAX 轉換相容,必須靜態指定此值。如果 sum(repeats) 大於指定的 total_repeat_length,則會捨棄剩餘的值。如果 sum(repeats) 小於 total_repeat_length,則最後一個值將會重複。

返回:

a 的重複值建構的陣列。

返回類型:

Array

參見

範例

沿著最後一個軸將每個值重複兩次

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.repeat(a, 2, axis=-1)
Array([[1, 1, 2, 2],
       [3, 3, 4, 4]], dtype=int32)

如果未指定 axis,則輸入陣列將會被平坦化

>>> jnp.repeat(a, 2)
Array([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)

傳遞陣列至 repeats 以將每個值重複不同的次數

>>> repeats = jnp.array([2, 3])
>>> jnp.repeat(a, repeats, axis=1)
Array([[1, 1, 2, 2, 2],
       [3, 3, 4, 4, 4]], dtype=int32)

為了在 jit 和其他 JAX 轉換中使用 repeat,必須使用 total_repeat_length 靜態指定輸出的大小

>>> jit_repeat = jax.jit(jnp.repeat, static_argnames=['axis', 'total_repeat_length'])
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=5)
Array([[1, 1, 2, 2, 2],
       [3, 3, 4, 4, 4]], dtype=int32)

如果 total_repeat_length 小於 sum(repeats),結果將會被截斷

>>> jit_repeat(a, repeats, axis=1, total_repeat_length=4)
Array([[1, 1, 2, 2],
       [3, 3, 4, 4]], dtype=int32)

如果它更大,則額外的條目將會以最後一個值填滿

>>> jit_repeat(a, repeats, axis=1, total_repeat_length=7)
Array([[1, 1, 2, 2, 2, 2, 2],
       [3, 3, 4, 4, 4, 4, 4]], dtype=int32)