好的,這是一個使用 PyTorch 實現深度卷積生成對抗網絡(DCGAN)的完整範例。這個範例展示了如何構建和訓練 DCGAN 來生成手寫數字(例如 MNIST 數據集)。你可以根據自己的數據集進行修改。
1. 安裝 PyTorch
首先,確保已經安裝 PyTorch。如果還沒有,請執行以下命令:
pip install torch torchvision
2. 準備資料
我們使用 MNIST 數據集作為範例:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
# 設定參數
batch_size = 128
image_size = 64
# 轉換函數
transform = transforms.Compose([
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
# 加載 MNIST 資料集
dataset = MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
3. 定義模型
我們定義生成器和判別器模型:
import torch.nn as nn
# 生成器模型
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.main = nn.Sequential(
nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(True),
nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
nn.Tanh()
)
def forward(self, input):
return self.main(input)
# 判別器模型
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(1, 64, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 256, 4, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(256, 512, 4, 2, 1, bias=False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(512, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
def forward(self, input):
return self.main(input).view(-1, 1).squeeze(1)
4. 訓練模型
定義損失函數和優化器,並訓練 DCGAN:
import torch.optim as optim
# 設定裝置
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 創建模型實例
netG = Generator().to(device)
netD = Discriminator().to(device)
# 損失函數
criterion = nn.BCELoss()
# 噪聲向量
fixed_noise = torch.randn(64, 100, 1, 1, device=device)
# 優化器
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
# 訓練DCGAN
num_epochs = 5
for epoch in range(num_epochs):
for i, data in enumerate(dataloader, 0):
############################
# (1) 更新 D: maximize log(D(x)) + log(1 - D(G(z)))
###########################
## 使用真實數據訓練 D
netD.zero_grad()
real_cpu = data[0].to(device)
batch_size = real_cpu.size(0)
labels = torch.full((batch_size,), 1, device=device)
output = netD(real_cpu)
errD_real = criterion(output, labels)
errD_real.backward()
D_x = output.mean().item()
## 使用假數據訓練 D
noise = torch.randn(batch_size, 100, 1, 1, device=device)
fake = netG(noise)
labels.fill_(0)
output = netD(fake.detach())
errD_fake = criterion(output, labels)
errD_fake.backward()
D_G_z1 = output.mean().item()
errD = errD_real + errD_fake
optimizerD.step()
############################
# (2) 更新 G: maximize log(D(G(z)))
###########################
netG.zero_grad()
labels.fill_(1)
output = netD(fake)
errG = criterion(output, labels)
errG.backward()
D_G_z2 = output.mean().item()
optimizerG.step()
# 輸出訓練狀況
if i % 100 == 0:
print(f'[{epoch}/{num_epochs}][{i}/{len(dataloader)}] Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}')
# 保存模型
torch.save(netG.state_dict(), 'generator.pth')
torch.save(netD.state_dict(), 'discriminator.pth')
5. 生成影像
訓練完成後,你可以使用生成器來生成新影像:
import matplotlib.pyplot as plt
import numpy as np
# 生成影像
netG.eval()
with torch.no_grad():
noise = torch.randn(64, 100, 1, 1, device=device)
fake = netG(noise).detach().cpu()
# 顯示生成影像
def show_generated_images(images):
fig, ax = plt.subplots(8, 8, figsize=(8, 8))
for i in range(8):
for j in range(8):
ax[i, j].imshow(images[i*8 + j].permute(1, 2, 0).squeeze(), cmap='gray')
ax[i, j].axis('off')
plt.show()
# 展示生成影像
show_generated_images(fake)
這是一個基本的 DCGAN 範例,用於生成 MNIST 手寫數字。你可以根據自己的數據集進行修改,包括調整網絡架構、超參數和訓練流程。
這個範例應該能夠幫助你理解如何使用 DCGAN 來生成訓練資料,從而避免過擬合問題。

留言板
歡迎留下建議與分享!希望一起交流!感恩!