Сборка сети: __init__ и forward

Пишем свою первую полноценную сеть как класс — наследник nn.Module.

Своя модель в PyTorch — это класс, наследующий nn.Module, где слои объявлены в __init__, а их применение описано в forward.

Шаблон любой модели

Почти все сети в PyTorch выглядят одинаково. Два обязательных метода: __init__ (что у нас за слои) и forward (как данные через них проходят). Вот многослойный перцептрон (MLP) для классификации:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()                  # обязательно!
        self.fc1 = nn.Linear(784, 128)      # вход 784 -> 128
        self.relu = nn.ReLU()               # активация
        self.fc2 = nn.Linear(128, 10)       # 128 -> 10 классов

    def forward(self, x):
        x = self.fc1(x)     # линейный слой
        x = self.relu(x)    # нелинейность
        x = self.fc2(x)     # выходной слой
        return x

model = Net()
print(model)

Разберём по частям.

Метод __init__: объявляем слои

Здесь мы создаём слои и сохраняем их в self. Это не просто переменные: когда вы присваиваете self.fc1 = nn.Linear(...), nn.Module автоматически регистрирует этот слой как подмодуль и подхватывает его параметры. Благодаря этому позже model.parameters() вернёт веса всех слоёв скопом.

Строка super().__init__() обязательна — она запускает инициализацию базового nn.Module, без которой регистрация параметров не работает. Забыть её — частая ошибка, дающая загадочные сбои.

Метод forward: описываем поток данных

В forward мы пишем, как вход x проходит через слои — строка за строкой, как обычный Python. Здесь и проявляется динамический граф: можно вставлять if, циклы, print для отладки. Главное — вернуть результат.

Важно: forward не вызывают напрямую. Мы запускаем модель как model(x), а PyTorch сам вызовет forward внутри. Это та же идиома, что и для одиночного слоя.

model = Net()
x = torch.randn(16, 784)   # батч из 16 картинок 28x28, развёрнутых в 784
out = model(x)             # НЕ model.forward(x)
print(out.shape)           # torch.Size([16, 10]) — по 10 логитов на пример

Что значит этот граф размерностей

Проследите формы насквозь: вход (16, 784) → fc1 → (16, 128) → relu (форма та же) → fc2 → (16, 10). Батч 16 идёт сквозь всю сеть нетронутым, меняется только размер признаков. Выход (16, 10) — это по 10 чисел (логитов) на каждый из 16 примеров; дальше их превратят в вероятности 10 классов.

Почему именно класс

Класс удобен тем, что инкапсулирует и архитектуру, и состояние (веса). Один раз описали — создавайте сколько угодно экземпляров. А ещё nn.Module даёт бесплатно кучу методов: .parameters(), .to(device), .train()/.eval(), сохранение весов. Всё это работает именно потому, что слои зарегистрированы через присваивание в __init__.

Итог

  • Своя сеть = класс, наследующий nn.Module, с методами __init__ и forward.
  • В __init__ объявляют слои в self — так они регистрируются и их параметры видны модели.
  • super().__init__() обязателен, иначе регистрация параметров сломается.
  • В forward описывают поток данных; запускают модель как model(x), а не forward напрямую.
Проверьте себя
1. Где в классе-наследнике nn.Module объявляют слои?
AВ методе forward
BВ методе __init__, присваивая их в self
CВ отдельном методе build
DВне класса, на уровне модуля
2. Почему нельзя забывать super().__init__() в своей модели?
AБез неё forward не вызовется
BОна запускает инициализацию nn.Module, без которой регистрация параметров не работает
CОна создаёт оптимизатор
DОна переносит модель на GPU
3. Если вход батча имеет форму (16, 784), какой будет форма выхода сети fc1(784->128), relu, fc2(128->10)?
A(16, 784)
B(16, 128)
C(16, 10)
D(10, 16)
Поддержать проект