admin管理员组

文章数量:1530020

他们各自的概念看以下链接就可以了:https://blog.csdn/weixin_43135178/category_11543123.html

 这里主要谈一下他们的区别?


先说结论:

  • VAE是AE的升级版,VAE也可以被看作是一种特殊的AE
  • AE主要用于数据的压缩与还原,不能用于生成(因为编码空间是不规整的),VAE主要用于生成
  • AE是将数据映直接映射为数值code(确定的数值),而VAE是先将数据映射为分布,再从分布中采样得到数值code。
  • 损失函数和优化目标不同


一、AE(Auto Encoder, 自动编码器)

1、AE的结构

如上图所示,自动编码器主要由两部分组成:编码器(Encoder)和解码器(Decoder)。编码器和解码器可以看作是两个函数,一个用于将高维输入(如图片)映射为低维编码(code),另一个用于将低维编码(code)映射为高维输出(如生成的图片)。这两个函数可以是任意形式,但在深度学习中,我们用神经网络去学习这两个函数。

这时候我们只要拿出Decoder部分,随机生成一个code然后输入,就可以得到一张生成的图像。但实际上这样的生成效果并不好(下面解释原因),因此AE多用于数据压缩,而数据生成则使用下面所介绍的VAE更好。

2、AE的缺陷

由上面介绍可以看出,AE的Encoder是将图片映射成“数值编码”,Decoder是将“数值编码”映射成图片。这样存在的问题是,在训练过程中,随着不断降低输入图片与输出图片之间的误差,模型会过拟合,泛化性能不好。也就是说对于一个训练好的AE,输入某个图片,就只会将其编码为某个确定的code,输入某个确定的code就只会输出某个确定的图片,如果这个latent code来自于没见过的图片,那么生成的图片也不会好。下面举个例子来说明:

假设我们训练好的AE将“新月”图片encode成code=1(这里假设code只有1维),将其decode能得到“新月”的图片;将“满月”encode成code=10,同样将其decode能得到“满月”图片。这时候如果我们给AE一个code=5,我们希望是能得到“半月”的图片,但由于之前训练时并没有将“半月”的图片编码,或者将一张非月亮的图片编码为5,那么我们就不太可能得到“半月”的图片。因此AE多用于数据的压缩和恢复,用于数据生成时效果并不理想。

3、AE的代码实现

3.1)AE encoder + decoder + AE的模型

import torch
from torch import nn
from torch.autograd import Variable


# Define the encoder and decoder networks
class Encoder(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, latent_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return x


class Decoder(nn.Module):
    def __init__(self, latent_dim, output_dim):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_dim, output_dim)

    def forward(self, x):
        x = torch.sigmoid(self.fc1(x))
        return x


class Autoencoder(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(Autoencoder, self).__init__()
        self.encoder = Encoder(input_dim, latent_dim)
        self.decoder = Decoder(latent_dim, input_dim)

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

3.2)AE模型训练过程

# Define the model, loss function, and optimizer
input_dim = 784  # For MNIST images
latent_dim = 32
model = Autoencoder(input_dim, latent_dim)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training loop
for epoch in range(num_epochs):
    for data in dataloader:
        img, _ = data
        img = img.view(img.size(0), -1)
        img = Variable(img)

        # Forward pass
        output = model(img)
        loss = criterion(output, img)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

3.3)AE模型的推理过程

# Evaluate the model
model.eval()
with torch.no_grad():
    for data in dataloader:
        img, _ = data
        img = img.view(img.size(0), -1)
        img = Variable(img)
        output = model(img)

3.4)AE怎么通过latent code生成新的图像

# Initialize the autoencoder
input_dim = 784  # Assuming input is a flattened 28x28 image
latent_dim = 20  
autoencoder = Autoencoder(input_dim, latent_dim)
latent_vector = torch.randn(latent_dim)

# Generate a new image by passing the latent vector through the decoder
with torch.no_grad(): 
    generated_image = autoencoder.decoder(latent_vector)

4、如何解决AE的问题呢?

