Операции и броадкастинг

Учимся считать с тензорами: арифметика, матричное умножение и магия броадкастинга.

Броадкастинг — автоматическое «растягивание» тензоров разной формы до совместимой, чтобы поэлементная операция стала возможной без явного копирования.

Поэлементные операции

Арифметика между тензорами одинаковой формы работает поэлементно — как с числами, только сразу со всем массивом:

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([10.0, 20.0, 30.0])

print(a + b)     # tensor([11., 22., 33.])
print(a * b)     # tensor([10., 40., 90.])  поэлементно!
print(a ** 2)    # tensor([1., 4., 9.])
print(torch.exp(a))  # экспонента от каждого элемента

Важно: a * b — это поэлементное умножение, а не матричное. Многие новички путают.

Матричное умножение

Для настоящего матричного умножения (строка на столбец) есть оператор @ и функция torch.matmul. Именно оно лежит в основе слоёв нейросети.

A = torch.randn(2, 3)   # форма (2, 3)
B = torch.randn(3, 4)   # форма (3, 4)

C = A @ B               # (2, 3) @ (3, 4) -> (2, 4)
print(C.shape)          # torch.Size([2, 4])

Правило размерностей: внутренние размеры должны совпадать. (2, 3) и (3, 4) — тройки совпали, результат (2, 4). Если внутренние размеры не совпадут, будет ошибка — это самая частая причина падений при сборке сети.

Агрегации по осям

Суммы, средние, максимумы можно считать по всему тензору или вдоль конкретной оси через аргумент dim. Понимание dim критично для работы с батчами.

m = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])   # форма (2, 3)

print(m.sum())          # 21.0  — по всему тензору
print(m.sum(dim=0))     # tensor([5., 7., 9.])  — вдоль строк, осталось 3 столбца
print(m.sum(dim=1))     # tensor([6., 15.])     — вдоль столбцов, осталось 2 строки
print(m.mean(dim=0))    # среднее по каждому столбцу

Запомните правило: dim=0 «схлопывает» нулевую ось (идём по строкам сверху вниз), dim=1 — первую (идём по столбцам слева направо). Размерность, по которой агрегируем, исчезает из формы результата.

Броадкастинг: операции с разными формами

Что если формы не совпадают? PyTorch пытается их согласовать по правилам броадкастинга, выравнивая формы справа налево: размерности совместимы, если они равны или одна из них равна 1. Размерность длины 1 «растягивается» до нужной.

# прибавить число к каждому элементу — простейший броадкастинг
a = torch.tensor([1.0, 2.0, 3.0])
print(a + 10)        # tensor([11., 12., 13.])

# матрица (2, 3) + вектор (3,)  ->  вектор применяется к каждой строке
m = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])
bias = torch.tensor([100.0, 200.0, 300.0])  # форма (3,)
print(m + bias)
# tensor([[101., 202., 303.],
#         [104., 205., 306.]])

Это не игрушка: именно так слой Linear прибавляет вектор сдвига (bias) сразу ко всему батчу примеров. Без броадкастинга пришлось бы вручную копировать bias под каждую строку.

Операции на месте (in-place)

Методы с подчёркиванием на конце меняют тензор на месте: a.add_(1) прибавит 1 прямо в a, без создания нового тензора. Это экономит память, но с такими операциями нужно быть осторожным при обучении — они могут «сломать» граф для autograd. На старте предпочитайте обычные операции, создающие новый тензор.

Итог

  • * — поэлементное умножение; матричное — это @ или torch.matmul.
  • В матричном умножении внутренние размерности обязаны совпадать.
  • dim в агрегациях указывает ось, которая «схлопывается» и исчезает из формы.
  • Броадкастинг растягивает размерности длины 1; так bias прибавляется ко всему батчу.
Проверьте себя
1. Чем отличается a * b от a @ b для тензоров?
AЭто одно и то же — оба матричное умножение
B* — поэлементное умножение, @ — матричное
C* — матричное, @ — поэлементное
D@ работает только для векторов, * — для матриц
2. Что вернёт m.sum(dim=0) для матрицы формы (2, 3)?
AСкаляр — сумму всех шести элементов
BВектор длины 2 — суммы по строкам
CВектор длины 3 — суммы по столбцам
DМатрицу той же формы (2, 3)
3. Почему сложение матрицы (2, 3) с вектором (3,) работает?
APyTorch обрезает матрицу до формы вектора
BИз-за броадкастинга: вектор растягивается на каждую строку
CЭто ошибка, такое сложение не выполнится
DВектор сначала превращается в матрицу вручную
Поддержать проект