Model training with Hugging Face

Model setup

Let’s consider a simple CNN model that trains on the MNIST dataset to predict handwritten digits.

The example model is similar to the one shown in the intro-to-dl-course.

import os
# Here the environment variable WRKDIR points to a personal work directory
os.environ["HF_HOME"] = f"{os.environ['WRKDIR']}/huggingface"
os.environ["HF_TOKEN_PATH"] = "~/.cache/huggingface/token"
import torch
import transformers

from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
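
Note that these variables have to be set before transformers is imported, as the cache locations are read when the library is loaded. If WRKDIR might be unset in your environment, a defensive variant could look like this (a sketch; the fallback path is our choice):

# Fall back to the home directory when WRKDIR is not set (hypothetical fallback)
work_dir = os.environ.get("WRKDIR", os.path.expanduser("~"))
os.environ["HF_HOME"] = f"{work_dir}/huggingface"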

Let’s define the datasets for the model (Trainer will create the actual data loaders for us later):

data_dir = './data'

batch_size = 32  # used later for the example DataLoader; Trainer uses its own setting

train_dataset = datasets.MNIST(data_dir, train=True, download=True, transform=ToTensor())
test_dataset = datasets.MNIST(data_dir, train=False, transform=ToTensor())
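
As a quick sanity check (a minimal sketch, not needed for training), each dataset item is an (image, label) tuple:

image, label = train_dataset[0]
print(image.shape)  # torch.Size([1, 28, 28]): one channel, 28x28 pixels
print(label)        # the digit class as an integer, here 5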

Let’s define the model architecture:

from torch import nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding='valid'),  # 28x28 -> 26x26
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),                       # 26x26 -> 13x13
            nn.Flatten(),
            nn.Linear(32*13*13, 128),                          # 32 * 13 * 13 = 5408 features
            nn.ReLU(),
            nn.Linear(128, 10)                                 # one logit per digit class
        )

    def forward(self, x):
        return self.layers(x)

model = SimpleCNN().to(device)
print(model)
SimpleCNN(
  (layers): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=valid)
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Flatten(start_dim=1, end_dim=-1)
    (4): Linear(in_features=5408, out_features=128, bias=True)
    (5): ReLU()
    (6): Linear(in_features=128, out_features=10, bias=True)
  )
)
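
The in_features=5408 of the first Linear layer follows from the shapes: the 28x28 input shrinks to 26x26 after the 'valid' 3x3 convolution and to 13x13 after pooling, and 32 channels × 13 × 13 = 5408. A dummy batch confirms the wiring (a sketch; the batch size of 2 is arbitrary):

dummy = torch.zeros(2, 1, 28, 28, device=device)
print(model(dummy).shape)  # torch.Size([2, 10]): one logit per digit class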

Data loading and Trainer setup

We could write our own training loop to train the model, but we can also use the Hugging Face Trainer.

Given a model and the datasets we want to use, Trainer will automatically handle the model training.

Trainer will automatically create DataLoaders for the datasets that we have, but depending on the data you might need to specify a data collator that combines individual data instances into batches.

Because each item coming from a Torchvision dataset is a tuple of an image tensor and the corresponding label, we can combine the data into batches with the following collator function:

def collator_fn(data):
    # Stack the individual image tensors into one batch tensor
    images = torch.stack([d[0] for d in data])
    # Collect the integer labels into a 1-D tensor
    labels = torch.tensor([d[1] for d in data])
    return {"images": images, "labels": labels}
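
We can try the collator on a couple of samples to see the batch layout it produces (a minimal sketch reusing the datasets defined above):

batch = collator_fn([train_dataset[0], train_dataset[1]])
print(batch["images"].shape)  # torch.Size([2, 1, 28, 28])
print(batch["labels"])        # tensor([5, 0]) for the first two MNIST samples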

To train the model we need to specify the loss function that we optimize. This is usually done by subclassing the Trainer class.

Our new MNISTTrainer implements the loss calculation in its compute_loss method:

from transformers import Trainer

class MNISTTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # Pull out the tensors produced by collator_fn
        images = inputs.pop('images')
        target = inputs.pop('labels')
        output = model(images)
        # Cross-entropy over the 10 digit-class logits
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output, target)
        return (loss, output) if return_outputs else loss

trainer = MNISTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=collator_fn,
)
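
Since we passed no TrainingArguments, Trainer fills in its defaults; inspecting them explains the step count in the training output below (60000 samples / batch size 8 × 3 epochs = 22500 steps):

print(trainer.args.per_device_train_batch_size)  # 8 by default
print(trainer.args.num_train_epochs)             # 3.0 by default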

Now that we have the Trainer specified, we can train the model:

trainer.train()
[22500/22500 00:43, Epoch 3/3]
Step Training Loss
500 1.627500
1000 0.758000
1500 0.490400
2000 0.393400
2500 0.359800
3000 0.346300
3500 0.312300
4000 0.331600
4500 0.313500
5000 0.292200
5500 0.315700
6000 0.273300
6500 0.301700
7000 0.285800
7500 0.276600
8000 0.270300
8500 0.260600
9000 0.250300
9500 0.260000
10000 0.245300
10500 0.262400
11000 0.247400
11500 0.237900
12000 0.240100
12500 0.210700
13000 0.251200
13500 0.223500
14000 0.233500
14500 0.235600
15000 0.229100
15500 0.230500
16000 0.243400
16500 0.207800
17000 0.211100
17500 0.206400
18000 0.221400
18500 0.197000
19000 0.216900
19500 0.216900
20000 0.213000
20500 0.216000
21000 0.215500
21500 0.199600
22000 0.202100
22500 0.228600

TrainOutput(global_step=22500, training_loss=0.30138072950575084, metrics={'train_runtime': 43.9328, 'train_samples_per_second': 4097.163, 'train_steps_per_second': 512.145, 'total_flos': 0.0, 'train_loss': 0.30138072950575084, 'epoch': 3.0})
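
If we want to keep the trained weights, a minimal sketch using plain PyTorch (the file name here is our choice) is:

# Save the trained parameters; "mnist_cnn.pt" is a hypothetical file name
torch.save(model.state_dict(), "mnist_cnn.pt")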

Checking model results

We can test the model’s predictions by running some data through it:

%matplotlib inline
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

example_loader = DataLoader(test_dataset, batch_size=32)

example_images, example_labels = next(iter(example_loader))

# Run the batch through the model without tracking gradients
model.eval()
with torch.no_grad():
    predicted_labels = model(example_images.to(device)).argmax(dim=1).cpu()

fig, axes = plt.subplots(4, 8, figsize=(15, 12))

for i, ax in enumerate(axes.flat):
    ax.imshow(example_images[i][0])
    ax.set_title(f"Real: {example_labels[i]} Pred: {predicted_labels[i]}")
    ax.axis("off")
[Figure: a 4x8 grid of MNIST test digits with their true and predicted labels]

It looks like our model works as expected.
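
Eyeballing 32 images only goes so far, so we can also estimate the overall test accuracy with a short loop (a sketch reusing the objects defined above; the batch size of 256 is arbitrary):

model.eval()
correct = 0
with torch.no_grad():
    for images, labels in DataLoader(test_dataset, batch_size=256):
        logits = model(images.to(device))
        correct += (logits.argmax(dim=1).cpu() == labels).sum().item()
print(f"Test accuracy: {correct / len(test_dataset):.4f}")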

Customizing Trainer with TrainingArguments

Trainer supports customizing most aspects of its behavior via TrainingArguments.

Arguments range from logging configurations to optimizer settings.

For example, if we want to train for one epoch and log the training loss only every 1000 steps, we can configure that easily:

from transformers import TrainingArguments

training_args = TrainingArguments(
    logging_steps=1000,
    num_train_epochs=1
)

trainer = MNISTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=collator_fn,
)

trainer.train()
[7500/7500 00:14, Epoch 1/1]
Step Training Loss
1000 0.215100
2000 0.189400
3000 0.206700
4000 0.203200
5000 0.194800
6000 0.192600
7000 0.202900

TrainOutput(global_step=7500, training_loss=0.19971505533854167, metrics={'train_runtime': 14.5625, 'train_samples_per_second': 4120.183, 'train_steps_per_second': 515.023, 'total_flos': 0.0, 'train_loss': 0.19971505533854167, 'epoch': 1.0})
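
TrainingArguments exposes many more options. A sketch of a few commonly used ones (the values here are illustrative, not the settings used above):

training_args = TrainingArguments(
    output_dir="mnist-trainer",       # where checkpoints and logs are written
    num_train_epochs=1,
    per_device_train_batch_size=64,   # overrides the default of 8
    learning_rate=5e-4,               # learning rate for the default AdamW optimizer
    logging_steps=1000,
)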