这时候我们转变思路,不将图片映射成“数值编码”,而将其映射成“分布”。还是刚刚的例子,我们将“新月”图片映射成μ=1的正态分布,那么就相当于在1附近加了噪声,此时不仅1表示“新月”,1附近的数值也表示“新月”,只是1的时候最像“新月”。将"满月"映射成μ=10的正态分布,10的附近也都表示“满月”。那么code=5时,就同时拥有了“新月”和“满月”的特点,那么这时候decode出来的大概率就是“半月”了。这就是VAE的思想。

二、VAE(Variational Auto-Encoder, 变分自动编码器)

1、VAE的结构

vae和常规自编码器不同的地方在于它的encoder的output不是一个latent vector让decoder去做decode,而是从某个连续的分布里(常见的是高斯分布)采样得到一个随机数or随机向量,然后decode再去针对这个scaler做解码。

不使用连续的AE这种分布,而是使用符合某种分布的VAE,这是因为:

常规的ae的潜在空间的规律性是一个难点,它取决于初始空间中数据的分布、潜在空间的维度和编码器的架构等。因此,我们基本不可能就认为ae的latent vector的distribution和我们产生随机数的distribution是一个distribution(不可能),那就很尴尬了,假设latent vector的取值范围都在0~1之间,然后我们产生了一个包含了大量的负数的随机数让decoder去decode,那么decoder压根就decode不出什么正常的东西,training阶段压根就没见过嘛。想一想,潜在空间中的latent vector 的分布规律很难知道也是正常的,因为常规的自动编码器的任务中没有任何东西被训练来强制获得这样的规律(但是vae就会假设latent vector服从高斯分布)自动编码器被训练成以尽可能少的损失进行编码和解码,压根就不care latent vector服从什么分布。那么自然,我们是不可能使用一个预定义的随机分布产生随机的input然后又期望decoder能够decode出有意义的东西的.

既然我们不知道latent vector服从什么分布,我们就直接人为对其进行约束满足某种预定义的分布,这个预定义的分布和我们产生随机数的分布保持一致,不就完美解决问题了吗?

所以通过VAE求出均值和方差,然后使用重参数化技巧在得到的这个分布中进行采样,就可以得到符合此分布的latent vector了。

为什么使用重参数化?

具体来说,在不使用重参数化的情况下,模型会直接从参数化的分布(例如,正态分布,由均值 μ 和方差 σ2 参数化)中采样,这使得梯度无法回传。

重参数化技巧通过引入一个不依赖于模型参数的外部噪声源(通常是标准正态分布中抽取的),并对这个噪声与模型的均值和方差进行变换,来生成符合目标分布的样本。这样,模型的随机输出就可以表示为模型参数的确定性函数和一个随机噪声的组合。便可以完成梯度回传。

其中的过程很复杂,你只需要知道可以将在某个分布进行采样的随机过程与标准正态分布通过重参数化技巧使得整个采样过程变为可微即可。

2、VAE的代码实现

整体架构,VAE计算以下两方面之间的损失:

  1. 重构损失(Reconstruction Loss):这一部分的损失计算的是输入数据与重构数据之间的差异。

  2. KL散度(Kullback-Leibler Divergence Loss):这一部分的损失衡量的是学习到的潜在表示的分布与先验分布(通常假设为标准正态分布)之间的差异。KL散度是一种衡量两个概率分布相似度的指标,VAE通过最小化KL散度来确保学习到的潜在表示的分布尽可能接近先验分布。这有助于模型生成性能的提升,因为它约束了潜在空间的结构,使其更加规整,便于采样和推断。

1)Encoder

image --> 均值 + 标准差

import torch
from torch import nn
from torch.nn import functional as F

# Encoder class definition
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        # 使用FC将输入变为隐藏层hidden_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        # Two fully connected layers to produce mean and log variance
        # These will represent the latent space distribution parameters
        self.fc21 = nn.Linear(hidden_dim, latent_dim) # 隐藏层hidden_dim --> 均值Mean μ
        self.fc22 = nn.Linear(hidden_dim, latent_dim) # 隐藏层hidden_dim --> 标准差Log variance σ

    def forward(self, x):
        # 使用RELU非线性变换,增加网络的表达能力
        h1 = F.relu(self.fc1(x))
        # Return the mean and log variance for the latent space
        return self.fc21(h1), self.fc22(h1)

2)Decoder

# Decoder class definition
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        # latent_dim --> hidden_dim
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        # hidden_dim --> output_dim(输出的图像)
        self.fc4 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

