jax.experimental.shard_map 模組#

API#

shard_map(f, mesh, in_specs, out_specs[, ...])

將函數映射到資料分片上。