Инференс, деплой, отладка и профилирование

Доводим модель до продакшена: правильный инференс, экспорт и поиск узких мест.

Инференс — применение обученной модели к новым данным для получения предсказаний, без обучения.

Правильный инференс

В проде модель только предсказывает, и важно делать это эффективно и корректно. Собираем правила из всего курса в один блок: режим оценки, без графа, нужное устройство, перевод в numpy для выдачи.

import torch

model.eval()                          # 1. режим оценки (dropout off)
with torch.no_grad():                # 2. без построения графа
    x = preprocess(raw_input)        # та же предобработка, что при обучении!
    x = x.to(device)                 # 3. на устройство модели
    logits = model(x)
    probs = torch.softmax(logits, dim=1)   # 4. логиты -> вероятности
    pred = probs.argmax(dim=1)
    result = pred.cpu().numpy()      # 5. в numpy для выдачи

Критичная деталь: на инференсе предобработку входа делают точно так же, как при обучении (тот же resize, та же нормализация). Расхождение здесь — частая причина, почему модель «работала в обучении, но врёт в проде».

Экспорт модели

Запускать модель в проде через сохранённый state_dict и Python-код можно, но иногда нужен переносимый формат, не зависящий от вашего кода. Два главных пути:

ФорматЗачем
TorchScriptзапуск без Python-кода модели, в т.ч. на C++
ONNXуниверсальный формат, запуск в других рантаймах
# TorchScript: «заморозить» модель в самодостаточный файл
scripted = torch.jit.script(model)
scripted.save("model_scripted.pt")

# ONNX: экспорт для сторонних рантаймов
torch.onnx.export(model, example_input, "model.onnx")

Оба варианта дают артефакт, который можно загрузить без вашего класса Net. Для большинства веб-сервисов хватает и обычной загрузки state_dict внутри Python-сервера (FastAPI/Flask) — экспорт нужен, когда Python в проде нежелателен.

Отладка моделей

Динамический граф PyTorch — большое преимущество при отладке: модель это обычный код. Поэтому работают привычные приёмы:

  • print(x.shape) внутри forward — самый быстрый способ найти, где ломаются размерности;
  • точка останова в отладчике прямо в forward — можно посмотреть тензоры на каждом шаге;
  • проверка на NaN: torch.isnan(loss) — если loss стал NaN, обычно виноват слишком большой lr или деление на ноль в данных.

Профилирование: где тратится время

Когда обучение медленное, нужно понять, что именно тормозит. PyTorch даёт встроенный профилировщик, который замеряет время операций на CPU и GPU:

from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU]) as prof:
    model(example_input)

# таблица самых тяжёлых операций
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

Профилировщик покажет, какие операции съедают время. Частые виновники медленного обучения, кстати, не в самой сети: это узкий DataLoader (увеличьте num_workers, чтобы готовить батчи параллельно) и лишние переносы между CPU и GPU. Сначала измеряйте, потом оптимизируйте — гадать бесполезно.

Итог

  • Инференс: eval() + no_grad() + та же предобработка, что при обучении.
  • Для переносимости экспортируют в TorchScript (запуск без Python) или ONNX (сторонние рантаймы).
  • Отлаживают как обычный код: print(shape), отладчик, проверка на NaN.
  • Узкое место ищут профилировщиком; частые виновники — DataLoader и лишние переносы CPU↔GPU.
Проверьте себя
1. Почему на инференсе важно повторять ту же предобработку входа, что и при обучении?
AИначе модель не загрузится
BРасхождение предобработки даёт неверные предсказания в проде при хорошем обучении
CЭто ускоряет инференс
DБез этого не работает no_grad
2. Для чего экспортируют модель в TorchScript или ONNX?
AЧтобы уменьшить число параметров
BЧтобы получить переносимый артефакт, запускаемый без Python-кода модели
CЧтобы ускорить обучение
DЧтобы включить dropout в проде
3. Что чаще всего оказывается узким местом медленного обучения по данным профилировщика?
AСам backward всегда самый медленный
BУзкий DataLoader и лишние переносы между CPU и GPU
CСлишком маленькая модель
DИспользование eval() режима
Поддержать проект