3)VAE

重参数化技巧:

这段代码对应的数学公式可以写作:

z= σ⋅ϵ + μ

其中:

  • z 是从潜在分布中采样得到的样本。
  • log var(log方差):  log(σ2)
  • μ 是潜在分布的均值,对应代码中的 mu
  • σ 是潜在分布的标准差 std ,通过 torch.exp(0.5*logvar) 计算得到,这里 logvar 是对数方差log(σ2),因此 σ=exp(0.5*​log(σ2))。
  • ϵ 是从标准正态分布 N(0,1) 中采样得到的随机噪声,对应 torch.randn_like(std)

代码中的 eps.mul(std).add_(mu) 实现了上述公式的计算,即首先将随机噪声 ϵ 与标准差 σ 相乘,然后将结果加上均值 μ。这样,得到的 z 既包含了模型学习到的分布的特征(通过 μ 和 σ),同时也引入了必要的随机性(通过 ϵ),允许模型通过采样生成多样化的数据。

# VAE class definition
# Encode the input --> reparameterize --> decode
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

    def reparameterize(self, mu, logvar):
        # Reparameterization trick to sample from the distribution represented by the mean and log variance
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def forward(self, x):
        mu, logvar = self.encoder(x.view(-1, input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar

4)Loss

# Loss function for VAE
def vae_loss_function(recon_x, x, mu, logvar):
    # Binary cross entropy between the target and the output
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, input_dim), reduction='sum')
    # KL divergence loss : 学习到的潜在表示的分布 <-->  先验分布(标准正态分布)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

5)训练过程

# Hyperparameters
input_dim = 784 # Assuming input is a flattened 28x28 image (e.g., from MNIST)
hidden_dim = 400
latent_dim = 20
epochs = 10
learning_rate = 1e-3

# Initialize VAE
vae = VAE(input_dim, hidden_dim, latent_dim)
optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)

# Training process function
for epoch in range(epochs):
    vae.train()  # Set the model to training mode
    train_loss = 0
    for batch_idx, (data, _) in enumerate(data_loader):
        optimizer.zero_grad()  # Zero the gradients
        recon_batch, mu, logvar = vae(data)  # Forward pass through VAE
        loss = vae_loss_function(recon_batch, data, mu, logvar)  # Compute the loss
        loss.backward()  # Backpropagate the loss
        train_loss += loss.item()
        optimizer.step()

三、VQ-VAE

之所以叫做VQ(向量量化),主要是因为将连续潜在空间的点映射到最近的一组离散的向量(即码本中的向量)上

VQ-VAE的全称是Vector Quantized-Variational AutoEncoder,即向量量化变分自编码器。这是一种结合了变分自编码器(VAE)和向量量化(VQ)的深度学习模型,主要用于高效地学习数据的潜在表示。VQ-VAE通过将连续的潜在表示空间离散化来改进传统VAE模型。向量量化的过程实质上是将连续潜在空间的点映射到最近的一组离散的向量(即码本中的向量)上,这有助于模型捕捉和表示更加丰富和复杂的数据分布,由于维护了一个codebook,编码范围更加可控,VQVAE相对于VAE(VAE的隐变量 z 的每一维都是一个连续的值, 而VQ-VAE最大的特点就是, z 的每一维都是离散的整数。),可以生成更大更高清的图片(这也为后续DALLE和VQGAN的出现做了铺垫)。【原文中说的是避免了“后验坍塌”的问题】

1、算法步骤:

  1. 通过Encoder学习出输入图像x(256*256)中间编码 Ze(x)(32*32)【绿色】
  2. 事先定义好codebook(512,64),它有N个e组成【紫色】
  3. 然后通过最邻近搜索与中间编码Ze(x)最相似(接近)的codebook中K个向量之一,并记住这个向量的index【每一个图像都有共512^32^32次方种选择,其实这一步你可以看成类似NLP中预测下一个词的方法,就是从字典中寻找到一个对应的值】【青色】
  4. 根据得到的所有index去映射对应的codebook中的vector,得到输入图像对应的特征表征Zq(x)【紫色】
  5. 然后通过Decoder对Zq(x)进行重建

