Как вычисляется MSE в вариационном автоэнкодере (VAE)
Вариационный автоэнкодер (VAE) - это генеративная модель, которая обучается восстанавливать входные данные, одновременно формируя непрерывное латентное пространство. Функция потерь VAE состоит из двух частей: дивергенции Кульбака - Лейблера (KL-дивергенции) и ошибки восстановления, часто измеряемой как среднеквадратичная ошибка (MSE). В этой статье подробно разберём, как именно вычисляется MSE при обучении VAE, особенно в контексте батчевой обработки.
Структура VAE и роль MSE
VAE состоит из энкодера, который преобразует входные данные в параметры распределения (среднее и дисперсия) в латентном пространстве, и декодера, который восстанавливает данные из сэмплированного латентного вектора. MSE выступает метрикой, показывающей, насколько восстановленные данные (Y) близки к исходным (X).
Ключевая особенность VAE - использование репараметризации: вместо прямого сэмплирования из распределения N(μ, σ²) мы генерируем случайный шум ε из стандартного нормального распределения и преобразуем его: z = μ + σ * ε. Это позволяет градиентам проходить через сэмплирование при обратном распространении.
Как MSE вычисляется для батча из 50 объектов
Рассмотрим практический пример с батчем, содержащим 50 объектов X = {x₁, x₂, ..., x₅₀}. Процесс вычисления MSE включает следующие шаги:
- Параметризация латентного пространства: Для каждого объекта xᵢ энкодер выдаёт два вектора: среднее μᵢ и логарифм дисперсии log(σᵢ²). Таким образом, для батча из 50 объектов мы получаем 50 пар (μᵢ, σᵢ).
- Сэмплирование латентных векторов: Используя репараметризацию, для каждого объекта генерируется свой латентный вектор zᵢ. То есть из 50 пар (μᵢ, σᵢ) мы получаем 50 различных zᵢ.
- Декодирование: Каждый zᵢ подаётся на декодер, который восстанавливает объект ŷᵢ. На выходе декодера - 50 восстановленных объектов.
- Вычисление MSE: MSE считается как среднее арифметическое квадратов разностей между соответствующими элементами исходного батча X и восстановленного батча Y: MSE = (1/50) * Σᵢ (xᵢ - ŷᵢ)². Если объекты многомерные (например, изображения), то сначала усредняется ошибка по всем пикселям/признакам внутри каждого объекта, а затем - по батчу.
Таким образом, ваше предположение верно: для батча из 50 объектов сэмплируется 50 латентных векторов, и MSE вычисляется попарно между каждым исходным и восстановленным объектом.
Почему MSE не усредняется по всему латентному пространству
Важно понимать: MSE не вычисляется напрямую из параметров распределения (μ и σ). MSE - это ошибка на выходе декодера, а KL-дивергенция отвечает за регуляризацию латентного пространства. Если бы мы усредняли MSE по всему латентному пространству, это привело бы к потере информации о конкретных объектах и модель перестала бы их различать.
На практике MSE может быть заменена на бинарную кросс-энтропию (BCE) для данных с бинарными пикселями, но принцип остаётся тем же: ошибка считается для каждого элемента батча индивидуально.
Пример кода на Python (PyTorch)
import torch import torch.nn.functional as F # Пусть batch_size = 50, dim = 784 (например, MNIST) x = torch.randn(50, 784) # исходные данные # Энкодер выдаёт mu и logvar mu = torch.randn(50, 20) # 20 - размерность латентного пространства logvar = torch.randn(50, 20) # Репараметризация std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) z = mu + eps * std # Декодер восстанавливает данные recon = torch.sigmoid(torch.randn(50, 784)) # MSE loss mse_loss = F.mse_loss(recon, x, reduction='mean') print(mse_loss.item()) В этом коде MSE вычисляется как среднее по всем 50 объектам и всем 784 признакам.
Частые ошибки при расчёте MSE в VAE
- Сэмплирование одного вектора для всего батча: Новички иногда сэмплируют один z для всех объектов, что ломает процесс обучения, так как модель не может восстановить разные объекты из одного вектора.
- Усреднение MSE по латентному пространству: Некоторые путают MSE с KL-дивергенцией. KL-дивергенция вычисляется между распределением N(μ, σ²) и стандартным нормальным распределением, а MSE - между реальными и восстановленными данными.
- Неправильное reduction: В фреймворках вроде PyTorch параметр reduction='sum' даст сумму ошибок, а не среднее. Для MSE обычно используют reduction='mean'.
Заключение
MSE в VAE вычисляется путём попарного сравнения каждого объекта из батча с его восстановленной версией. Количество сэмплированных латентных векторов равно размеру батча. Это обеспечивает как точное восстановление, так и возможность генерации новых данных из непрерывного латентного пространства. Понимание этого механизма - ключ к успешному обучению VAE и его модификаций.