DCGAN
DCGAN is short for deep convolutional GAN.

Dataset

CIFAR10 consists of 32x32x3 images across 10 output classes, with 5,000 images per class and 50,000 images in total.

About this example

To keep the example simple, only a single class is used for testing. The focus here is on implementation; there will be no deep theoretical discussion.

Getting started
A GAN is built from two neural networks, a Generator and a Discriminator. The iterative process goes roughly as follows:
- Initialize Generator-v1 and Discriminator-v1
- Sample a point from the latent space and generate from it with Generator-v1
- Discriminator-v1 judges real vs. fake, and spots the fake
- Generator-v1 upgrades to Generator-v2
- Generator-v2 successfully fools the Discriminator
- Discriminator-v1 upgrades to Discriminator-v2
- Discriminator-v2 judges real vs. fake, and spots the fake
- …
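This back-and-forth is the standard GAN minimax game, which can be written as:

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

D tries to push D(x) toward 1 for real data and D(G(z)) toward 0 for fakes, while G tries to push D(G(z)) toward 1.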
import os
# Restrict GPU resources
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
First we build the Generator. Load the packages it requires:
import keras
from keras import layers
import numpy as np
This package will be used later for saving images.
from keras.preprocessing import image
Define the dimensions. CIFAR10 images are 32x32x3. The Generator samples a point from the latent space; that point is a vector, configured in the example below as laten_dim.
# CIFAR10 data dimensions
height = 32
width = 32
channel = 3
# Dimension of the latent vector to sample
laten_dim = 32
- W: input dimension
- F: filter size
- S: stride
- P: padding size
- N: output dimension
layers.Conv2D: convolution; after a convolution the output dimension is N = (W - F + 2P) / S + 1 (rounded down).
layers.Conv2DTranspose: transposed convolution (deconvolution); after a transposed convolution the output dimension is N = (W - 1) x S - 2P + F.
generator_input = keras.Input(shape=(laten_dim, ))
x = layers.Dense(128 * 16 * 16)(generator_input)
x = layers.LeakyReLU()(x)
x = layers.Reshape((16, 16, 128))(x)
x = layers.Conv2D(256, 5, padding='same')(x)
x = layers.LeakyReLU()(x)
# Upsamples the output to 32x32
x = layers.Conv2DTranspose(256, 4, strides=2, padding='same')(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(256, 5, padding='same')(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(256, 5, padding='same')(x)
x = layers.LeakyReLU()(x)
# The image channel is set to 3, so the output is 32x32x3
x = layers.Conv2D(channel, 7, activation='tanh', padding='same')(x)
generator = keras.models.Model(generator_input, x)
Use Model.summary() to check how the data dimensions change through the model.
generator.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) (None, 32) 0
_________________________________________________________________
dense_2 (Dense) (None, 32768) 1081344
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU) (None, 32768) 0
_________________________________________________________________
reshape_2 (Reshape) (None, 16, 16, 128) 0
_________________________________________________________________
conv2d_5 (Conv2D) (None, 16, 16, 256) 819456
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU) (None, 16, 16, 256) 0
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 32, 32, 256) 1048832
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU) (None, 32, 32, 256) 0
_________________________________________________________________
conv2d_6 (Conv2D) (None, 32, 32, 256) 1638656
_________________________________________________________________
leaky_re_lu_9 (LeakyReLU) (None, 32, 32, 256) 0
_________________________________________________________________
conv2d_7 (Conv2D) (None, 32, 32, 256) 1638656
_________________________________________________________________
leaky_re_lu_10 (LeakyReLU) (None, 32, 32, 256) 0
_________________________________________________________________
conv2d_8 (Conv2D) (None, 32, 32, 3) 37635
=================================================================
Total params: 6,264,579
Trainable params: 6,264,579
Non-trainable params: 0
_________________________________________________________________
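As a sanity check, the layer sizes in the summary above follow the standard convolution output-size arithmetic (a standalone sketch; `conv_out` and `convt_out` are helper names introduced here, not part of the post's code):

```python
def conv_out(w, f, s=1, p=0):
    # Conv2D output size: N = (W - F + 2P) / S + 1, rounded down
    return (w - f + 2 * p) // s + 1

def convt_out(w, f, s=1, p=0):
    # Conv2DTranspose output size: N = (W - 1) * S - 2P + F
    return (w - 1) * s - 2 * p + f

# 'same' padding with stride 1 keeps the size (Keras pads with P = (F - 1) / 2)
print(conv_out(16, 5, s=1, p=2))   # conv2d_5: 16 -> 16
# 'same' transposed convolution with stride 2 doubles the size
print(convt_out(16, 4, s=2, p=1))  # conv2d_transpose_2: 16 -> 32
```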
With the Generator done, we build the Discriminator. Its job is to judge whether given data is real or fake, so its input dimension is the image dimension (32x32x3 in this example).
discriminator_input = layers.Input(shape=(height, width, channel))
x = layers.Conv2D(128, 3)(discriminator_input)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Flatten()(x)
x = layers.Dropout(0.5)(x)
# Real-or-fake decision, so the output is a single unit with sigmoid
x = layers.Dense(1, activation='sigmoid')(x)
discriminator = keras.models.Model(discriminator_input, x)
Again, use Model.summary() to check the dimension changes.
discriminator.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_3 (InputLayer) (None, 32, 32, 3) 0
_________________________________________________________________
conv2d_9 (Conv2D) (None, 30, 30, 128) 3584
_________________________________________________________________
leaky_re_lu_11 (LeakyReLU) (None, 30, 30, 128) 0
_________________________________________________________________
conv2d_10 (Conv2D) (None, 14, 14, 128) 262272
_________________________________________________________________
leaky_re_lu_12 (LeakyReLU) (None, 14, 14, 128) 0
_________________________________________________________________
conv2d_11 (Conv2D) (None, 6, 6, 128) 262272
_________________________________________________________________
leaky_re_lu_13 (LeakyReLU) (None, 6, 6, 128) 0
_________________________________________________________________
conv2d_12 (Conv2D) (None, 2, 2, 128) 262272
_________________________________________________________________
leaky_re_lu_14 (LeakyReLU) (None, 2, 2, 128) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 512) 0
_________________________________________________________________
dropout_1 (Dropout) (None, 512) 0
_________________________________________________________________
dense_3 (Dense) (None, 1) 513
=================================================================
Total params: 790,913
Trainable params: 790,913
Non-trainable params: 0
_________________________________________________________________
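The parameter counts in this summary can also be verified by hand: a Conv2D layer has (F x F x C_in + 1) x C_out parameters (weights plus one bias per filter), and a Dense layer has (units_in + 1) x units_out (a standalone check; `conv_params` is a helper name introduced here):

```python
def conv_params(f, c_in, c_out):
    # F*F*C_in weights per filter, plus one bias per filter
    return (f * f * c_in + 1) * c_out

print(conv_params(3, 3, 128))    # conv2d_9: 3584
print(conv_params(4, 128, 128))  # conv2d_10..12: 262272 each
print((512 + 1) * 1)             # dense_3: 513
```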
Define the model optimizer

This example uses RMSprop as the optimizer, with clipvalue to limit the gradient range.
discriminator_optimizer = keras.optimizers.RMSprop(
    lr=0.0008,
    clipvalue=1.0,
    decay=1e-8
)
discriminator.compile(optimizer=discriminator_optimizer,
                      loss='binary_crossentropy')
With both the Generator and the Discriminator set up, we now make them compete. A few quick notes:
- Although the original GAN paper mentions the possibility of multiple training steps per iteration, in practice Ian Goodfellow used only one
- While training the Generator, the Discriminator is frozen
discriminator.trainable = False
Note that since the discriminator was already compiled above, it remains trainable when trained on its own; the trainable = False flag only takes effect in models compiled afterwards, i.e. the combined gan model below.
The gan model's input is the laten_dim vector set up at the start; its output is the discriminator model's judgment on the generator model's output, i.e. whether the generated data is real or fake.
gan_input = keras.Input(shape=(laten_dim, ))
gan_output = discriminator(generator(gan_input))
gan = keras.models.Model(gan_input, gan_output)
Set up the gan model's optimizer.
gan_optimizer = keras.optimizers.RMSprop(
    lr=0.0004,
    clipvalue=1.0,
    decay=1e-8
)
gan.compile(optimizer=gan_optimizer,
            loss='binary_crossentropy')
gan.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_5 (InputLayer) (None, 32) 0
_________________________________________________________________
model_2 (Model) (None, 32, 32, 3) 6264579
_________________________________________________________________
model_3 (Model) (None, 1) 790913
=================================================================
Total params: 7,055,492
Trainable params: 6,264,579
Non-trainable params: 790,913
_________________________________________________________________
Now we can start training the model. As noted above, we randomly sample from the latent space and run the sample through the generator model. That data is of course fake, so it is labeled Negative; we also have real data, labeled Positive. First, download the CIFAR10 dataset bundled with Keras. Its classes are:
- 0: airplane
- 1: automobile
- 2: bird
- 3: cat
- 4: deer
- 5: dog
- 6: frog
- 7: horse
- 8: ship
- 9: truck
(x_train, y_train), (_, _) = keras.datasets.cifar10.load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 109s 1us/step
Whatever the dataset, always verify its dimensions first.
x_train.shape, y_train.shape
((50000, 32, 32, 3), (50000, 1))
From the downloaded images we keep only one class to simplify the example, using y_train.flatten() == 5 (dog) as a mask.
x_train = x_train[y_train.flatten() == 5]
After filtering, confirm the data dimensions.
x_train.shape
(5000, 32, 32, 3)
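The boolean-mask selection used here works like this plain-Python equivalent (numpy just applies the same idea elementwise across the whole array):

```python
# Toy stand-ins for y_train (flattened) and x_train
y = [3, 5, 0, 5, 9]
x = ['img0', 'img1', 'img2', 'img3', 'img4']

# Equivalent of x_train[y_train.flatten() == 5]
mask = [label == 5 for label in y]
selected = [img for img, keep in zip(x, mask) if keep]
print(selected)  # ['img1', 'img3']
```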
We won't need to pay much attention to y_train from here on; for our current purposes, all real data is simply Positive.
x_train.dtype
dtype('uint8')
Normalize the data.
x_train = x_train / 255.
x_train.dtype
dtype('float64')
Convert the type; in principle float32 is plenty.
x_train = x_train.astype(np.float32)
x_train.dtype
dtype('float32')
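The float32-versus-float64 difference is simply storage width: 4 bytes per value instead of 8, which halves the dataset's memory footprint (a quick stdlib check using the array module):

```python
from array import array

floats64 = array('d', [0.5] * 1000)  # 'd' = C double, what / 255. produces
floats32 = array('f', [0.5] * 1000)  # 'f' = C float, i.e. float32

print(floats64.itemsize, floats32.itemsize)      # 8 4
# For the 5000 x 32 x 32 x 3 subset, the saving in MB is:
print(5000 * 32 * 32 * 3 * (8 - 4) / 2**20)      # 58.59375
```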
Set a few simple parameters.
# Number of iterations to run
iterations = 10000
# Number of samples per batch
batch_size = 20
# Where to save the generated images; set to suit your environment
folder_path = '/tf/GAN/DCGAN'
With the preparation done, we can train the model for real.
# Track the start index into the real data
idx_start = 0
for i in range(iterations):
    # Randomly sample points from a Gaussian distribution
    random_vectors = np.random.normal(size=(batch_size, laten_dim))
    # Get the generator model's output
    # With the model defined above, this returns Nx32x32x3 data
    generator_img = generator.predict(random_vectors)
    # Current stop index
    idx_stop = idx_start + batch_size
    # Slice out the corresponding real images
    true_imgs = x_train[idx_start: idx_stop]
    # Stack the fake and real images together
    combined_imgs = np.concatenate([generator_img, true_imgs])
    # Build the matching labels: generated (fake) = 1, real = 0
    labels = np.concatenate([np.ones((batch_size, 1)),
                             np.zeros((batch_size, 1))])
    # Add random noise to the labels -- a trick that is hard to justify
    # rigorously, but following common practice here tends to help
    labels += 0.01 * np.random.random(labels.shape)
    # Train the discriminator on the generated plus real images
    discriminator_loss = discriminator.train_on_batch(combined_imgs, labels)
    # The discriminator now knows those images were fake, so update the generator
    # As I recall, Hung-yi Lee's online course mentioned that re-sampling
    # here (or reusing the previous sample) works either way
    random_vectors = np.random.normal(size=(batch_size, laten_dim))
    # Labels for the newly generated images
    # They are set to "Positive" because we want to fool the discriminator
    # into believing the generator's output is real
    # After this step, the generator has leveled up
    generator_labels = np.zeros((batch_size, 1))
    # Train the gan model -- remember this is the model with the frozen discriminator
    gan_loss = gan.train_on_batch(random_vectors, generator_labels)
    # Advance the index
    idx_start = idx_start + batch_size
    # If the index would run past the dataset, reset it
    # Note the relationship between dataset size and batch_size:
    # the former should be divisible by the latter
    if idx_start > len(x_train) - batch_size:
        idx_start = 0
    # Log every 100 iterations
    if i % 100 == 0:
        gan.save_weights('gan.h5')
        print('discriminator loss: ', discriminator_loss)
        print('gan loss: ', gan_loss)
        img = image.array_to_img(generator_img[0] * 255., scale=False)
        img.save(os.path.join(folder_path, 'epoch-' + str(i) + '-generator.jpg'))
        img = image.array_to_img(true_imgs[0] * 255., scale=False)
        img.save(os.path.join(folder_path, 'epoch-' + str(i) + '-true_image.jpg'))
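The start-index bookkeeping in the loop can be simulated on its own to confirm every batch stays within the 5,000-image dataset (a standalone sketch mirroring the wrap-around logic; `batch_starts` is a helper name introduced here):

```python
def batch_starts(n_samples, batch_size, iterations):
    """Reproduce the idx_start update from the training loop above."""
    idx_start = 0
    starts = []
    for _ in range(iterations):
        starts.append(idx_start)
        idx_start += batch_size
        if idx_start > n_samples - batch_size:
            idx_start = 0
    return starts

starts = batch_starts(5000, 20, 600)
# Every slice [start : start + 20] fits inside the 5,000 samples
print(max(starts) + 20 <= 5000)  # True
```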
discriminator loss: 5.1067543
gan loss: 13.483853
discriminator loss: 0.55595607
gan loss: 0.9773754
discriminator loss: 0.6778943
gan loss: 0.7218493
discriminator loss: 0.7140575
gan loss: 0.74051535
discriminator loss: 0.6997445
gan loss: 0.74154454
discriminator loss: 0.6921107
gan loss: 0.7020267
discriminator loss: 0.69686
gan loss: 0.7030369
discriminator loss: 0.6800611
gan loss: 0.6471882
discriminator loss: 0.6967187
gan loss: 0.77145576
discriminator loss: 0.7020233
gan loss: 0.71897966
discriminator loss: 0.7126138
gan loss: 0.73591757
discriminator loss: 0.7110381
gan loss: 0.7163498
discriminator loss: 0.6996112
gan loss: 0.71271735
discriminator loss: 0.73551977
gan loss: 0.6986626
discriminator loss: 0.67326313
gan loss: 0.71304166
discriminator loss: 0.7034636
gan loss: 1.0788002
discriminator loss: 0.7141048
gan loss: 0.7726129
discriminator loss: 0.696671
gan loss: 0.720789
discriminator loss: 0.7359967
gan loss: 0.7729639
discriminator loss: 0.69505966
gan loss: 0.65915006
discriminator loss: 0.67265576
gan loss: 0.89056253
discriminator loss: 0.70829356
gan loss: 0.6985985
discriminator loss: 0.6924041
gan loss: 0.6630027
discriminator loss: 0.71791756
gan loss: 0.6948942
discriminator loss: 0.7060956
gan loss: 0.70852387
discriminator loss: 0.69083965
gan loss: 0.7589138
discriminator loss: 0.7217103
gan loss: 0.7043828
discriminator loss: 0.69030714
gan loss: 0.7687561
discriminator loss: 0.7466472
gan loss: 0.83582276
discriminator loss: 0.6909427
gan loss: 0.75084436
discriminator loss: 0.692063
gan loss: 0.66303813
discriminator loss: 0.6578157
gan loss: 0.7915683
discriminator loss: 0.660645
gan loss: 0.717038
discriminator loss: 0.6897063
gan loss: 0.71162593
discriminator loss: 0.72813225
gan loss: 0.7520748
discriminator loss: 0.70937884
gan loss: 0.9043194
discriminator loss: 0.6839577
gan loss: 0.72316694
discriminator loss: 0.68821317
gan loss: 0.6758722
discriminator loss: 0.66450214
gan loss: 0.7818069
discriminator loss: 0.68543863
gan loss: 0.7772101
discriminator loss: 0.72352195
gan loss: 1.1552292
discriminator loss: 0.66245687
gan loss: 0.8431295
discriminator loss: 0.64920413
gan loss: 0.6762968
discriminator loss: 0.6851862
gan loss: 0.6802823
discriminator loss: 0.6546106
gan loss: 0.8981015
discriminator loss: 0.66509783
gan loss: 0.8244335
discriminator loss: 0.71177804
gan loss: 0.64901894
discriminator loss: 0.6947624
gan loss: 0.6984755
discriminator loss: 0.70206606
gan loss: 0.8204377
discriminator loss: 0.7477706
gan loss: 0.9481209
discriminator loss: 0.67449886
gan loss: 0.83607197
discriminator loss: 0.6896846
gan loss: 0.6858331
discriminator loss: 0.6856981
gan loss: 0.71470636
discriminator loss: 0.7168747
gan loss: 0.58466285
discriminator loss: 0.69816226
gan loss: 0.76343626
discriminator loss: 0.7020701
gan loss: 0.8184387
discriminator loss: 0.64522004
gan loss: 0.6753376
discriminator loss: 0.6865121
gan loss: 0.7008316
discriminator loss: 0.7476796
gan loss: 0.87338126
discriminator loss: 0.7263964
gan loss: 0.6889228
discriminator loss: 0.70218945
gan loss: 0.6831338
discriminator loss: 0.6857745
gan loss: 0.6565727
discriminator loss: 0.68626153
gan loss: 0.6991853
discriminator loss: 0.68162763
gan loss: 0.74716157
discriminator loss: 0.7115513
gan loss: 0.7860044
discriminator loss: 0.7063864
gan loss: 0.74256736
discriminator loss: 0.6947535
gan loss: 0.73066074
discriminator loss: 0.70494205
gan loss: 0.64904773
discriminator loss: 0.69508123
gan loss: 0.72680366
discriminator loss: 0.6899625
gan loss: 0.81384104
discriminator loss: 0.7159114
gan loss: 0.7416827
discriminator loss: 0.6912014
gan loss: 0.7360595
discriminator loss: 0.72471696
gan loss: 0.6513818
discriminator loss: 0.6992035
gan loss: 0.6941215
discriminator loss: 0.6863073
gan loss: 0.7213627
discriminator loss: 0.6924833
gan loss: 0.71849513
discriminator loss: 0.6969613
gan loss: 0.72678196
discriminator loss: 0.685106
gan loss: 0.7152079
discriminator loss: 0.68887657
gan loss: 0.73665506
discriminator loss: 0.71015817
gan loss: 0.732918
discriminator loss: 0.67974985
gan loss: 0.8060713
discriminator loss: 0.693206
gan loss: 0.67660487
discriminator loss: 0.68673164
gan loss: 0.7380837
discriminator loss: 0.6858455
gan loss: 0.7205802
discriminator loss: 0.69764405
gan loss: 0.7232281
discriminator loss: 0.7109472
gan loss: 0.66223055
discriminator loss: 0.707724
gan loss: 1.1661718
discriminator loss: 0.6913794
gan loss: 0.6941385
discriminator loss: 0.69227374
gan loss: 0.7601951
discriminator loss: 0.6878415
gan loss: 0.7479371
discriminator loss: 0.71728295
gan loss: 0.83168066
discriminator loss: 0.6864918
gan loss: 0.7533535
discriminator loss: 0.7127117
gan loss: 0.6897801
discriminator loss: 0.69414794
gan loss: 0.7193168
discriminator loss: 0.67149776
gan loss: 0.7141426
discriminator loss: 0.6828767
gan loss: 0.7632409
discriminator loss: 0.67732584
gan loss: 0.7958939
discriminator loss: 0.69132864
gan loss: 0.721274
discriminator loss: 0.71267766
gan loss: 0.71191424
discriminator loss: 0.6940938
gan loss: 0.8811827
Check the generated images
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
img_path = 'epoch-9600-generator.jpg'
gen_img = mpimg.imread(img_path)
plt.imshow(gen_img)
A pug, maybe?
A corgi, or a shiba inu?
Conclusion

As the example shows, with the help of Keras's high-level API, implementing a GAN is not that difficult. But training a GAN involves many difficulties beyond the code itself (as the results above show, optimization was not great). Whether for the Generator or the Discriminator, being too strong or too weak is bad for the model. In many cases GANs are simply hard to train and demand constant parameter tuning; only by repeatedly stepping on mines do you build up enough experience.
Postscript: after roughly another 50,000 iterations of training, the whole thing fell apart!
Extension

Saving a single image makes it hard to judge how generation is going, so the function below stitches several images into one, again using the Keras-bundled tools to handle saving.
import os
import numpy as np
from keras.preprocessing import image
def save_generator_img(generator_img, h_int, w_int, epoch, save_file_name='generator', save_dir='.'):
    """Save images produced by the generator.

    function:
        Uses the Keras-bundled keras.preprocessing.image tools to save images.
    parameters:
        generator_img: numpy array produced by the generator
        h_int: number of tiles along the height of the combined image
        w_int: number of tiles along the width of the combined image
        epoch: iteration number, inserted into the file name on save
        save_file_name: file name, defaults to 'generator'
        save_dir: save path, defaults to the current working directory
    remark:
        h_int x w_int must exactly equal the number of images in generator_img
    example:
        save_generator_img(generator_img, 4, 5, 9900)
    """
    # Get the data dimensions
    N, H, W, C = generator_img.shape
    # Verify the tile count matches the number of images
    assert int(h_int) * int(w_int) == N
    # Allocate an all-zero array for the combined image
    target_img = np.zeros(shape=(h_int * H, w_int * W, C))
    # Running index into generator_img
    now_idx = 0
    for h_idx in range(h_int):
        for w_idx in range(w_int):
            # Grab one image
            _img = generator_img[now_idx]
            # Place it at the corresponding position in the array
            target_img[h_idx * H: h_idx * H + H, w_idx * W: w_idx * W + W] = _img
            now_idx += 1
    file_name = os.path.join(save_dir, save_file_name + str(epoch) + '.png')
    save_img = image.array_to_img(target_img * 255., scale=False)
    save_img.save(file_name)
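The tiling arithmetic in save_generator_img can be checked without numpy by listing where each sub-image lands (a standalone sketch; `tile_offsets` is a helper name introduced here):

```python
def tile_offsets(h_int, w_int, H, W):
    """Yield (image_index, top, left) in the same order as the loop above."""
    offsets = []
    now_idx = 0
    for h_idx in range(h_int):
        for w_idx in range(w_int):
            offsets.append((now_idx, h_idx * H, w_idx * W))
            now_idx += 1
    return offsets

# A 2 x 3 grid of 32 x 32 tiles fills rows left-to-right, top-to-bottom
print(tile_offsets(2, 3, 32, 32))
```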