zero_grad: почему градиенты обнуляют
Разбираем неочевидную, но критичную деталь: почему в каждом цикле обучения есть zero_grad().
zero_grad() — обнуление накопленных градиентов перед новым обратным проходом, без которого градиенты разных шагов суммируются и обучение ломается.
Градиенты накапливаются, а не перезаписываются
Ключевой и контринтуитивный факт: каждый вызов .backward() не заменяет старое значение в .grad, а прибавляет к нему. Если вызвать backward дважды подряд, не очистив .grad, градиенты сложатся:
import torch
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
y.backward()
print(x.grad) # tensor(4.) — это 2*x
# ещё раз, БЕЗ обнуления
y = x ** 2
y.backward()
print(x.grad) # tensor(8.) — НЕ 4! сложилось 4 + 4
Видите: второй раз в x.grad оказалось 8, хотя честный градиент по-прежнему 4. Старое значение никуда не делось — новое прибавилось к нему.
Как чистят градиенты
Поэтому перед каждым новым backward() градиенты обнуляют. Способ зависит от того, есть ли у вас оптимизатор:
# вручную для одного тензора
if x.grad is not None:
x.grad.zero_()
# в реальном обучении — через оптимизатор (разберём позже)
optimizer.zero_grad() # обнулит .grad у всех параметров модели
loss.backward()
optimizer.step()
В настоящем training loop вы почти всегда увидите тройку zero_grad() → backward() → step(). Пропуск zero_grad() — одна из самых частых и коварных ошибок новичка: код не падает, но модель учится плохо, потому что градиенты «помнят» все прошлые шаги.
Что именно сломается, если забыть
Без обнуления градиент на шаге N будет суммой градиентов шагов 1..N. Шаг оптимизатора станет слишком большим и направленным «не туда» — обучение будет расходиться или скакать. Самое неприятное: ошибки в логах нет, просто loss ведёт себя странно. Поэтому забытый zero_grad() — классический пункт в любом чек-листе отладки.
Зачем вообще сделали накопление
Возникает вопрос: почему разработчики не очищают градиент автоматически? Накопление — это полезная фича для продвинутых сценариев. Главный из них — gradient accumulation: когда большой батч не влезает в память GPU, его дробят на части, прогоняют по очереди, дают градиентам накопиться, и только потом делают один шаг оптимизатора. Получается эффект большого батча при маленькой памяти. Так что накопление — это сознательный дизайн, а zero_grad() — цена за гибкость.
# gradient accumulation: эффект большого батча
optimizer.zero_grad()
for i, mini in enumerate(chunks): # дробим батч на куски
loss = compute_loss(mini)
loss.backward() # градиенты НАКАПЛИВАЮТСЯ — это нам нужно
# один шаг по сумме градиентов всех кусков
optimizer.step()
Итог
.backward()прибавляет градиенты к.grad, а не перезаписывает их.- Поэтому перед каждым обратным проходом вызывают
optimizer.zero_grad(). - Забытый
zero_grad()не даёт ошибки, но портит обучение — градиенты суммируются по шагам. - Накопление сделано намеренно: оно позволяет приём gradient accumulation для больших батчей.