Сохранение и загрузка модели (state_dict)

Учимся сохранять результат обучения на диск и поднимать его обратно правильным способом.

state_dict — словарь, в котором лежат все обучаемые веса модели; именно его сохраняют на диск, а не модель целиком.

Что вообще сохранять

Обучение может идти часы — терять результат при перезапуске нельзя. Но сохранять нужно не «модель» как объект Python, а её веса. Они лежат в model.state_dict() — обычном словаре «имя параметра → тензор»:

import torch

sd = model.state_dict()
for name in sd:
    print(name, tuple(sd[name].shape))
# fc1.weight (128, 784)
# fc1.bias   (128,)
# fc2.weight (10, 128)
# fc2.bias   (10,)

Сохранение

Рекомендуемый способ — сохранить именно state_dict через torch.save. Файлам принято давать расширение .pt или .pth:

torch.save(model.state_dict(), "model.pth")

Загрузка

Загрузка — в два шага, и это важно. Сначала надо создать модель той же архитектуры (пустую, со случайными весами), а потом залить в неё сохранённые веса:

# 1. создаём модель той же структуры
model = Net()

# 2. грузим веса в неё
model.load_state_dict(torch.load("model.pth"))

# 3. для инференса переводим в режим оценки
model.eval()

Почему так: state_dict хранит только числа весов, но не саму структуру сети. Поэтому код модели (класс Net) должен быть под рукой, чтобы было куда эти числа положить. Не забудьте model.eval() после загрузки, если собираетесь предсказывать — иначе dropout/batchnorm останутся в режиме обучения.

Почему не сохранять модель целиком

Можно технически сделать torch.save(model, ...) и сохранить весь объект. Но так не рекомендуют: этот способ жёстко привязывается к структуре вашего кода и путям импортов. Переименуете класс или переложите файл — загрузка сломается. state_dict же — просто словарь чисел, он переносим и надёжен.

СпособЧто сохраняетНадёжность
state_dict (рекомендуется)только веса (словарь чисел)переносим, надёжен
вся модельвеса + привязку к кодуломается при изменении кода

Чекпойнт: больше, чем веса

В долгом обучении сохраняют не только веса модели, но и состояние оптимизатора и номер эпохи — чтобы можно было продолжить ровно с места остановки. Всё это кладут в один словарь:

# сохранить чекпойнт
checkpoint = {
    "epoch": epoch,
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
}
torch.save(checkpoint, "checkpoint.pth")

# восстановить и продолжить обучение
ckpt = torch.load("checkpoint.pth")
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
start_epoch = ckpt["epoch"] + 1

Оптимизатор тоже хранит состояние (например, накопленные моменты у Adam), поэтому для бесшовного продолжения сохраняют и его state_dict.

Итог

  • Сохраняют веса через torch.save(model.state_dict(), путь), а не модель целиком.
  • Загрузка в два шага: создать модель той же архитектуры, затем load_state_dict.
  • После загрузки для инференса вызывают model.eval().
  • Чекпойнт для продолжения обучения хранит и веса, и состояние оптимизатора, и эпоху.
Проверьте себя
1. Что именно рекомендуют сохранять при работе с моделью?
AВесь объект модели через torch.save(model)
Bmodel.state_dict() — словарь весов
CТолько архитектуру без весов
DГрадиенты модели
2. Что нужно сделать перед вызовом load_state_dict?
AУдалить старую модель
BСоздать модель той же архитектуры, в которую загрузятся веса
CПеревести модель на GPU
DВызвать backward
3. Зачем в чекпойнт для продолжения обучения кладут optimizer.state_dict()?
AЧтобы ускорить загрузку весов
BОптимизатор хранит своё состояние (например, моменты Adam), нужное для бесшовного продолжения
CБез этого модель не загрузится
DЧтобы уменьшить размер файла
Поддержать проект