使用 PyTorch 資料載入訓練簡單神經網路#
版權所有 2018 JAX 作者群。
依 Apache 授權條款 2.0 版(「授權條款」)授權;除非遵守授權條款,否則您不得使用此檔案。您可以在以下網址取得授權條款副本:https://www.apache.org/licenses/LICENSE-2.0
讓我們結合在快速入門中展示的所有內容,來訓練一個簡單的神經網路。我們將首先指定並訓練一個基於 JAX 在 MNIST 上進行計算的簡單 MLP。我們將使用 PyTorch 的資料載入 API 來載入影像和標籤(因為它非常棒,而且世界不需要另一個資料載入函式庫)。
當然,您可以將 JAX 與任何與 NumPy 相容的 API 搭配使用,使指定模型更隨插即用。在這裡,僅為了解釋目的,我們將不使用任何神經網路函式庫或特殊 API 來建構我們的模型。
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
# Helper that draws the random initial weights and biases of one dense layer.
def random_layer_params(m, n, key, scale=1e-2):
    """Draw scaled Gaussian parameters for a single dense layer.

    Args:
        m: Number of input features of the layer.
        n: Number of output units of the layer.
        key: JAX PRNG key; it is split so weights and biases use
            independent subkeys.
        scale: Factor applied to the standard-normal samples.

    Returns:
        A ``(weights, biases)`` tuple with shapes ``(n, m)`` and ``(n,)``.
    """
    weight_key, bias_key = random.split(key)
    weights = scale * random.normal(weight_key, (n, m))
    biases = scale * random.normal(bias_key, (n,))
    return weights, biases
# Build the parameters of every layer of a fully-connected network whose
# layer widths are listed in "sizes".
def init_network_params(sizes, key):
    """Initialise all dense layers of an MLP.

    Args:
        sizes: Layer widths, input first and output last; consecutive pairs
            define one layer each.
        key: JAX PRNG key used to derive one subkey per layer.

    Returns:
        A list of ``(weights, biases)`` tuples, one per consecutive pair in
        ``sizes``.
    """
    # len(sizes) subkeys are drawn, but zip() stops at the shorter
    # sequences, so only the first len(sizes) - 1 are consumed.
    layer_keys = random.split(key, len(sizes))
    fan_ins, fan_outs = sizes[:-1], sizes[1:]
    layers = []
    for fan_in, fan_out, layer_key in zip(fan_ins, fan_outs, layer_keys):
        layers.append(random_layer_params(fan_in, fan_out, layer_key))
    return layers
# Optimisation settings.
num_epochs = 8
batch_size = 128
step_size = 0.01

# Network shape: 784 inputs, two hidden layers of 512 units, 10 outputs.
n_targets = 10
layer_sizes = [784, 512, 512, 10]

# Initial parameters, drawn from a fixed seed.
params = init_network_params(layer_sizes, random.key(0))
讓我們首先定義我們的預測函數。請注意,我們正在為單個影像範例定義此函數。我們將使用 JAX 的 `vmap` 函數來自動處理小批次,且不會造成效能損失。
from jax.scipy.special import logsumexp
def relu(x):
    """Rectified linear unit: element-wise maximum of ``x`` and zero."""
    floor = 0
    return jnp.maximum(floor, x)
def predict(params, image):
    """Run the MLP forward pass for one example.

    Args:
        params: List of ``(weights, biases)`` pairs, one per layer.
        image: A single input example (not a batch).

    Returns:
        The final-layer logits minus their log-sum-exp, i.e. normalised
        log-probabilities.
    """
    hidden_layers, output_layer = params[:-1], params[-1]

    # Every layer except the last is an affine map followed by ReLU.
    hidden = image
    for weights, biases in hidden_layers:
        hidden = relu(jnp.dot(weights, hidden) + biases)

    # The last layer is affine only.
    final_w, final_b = output_layer
    logits = jnp.dot(final_w, hidden) + final_b
    return logits - logsumexp(logits)
# This works on single examples
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)

# Doesn't work with a batch: `predict` is written for one example, so a
# (10, 784) input makes the matrix-vector products shape-incompatible.
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
    preds = predict(params, random_flattened_images)
except TypeError:
    print('Invalid shapes!')
Invalid shapes!
# Let's upgrade it to handle batches using `vmap`.
# Make a batched version of the `predict` function: in_axes=(None, 0) leaves
# `params` unmapped (shared by every example) and maps over axis 0 of the
# second argument, the stack of images.
batched_predict = vmap(predict, in_axes=(None, 0))
# `batched_predict` has the same call signature as `predict`, but takes a
# batch of images instead of a single one.
batched_preds = batched_predict(params, random_flattened_images)
(10, 10)
此時,我們擁有定義神經網路並訓練它所需的所有要素。我們建構了 `predict` 的自動批次版本,我們應該能夠在損失函數中使用它。我們應該能夠使用 `grad` 來取得損失相對於神經網路參數的導數。最後,我們應該能夠使用 `jit` 來加速整個過程。
def one_hot(x, k, dtype=jnp.float32):
    """Create a one-hot encoding of x of size k.

    Args:
        x: 1-D array of integer class labels.
        k: Number of classes, i.e. the width of each encoded row.
        dtype: Element type of the returned array.

    Returns:
        A ``(len(x), k)`` array that is 1 where the column index equals the
        label of that row and 0 elsewhere.
    """
    class_ids = jnp.arange(k)
    # Broadcasting the label column against the row of class ids yields a
    # boolean mask that is True where label == class id.
    mask = x[:, None] == class_ids
    return jnp.array(mask, dtype)
