Дообучение с Trainer API
Осваиваем Trainer — высокоуровневый инструмент для дообучения без ручного цикла обучения.
Trainer — класс библиотеки
transformers, который берёт на себя весь цикл обучения: батчи, шаги оптимизатора, логирование, сохранение чекпойнтов.
Зачем нужен Trainer
Ручной цикл обучения в PyTorch — это десятки строк: перебор эпох, батчей, обнуление градиентов, шаг оптимизатора, валидация. Trainer скрывает всё это. Вы описываете что обучать и как (гиперпараметры), а Trainer делает как именно.
Минимальный пример
Код требует transformers и не исполняется в браузере — он для чтения:
from transformers import (
AutoModelForSequenceClassification,
TrainingArguments, Trainer
)
model = AutoModelForSequenceClassification.from_pretrained(
"distilbert-base-uncased", num_labels=2
)
args = TrainingArguments(
output_dir="out",
learning_rate=2e-5,
per_device_train_batch_size=16,
num_train_epochs=3,
eval_strategy="epoch",
)
trainer = Trainer(
model=model,
args=args,
train_dataset=tokenized["train"],
eval_dataset=tokenized["test"],
)
trainer.train()TrainingArguments — пульт управления
| Параметр | Что задаёт |
learning_rate | скорость обучения (для fine-tuning малая, ~2e-5) |
num_train_epochs | сколько раз пройти по данным |
per_device_train_batch_size | размер батча |
eval_strategy | когда оценивать на валидации |
Как работает под капотом
При вызове trainer.train() Trainer создаёт загрузчики данных, оптимизатор (по умолчанию AdamW) и расписание learning rate, затем гоняет цикл: для каждого батча считает предсказание, вычисляет loss (модель делает это сама, если в батче есть метки labels), считает градиенты и делает шаг оптимизатора. Параллельно логирует метрики и периодически сохраняет чекпойнты в output_dir. Если доступен GPU, Trainer использует его автоматически; библиотека accelerate позволяет распределить обучение на несколько устройств.
Частые ошибки
- Большой learning rate. Для fine-tuning берут малые значения (2e-5…5e-5), иначе модель «забывает» предобучение.
- Несовпадение num_labels и числа классов. Голова классификации должна иметь столько выходов, сколько классов в данных.
- Передать нетокенизированный датасет. Trainer ждёт уже токенизированные данные с
input_ids.
Итог
- Trainer скрывает ручной цикл обучения: батчи, оптимизатор, логи, чекпойнты.
TrainingArgumentsзадаёт гиперпараметры (learning rate, эпохи, батч).- Для fine-tuning нужен малый learning rate.
- Trainer сам использует GPU и считает loss, если в данных есть
labels.