detach и no_grad: когда градиенты не нужны
Учимся отключать autograd там, где градиенты не нужны, — ради скорости и памяти.
torch.no_grad() — контекст, внутри которого PyTorch не строит граф; .detach() — метод, возвращающий копию тензора, отрезанную от графа.
Зачем выключать граф
Построение графа стоит памяти и времени: PyTorch хранит все промежуточные тензоры, чтобы потом пройти назад. Но градиенты нужны только при обучении. Когда мы просто применяем готовую модель (инференс) или считаем метрику на валидации, обратный проход не делается — значит, и граф строить незачем. Отключив его, мы экономим память и ускоряем вычисления.
torch.no_grad() — контекст для инференса
Самый частый инструмент — менеджер контекста with torch.no_grad():. Всё, что выполнено внутри, не попадает в граф:
import torch
w = torch.tensor(2.0, requires_grad=True)
x = torch.tensor(5.0)
# обычный режим — граф строится
y = w * x
print(y.requires_grad) # True
# режим без градиентов
with torch.no_grad():
y2 = w * x
print(y2.requires_grad) # False — граф не строился
В реальном коде вы будете оборачивать в no_grad() весь цикл предсказаний на тесте или в проде. Это стандарт: меньше памяти, выше скорость, поведение модели то же самое.
.detach() — отрезать один тензор от графа
Иногда нужно вытащить значение из середины вычислений и использовать его дальше так, чтобы градиент через него не шёл. Для этого есть .detach() — он возвращает тензор с теми же числами, но без связи с графом (и с requires_grad=False).
w = torch.tensor(3.0, requires_grad=True)
y = w * w # y в графе, requires_grad=True
y_const = y.detach() # те же 9.0, но вне графа
print(y.requires_grad) # True
print(y_const.requires_grad) # False
Классический случай — логирование и метрики. Чтобы накопить значение loss для печати, его отрезают от графа, иначе вы случайно удержите в памяти весь граф ради одного числа:
# правильно: для статистики берём число без графа
running_loss += loss.detach().item()
Мы уже встречали detach в мосте с numpy: tensor.detach().cpu().numpy() — там это нужно ровно по той же причине, нельзя превратить в массив тензор, сидящий в графе.
Чем отличаются no_grad и detach
| Инструмент | Область действия | Когда применять |
with torch.no_grad(): | весь блок кода внутри | инференс, валидация — целиком |
.detach() | один конкретный тензор | достать значение для лога/метрики/numpy |
Грубое правило: оборачиваете процесс — no_grad; вынимаете одно значение — detach.
Частая ошибка производительности
Если в цикле обучения копить статистику как total += loss (без detach/item), вы нечаянно сохраните ссылку на граф каждой итерации, и память будет течь, пока процесс не упадёт по нехватке памяти. Поэтому для метрик всегда loss.item() или loss.detach().
Итог
- Граф нужен только для обучения; на инференсе и валидации его отключают ради скорости и памяти.
with torch.no_grad():отключает построение графа для всего блока — это для инференса..detach()отрезает один тензор от графа — для метрик, логов и перехода в numpy.- Накапливать loss для статистики надо через
.item()/.detach(), иначе утечёт память.