detach и no_grad: когда градиенты не нужны

Учимся отключать autograd там, где градиенты не нужны, — ради скорости и памяти.

torch.no_grad() — контекст, внутри которого PyTorch не строит граф; .detach() — метод, возвращающий копию тензора, отрезанную от графа.

Зачем выключать граф

Построение графа стоит памяти и времени: PyTorch хранит все промежуточные тензоры, чтобы потом пройти назад. Но градиенты нужны только при обучении. Когда мы просто применяем готовую модель (инференс) или считаем метрику на валидации, обратный проход не делается — значит, и граф строить незачем. Отключив его, мы экономим память и ускоряем вычисления.

torch.no_grad() — контекст для инференса

Самый частый инструмент — менеджер контекста with torch.no_grad():. Всё, что выполнено внутри, не попадает в граф:

import torch

w = torch.tensor(2.0, requires_grad=True)
x = torch.tensor(5.0)

# обычный режим — граф строится
y = w * x
print(y.requires_grad)        # True

# режим без градиентов
with torch.no_grad():
    y2 = w * x
    print(y2.requires_grad)   # False — граф не строился

В реальном коде вы будете оборачивать в no_grad() весь цикл предсказаний на тесте или в проде. Это стандарт: меньше памяти, выше скорость, поведение модели то же самое.

.detach() — отрезать один тензор от графа

Иногда нужно вытащить значение из середины вычислений и использовать его дальше так, чтобы градиент через него не шёл. Для этого есть .detach() — он возвращает тензор с теми же числами, но без связи с графом (и с requires_grad=False).

w = torch.tensor(3.0, requires_grad=True)
y = w * w                 # y в графе, requires_grad=True

y_const = y.detach()      # те же 9.0, но вне графа
print(y.requires_grad)        # True
print(y_const.requires_grad)  # False

Классический случай — логирование и метрики. Чтобы накопить значение loss для печати, его отрезают от графа, иначе вы случайно удержите в памяти весь граф ради одного числа:

# правильно: для статистики берём число без графа
running_loss += loss.detach().item()

Мы уже встречали detach в мосте с numpy: tensor.detach().cpu().numpy() — там это нужно ровно по той же причине, нельзя превратить в массив тензор, сидящий в графе.

Чем отличаются no_grad и detach

ИнструментОбласть действияКогда применять
with torch.no_grad():весь блок кода внутриинференс, валидация — целиком
.detach()один конкретный тензордостать значение для лога/метрики/numpy

Грубое правило: оборачиваете процессno_grad; вынимаете одно значениеdetach.

Частая ошибка производительности

Если в цикле обучения копить статистику как total += loss (без detach/item), вы нечаянно сохраните ссылку на граф каждой итерации, и память будет течь, пока процесс не упадёт по нехватке памяти. Поэтому для метрик всегда loss.item() или loss.detach().

Итог

  • Граф нужен только для обучения; на инференсе и валидации его отключают ради скорости и памяти.
  • with torch.no_grad(): отключает построение графа для всего блока — это для инференса.
  • .detach() отрезает один тензор от графа — для метрик, логов и перехода в numpy.
  • Накапливать loss для статистики надо через .item()/.detach(), иначе утечёт память.
Проверьте себя
1. Когда оборачивают код в with torch.no_grad()?
AПри обучении, чтобы ускорить backward
BПри инференсе и валидации, где градиенты не нужны
CВсегда, во всех вычислениях с тензорами
DТолько при работе с numpy
2. Что вернёт y.detach()?
AГрадиент тензора y
BКопию y с теми же числами, но отрезанную от графа
CТензор на GPU
DNone, метод действует на месте
3. Почему running_loss += loss (без detach/item) опасно в цикле обучения?
Aloss нельзя складывать
BТак теряется точность вычислений
CСохраняются ссылки на графы всех итераций и память течёт
DЭто замедлит только первую итерацию
Поддержать проект