好的,這是一個使用 PyTorch 實現深度卷積生成對抗網絡(DCGAN)的完整範例。這個範例展示了如何構建和訓練 DCGAN 來生成手寫數字(例如 MNIST 數據集)。你可以根據自己的數據集進行修改。

1. 安裝 PyTorch
首先,確保已經安裝 PyTorch。如果還沒有,請執行以下命令:

 

pip install torch torchvision

 

 


2. 準備資料
我們使用 MNIST 數據集作為範例:

 

import torch
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

# 設定參數
batch_size = 128
image_size = 64

# 轉換函數
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# 加載 MNIST 資料集
dataset = MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

 

3. 定義模型
我們定義生成器和判別器模型:

 

import torch.nn as nn

# 生成器模型
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)

# 判別器模型
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input).view(-1, 1).squeeze(1)

 


4. 訓練模型
定義損失函數和優化器,並訓練 DCGAN:

import torch.optim as optim

# 設定裝置
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 創建模型實例
netG = Generator().to(device)
netD = Discriminator().to(device)

# 損失函數
criterion = nn.BCELoss()

# 噪聲向量
fixed_noise = torch.randn(64, 100, 1, 1, device=device)

# 優化器
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

# 訓練DCGAN
num_epochs = 5
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) 更新 D: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## 使用真實數據訓練 D
        netD.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        labels = torch.full((batch_size,), 1, device=device)
        output = netD(real_cpu)
        errD_real = criterion(output, labels)
        errD_real.backward()
        D_x = output.mean().item()

        ## 使用假數據訓練 D
        noise = torch.randn(batch_size, 100, 1, 1, device=device)
        fake = netG(noise)
        labels.fill_(0)
        output = netD(fake.detach())
        errD_fake = criterion(output, labels)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) 更新 G: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        labels.fill_(1)
        output = netD(fake)
        errG = criterion(output, labels)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # 輸出訓練狀況
        if i % 100 == 0:
            print(f'[{epoch}/{num_epochs}][{i}/{len(dataloader)}] Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}')

# 保存模型
torch.save(netG.state_dict(), 'generator.pth')
torch.save(netD.state_dict(), 'discriminator.pth')

5. 生成影像

訓練完成後,你可以使用生成器來生成新影像:

import matplotlib.pyplot as plt
import numpy as np

# 生成影像
netG.eval()
with torch.no_grad():
    noise = torch.randn(64, 100, 1, 1, device=device)
    fake = netG(noise).detach().cpu()

# 顯示生成影像
def show_generated_images(images):
    fig, ax = plt.subplots(8, 8, figsize=(8, 8))
    for i in range(8):
        for j in range(8):
            ax[i, j].imshow(images[i*8 + j].permute(1, 2, 0).squeeze(), cmap='gray')
            ax[i, j].axis('off')
    plt.show()

# 展示生成影像
show_generated_images(fake)

 

 

 

這是一個基本的 DCGAN 範例,用於生成 MNIST 手寫數字。你可以根據自己的數據集進行修改,包括調整網絡架構、超參數和訓練流程。

這個範例應該能夠幫助你理解如何使用 DCGAN 來生成訓練資料,從而避免過擬合問題。