另外由于最邻近搜索使用argmax来找codebook中的索引位置,导致不可导问题,VQVAE通过stop gradient操作来避免最邻近搜索的不可导问题,也就是通过stop gradient操作,将decoder输入的梯度复制到encoder的输出上【红色的线】。

总结来讲:VQ-VAE的过程就是将原始连续的高斯分布变为了离散的分布(通过从有限的字典中找到每一个encoder latent code对应的向量)。

2、一些问题

A. 为什么要进行向量量化?(为什么要将 z 离散化)?

1. 离散的表示通常更适合于捕捉数据中的类别性质,如不同种类的对象、语音或文本数据的不同模式等。

2. 离散的潜在表示有助于模型生成更加清晰的输出。在连续潜在空间中,模型可能在生成新样本时产生模糊的结果,特别是在空间的某些区域中。而离散潜在空间能够降低这种模糊性,提高生成样本的质量。解决VAE生成不清晰的问题

3. 增强模型的解释性,相比于连续潜在空间,离散潜在空间可以为每个离散的潜在变量赋予更明确的语义解释。例如,在处理图像的任务中,不同的离散潜在变量可能对应于不同的视觉特征或对象类别,这使得模型的行为和学习到的表示更易于理解和解释。

5. 缓解潜在空间的过度平滑问题,VAE有时会遇到潜在空间的"过度平滑"问题,即潜在空间中不同区域之间的过渡太平滑,导致生成的样本缺乏多样性或区分度不够(容易模型崩塌)。通过引入离散潜在空间,VQ-VAE可以缓解这个问题,因为离散空间天然具有区分不同区域的能力。

B. 如何将 z 离散化?

1)构建codebook进行VQ 

将 z 离散化的关键就是VQ, 即vector quatization.

简单来说, 就是要先有一个codebook, 这个codebook是一个embedding table,然后再利用均匀分布对权重初始化

self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings)

2)找到与图像embedding(flat_input)最近的codebook中的embedding 

我们在这个codebook 中找到和 vector 最接近(比如欧氏距离最近)的一个embedding, 用这个embedding的index来代表这个vector.

∑ [flat_input(16384, 64) - self._embedding.weight(512,64)]^2

# Calculate the Z_e(x) and e distances
# 这里使用欧几里得距离的平方求距离
distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
             + torch.sum(self._embedding.weight ** 2, dim=1)
             - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

# Encoding
encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)

3)怎么解决的梯度截断(不可导)问题?

具体的介绍看:VQ-VAE中如何解决梯度截断(不可导)问题?直通估计、(stop gradient,停止梯度)-CSDN博客

另外由于最邻近搜索使用argmin来找codebook中的索引位置,导致不可导问题,VQVAE通过stop gradient操作来避免最邻近搜索的不可导问题,也就是通过stop gradient操作,将decoder输入(quantize)的梯度复制到encoder的输出上(input)。

quantize = input + (quantize - input).detach()
# 正向传播和往常一样

# 反向传播时,detach()这部分梯度为0,quantize和input的梯度相同
# 即实现将quantize复制给input

VQVAE相比于VAE最大的不同是,直接找每个属性的离散值,通过类似于查表的

直接按照下面这样,而不进行 (quantized - inputs).detach() 是不可以的,因为这样会让模型退化到一个VAE的样子,因为你是直接将输出复制给输入的。

quantize = input

C. 怎么根据一个codebook解码生成多种不同的图像?

当解码器只能接受一组 codebook 向量作为输入时,人们怎么能指望它产生大量多样化的图像呢?

例如,对于图像,编码器可能会输出一个32x32 的矢量网格,每个网格都被量化,然后将整个网格送到解码器。

例如,假设我们正在处理图像,我们有一个尺寸为512的codebook(每个维度是64),我们的编码器输出一个 32x32 的矢量网格。在这种情况下,我们的解码器可以输出

 个不同图像。

所以encoder后的维度大小是:32*32*64(这里做的依据就是nn.Embedding)

所以我们可以通过一个简单的小的codebook就可以采样出无穷无尽的图像。

深入理解 VQ-VAE – sunlin-ai

D. 为什么VQ-VAE是离散化的过程?怎么体现的?

你可能会有一个问题,就是codebook是由nn.Embedding()构成的,那么 nn.Embedding()里面分明是一堆浮点数,为什么VQ-VAE叫做离散化?

