Классификация изображений: как работает и как мерить

Классификация — базовая задача CV: «что на картинке?». Разберём, как её обучают и как честно оценивают.

Классификация изображений — задача присвоить всему изображению один ярлык из заранее заданного набора классов.

Как обучается классификатор

Нужен размеченный датасет: много картинок, у каждой — правильный класс. Обучение идёт циклом:

  1. Сеть смотрит на картинку и выдаёт вероятности классов (через softmax).
  2. Сравниваем прогноз с истинным ярлыком — считаем ошибку (loss).
  3. Алгоритм обратного распространения подправляет веса, чтобы в следующий раз ошибиться меньше.
  4. Повторяем на тысячах картинок много эпох, пока качество не перестанет расти.

Качество всегда проверяют на отложенных данных, которые сеть не видела при обучении — иначе можно принять зубрёжку за понимание.

Главная метрика: accuracy

Самая простая метрика — точность (accuracy): доля верно классифицированных картинок.

predictions = ["кошка", "собака", "кошка", "птица", "собака"]
truth       = ["кошка", "собака", "собака", "птица", "собака"]

correct = sum(p == t for p, t in zip(predictions, truth))
accuracy = correct / len(truth)
print(f"Верно: {correct} из {len(truth)}")
print(f"Accuracy: {accuracy:.0%}")

Вывод:

Верно: 4 из 5
Accuracy: 80%

Где accuracy обманывает

Точность коварна при дисбалансе классов. Пусть на снимках болезнь встречается в 1% случаев. Модель, которая всегда отвечает «здоров», получит 99% accuracy — и при этом не найдёт ни одного больного! Поэтому используют:

  • Precision (точность): из тех, кого назвали больными, сколько действительно больны.
  • Recall (полнота): из всех больных, скольких нашли.
  • F1: баланс между precision и recall.
  • Матрица ошибок (confusion matrix): кого с кем путает модель.

Top-5 accuracy

Когда классов тысячи (как в ImageNet — 1000), мерить «угадал ли единственный ответ» слишком жёстко: кошку легко спутать с похожей породой. Поэтому считают top-5 accuracy: засчитывается, если верный класс попал в пятёрку самых вероятных. Именно top-5 на ImageNet был главным мерилом гонки архитектур.

# top-5: верный класс среди 5 самых вероятных?
top5_predictions = ["собака", "волк", "кошка", "лиса", "кролик"]
true_label = "кошка"
hit = true_label in top5_predictions
print("Топ-5 прогнозов:", top5_predictions)
print("Верный класс в топ-5:", hit)

Вывод:

Топ-5 прогнозов: ['собака', 'волк', 'кошка', 'лиса', 'кролик']
Верный класс в топ-5: True

Итог

  • Классификация: один ярлык на всё изображение; обучается на размеченном датасете.
  • Качество проверяют на отложенных данных, accuracy — доля верных.
  • При дисбалансе классов accuracy обманчива — нужны precision, recall, F1.
  • При тысячах классов мерят top-5 accuracy.
Проверьте себя
1. Почему accuracy обманчива при сильном дисбалансе классов?
AОна всегда занижена
BМодель, всегда предсказывающая частый класс, получит высокую accuracy, но провалит редкий важный класс
CОна требует цветных картинок
DAccuracy нельзя посчитать
2. Что засчитывает метрика top-5 accuracy?
AТолько если верный класс — самый вероятный
BЕсли верный класс попал в пятёрку самых вероятных предсказаний
CЕсли ошибок меньше пяти
DЕсли пять картинок классифицированы верно
3. На каких данных нужно оценивать качество классификатора?
AНа тех же, на которых обучали
BНа отложенных данных, которые сеть не видела при обучении
CНа самых лёгких примерах
DОценивать необязательно
Поддержать проект