jax.numpy.repeat#
- jax.numpy.repeat(a, repeats, axis=None, *, total_repeat_length=None)[原始碼]#
從重複的元素建構陣列。
numpy.repeat()
的 JAX 實作。- 參數:
a (ArrayLike) – N 維陣列
repeats (ArrayLike) – 1 維整數陣列,指定重複次數。必須符合重複軸的長度。
axis (int | None | None) – 整數,指定要沿著哪個軸在
a
中建構重複陣列。如果為 None (預設),則先將a
平坦化。total_repeat_length (int | None | None) – 為了使
jnp.repeat
與jit()
和其他 JAX 轉換相容,必須靜態指定此值。如果sum(repeats)
大於指定的total_repeat_length
,則會捨棄剩餘的值。如果sum(repeats)
小於total_repeat_length
,則最後一個值將會重複。
- 返回:
從
a
的重複值建構的陣列。- 返回類型:
參見
jax.numpy.tile()
:重複整個陣列,而不是個別值。
範例
沿著最後一個軸將每個值重複兩次
>>> a = jnp.array([[1, 2], ... [3, 4]]) >>> jnp.repeat(a, 2, axis=-1) Array([[1, 1, 2, 2], [3, 3, 4, 4]], dtype=int32)
如果未指定
axis
,則輸入陣列將會被平坦化>>> jnp.repeat(a, 2) Array([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)
傳遞陣列至
repeats
以將每個值重複不同的次數>>> repeats = jnp.array([2, 3]) >>> jnp.repeat(a, repeats, axis=1) Array([[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]], dtype=int32)
為了在
jit
和其他 JAX 轉換中使用repeat
,必須使用total_repeat_length
靜態指定輸出的大小>>> jit_repeat = jax.jit(jnp.repeat, static_argnames=['axis', 'total_repeat_length']) >>> jit_repeat(a, repeats, axis=1, total_repeat_length=5) Array([[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]], dtype=int32)
如果 total_repeat_length 小於
sum(repeats)
,結果將會被截斷>>> jit_repeat(a, repeats, axis=1, total_repeat_length=4) Array([[1, 1, 2, 2], [3, 3, 4, 4]], dtype=int32)
如果它更大,則額外的條目將會以最後一個值填滿
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=7) Array([[1, 1, 2, 2, 2, 2, 2], [3, 3, 4, 4, 4, 4, 4]], dtype=int32)