这是因为虽然 nn.Embedding()里面是一堆浮点数,但是这些浮点数是从有限的集合中选出来的:

离散化后的隐空间zq向量不再是原始的连续实数向量,而是codebook中的离散的嵌入向量集合。这些离散的嵌入向量是从有限的codebook e中选出来的,因此可以看作是离散化的表示。

3、VQVAE的损失

与VAE的不同:去掉VAE的KL loss,增加了两项loss

1)重构损失(reconstruction loss):

  • 目标:衡量重构数据和原始数据之间的相似度。
  • 计算:通常使用均方误差(MSE)或交叉熵损失来计算重构图像和原始图像之间的差异。

2)代码本损失(codebook loss):

代码本损失训练更新码本向量,使其更好地代表输入数据的连续潜在表示。

假设Ze​(x)是编码器对输入x的连续潜在表示,e是选取的最接近的码本向量。

3)提交损失(Commitment Loss)

提交损失 只训练encoder确保编码器的输出不会偏离它选择的码本向量太远,从而保证训练过程的稳定性。

计算模型编码器输出的连续潜在表示和量化后的表示(即选取的码本向量)之间的距离。这有助于稳定训练,确保编码器输出与码本的选择保持一致(原文:encourage the output of encoder to stay close to the chosen codebook vector to prevent it from flucturating too frequently from one code vector to another, 即防止encoder的输出频繁在各个codebook embedding之间跳),也通常使用均方误差(MSE)计算。

假设Ze​(x)是编码器对输入x的连续潜在表示,e是选取的最接近的码本向量。

备注:关于loss怎么起到的将模型的某些参数更新,某些参数不更新,看:VQ-VAE中如何解决梯度截断(不可导)问题?直通估计、(stop gradient,停止梯度)-CSDN博客

4、如何采样?

离散向量的另一个问题是它不好采样。回忆一下,VAE之所以把图片编码成符合正态分布的连续向量,就是为了能在图像生成时把编码器扔掉,让随机采样出的向量也能通过解码器变成图片。现在倒好,VQ-VAE把图片编码了一个离散向量,这个离散向量构成的空间是不好采样的。VQ-VAE不是面临着和AE一样的问题嘛。

在离散空间直接采样这个问题是无解的。没错!VQ-VAE根本不是一个图像生成模型。它和AE一样,只能很好地完成图像压缩,把图像变成一个短得多的向量,而不支持随机图像生成。VQ-VAE和AE的唯一区别,就是VQ-VAE会编码出离散向量,而AE会编码出连续向量。

可为什么VQ-VAE会被归类到图像生成模型中呢?这是因为VQ-VAE的作者利用VQ-VAE能编码离散向量的特性,使用了一种特别的方法对VQ-VAE的离散编码空间采样。VQ-VAE的作者之前设计了一种图像生成网络,叫做PixelCNN。PixelCNN能拟合一个离散的分布。比如对于图像,PixelCNN能输出某个像素的某个颜色通道取0~255中某个值的概率分布。这不刚好嘛,VQ-VAE也是把图像编码成离散向量。换个更好理解的说法,VQ-VAE能把图像映射成一个「小图像」。我们可以把PixelCNN生成图像的方法搬过来,让PixelCNN学习生成「小图像」。这样,我们就可以用PixelCNN生成离散编码,再利用VQ-VAE的解码器把离散编码变成图像。

  1. 训练VQ-VAE的编码器和解码器,使得VQ-VAE能把图像变成「小图像(32*32*64的codebook对应的值)」,也能把「小图像」变回图像。
  2. 训练PixelCNN,让它学习怎么生成「小图像」(这一步类似于RCG的RDM)。
  3. 随机采样时,先用PixelCNN采样出「小图像」,再用VQ-VAE把「小图像」翻译成最终的生成图像。

4、代码示例

需要确保经过encoder的图像的通道数D(此时经过encoder后的图像“通道数”不一定再是3了,可能会更大,例如64,我们这里只是形象的把它叫做“通道数”罢了)与codebook中的向量维度是相同的。

整体步骤

具体实现步骤如下:

完整代码如下:

from __future__ import print_function


import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import savgol_filter


from six.moves import xrange

import umap

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import make_grid


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ===================================================== Load Data =====================================================
training_data = datasets.CIFAR10(root="data", train=True, download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5,0.5,0.5), (1.0,1.0,1.0))
                                  ]))