def accuracy(params, images, targets):
    """Return the fraction of examples whose predicted class is correct.

    Args:
        params: Network parameters as accepted by ``batched_predict``.
        images: Batch of input examples.
        targets: Per-example target rows; the arg-max along axis 1 is taken
            as the true class.

    Returns:
        The mean of the element-wise "prediction equals target" comparison.
    """
    log_probs = batched_predict(params, images)
    predicted = jnp.argmax(log_probs, axis=1)
    expected = jnp.argmax(targets, axis=1)
    return jnp.mean(predicted == expected)
def loss(params, images, targets):
    """Return the negated mean of predictions weighted by the targets.

    Args:
        params: Network parameters as accepted by ``batched_predict``.
        images: Batch of input examples.
        targets: Array multiplied element-wise with the predictions
            (presumably one-hot rows, as produced by ``one_hot``).

    Returns:
        A scalar. The mean runs over every entry of the product, not only
        over the batch axis.
    """
    log_probs = batched_predict(params, images)
    weighted = log_probs * targets
    return -jnp.mean(weighted)
@jit
def update(params, x, y):
    """Apply one step of plain gradient descent to the network parameters.

    The function is compiled with ``jit`` so the gradient computation and
    the parameter update are traced once and not re-interpreted for every
    mini-batch.

    Args:
        params: List of ``(weights, biases)`` pairs.
        x: Batch of input examples.
        y: Batch of targets matching ``x``.

    Returns:
        A new list of ``(weights, biases)`` pairs; ``params`` is not
        modified.
    """
    # NOTE: `step_size` is a module-level global read while tracing, so
    # rebinding it later does not affect an already-compiled version.
    grads = grad(loss)(params, x, y)
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]
使用 PyTorch 載入資料#
JAX 專注於程式轉換和加速器支援的 NumPy,因此我們在 JAX 函式庫中不包含資料載入或整理。已經有很多很棒的資料載入器,所以讓我們直接使用它們,而不是重新發明任何東西。我們將使用 PyTorch 的資料載入器,並加上一層很薄的轉接層(shim),使其能與 NumPy 陣列搭配使用。
!pip install torch torchvision
import numpy as np
from jax.tree_util import tree_map
from torch.utils import data
from torchvision.datasets import MNIST
def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays.

    The samples are first merged with PyTorch's default collate function;
    every leaf of the resulting structure is then converted with
    ``np.asarray``.

    Args:
        batch: List of samples as produced by the dataset.

    Returns:
        The collated batch with the same nesting, holding NumPy arrays.
    """
    collated = data.default_collate(batch)
    return tree_map(np.asarray, collated)
class NumpyLoader(data.DataLoader):
    """``DataLoader`` that collates batches with ``numpy_collate``.

    Accepts the subset of ``DataLoader`` constructor arguments listed below
    and forwards each of them unchanged; only ``collate_fn`` is fixed.
    """

    def __init__(self, dataset, batch_size=1,
                 shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0,
                 pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        # Zero-argument super() rather than super(self.__class__, self):
        # the latter recurses forever once this class is itself subclassed.
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=numpy_collate,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
        )
class FlattenAndCast:
    """Callable transform turning an image into a flat float32 vector."""

    def __call__(self, pic):
        """Return ``pic`` as a 1-D NumPy array of dtype float32."""
        as_float = np.array(pic, dtype=jnp.float32)
        return np.ravel(as_float)
# Define our dataset, using torch datasets
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)

# Get the full train dataset (for checking accuracy while training).
# `data` / `targets` are the current torchvision attribute names; the older
# `train_data` / `train_labels` aliases are deprecated and emit warnings.
train_images = np.array(mnist_dataset.data).reshape(len(mnist_dataset.data), -1)
train_labels = one_hot(np.array(mnist_dataset.targets), n_targets)

# Get full test dataset (same attribute names; `train=False` selects the
# test split).
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = jnp.array(mnist_dataset_test.data.numpy().reshape(len(mnist_dataset_test.data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.targets), n_targets)
import time
for epoch in range(num_epochs):
    start_time = time.time()
    # One pass over the training set: one parameter update per mini-batch.
    for x, y in training_generator:
        y = one_hot(y, n_targets)
        params = update(params, x, y)
    epoch_time = time.time() - start_time

    # Accuracy is evaluated after the timer stops, so it is not part of the
    # reported epoch time.
    train_acc = accuracy(params, train_images, train_labels)
    test_acc = accuracy(params, test_images, test_labels)
    print(
        "Epoch {} in {:0.2f} sec".format(epoch, epoch_time),
        "Training set accuracy {}".format(train_acc),
        "Test set accuracy {}".format(test_acc),
        sep="\n",
    )
Epoch 0 in 55.15 sec
Training set accuracy 0.9157500267028809
Test set accuracy 0.9195000529289246
Epoch 1 in 42.26 sec
Training set accuracy 0.9372166991233826
Test set accuracy 0.9384000301361084
Epoch 2 in 44.37 sec
Training set accuracy 0.9491666555404663
Test set accuracy 0.9469000697135925
Epoch 3 in 41.75 sec
Training set accuracy 0.9568166732788086
Test set accuracy 0.9534000158309937
Epoch 4 in 41.16 sec
Training set accuracy 0.9631333351135254
Test set accuracy 0.9577000737190247
Epoch 5 in 38.89 sec
Training set accuracy 0.9675000309944153
Test set accuracy 0.9616000652313232
Epoch 6 in 40.68 sec
Training set accuracy 0.9708333611488342
Test set accuracy 0.9650000333786011
Epoch 7 in 41.50 sec
Training set accuracy 0.973716676235199
Test set accuracy 0.9672000408172607
我們現在已經使用了整個 JAX API:用於導數的 `grad`、用於加速的 `jit` 和用於自動向量化的 `vmap`。我們使用 NumPy 來指定我們所有的計算,並借用了 PyTorch 的出色資料載入器,並在 GPU 上執行了整個過程。