Тензоры: создание, типы, форма, индексация
Разбираем тензор — центральную структуру данных PyTorch — со всех сторон.
Тензор — многомерный массив чисел одного типа; обобщение скаляра (0D), вектора (1D) и матрицы (2D) на любое число измерений.
Создание тензоров
Способов много, и каждый под свою задачу. Из готовых данных — torch.tensor; заполненные нулями/единицами — zeros/ones; случайные — rand (равномерно 0..1) и randn (нормальное распределение).
import torch
a = torch.tensor([[1, 2, 3], [4, 5, 6]]) # из списка списков -> матрица 2x3
z = torch.zeros(2, 4) # матрица 2x4 из нулей
o = torch.ones(3) # вектор из трёх единиц
r = torch.randn(2, 3) # 2x3 случайных из N(0, 1)
seq = torch.arange(0, 10, 2) # tensor([0, 2, 4, 6, 8])
print(a)
print(a.shape) # torch.Size([2, 3])
Форма (shape) и размерности
Форма — это кортеж длин по каждому измерению. У тензора выше форма (2, 3): два ряда, три столбца. Связанные свойства:
a.shapeилиa.size()— форма, например (2, 3);a.ndim— число измерений (рангов): у матрицы это 2;a.numel()— всего элементов: 6.
Форму нужно держать в голове постоянно: 90% ошибок в PyTorch — это несовпадение размерностей при операциях. Привыкайте читать форму как «(батч, признаки)» или «(батч, каналы, высота, ширина)» для картинок.
Типы данных (dtype)
Все элементы тензора одного типа. По умолчанию из целых списков получается int64, из дробных — float32. Для обучения сетей почти всегда нужен float32: градиенты считаются только для чисел с плавающей точкой.
i = torch.tensor([1, 2, 3]) # dtype = int64
f = torch.tensor([1.0, 2.0, 3.0]) # dtype = float32
print(i.dtype, f.dtype) # torch.int64 torch.float32
# явное приведение типа
f2 = i.float() # int64 -> float32
i2 = f.long() # float32 -> int64
g = torch.zeros(2, 2, dtype=torch.float64) # сразу нужный dtype
Если при подаче в модель вы получите ошибку вида «expected Float but got Long» — почти наверняка забыли перевести целые в float().
Индексация и срезы
Индексация работает как в numpy: по осям через запятую, со срезами через двоеточие. Индексы с нуля.
m = torch.tensor([[10, 11, 12],
[20, 21, 22],
[30, 31, 32]])
print(m[0]) # первый ряд: tensor([10, 11, 12])
print(m[1, 2]) # элемент ряд 1, столбец 2: tensor(22)
print(m[:, 0]) # весь столбец 0: tensor([10, 20, 30])
print(m[0:2, 1:]) # подматрица: tensor([[11, 12], [21, 22]])
Одиночный элемент — это тоже тензор (0D). Чтобы вытащить обычное число Python, используют .item(): m[1, 2].item() вернёт 22 как int.
Смена формы: reshape и view
Часто нужно «перекомпоновать» те же числа в другую форму, не меняя их количество. Это делают view и reshape. Они меняют только то, как PyTorch читает память, — сами данные остаются.
t = torch.arange(12) # 12 чисел: 0..11, форма (12,)
a = t.view(3, 4) # та же память, форма (3, 4)
b = t.reshape(2, 6) # форма (2, 6)
c = t.view(2, -1) # -1 = «посчитай сам»: получится (2, 6)
print(a.shape, b.shape, c.shape) # (3,4) (2,6) (2,6)
Удобный приём -1 означает «вычисли эту размерность автоматически из общего числа элементов». Разница между методами тонкая: view требует, чтобы память была непрерывной, и работает быстрее, но иногда падает; reshape надёжнее — сделает копию, если нужно. На старте используйте reshape, если сомневаетесь.
Ещё две частые операции формы: squeeze() убирает размерности длины 1, а unsqueeze(dim) добавляет новую ось — это нужно, например, чтобы превратить один пример формы (3,) в «батч из одного» формы (1, 3).
Итог
- Тензор — многомерный массив одного
dtype; для обучения нуженfloat32. shape/size()— форма; следить за ней критично, это источник большинства ошибок.- Индексация и срезы — как в numpy;
.item()достаёт число Python из тензора-скаляра. view/reshapeменяют форму без изменения данных;-1— «посчитай размерность сам».