# MNIST GAN training script (TensorFlow / Keras).
# Standard library.
import glob
import os
import time

# Third-party.
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models

# Local generator/discriminator definitions.
import model
# Training hyper-parameters.
batch_size = 256
epochs = 500
noise_dim = 100  # length of the generator's latent input vector

# Sample grid: tab_size x tab_size generated images per snapshot.
tab_size = 5
num_examples_to_generate = tab_size * tab_size

# Output and checkpoint locations.
dir_images = 'images_gan'
checkpoint_dir = './training_checkpoints_gan'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

# makedirs(..., exist_ok=True) is race-free and also creates missing
# parents, unlike the original isdir()-then-mkdir() check.
os.makedirs(dir_images, exist_ok=True)
# MNIST digits: 28x28 grayscale images (labels and test split are unused here).
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Add a trailing channel axis and rescale pixels from [0, 255] to [-1, 1]
# (generate_and_save_images undoes this with *127.5 + 127.5).
train_images = train_images[..., np.newaxis].astype('float32')
train_images = (train_images - 127.5) / 127.5

# Full-set shuffle each epoch, then fixed-size batches.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .shuffle(len(train_images))
    .batch(batch_size)
)
def discriminator_loss(real_output, fake_output):
    """Total discriminator BCE loss.

    Real samples are scored against a target of all ones, generated
    samples against all zeros; the two terms are summed.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake
def generator_loss(fake_output):
    """Generator BCE loss.

    Low when the discriminator scores generated samples as real (ones).
    """
    all_real_target = tf.ones_like(fake_output)
    return cross_entropy(all_real_target, fake_output)
# Shared loss primitive: binary cross-entropy on raw logits.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# Build the two networks from the local model module.
generator = model.generator_model()
discriminator = model.discriminator_model()

# Running per-epoch loss averages, reported (and reset) in train().
train_generator_loss = tf.keras.metrics.Mean()
train_discriminator_loss = tf.keras.metrics.Mean()

# Independent Adam optimizers, one per network.
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

# Checkpoint both networks together with their optimizer state.
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)

# Fixed latent vectors reused for every snapshot so the sample grids
# are comparable across epochs.
seed = tf.random.normal([num_examples_to_generate, noise_dim])
@tf.function
def train_step(images):
    """Run one adversarial optimization step on a batch of real images.

    Generates fakes, scores both real and generated batches with the
    discriminator, then applies one gradient update to each network.
    Also accumulates both losses into the module-level Mean metrics.
    """
    # Fix: use the actual batch size rather than the global batch_size —
    # the dataset is batched without drop_remainder, so the last batch
    # may be smaller and would otherwise be paired with too many fakes.
    noise = tf.random.normal([tf.shape(images)[0], noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # Feed the running per-epoch averages reported by train().
    train_generator_loss(gen_loss)
    train_discriminator_loss(disc_loss)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
    """Train the GAN for the given number of epochs.

    After each epoch: writes a sample-image grid, prints the mean
    generator/discriminator losses and the epoch wall time, and resets
    the loss metrics. Checkpoints every 15 epochs and once at the end.
    """
    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            train_step(image_batch)

        # Snapshot a fixed grid of samples so progress is comparable.
        generate_and_save_images(generator, epoch + 1, seed)

        if (epoch + 1) % 15 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print('Epoch {}: loss generator: {:.4f} loss discriminator: {:.4f} {:.4f} sec'.format(
            epoch + 1,
            train_generator_loss.result(),
            train_discriminator_loss.result(),
            time.time() - start))

        # Metrics accumulate across calls; reset for the next epoch.
        train_generator_loss.reset_states()
        train_discriminator_loss.reset_states()

    # Fix: the original only saved every 15 epochs, so with e.g.
    # epochs=500 the last 5 epochs of training were never checkpointed.
    if epochs % 15 != 0:
        checkpoint.save(file_prefix=checkpoint_prefix)
def generate_and_save_images(model, epoch, test_input):
    """Render a tab_size x tab_size grid of generator samples to a PNG.

    test_input is a [num_examples_to_generate, noise_dim] latent batch;
    the output file is written to dir_images/image_NNNN.png.
    """
    # Fix: the original called model([test_input, one_hot_labels]) in
    # conditional-GAN style, but train_step trains the generator on noise
    # alone — the two-input call does not match the trained model.
    # NOTE(review): confirm against model.py that the generator is
    # indeed unconditional.
    predictions = model(test_input, training=False)

    grid = np.empty(shape=(tab_size * 28, tab_size * 28), dtype=np.float32)
    for col in range(tab_size):
        for row in range(tab_size):
            # De-normalize from [-1, 1] back to [0, 255] pixel values.
            tile = predictions[row * tab_size + col, :, :, 0] * 127.5 + 127.5
            grid[row * 28:(row + 1) * 28, col * 28:(col + 1) * 28] = tile

    cv2.imwrite('{}/image_{:04d}.png'.format(dir_images, epoch), grid)
if __name__ == '__main__':
    # Only start the (long) training run when executed as a script, so the
    # module can be imported for its helpers without side-starting training.
    train(train_dataset, epochs)