backward() и .grad: как считаются градиенты

Нажимаем «кнопку backprop»: как .backward() заполняет .grad и почему результат совпадает с ручной производной.

.backward() — метод, который проходит граф от результата назад к листьям и заполняет их .grad градиентами.

Минимальный пример

Возьмём простейшую функцию одной переменной и попросим её производную в точке. Пусть y = x². Из математики известно: производная dy/dx = 2x, значит в точке x = 3 она равна 6. Посмотрим, что скажет autograd:

import torch

x = torch.tensor(3.0, requires_grad=True)

y = x ** 2          # строим граф: y = x^2
y.backward()         # backprop: посчитать dy/dx

print(x.grad)        # tensor(6.)  — ровно 2*x при x=3

Что произошло по шагам: x ** 2 построил граф, y.backward() прошёл его назад, применил правило для возведения в квадрат (производная 2x) и положил результат в x.grad. Мы не писали формулу производной — PyTorch вывел её сам.

Несколько переменных и chain rule

Сила autograd в том, что он автоматически применяет правило цепочки к длинным выражениям. Пусть результат зависит от двух параметров:

w = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(1.0, requires_grad=True)
x = torch.tensor(4.0)               # данные

y = w * x + b        # y = 2*4 + 1 = 9
y.backward()

print(w.grad)        # tensor(4.)   = dy/dw = x
print(b.grad)        # tensor(1.)   = dy/db = 1

autograd сам разложил частные производные: dy/dw = x = 4, dy/db = 1. Вы только описали forward y = w*x + b — а градиенты по всем параметрам появились бесплатно. Это в точности то, что делает backprop в нейросети, только над миллионами параметров.

backward() требует скаляр

Важное правило: .backward() без аргументов вызывают только на скаляре (тензоре из одного числа). Поэтому в обучении из вектора ошибок всегда получают одно число — функцию потерь (loss) через .mean() или .sum():

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

y = (x ** 2)          # вектор [1, 4, 9] — НЕ скаляр
loss = y.sum()        # сворачиваем в одно число = 14
loss.backward()       # теперь можно

print(x.grad)         # tensor([2., 4., 6.]) = 2*x поэлементно

Loss — это всегда одно число, и именно поэтому в любом цикле обучения вы увидите loss.backward(), а не predictions.backward().

Сверим с ручным градиентом

Чтобы убедиться, что autograd не «магия», а честная производная, посчитаем тот же градиент в чистом Python двумя способами: аналитически (формула 2x) и численно (определение производной через малое приращение). Этот код запускается:

# f(x) = x^2, хотим производную в точке x0
def f(x):
    return x ** 2

x0 = 3.0

# 1) аналитически: f'(x) = 2x
analytic = 2 * x0

# 2) численно: (f(x+h) - f(x-h)) / (2h)
h = 1e-6
numeric = (f(x0 + h) - f(x0 - h)) / (2 * h)

print("аналитически:", analytic)
print("численно:    ", round(numeric, 4))

Вывод:

аналитически: 6.0
численно:     6.0

Оба способа дают 6 — ровно то, что вернул x.grad в PyTorch. autograd делает по сути аналитический вариант, но автоматически и для любой функции, какой бы сложной она ни была. Численный способ (как здесь) иногда используют, чтобы проверить правильность градиентов, но для обучения он слишком медленный.

Итог

  • .backward() запускает backprop по графу и кладёт градиенты в .grad листьев.
  • autograd сам применяет chain rule: вы пишете только forward, частные производные появляются автоматически.
  • .backward() вызывают на скаляре — поэтому из ошибок всегда делают одно число (loss).
  • Результат совпадает с производной, посчитанной руками: это честная математика, а не приближение.
Проверьте себя
1. Почему loss.backward() вызывают на функции потерь, а не на векторе предсказаний?
AТак быстрее, разницы по сути нет
Bbackward() без аргументов требует скаляр, а loss — это одно число
CПредсказания нельзя дифференцировать
DВектор предсказаний не входит в граф
2. Куда PyTorch кладёт результат после y.backward()?
AВ новый тензор, который возвращает backward()
BВ атрибут .grad листовых тензоров
CВ атрибут .grad_fn
DВ отдельный файл на диске
3. Для y = w * x + b чему равен w.grad после backward (x — данные)?
A1, так как это линейная функция
Bзначению x, потому что dy/dw = x
Cзначению w
D0, потому что b не зависит от w
Поддержать проект