requires_grad и граф вычислений

Понимаем, как и зачем PyTorch запоминает все операции над тензором.

Граф вычислений — записанная последовательность операций от исходных тензоров к результату; по нему autograd потом «прокручивает» backprop назад.

Флаг requires_grad

По умолчанию тензоры градиентов не требуют — PyTorch не тратит ресурсы на запоминание операций. Но если тензор — это обучаемый параметр (вес сети), мы помечаем его флагом requires_grad=True. С этого момента PyTorch начинает следить за всеми операциями, в которых тензор участвует, и строит из них граф.

import torch

x = torch.tensor([2.0, 3.0], requires_grad=True)
print(x.requires_grad)   # True

w = torch.tensor([1.0, 1.0])    # обычный тензор
print(w.requires_grad)          # False

Связь с теорией прямая: в backprop градиенты нужны именно для весов, чтобы их обновлять. Входные данные обновлять не надо — поэтому им requires_grad не ставят.

Как строится граф

Каждая операция над тензором с requires_grad=True порождает новый тензор, который помнит, какой операцией он получен. Эта «память» хранится в атрибуте grad_fn — функция, умеющая посчитать обратный шаг для данной операции.

x = torch.tensor(2.0, requires_grad=True)

y = x * 3        # y помнит: «я получен умножением»
z = y + 1        # z помнит: «я получен сложением»

print(y.grad_fn)   # <MulBackward0>
print(z.grad_fn)   # <AddBackward0>

Так выстраивается цепочка: z <- (+1) <- y <- (*3) <- x. Это и есть граф вычислений — динамически, прямо во время выполнения. Когда мы попросим градиент, autograd пройдёт по этой цепочке в обратную сторону, применяя на каждом шаге правило дифференцирования (это и есть chain rule из backprop).

Листовые тензоры

Тензоры, которые вы создали сами (а не получили операцией) и которым поставили requires_grad=True, называются листовыми (leaf). Именно для них в итоге накопится градиент в атрибуте .grad. Промежуточные тензоры (как y, z выше) — внутренние узлы графа, их градиенты по умолчанию не сохраняются, они нужны лишь для проброса дальше.

Тип тензораПримерХранит .grad?
Листовой, requires_grad=Trueвес сети wда
Промежуточный (узел графа)y = x * 3нет (по умолчанию)
Обычный, requires_grad=Falseвходные данныене участвует

Заразность requires_grad

Флаг «распространяется»: если хотя бы один вход операции требует градиент, результат тоже будет его требовать. Логично — раз результат зависит от обучаемого параметра, через него должен проходить градиент.

w = torch.tensor(2.0, requires_grad=True)   # параметр
x = torch.tensor(5.0)                        # данные, requires_grad=False

y = w * x        # зависит от w -> y.requires_grad == True
print(y.requires_grad)   # True

Зачем всё это

Граф — это и есть способ автоматизировать backprop. Вы пишете только прямой проход (как считается loss из весов), а PyTorch сам запоминает путь и умеет пройти его назад. Никаких ручных производных: цепочку grad_fn autograd разворачивает за вас. В следующем уроке мы наконец «нажмём кнопку» — вызовем .backward() и получим градиенты.

Итог

  • requires_grad=True включает слежение за операциями над тензором — обычно ставится весам.
  • Каждая операция записывается в граф; результат помнит её через grad_fn.
  • Граф строится динамически, во время forward; это автоматизированный backprop.
  • Листовые тензоры с requires_grad=True накопят градиент в .grad; промежуточные — нет.
Проверьте себя
1. Каким тензорам обычно ставят requires_grad=True?
AВходным данным, чтобы их тоже обновлять
BОбучаемым параметрам (весам), которые нужно двигать по градиенту
CВсем тензорам без исключения для надёжности
DТолько тензорам на GPU
2. Что хранит атрибут grad_fn у тензора?
AСамо значение градиента
BУстройство, на котором лежит тензор
CФункцию для обратного шага той операции, которой получен тензор
DИсторию всех значений тензора
3. Что произойдёт с requires_grad у результата операции, если один из входов имеет requires_grad=True?
AРезультат тоже будет requires_grad=True
BРезультат всегда requires_grad=False
CВозникнет ошибка о несовместимых флагах
DФлаг скопируется только если оба входа требуют градиент
Поддержать проект