Инференс, деплой, отладка и профилирование
Доводим модель до продакшена: правильный инференс, экспорт и поиск узких мест.
Инференс — применение обученной модели к новым данным для получения предсказаний, без обучения.
Правильный инференс
В проде модель только предсказывает, и важно делать это эффективно и корректно. Собираем правила из всего курса в один блок: режим оценки, без графа, нужное устройство, перевод в numpy для выдачи.
import torch
model.eval() # 1. режим оценки (dropout off)
with torch.no_grad(): # 2. без построения графа
x = preprocess(raw_input) # та же предобработка, что при обучении!
x = x.to(device) # 3. на устройство модели
logits = model(x)
probs = torch.softmax(logits, dim=1) # 4. логиты -> вероятности
pred = probs.argmax(dim=1)
result = pred.cpu().numpy() # 5. в numpy для выдачи
Критичная деталь: на инференсе предобработку входа делают точно так же, как при обучении (тот же resize, та же нормализация). Расхождение здесь — частая причина, почему модель «работала в обучении, но врёт в проде».
Экспорт модели
Запускать модель в проде через сохранённый state_dict и Python-код можно, но иногда нужен переносимый формат, не зависящий от вашего кода. Два главных пути:
| Формат | Зачем |
| TorchScript | запуск без Python-кода модели, в т.ч. на C++ |
| ONNX | универсальный формат, запуск в других рантаймах |
# TorchScript: «заморозить» модель в самодостаточный файл
scripted = torch.jit.script(model)
scripted.save("model_scripted.pt")
# ONNX: экспорт для сторонних рантаймов
torch.onnx.export(model, example_input, "model.onnx")
Оба варианта дают артефакт, который можно загрузить без вашего класса Net. Для большинства веб-сервисов хватает и обычной загрузки state_dict внутри Python-сервера (FastAPI/Flask) — экспорт нужен, когда Python в проде нежелателен.
Отладка моделей
Динамический граф PyTorch — большое преимущество при отладке: модель это обычный код. Поэтому работают привычные приёмы:
print(x.shape)внутриforward— самый быстрый способ найти, где ломаются размерности;- точка останова в отладчике прямо в
forward— можно посмотреть тензоры на каждом шаге; - проверка на
NaN:torch.isnan(loss)— если loss стал NaN, обычно виноват слишком большойlrили деление на ноль в данных.
Профилирование: где тратится время
Когда обучение медленное, нужно понять, что именно тормозит. PyTorch даёт встроенный профилировщик, который замеряет время операций на CPU и GPU:
from torch.profiler import profile, ProfilerActivity
with profile(activities=[ProfilerActivity.CPU]) as prof:
model(example_input)
# таблица самых тяжёлых операций
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
Профилировщик покажет, какие операции съедают время. Частые виновники медленного обучения, кстати, не в самой сети: это узкий DataLoader (увеличьте num_workers, чтобы готовить батчи параллельно) и лишние переносы между CPU и GPU. Сначала измеряйте, потом оптимизируйте — гадать бесполезно.
Итог
- Инференс:
eval()+no_grad()+ та же предобработка, что при обучении. - Для переносимости экспортируют в TorchScript (запуск без Python) или ONNX (сторонние рантаймы).
- Отлаживают как обычный код:
print(shape), отладчик, проверка наNaN. - Узкое место ищут профилировщиком; частые виновники — DataLoader и лишние переносы CPU↔GPU.