jax.numpy.left_shift#

jax.numpy.left_shift(x, y, /)[原始碼]#

x 的位元向左移動 y 指定的量,逐元素運算。

numpy.left_shift 的 JAX 實作。

參數:
  • x (ArrayLike) – 輸入陣列,必須為整數類型。

  • y (ArrayLike) – 將 x 中每個元素向左移動的位元數,僅接受整數子類型。xy 必須具有相同的形狀或可廣播相容。

返回:

一個陣列,包含 x 中向左移動 y 指定量的元素,其形狀與 xy 的廣播形狀相同。

返回類型:

陣列

注意

在所涉及的 dtype 範圍內,將 x 左移 y 位元等同於 x * (2**y)

參見

範例

>>> def print_binary(x):
...   return [bin(int(val)) for val in x]
>>> x1 = jnp.arange(5)
>>> x1
Array([0, 1, 2, 3, 4], dtype=int32)
>>> print_binary(x1)
['0b0', '0b1', '0b10', '0b11', '0b100']
>>> x2 = 1
>>> result = jnp.left_shift(x1, x2)
>>> result
Array([0, 2, 4, 6, 8], dtype=int32)
>>> print_binary(result)
['0b0', '0b10', '0b100', '0b110', '0b1000']
>>> x3 = 4
>>> print_binary([x3])
['0b100']
>>> x4 = jnp.array([1, 2, 3, 4])
>>> result1 = jnp.left_shift(x3, x4)
>>> result1
Array([ 8, 16, 32, 64], dtype=int32)
>>> print_binary(result1)
['0b1000', '0b10000', '0b100000', '0b1000000']