jax.numpy.heaviside#
- jax.numpy.heaviside(x1, x2, /)[原始碼]#
計算 heaviside 步階函數。
JAX 版本的
numpy.heaviside
。heaviside 步階函數定義為
\[\begin{split}\mathrm{heaviside}(x1, x2) = \begin{cases} 0., & x < 0\\ x2, & x = 0\\ 1., & x > 0. \end{cases}\end{split}\]- 參數:
x1 (ArrayLike) – 輸入陣列或純量。
complex
dtype 不支援。x2 (ArrayLike) – 純量或陣列。指定當
x1
為0
時的回傳值。complex
dtype 不支援。x1
和x2
必須具有相同形狀或可廣播相容。
- 回傳值:
一個包含
x1
的 heaviside 步階函數的陣列,提升為非精確 dtype。- 回傳型別:
範例
>>> x1 = jnp.array([[-2, 0, 3], ... [5, -1, 0], ... [0, 7, -3]]) >>> x2 = jnp.array([2, 0.5, 1]) >>> jnp.heaviside(x1, x2) Array([[0. , 0.5, 1. ], [1. , 0. , 1. ], [2. , 1. , 0. ]], dtype=float32) >>> jnp.heaviside(x1, 0.5) Array([[0. , 0.5, 1. ], [1. , 0. , 0.5], [0.5, 1. , 0. ]], dtype=float32) >>> jnp.heaviside(-3, x2) Array([0., 0., 0.], dtype=float32)