Dataset и DataLoader

Учимся правильно кормить модель данными: батчами, с перемешиванием, без ручной возни.

DataLoader — обёртка, которая берёт Dataset и выдаёт данные удобными батчами, при желании перемешивая их.

Зачем не обучать на всём сразу

В примерах выше мы подавали в модель сразу все 100 точек. На реальных данных (миллионы примеров) так нельзя — не влезет в память. Поэтому обучают батчами: берут по N примеров, делают шаг, берут следующие N. Это и память экономит, и помогает обучению (шум от маленьких батчей не даёт застрять). Разбивать данные на батчи и перемешивать их вручную — рутина, и PyTorch её автоматизирует через пару Dataset + DataLoader.

Dataset: доступ к примерам

Dataset — это класс, который умеет ответить на два вопроса: «сколько у тебя примеров» (__len__) и «дай пример номер i» (__getitem__). Свой Dataset пишут так:

import torch
from torch.utils.data import Dataset, DataLoader

class PointsDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)            # сколько примеров

    def __getitem__(self, i):
        return self.X[i], self.y[i]   # i-й пример: (признаки, метка)

dataset = PointsDataset(X, y)
print(len(dataset))      # число примеров
print(dataset[0])        # первый пример: (x0, y0)

Идея: Dataset знает, как достать один пример (хоть из памяти, хоть с диска, хоть по сети), но ничего не знает про батчи. Этим займётся DataLoader.

DataLoader: батчи и перемешивание

DataLoader оборачивает Dataset и превращает его в источник батчей, по которому можно итерироваться циклом for:

loader = DataLoader(dataset, batch_size=16, shuffle=True)

for X_batch, y_batch in loader:
    print(X_batch.shape)   # torch.Size([16, 1]) — батч из 16 примеров
    break

Два главных параметра:

  • batch_size — сколько примеров в одном батче (типично 16, 32, 64);
  • shuffle=True — перемешивать порядок каждую эпоху. Для обучения почти всегда нужно: иначе модель увидит данные в одном и том же порядке и может «выучить порядок». Для валидации shuffle=False.

DataLoader сам собирает отдельные примеры в батч-тензор: 16 пар (x, y) превращаются в тензоры X_batch формы (16, 1) и y_batch (16, 1). Возиться с конкатенацией вручную не надо.

Как это меняет цикл обучения

Цикл обучения теперь двойной: внешний по эпохам, внутренний по батчам из DataLoader. Пять шагов остаются те же, но выполняются на каждом батче:

for epoch in range(num_epochs):
    for X_batch, y_batch in loader:        # внутренний цикл по батчам
        pred = model(X_batch)
        loss = criterion(pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Один проход по всем батчам = одна эпоха. Теперь обучение масштабируется на любой объём данных: DataLoader подаёт их по кусочку.

Готовые датасеты

Часто писать Dataset руками не нужно — для популярных наборов есть готовые. Например, torchvision.datasets.MNIST сам скачает рукописные цифры и отдаст их как Dataset, который можно сразу обернуть в DataLoader. Свой Dataset пишут, когда данные нестандартные (свои картинки, свой формат).

Итог

  • Реальные данные обучают батчами — иначе не влезут в память и хуже сходится.
  • Dataset отвечает на __len__ и __getitem__ — как достать один пример.
  • DataLoader нарезает Dataset на батчи и перемешивает (shuffle=True для обучения).
  • Цикл обучения становится двойным: по эпохам и по батчам внутри.
Проверьте себя
1. Какие два метода обязан реализовать класс Dataset?
Afit и predict
B__len__ и __getitem__
Cforward и backward
Dtrain и eval
2. За что отвечает параметр shuffle=True в DataLoader?
AПеремешивает признаки внутри одного примера
BПеремешивает порядок примеров каждую эпоху
CУвеличивает размер батча
DВключает обучение на GPU
3. Зачем вообще обучать батчами, а не на всём датасете сразу?
AБатчи всегда дают нулевой loss
BБольшие данные не влезают в память, а батчи ещё и помогают сходимости
CБез батчей не работает backward
DБатчи отключают переобучение
Поддержать проект