Тензоры: создание, типы, форма, индексация

Разбираем тензор — центральную структуру данных PyTorch — со всех сторон.

Тензор — многомерный массив чисел одного типа; обобщение скаляра (0D), вектора (1D) и матрицы (2D) на любое число измерений.

Создание тензоров

Способов много, и каждый под свою задачу. Из готовых данных — torch.tensor; заполненные нулями/единицами — zeros/ones; случайные — rand (равномерно 0..1) и randn (нормальное распределение).

import torch

a = torch.tensor([[1, 2, 3], [4, 5, 6]])   # из списка списков -> матрица 2x3
z = torch.zeros(2, 4)                       # матрица 2x4 из нулей
o = torch.ones(3)                           # вектор из трёх единиц
r = torch.randn(2, 3)                       # 2x3 случайных из N(0, 1)
seq = torch.arange(0, 10, 2)               # tensor([0, 2, 4, 6, 8])

print(a)
print(a.shape)   # torch.Size([2, 3])

Форма (shape) и размерности

Форма — это кортеж длин по каждому измерению. У тензора выше форма (2, 3): два ряда, три столбца. Связанные свойства:

  • a.shape или a.size() — форма, например (2, 3);
  • a.ndim — число измерений (рангов): у матрицы это 2;
  • a.numel() — всего элементов: 6.

Форму нужно держать в голове постоянно: 90% ошибок в PyTorch — это несовпадение размерностей при операциях. Привыкайте читать форму как «(батч, признаки)» или «(батч, каналы, высота, ширина)» для картинок.

Типы данных (dtype)

Все элементы тензора одного типа. По умолчанию из целых списков получается int64, из дробных — float32. Для обучения сетей почти всегда нужен float32: градиенты считаются только для чисел с плавающей точкой.

i = torch.tensor([1, 2, 3])        # dtype = int64
f = torch.tensor([1.0, 2.0, 3.0])  # dtype = float32
print(i.dtype, f.dtype)            # torch.int64 torch.float32

# явное приведение типа
f2 = i.float()                     # int64 -> float32
i2 = f.long()                      # float32 -> int64
g = torch.zeros(2, 2, dtype=torch.float64)  # сразу нужный dtype

Если при подаче в модель вы получите ошибку вида «expected Float but got Long» — почти наверняка забыли перевести целые в float().

Индексация и срезы

Индексация работает как в numpy: по осям через запятую, со срезами через двоеточие. Индексы с нуля.

m = torch.tensor([[10, 11, 12],
                  [20, 21, 22],
                  [30, 31, 32]])

print(m[0])        # первый ряд: tensor([10, 11, 12])
print(m[1, 2])     # элемент ряд 1, столбец 2: tensor(22)
print(m[:, 0])     # весь столбец 0: tensor([10, 20, 30])
print(m[0:2, 1:])  # подматрица: tensor([[11, 12], [21, 22]])

Одиночный элемент — это тоже тензор (0D). Чтобы вытащить обычное число Python, используют .item(): m[1, 2].item() вернёт 22 как int.

Смена формы: reshape и view

Часто нужно «перекомпоновать» те же числа в другую форму, не меняя их количество. Это делают view и reshape. Они меняют только то, как PyTorch читает память, — сами данные остаются.

t = torch.arange(12)          # 12 чисел: 0..11, форма (12,)

a = t.view(3, 4)              # та же память, форма (3, 4)
b = t.reshape(2, 6)          # форма (2, 6)
c = t.view(2, -1)            # -1 = «посчитай сам»: получится (2, 6)

print(a.shape, b.shape, c.shape)  # (3,4) (2,6) (2,6)

Удобный приём -1 означает «вычисли эту размерность автоматически из общего числа элементов». Разница между методами тонкая: view требует, чтобы память была непрерывной, и работает быстрее, но иногда падает; reshape надёжнее — сделает копию, если нужно. На старте используйте reshape, если сомневаетесь.

Ещё две частые операции формы: squeeze() убирает размерности длины 1, а unsqueeze(dim) добавляет новую ось — это нужно, например, чтобы превратить один пример формы (3,) в «батч из одного» формы (1, 3).

Итог

  • Тензор — многомерный массив одного dtype; для обучения нужен float32.
  • shape/size() — форма; следить за ней критично, это источник большинства ошибок.
  • Индексация и срезы — как в numpy; .item() достаёт число Python из тензора-скаляра.
  • view/reshape меняют форму без изменения данных; -1 — «посчитай размерность сам».
Проверьте себя
1. Какой dtype получится у torch.tensor([1.0, 2.0, 3.0])?
Aint64, потому что это целые индексы
Bfloat32 — стандартный тип для дробных значений
Cfloat64, как в numpy по умолчанию
Dbool, потому что значения ненулевые
2. Что означает -1 в вызове t.view(2, -1)?
AРазвернуть тензор в обратном порядке
BУдалить последнюю размерность
CВычислить эту размерность автоматически из числа элементов
DСделать размерность отрицательной длины
3. Как из тензора-скаляра m[1, 2] получить обычное число Python?
AВызвать m[1, 2].item()
BВызвать m[1, 2].float()
CПрименить int(m) ко всей матрице
DИспользовать m[1, 2].shape
Поддержать проект