validation_data = datasets.CIFAR10(root="data", train=False, download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5,0.5,0.5), (1.0,1.0,1.0))
                                  ]))
# 计算 方差
data_variance = np.var(training_data.data / 255.0)



# ===================================================== Vector Quantizer Layer =====================================================
"""
This layer takes a tensor to be quantized. 
The channel dimension will be used as the space in which to quantize. All other dimensions will be flattened and will be seen as different examples to quantize.
The output tensor will have the same shape as the input.
As an example for a BCHW tensor of shape [16, 64, 32, 32], we will first convert it to an BHWC tensor of shape [16, 32, 32, 64] and then reshape it into [16384, 64] and all 16384 vectors of size 64 will be quantized independently. 
In otherwords, the channels are used as the space in which to quantize. All other dimensions will be flattened and be seen as different examples to quantize, 16384 in this case.
"""

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        '''

        :param num_embeddings: codebook的大小
        :param embedding_dim: codebook中每个vector的维度
        :param commitment_cost: commit loss的β
        '''

        super(VectorQuantizer, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        # 构建一个codebook,用均匀分布对codebook的权重进行初始化
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):

        # convert inputs(encoder's output) from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate the Z_e(x) and e distances
        # 这里使用欧几里得距离的平方求距离distance(16384,512)
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # 通过distance得到距离最近的index
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        # 转为one-hot格式
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten:得到最近邻的Embedding Vector
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Loss
        # commit loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        # codebook loss
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss
        # trick(梯度复制),通过添加一个常数让编码器和解码器连续可导
        quantized = inputs + (quantized - inputs).detach()
        # 利用困惑度监测分布,困惑度越大,信息熵也就越大,分布就没有这么均匀
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings


"""
We will also implement a slightly modified version which will use exponential moving averages to update the embedding vectors instead of an auxillary loss.
This has the advantage that the embedding updates are independent of the choice of optimizer for the encoder, decoder and other parts of the architecture. 
For most experiments the EMA version trains faster than the non-EMA version.
"""

class VectorQuantizerEMA(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost

        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()

        self._decay = decay
        self._epsilon = epsilon

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate distances
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Use EMA to update the embedding vectors
        if self.training:
            self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                                     (1 - self._decay) * torch.sum(encodings, 0)

            # Laplace smoothing of the cluster size
            n = torch.sum(self._ema_cluster_size.data)
            self._ema_cluster_size = (
                    (self._ema_cluster_size + self._epsilon)
                    / (n + self._num_embeddings * self._epsilon) * n)

            dw = torch.matmul(encodings.t(), flat_input)
            self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)

            self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        loss = self._commitment_cost * e_latent_loss

        # Straight Through Estimator
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings


# ===================================================== Encoder & Decoder Architecture =====================================================
class Residual(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
        super(Residual, self).__init__()
        self._block = nn.Sequential(
            nn.ReLU(True),
            nn.Conv2d(in_channels=in_channels,
                      out_channels=num_residual_hiddens,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(in_channels=num_residual_hiddens,
                      out_channels=num_hiddens,
                      kernel_size=1, stride=1, bias=False)
        )

    def forward(self, x):
        return x + self._block(x)


class ResidualStack(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(ResidualStack, self).__init__()
        self._num_residual_layers = num_residual_layers
        self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens)
                                      for _ in range(self._num_residual_layers)])

    def forward(self, x):
        for i in range(self._num_residual_layers):
            x = self._layers[i](x)
        return F.relu(x)


class Encoder(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Encoder, self).__init__()

        self._conv_1 = nn.Conv2d(in_channels=in_channels,
                                 out_channels=num_hiddens // 2,
                                 kernel_size=4,
                                 stride=2, padding=1)
        self._conv_2 = nn.Conv2d(in_channels=num_hiddens // 2,
                                 out_channels=num_hiddens,
                                 kernel_size=4,
                                 stride=2, padding=1)
        self._conv_3 = nn.Conv2d(in_channels=num_hiddens,
                                 out_channels=num_hiddens,
                                 kernel_size=3,
                                 stride=1, padding=1)
        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)

    def forward(self, inputs):
        x = self._conv_1(inputs)
        x = F.relu(x)

        x = self._conv_2(x)
        x = F.relu(x)

        x = self._conv_3(x)
        return self._residual_stack(x)


