Дообучение с Trainer API

Осваиваем Trainer — высокоуровневый инструмент для дообучения без ручного цикла обучения.

Trainer — класс библиотеки transformers, который берёт на себя весь цикл обучения: батчи, шаги оптимизатора, логирование, сохранение чекпойнтов.

Зачем нужен Trainer

Ручной цикл обучения в PyTorch — это десятки строк: перебор эпох, батчей, обнуление градиентов, шаг оптимизатора, валидация. Trainer скрывает всё это. Вы описываете что обучать и как (гиперпараметры), а Trainer делает как именно.

Минимальный пример

Код требует transformers и не исполняется в браузере — он для чтения:

from transformers import (
    AutoModelForSequenceClassification,
    TrainingArguments, Trainer
)

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2
)

args = TrainingArguments(
    output_dir="out",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    num_train_epochs=3,
    eval_strategy="epoch",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["test"],
)

trainer.train()

TrainingArguments — пульт управления

ПараметрЧто задаёт
learning_rateскорость обучения (для fine-tuning малая, ~2e-5)
num_train_epochsсколько раз пройти по данным
per_device_train_batch_sizeразмер батча
eval_strategyкогда оценивать на валидации

Как работает под капотом

При вызове trainer.train() Trainer создаёт загрузчики данных, оптимизатор (по умолчанию AdamW) и расписание learning rate, затем гоняет цикл: для каждого батча считает предсказание, вычисляет loss (модель делает это сама, если в батче есть метки labels), считает градиенты и делает шаг оптимизатора. Параллельно логирует метрики и периодически сохраняет чекпойнты в output_dir. Если доступен GPU, Trainer использует его автоматически; библиотека accelerate позволяет распределить обучение на несколько устройств.

Частые ошибки

  • Большой learning rate. Для fine-tuning берут малые значения (2e-5…5e-5), иначе модель «забывает» предобучение.
  • Несовпадение num_labels и числа классов. Голова классификации должна иметь столько выходов, сколько классов в данных.
  • Передать нетокенизированный датасет. Trainer ждёт уже токенизированные данные с input_ids.

Итог

  • Trainer скрывает ручной цикл обучения: батчи, оптимизатор, логи, чекпойнты.
  • TrainingArguments задаёт гиперпараметры (learning rate, эпохи, батч).
  • Для fine-tuning нужен малый learning rate.
  • Trainer сам использует GPU и считает loss, если в данных есть labels.
Проверьте себя
1. Что берёт на себя Trainer?
AТолько токенизацию
BВесь цикл обучения: батчи, оптимизатор, логи, чекпойнты
CТолько скачивание модели
DТолько инференс
2. Какой learning rate обычно берут для fine-tuning?
AБольшой, около 0.1
BМалый, около 2e-5
CНулевой
DСлучайный на каждом шаге
3. Где Trainer задаёт гиперпараметры обучения?
AВ токенизаторе
BВ объекте TrainingArguments
CВ config.json модели
DВ attention_mask