Dataset и DataLoader
Учимся правильно кормить модель данными: батчами, с перемешиванием, без ручной возни.
DataLoader — обёртка, которая берёт Dataset и выдаёт данные удобными батчами, при желании перемешивая их.
Зачем не обучать на всём сразу
В примерах выше мы подавали в модель сразу все 100 точек. На реальных данных (миллионы примеров) так нельзя — не влезет в память. Поэтому обучают батчами: берут по N примеров, делают шаг, берут следующие N. Это и память экономит, и помогает обучению (шум от маленьких батчей не даёт застрять). Разбивать данные на батчи и перемешивать их вручную — рутина, и PyTorch её автоматизирует через пару Dataset + DataLoader.
Dataset: доступ к примерам
Dataset — это класс, который умеет ответить на два вопроса: «сколько у тебя примеров» (__len__) и «дай пример номер i» (__getitem__). Свой Dataset пишут так:
import torch
from torch.utils.data import Dataset, DataLoader
class PointsDataset(Dataset):
def __init__(self, X, y):
self.X = X
self.y = y
def __len__(self):
return len(self.X) # сколько примеров
def __getitem__(self, i):
return self.X[i], self.y[i] # i-й пример: (признаки, метка)
dataset = PointsDataset(X, y)
print(len(dataset)) # число примеров
print(dataset[0]) # первый пример: (x0, y0)
Идея: Dataset знает, как достать один пример (хоть из памяти, хоть с диска, хоть по сети), но ничего не знает про батчи. Этим займётся DataLoader.
DataLoader: батчи и перемешивание
DataLoader оборачивает Dataset и превращает его в источник батчей, по которому можно итерироваться циклом for:
loader = DataLoader(dataset, batch_size=16, shuffle=True)
for X_batch, y_batch in loader:
print(X_batch.shape) # torch.Size([16, 1]) — батч из 16 примеров
break
Два главных параметра:
batch_size— сколько примеров в одном батче (типично 16, 32, 64);shuffle=True— перемешивать порядок каждую эпоху. Для обучения почти всегда нужно: иначе модель увидит данные в одном и том же порядке и может «выучить порядок». Для валидацииshuffle=False.
DataLoader сам собирает отдельные примеры в батч-тензор: 16 пар (x, y) превращаются в тензоры X_batch формы (16, 1) и y_batch (16, 1). Возиться с конкатенацией вручную не надо.
Как это меняет цикл обучения
Цикл обучения теперь двойной: внешний по эпохам, внутренний по батчам из DataLoader. Пять шагов остаются те же, но выполняются на каждом батче:
for epoch in range(num_epochs):
for X_batch, y_batch in loader: # внутренний цикл по батчам
pred = model(X_batch)
loss = criterion(pred, y_batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
Один проход по всем батчам = одна эпоха. Теперь обучение масштабируется на любой объём данных: DataLoader подаёт их по кусочку.
Готовые датасеты
Часто писать Dataset руками не нужно — для популярных наборов есть готовые. Например, torchvision.datasets.MNIST сам скачает рукописные цифры и отдаст их как Dataset, который можно сразу обернуть в DataLoader. Свой Dataset пишут, когда данные нестандартные (свои картинки, свой формат).
Итог
- Реальные данные обучают батчами — иначе не влезут в память и хуже сходится.
Datasetотвечает на__len__и__getitem__— как достать один пример.DataLoaderнарезает Dataset на батчи и перемешивает (shuffle=Trueдля обучения).- Цикл обучения становится двойным: по эпохам и по батчам внутри.