CNN для классификации изображений

Собираем настоящую свёрточную сеть для распознавания картинок и понимаем её устройство.

CNN (свёрточная сеть) — архитектура, где слои-свёртки сканируют изображение небольшими окнами, выделяя признаки от краёв до целых объектов.

Почему для картинок не годится Linear

Картинка 28×28 — это 784 числа. Подать её в nn.Linear можно (мы так и делали), но плохо: полносвязный слой не знает, что соседние пиксели связаны, и тратит огромное число весов. Свёрточный слой устроен умнее: он скользит маленьким окном (например 3×3) по всей картинке с одними и теми же весами, выделяя локальные узоры — края, углы, текстуры. Это и экономит параметры, и учитывает структуру изображения.

Форма данных-изображений

Картинки в PyTorch имеют форму (батч, каналы, высота, ширина). Например (32, 1, 28, 28) — батч из 32 чёрно-белых картинок 28×28 (1 канал); у цветных каналов 3 (RGB). Эту четырёхмерную форму держите в голове — она ключ к пониманию свёрток.

Слои свёрточной сети

СлойЧто делает
nn.Conv2d(in_ch, out_ch, k)свёртка окном k×k, выделяет признаки
nn.ReLUнелинейность после свёртки
nn.MaxPool2d(2)уменьшает картинку вдвое, оставляя главное
nn.Linearфинальный слой, выдаёт логиты классов

Классическая структура CNN

Типичная маленькая сеть для цифр MNIST: пара блоков «свёртка → relu → пулинг», затем «выпрямление» в вектор и полносвязный хвост. Вот она целиком:

import torch
import torch.nn as nn

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)   # 1 канал -> 16
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)  # 16 -> 32
        self.pool = nn.MaxPool2d(2)                    # уменьшает в 2 раза
        self.relu = nn.ReLU()
        self.fc = nn.Linear(32 * 7 * 7, 10)            # хвост -> 10 классов

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))   # 28x28 -> 14x14
        x = self.pool(self.relu(self.conv2(x)))   # 14x14 -> 7x7
        x = x.view(x.size(0), -1)                  # выпрямляем в вектор
        x = self.fc(x)                             # логиты 10 классов
        return x

model = CNN()

Разбор forward по формам

Проследим путь картинки через сеть (вход — батч (N, 1, 28, 28)):

  • conv1 + pool: каналы 1→16, размер 28×28 → 14×14. Форма (N, 16, 14, 14).
  • conv2 + pool: каналы 16→32, размер 14×14 → 7×7. Форма (N, 32, 7, 7).
  • x.view(x.size(0), -1): «выпрямляем» каждую картинку в вектор длины 32×7×7 = 1568. Форма (N, 1568).
  • fc: 1568 → 10. Форма (N, 10) — логиты классов.

Ключевой момент — строка x.view(x.size(0), -1): она превращает «объёмные» признаки свёрток в плоский вектор, чтобы их принял полносвязный слой. x.size(0) — это батч (сохраняем), -1 — «остальное в одну размерность». Число входов fc (32×7×7) обязано совпасть с тем, что выдали свёртки, иначе ошибка размерностей.

Обучается так же

Самое приятное: CNN обучают тем же циклом из пяти шагов и той же CrossEntropyLoss, что и обычную сеть. Поменялась только архитектура внутри — а forward → loss → zero_grad → backward → step остаются нетронутыми. Всё, что вы выучили о цикле обучения, работает здесь без изменений.

Итог

  • Свёрточный слой сканирует картинку окном с общими весами — экономит параметры и видит структуру.
  • Изображения имеют форму (батч, каналы, высота, ширина).
  • Типичная CNN: блоки «conv → relu → pool», затем view в вектор и Linear-хвост.
  • Обучается тем же циклом и тем же CrossEntropyLoss, что и полносвязная сеть.
Проверьте себя
1. Чем свёрточный слой выгоднее полносвязного для изображений?
AОн всегда точнее на любых данных
BОн сканирует картинку окном с общими весами: меньше параметров и учитывается структура
CОн не требует функции активации
DОн работает только на GPU
2. Какую форму имеет батч изображений в PyTorch?
A(батч, высота, ширина)
B(батч, каналы, высота, ширина)
C(высота, ширина, каналы)
D(батч, признаки)
3. Зачем перед полносвязным слоем делают x.view(x.size(0), -1)?
AЧтобы перенести данные на GPU
BЧтобы выпрямить объёмные признаки свёрток в плоский вектор для Linear
CЧтобы перемешать батч
DЧтобы добавить нелинейность
Поддержать проект