jax.eval_shape#

jax.eval_shape(fun, *args, **kwargs)[原始碼]#

計算 fun 的形狀/dtype,而無需任何 FLOP。

此實用函式對於執行形狀推斷非常有用。其輸入/輸出行為由以下定義

def eval_shape(fun, *args, **kwargs):
  out = fun(*args, **kwargs)
  shape_dtype_struct = lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype)
  return jax.tree_util.tree_map(shape_dtype_struct, out)

但是,它不是直接應用 fun (這可能很昂貴),而是使用 JAX 的抽象解譯機制來評估形狀,而無需執行任何 FLOP。

使用 eval_shape() 也可以捕捉形狀錯誤,並且會引發與評估 fun(*args, **kwargs) 相同的形狀錯誤。

參數:
  • fun (Callable) – 應評估其輸出形狀的函式。

  • *args – 陣列、純量或這些類型的 (巢狀) 標準 Python 容器 (tuple、list、dict、namedtuple,即 pytree) 的位置引數 tuple。由於僅存取 shapedtype 屬性,因此可以使用 jax.ShapeDtypeStruct 或另一個 duck-type 為 ndarray 的容器 (但請注意,duck-type 物件不能是 namedtuple,因為這些物件被視為標準 Python 容器)。

  • **kwargs – 陣列、純量或這些類型的 (巢狀) 標準 Python 容器 (pytree) 的關鍵字引數 dict。與 args 中一樣,陣列值只需要 duck-type 為具有 shapedtype 屬性。

傳回:

包含 jax.ShapeDtypeStruct 物件作為葉節點的巢狀 PyTree。

傳回類型:

out

例如

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> f = lambda A, x: jnp.tanh(jnp.dot(A, x))
>>> A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
>>> x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)
>>> out = jax.eval_shape(f, A, x)  # no FLOPs performed
>>> print(out.shape)
(2000, 1000)
>>> print(out.dtype)
float32

透過 eval_shape() 傳遞的所有引數都將被視為動態;靜態引數可以透過閉包包含,例如使用 functools.partial()

>>> import jax
>>> from jax import lax
>>> from functools import partial
>>> import jax.numpy as jnp
>>>
>>> x = jax.ShapeDtypeStruct((1, 1, 28, 28), jnp.float32)
>>> kernel = jax.ShapeDtypeStruct((32, 1, 3, 3), jnp.float32)
>>>
>>> conv_same = partial(lax.conv_general_dilated, window_strides=(1, 1), padding="SAME")
>>> out = jax.eval_shape(conv_same, x, kernel)
>>> print(out.shape)
(1, 32, 28, 28)
>>> print(out.dtype)
float32