1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
|
def train(dataset, epochs):
    """Run ``epochs`` full passes over ``dataset``, one ``train_step`` per batch.

    Every batch yielded by ``dataset`` is cast to float32 before it is passed
    to ``train_step``.

    Args:
        dataset: Iterable of image batches (presumably a ``tf.data.Dataset``;
            TODO confirm against the caller).
        epochs: Number of passes over ``dataset``.
    """
    for _epoch in range(epochs):
        for batch in dataset:
            train_step(tf.cast(batch, tf.dtypes.float32))
def train_step(images, noise_dim=100):
    """Run one adversarial update of the generator and the discriminator.

    Draws standard-normal noise and generates fake images from it. The
    discriminator scores both real and fake images. Each model then gets one
    optimizer step using its own loss, and both mean losses are printed.

    Args:
        images: One batch of real images, batch dimension first (assumed;
            TODO confirm against the dataset pipeline).
        noise_dim: Length of each latent noise vector fed to ``generator``.
            It must match the generator's input size. Defaults to 100, the
            value previously hard-coded here.
    """
    # Size the noise batch from the real batch instead of BATCH_SIZE. The last
    # batch of a dataset batched without drop_remainder=True can be smaller,
    # and real/fake batches should stay the same size for a balanced loss.
    batch_size = int(tf.shape(images)[0])
    fake_image_noise = np.random.randn(batch_size, noise_dim).astype("float32")

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # training=True makes layers such as BatchNormalization/Dropout run in
        # training mode. Called outside fit(), Keras layers default to inference.
        generated_images = generator(fake_image_noise, training=True)
        real_output = model_discriminator(images, training=True)
        fake_output = model_discriminator(generated_images, training=True)
        gen_loss = generator_loss(fake_output)
        disc_loss = get_discriminator_loss(real_output, fake_output)

    # Take gradients after leaving the tape contexts so the gradient ops
    # themselves are not recorded.
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(
        disc_loss, model_discriminator.trainable_variables
    )

    generator_optimizer.apply_gradients(
        zip(gradients_of_generator, generator.trainable_variables)
    )
    discriminator_optimizer.apply_gradients(
        zip(gradients_of_discriminator, model_discriminator.trainable_variables)
    )

    print('generator_loss :', np.mean(gen_loss))
    print('discriminator loss:',np.mean(disc_loss))
Partager