jax.eval_shape#
- jax.eval_shape(fun, *args, **kwargs)[原始碼]#
計算
fun
的形狀/dtype,而無需任何 FLOP。此實用函式對於執行形狀推斷非常有用。其輸入/輸出行為由以下定義
def eval_shape(fun, *args, **kwargs): out = fun(*args, **kwargs) shape_dtype_struct = lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype) return jax.tree_util.tree_map(shape_dtype_struct, out)
但是,它不是直接應用
fun
(這可能很昂貴),而是使用 JAX 的抽象解譯機制來評估形狀,而無需執行任何 FLOP。使用
eval_shape()
也可以捕捉形狀錯誤,並且會引發與評估fun(*args, **kwargs)
相同的形狀錯誤。- 參數:
fun (Callable) – 應評估其輸出形狀的函式。
*args – 陣列、純量或這些類型的 (巢狀) 標準 Python 容器 (tuple、list、dict、namedtuple,即 pytree) 的位置引數 tuple。由於僅存取
shape
和dtype
屬性,因此可以使用jax.ShapeDtypeStruct
或另一個 duck-type 為 ndarray 的容器 (但請注意,duck-type 物件不能是 namedtuple,因為這些物件被視為標準 Python 容器)。**kwargs – 陣列、純量或這些類型的 (巢狀) 標準 Python 容器 (pytree) 的關鍵字引數 dict。與
args
中一樣,陣列值只需要 duck-type 為具有shape
和dtype
屬性。
- 傳回:
包含
jax.ShapeDtypeStruct
物件作為葉節點的巢狀 PyTree。- 傳回類型:
out
例如
>>> import jax >>> import jax.numpy as jnp >>> >>> f = lambda A, x: jnp.tanh(jnp.dot(A, x)) >>> A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32) >>> x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32) >>> out = jax.eval_shape(f, A, x) # no FLOPs performed >>> print(out.shape) (2000, 1000) >>> print(out.dtype) float32
透過
eval_shape()
傳遞的所有引數都將被視為動態;靜態引數可以透過閉包包含,例如使用functools.partial()
>>> import jax >>> from jax import lax >>> from functools import partial >>> import jax.numpy as jnp >>> >>> x = jax.ShapeDtypeStruct((1, 1, 28, 28), jnp.float32) >>> kernel = jax.ShapeDtypeStruct((32, 1, 3, 3), jnp.float32) >>> >>> conv_same = partial(lax.conv_general_dilated, window_strides=(1, 1), padding="SAME") >>> out = jax.eval_shape(conv_same, x, kernel) >>> print(out.shape) (1, 32, 28, 28) >>> print(out.dtype) float32