Deep Q-Network: replay buffer и target network
Урок объясняет идею DQN и два приёма, без которых глубокий Q-learning нестабилен.
Deep Q-Network (DQN) — это Q-learning, где Q-функцию аппроксимирует нейросеть, а стабильность обеспечивают два приёма: experience replay (буфер опыта) и target network (целевая сеть).
Зачем DQN наделал шума
В 2013–2015 годах DQN научился играть в десятки игр Atari, видя только пиксели и счёт, и в некоторых превзошёл человека. Это доказало, что Q-learning масштабируется на сложные задачи, если аккуратно укротить нестабильность нейросетевой аппроксимации.
Проблема: наивный deep Q-learning разъезжается
Если просто подставить нейросеть в формулу Q-learning, обучение часто расходится. Две причины:
- Коррелированные данные. Последовательные кадры очень похожи. Обучать сеть на потоке почти одинаковых примеров плохо — нейросети нужны разнообразные, независимые батчи.
- Движущаяся цель. В цели обновления r + gamma·max Q(s',a') фигурирует та же сеть, что мы обучаем. Меняя веса, мы сдвигаем и цель — «собака гонится за собственным хвостом».
Приём 1: experience replay (буфер опыта)
Агент складывает переходы (s, a, r, s', done) в большой буфер, а для обучения случайно выбирает мини-батч из прошлого опыта. Это разрывает корреляцию между соседними примерами и позволяет переиспользовать один и тот же опыт многократно.
from collections import deque
import random
random.seed(0)
# Упрощённый replay buffer: храним переходы, берём случайный батч
buffer = deque(maxlen=5)
for i in range(8):
buffer.append(("s%d" % i, "a", i, "s%d" % (i + 1)))
print("В буфере (последние 5):", [t[0] for t in buffer])
batch = random.sample(buffer, 3)
print("Случайный батч для обучения:", [t[0] for t in batch])Вывод:
В буфере (последние 5): ['s3', 's4', 's5', 's6', 's7'] Случайный батч для обучения: ['s6', 's7', 's3']
Буфер хранит последние переходы (старые вытесняются), а обучение идёт по случайной выборке — батч получился перемешанным, а не из подряд идущих кадров.
Приём 2: target network (целевая сеть)
Заводят две копии сети. Основная (online) сеть обучается каждый шаг. Целевая (target) сеть — её «замороженная» копия, которую используют для вычисления цели r + gamma·max Q_target(s',a') и обновляют редко (раз в несколько тысяч шагов копируют в неё веса из основной). Так цель остаётся стабильной, и сеть не гонится за собственным хвостом.
Как работает под капотом
Полный шаг DQN: агент действует epsilon-greedy, складывает переход в буфер, берёт случайный батч, считает для него TD-цель через целевую сеть, делает шаг градиентного спуска по основной сети и периодически синхронизирует целевую. Реальная реализация требует PyTorch и не исполняется в браузере — ниже псевдокод для понимания. Нейросети и обучение разобраны в наших курсах «Глубокое обучение» и «PyTorch».
// Псевдокод одного шага обучения DQN (требует PyTorch, не запускается)
for step in range(N):
a = epsilon_greedy(online_net, s)
s2, r, done = env.step(a)
buffer.add((s, a, r, s2, done))
batch = buffer.sample(64)
target = r + gamma * max(target_net(s2)) // цель из target-сети
loss = mse(online_net(s)[a], target)
optimizer.step(loss) // учим online-сеть
if step % 1000 == 0:
copy_weights(online_net -> target_net) // синхронизация
s = s2Частые ошибки
- Обойтись без replay buffer. На коррелированных подряд идущих кадрах сеть быстро расходится.
- Не использовать target network. Движущаяся цель приводит к колебаниям и расхождению Q-значений.
- Переоценка ценности. Из-за max в цели DQN склонен завышать Q; это лечит вариант Double DQN, где выбор и оценку действия делают разные сети.
Итоги
- DQN — это Q-learning с нейросетью вместо таблицы.
- Experience replay разрывает корреляцию данных и переиспользует опыт.
- Target network делает цель обновления стабильной, предотвращая расхождение.