class Decoder(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Decoder, self).__init__()

        self._conv_1 = nn.Conv2d(in_channels=in_channels,
                                 out_channels=num_hiddens,
                                 kernel_size=3,
                                 stride=1, padding=1)

        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)

        self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens,
                                                out_channels=num_hiddens // 2,
                                                kernel_size=4,
                                                stride=2, padding=1)

        self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens // 2,
                                                out_channels=3,
                                                kernel_size=4,
                                                stride=2, padding=1)

    def forward(self, inputs):
        x = self._conv_1(inputs)

        x = self._residual_stack(x)

        x = self._conv_trans_1(x)
        x = F.relu(x)

        return self._conv_trans_2(x)


# ===================================================== Train =====================================================
batch_size = 256
num_training_updates = 15000

num_hiddens = 128
num_residual_hiddens = 32
num_residual_layers = 2

# codebook的维度
embedding_dim = 64
num_embeddings = 512

commitment_cost = 0.25

decay = 0.99
decay = 0
learning_rate = 1e-3

training_loader = DataLoader(training_data,
                             batch_size=batch_size,
                             shuffle=True,
                             pin_memory=True)
validation_loader = DataLoader(validation_data,
                               batch_size=32,
                               shuffle=True,
                               pin_memory=True)


class Model(nn.Module):
    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens, num_embeddings, embedding_dim, commitment_cost, decay=0):
        super(Model, self).__init__()

        self._encoder = Encoder(3, num_hiddens, num_residual_layers, num_residual_hiddens)
        # 对encoder的输出进行后处理,得到与embedding table一样大小的维度
        self._pre_vq_conv = nn.Conv2d(in_channels=num_hiddens,
                                      out_channels=embedding_dim,
                                      kernel_size=1,
                                      stride=1)
        if decay > 0.0:
            self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim, commitment_cost, decay)
        else:
            self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)

        self._decoder = Decoder(embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)

    def forward(self, x):
        z = self._encoder(x)
        z = self._pre_vq_conv(z)
        loss, quantized, perplexity, _ = self._vq_vae(z)
        x_recon = self._decoder(quantized)

        return loss, x_recon, perplexity


model = Model(num_hiddens, num_residual_layers, num_residual_hiddens, num_embeddings, embedding_dim, commitment_cost, decay).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, amsgrad=False)
model.train()
train_res_recon_error = []
train_res_perplexity = []

for i in xrange(num_training_updates):
    (data, _) = next(iter(training_loader))
    data = data.to(device)
    optimizer.zero_grad()

    vq_loss, data_recon, perplexity = model(data)
    recon_error = F.mse_loss(data_recon, data) / data_variance
    loss = recon_error + vq_loss
    loss.backward()

    optimizer.step()

    train_res_recon_error.append(recon_error.item())
    train_res_perplexity.append(perplexity.item())

    if (i + 1) % 100 == 0:
        print('%d iterations' % (i + 1))
        print('recon_error: %.3f' % np.mean(train_res_recon_error[-100:]))
        print('perplexity: %.3f' % np.mean(train_res_perplexity[-100:]))
        print()

pytorch-vq-vae/vq-vae.ipynb at master · zalandoresearch/pytorch-vq-vae · GitHub

四、总结

  • AE主要用于数据的压缩与还原,在生成数据上使用VAE。
  • AE是将数据映直接映射为数值code,而VAE是先将数据映射为分布,再从分布中采样得到数值code。
  • VQ-VAE是将中间编码映射为codebook中K个向量之一,然后通过Decoder对latent code进行重建

因此AutoEncoder、VAE和VQ-VAE可以统一为latent code的概率分布设计不一样,AutoEncoder通过网络学习得到任意概率分布VAE设计为正态分布VQVAE设计为codebook的离散分布总之,AutoEncoder的重构思想就是用低纬度的latent code分布来表达高纬度的数据分布,VAE和VQVAE的重构思想是通过设计latent code的分布形式,进而控制图片生成的过程。

漫谈VAE和VQVAE,从连续分布到离散分布 - 知乎

本文标签: 编码器向量区别aeVQ