Сохранение, загрузка и инференс
Урок про финальный этап: сохранить обученную модель, загрузить её и делать предсказания в продакшене.
Инференс — это применение уже обученной модели к новым данным для получения предсказаний, без дальнейшего обновления весов.
Обучение — дорогая разовая операция, инференс — то, что модель делает в проде тысячи раз в секунду. Поэтому обученную модель сохраняют на диск, а потом загружают там, где нужно предсказывать. Разберём весь цикл.
Сохранение модели
Современный формат Keras — один файл .keras, в нём и архитектура, и веса. Сохранение и загрузка (требует TF):
model.save("mnist_model.keras") # сохранить всё в один файл
# позже, в другом скрипте:
import tensorflow as tf
loaded = tf.keras.models.load_model("mnist_model.keras")
loaded.summary()Инференс
Загруженная модель готова предсказывать сразу — без compile и fit:
import numpy as np
# new_image — картинка формы (28, 28), нормализованная
batch = new_image.reshape(1, 28, 28) # добавили ось батча
probs = loaded.predict(batch)
print("вероятности:", probs[0])
print("предсказанная цифра:", np.argmax(probs[0]))Только веса
Иногда сохраняют лишь веса (если архитектура описана в коде):
model.save_weights("weights.weights.h5")
# восстановить:
new_model.load_weights("weights.weights.h5")argmax: из вероятностей в класс
Инференс классификатора почти всегда заканчивается argmax по выходу softmax. Сделаем это на чистом Python — код запускается:
probs = [0.02, 0.01, 0.90, 0.04, 0.03]
best_index = probs.index(max(probs))
print("индекс класса:", best_index)
print("уверенность:", probs[best_index])Вывод:
индекс класса: 2 уверенность: 0.9
Форматы — памятка
| Способ | Что внутри | Когда |
.keras | архитектура + веса | обычное сохранение модели |
.weights.h5 | только веса | архитектура в коде |
| SavedModel | граф для сервинга | деплой в TF Serving |
Как работает под капотом
При инференсе модель делает только forward pass: данные проходят слой за слоем к выходу, градиенты не считаются, веса не меняются. Слои вроде Dropout автоматически переключаются в режим применения (dropout отключается, BatchNorm использует накопленную статистику). Поэтому инференс быстрее обучения и потребляет меньше памяти.
Частые ошибки
- Забыть ось батча.
predictждёт форму (batch, ...); одиночную картинку нужно завернуть вreshape(1, ...). - Не нормализовать вход на инференсе. Данные нужно обрабатывать так же, как при обучении (то же деление на 255).
- Переобучать вместо загрузки. Загруженная модель готова к
predictсразу — повторныйfitне нужен.
Итог
- Обученную модель сохраняют в файл
.keras(архитектура + веса) и загружают черезload_model. - Инференс — только forward pass;
compile/fitне нужны. - Вход для
predictдолжен иметь ось батча и ту же предобработку, что при обучении. argmaxвыхода softmax даёт итоговый класс.