生成對抗模型 FastGAN 有詳細範例嗎?
當然,FastGAN 是一種能夠在少量標註數據上快速訓練並生成高品質影像的生成對抗網絡。FastGAN 使用了一種新的自監督判別器設計(self-supervised discriminator)和「迴歸生成」技術,使其能夠高效地訓練和生成影像。
以下是 FastGAN 的詳細範例,包括模型架構、損失函數的選擇和 PyTorch 代碼示例。
FastGAN 的核心技術
-
自監督判別器(Self-Supervised Discriminator):
- FastGAN 的判別器同時學習判斷影像的真實性(真或假)以及額外的自監督任務,如分類或旋轉預測。這種設計使得判別器能夠從生成的影像中學習更多特徵,提高判別器的表達能力。
-
迴歸生成(Regression-based Generation):
- FastGAN 的生成器使用迴歸生成技術,學習從潛在空間中生成真實感較高的影像。這種技術有助於在小樣本數據集上快速收斂。
-
Loss 函數設計:
- FastGAN 使用標準的 adversarial loss,同時引入自監督任務的損失來增強訓練效果。
- Generator loss: 由生成影像被判別器誤判為真影像的程度來計算。
- Discriminator loss: 除了判別生成影像和真實影像之外,還包括自監督任務的損失。
PyTorch 範例代碼
以下是使用 PyTorch 實現 FastGAN 的基本框架:
import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torchvision import datasets, transforms # 定義生成器 class Generator(nn.Module): def __init__(self, latent_dim, img_channels): super(Generator, self).__init__() self.init_size = 64 // 4 self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2)) self.conv_blocks = nn.Sequential( nn.BatchNorm2d(128), nn.Upsample(scale_factor=2), nn.Conv2d(128, 128, 3, stride=1, padding=1), nn.BatchNorm2d(128, 0.8), nn.LeakyReLU(0.2, inplace=True), nn.Upsample(scale_factor=2), nn.Conv2d(128, 64, 3, stride=1, padding=1), nn.BatchNorm2d(64, 0.8), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(64, img_channels, 3, stride=1, padding=1), nn.Tanh(), ) def forward(self, z): out = self.l1(z) out = out.view(out.shape[0], 128, self.init_size, self.init_size) img = self.conv_blocks(out) return img # 定義判別器 class Discriminator(nn.Module): def __init__(self, img_channels): super(Discriminator, self).__init__() def discriminator_block(in_filters, out_filters, bn=True): block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)] if bn: block.append(nn.BatchNorm2d(out_filters, 0.8)) return block self.model = nn.Sequential( *discriminator_block(img_channels, 16, bn=False), *discriminator_block(16, 32), *discriminator_block(32, 64), *discriminator_block(64, 128), ) # 影像分類頭 self.adv_layer = nn.Sequential(nn.Linear(128 * (64 // 2 ** 4) ** 2, 1), nn.Sigmoid()) # 自監督任務分類頭(可選,例如旋轉分類) self.aux_layer = nn.Sequential(nn.Linear(128 * (64 // 2 ** 4) ** 2, 4), nn.Softmax(dim=1)) def forward(self, img): out = self.model(img) out = out.view(out.shape[0], -1) validity = self.adv_layer(out) label = self.aux_layer(out) return validity, label # 超參數設置 latent_dim = 100 img_size = 64 img_channels = 3 lr = 0.0002 b1 = 0.5 b2 = 0.999 batch_size = 64 epochs = 100 # 準備數據集和 DataLoader transform = transforms.Compose([transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) # 初始化模型 generator = Generator(latent_dim, img_channels).cuda() discriminator = Discriminator(img_channels).cuda() # 設置優化器 optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2)) optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2)) adversarial_loss = nn.BCELoss().cuda() auxiliary_loss = nn.CrossEntropyLoss().cuda() # 開始訓練 for epoch in range(epochs): for i, (imgs, _) in enumerate(dataloader): # --------------- 訓練生成器 --------------- optimizer_G.zero_grad() # 生成隨機噪聲 z = torch.randn(imgs.size(0), latent_dim).cuda() gen_imgs = generator(z) # 判別生成影像的真實性 validity, _ = discriminator(gen_imgs) g_loss = adversarial_loss(validity, torch.ones_like(validity).cuda()) g_loss.backward() optimizer_G.step() # --------------- 訓練判別器 --------------- optimizer_D.zero_grad() # 判別真實影像 real_validity, real_label = discriminator(imgs.cuda()) d_real_loss = adversarial_loss(real_validity, torch.ones_like(real_validity).cuda()) # 判別生成影像 fake_validity, fake_label = discriminator(gen_imgs.detach()) d_fake_loss = adversarial_loss(fake_validity, torch.zeros_like(fake_validity).cuda()) # 自監督任務損失(假設是旋轉分類) real_aux_loss = auxiliary_loss(real_label, torch.randint(0, 4, (imgs.size(0),)).cuda()) d_loss = (d_real_loss + d_fake_loss + real_aux_loss) / 3 d_loss.backward() optimizer_D.step() print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")
代碼說明
-
Generator 和 Discriminator:
- Generator:使用全連接層和轉置卷積(上採樣)來生成影像。
- Discriminator:使用卷積層來判別影像的真偽,並增加一個自監督頭(auxiliary layer)進行輔助任務。
-
損失函數:
- 使用
BCELoss
計算真實性損失。 - 使用
CrossEntropyLoss
計算自監督任務損失(如旋轉預測)。
- 使用
-
訓練過程:
- 在每個 epoch 中,首先更新生成器,再更新判別器,交替進行。生成器試圖生成更真實的影像以欺騙判別器,而判別器則學習分辨生成影像和真實影像的區別,並通過自監督任務增強其表現。
這個範例是 FastGAN 的基本實現。通過調整網絡結構、損失權重或數據集,您可以進一步提升生成影像的品質和模型的表現。
留言板
歡迎留下建議與分享!希望一起交流!感恩!