Сборка сети: __init__ и forward
Пишем свою первую полноценную сеть как класс — наследник nn.Module.
Своя модель в PyTorch — это класс, наследующий
nn.Module, где слои объявлены в__init__, а их применение описано вforward.
Шаблон любой модели
Почти все сети в PyTorch выглядят одинаково. Два обязательных метода: __init__ (что у нас за слои) и forward (как данные через них проходят). Вот многослойный перцептрон (MLP) для классификации:
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super().__init__() # обязательно!
self.fc1 = nn.Linear(784, 128) # вход 784 -> 128
self.relu = nn.ReLU() # активация
self.fc2 = nn.Linear(128, 10) # 128 -> 10 классов
def forward(self, x):
x = self.fc1(x) # линейный слой
x = self.relu(x) # нелинейность
x = self.fc2(x) # выходной слой
return x
model = Net()
print(model)
Разберём по частям.
Метод __init__: объявляем слои
Здесь мы создаём слои и сохраняем их в self. Это не просто переменные: когда вы присваиваете self.fc1 = nn.Linear(...), nn.Module автоматически регистрирует этот слой как подмодуль и подхватывает его параметры. Благодаря этому позже model.parameters() вернёт веса всех слоёв скопом.
Строка super().__init__() обязательна — она запускает инициализацию базового nn.Module, без которой регистрация параметров не работает. Забыть её — частая ошибка, дающая загадочные сбои.
Метод forward: описываем поток данных
В forward мы пишем, как вход x проходит через слои — строка за строкой, как обычный Python. Здесь и проявляется динамический граф: можно вставлять if, циклы, print для отладки. Главное — вернуть результат.
Важно: forward не вызывают напрямую. Мы запускаем модель как model(x), а PyTorch сам вызовет forward внутри. Это та же идиома, что и для одиночного слоя.
model = Net()
x = torch.randn(16, 784) # батч из 16 картинок 28x28, развёрнутых в 784
out = model(x) # НЕ model.forward(x)
print(out.shape) # torch.Size([16, 10]) — по 10 логитов на пример
Что значит этот граф размерностей
Проследите формы насквозь: вход (16, 784) → fc1 → (16, 128) → relu (форма та же) → fc2 → (16, 10). Батч 16 идёт сквозь всю сеть нетронутым, меняется только размер признаков. Выход (16, 10) — это по 10 чисел (логитов) на каждый из 16 примеров; дальше их превратят в вероятности 10 классов.
Почему именно класс
Класс удобен тем, что инкапсулирует и архитектуру, и состояние (веса). Один раз описали — создавайте сколько угодно экземпляров. А ещё nn.Module даёт бесплатно кучу методов: .parameters(), .to(device), .train()/.eval(), сохранение весов. Всё это работает именно потому, что слои зарегистрированы через присваивание в __init__.
Итог
- Своя сеть = класс, наследующий
nn.Module, с методами__init__иforward. - В
__init__объявляют слои вself— так они регистрируются и их параметры видны модели. super().__init__()обязателен, иначе регистрация параметров сломается.- В
forwardописывают поток данных; запускают модель какmodel(x), а неforwardнапрямую.