好的,這是一個使用 PyTorch 實現深度卷積生成對抗網絡(DCGAN)的完整範例。這個範例展示了如何構建和訓練 DCGAN 來生成手寫數字(例如 MNIST 數據集)。你可以根據自己的數據集進行修改。
1. 安裝 PyTorch
首先,確保已經安裝 PyTorch。如果還沒有,請執行以下命令:
pip install torch torchvision
2. 準備資料
我們使用 MNIST 數據集作為範例:
import torch import torchvision.transforms as transforms from torchvision.datasets import MNIST from torch.utils.data import DataLoader # 設定參數 batch_size = 128 image_size = 64 # 轉換函數 transform = transforms.Compose([ transforms.Resize(image_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]) ]) # 加載 MNIST 資料集 dataset = MNIST(root='./data', train=True, transform=transform, download=True) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
3. 定義模型
我們定義生成器和判別器模型:
import torch.nn as nn # 生成器模型 class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.main = nn.Sequential( nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True), nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False), nn.Tanh() ) def forward(self, input): return self.main(input) # 判別器模型 class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.main = nn.Sequential( nn.Conv2d(1, 64, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(64, 128, 4, 2, 1, bias=False), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(128, 256, 4, 2, 1, bias=False), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(256, 512, 4, 2, 1, bias=False), nn.BatchNorm2d(512), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(512, 1, 4, 1, 0, bias=False), nn.Sigmoid() ) def forward(self, input): return self.main(input).view(-1, 1).squeeze(1)
4. 訓練模型
定義損失函數和優化器,並訓練 DCGAN:
import torch.optim as optim # 設定裝置 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 創建模型實例 netG = Generator().to(device) netD = Discriminator().to(device) # 損失函數 criterion = nn.BCELoss() # 噪聲向量 fixed_noise = torch.randn(64, 100, 1, 1, device=device) # 優化器 optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999)) optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999)) # 訓練DCGAN num_epochs = 5 for epoch in range(num_epochs): for i, data in enumerate(dataloader, 0): ############################ # (1) 更新 D: maximize log(D(x)) + log(1 - D(G(z))) ########################### ## 使用真實數據訓練 D netD.zero_grad() real_cpu = data[0].to(device) batch_size = real_cpu.size(0) labels = torch.full((batch_size,), 1, device=device) output = netD(real_cpu) errD_real = criterion(output, labels) errD_real.backward() D_x = output.mean().item() ## 使用假數據訓練 D noise = torch.randn(batch_size, 100, 1, 1, device=device) fake = netG(noise) labels.fill_(0) output = netD(fake.detach()) errD_fake = criterion(output, labels) errD_fake.backward() D_G_z1 = output.mean().item() errD = errD_real + errD_fake optimizerD.step() ############################ # (2) 更新 G: maximize log(D(G(z))) ########################### netG.zero_grad() labels.fill_(1) output = netD(fake) errG = criterion(output, labels) errG.backward() D_G_z2 = output.mean().item() optimizerG.step() # 輸出訓練狀況 if i % 100 == 0: print(f'[{epoch}/{num_epochs}][{i}/{len(dataloader)}] Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}') # 保存模型 torch.save(netG.state_dict(), 'generator.pth') torch.save(netD.state_dict(), 'discriminator.pth')
5. 生成影像
訓練完成後,你可以使用生成器來生成新影像:
import matplotlib.pyplot as plt import numpy as np # 生成影像 netG.eval() with torch.no_grad(): noise = torch.randn(64, 100, 1, 1, device=device) fake = netG(noise).detach().cpu() # 顯示生成影像 def show_generated_images(images): fig, ax = plt.subplots(8, 8, figsize=(8, 8)) for i in range(8): for j in range(8): ax[i, j].imshow(images[i*8 + j].permute(1, 2, 0).squeeze(), cmap='gray') ax[i, j].axis('off') plt.show() # 展示生成影像 show_generated_images(fake)
這是一個基本的 DCGAN 範例,用於生成 MNIST 手寫數字。你可以根據自己的數據集進行修改,包括調整網絡架構、超參數和訓練流程。
這個範例應該能夠幫助你理解如何使用 DCGAN 來生成訓練資料,從而避免過擬合問題。
留言板
歡迎留下建議與分享!希望一起交流!感恩!