"""Train a simple fully-connected GAN on MNIST and plot ten generated digits."""
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Reshape, Flatten
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.optimizers import Adam

# Dimensionality of the generator's input noise vector.
latent_dim = 100

# Load MNIST dataset (only the training images are used; labels are discarded)
(X_train, _), (_, _) = mnist.load_data()

# Normalize pixel values from [0, 255] to [-1, 1] to match the generator's tanh
# output range. Keep the (N, 28, 28) image layout: the discriminator declares
# Flatten(input_shape=(28, 28)) and the generator emits (28, 28) images, so
# flattening real images to 784 here would make them incompatible with the
# discriminator's expected input shape.
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = X_train.reshape(-1, 28, 28)

# Generator model: latent noise vector -> 28x28 image with values in [-1, 1]
generator = Sequential([
    Dense(256, input_dim=latent_dim),
    LeakyReLU(alpha=0.2),
    Dense(512),
    LeakyReLU(alpha=0.2),
    Dense(1024),
    LeakyReLU(alpha=0.2),
    Dense(784, activation='tanh'),
    Reshape((28, 28)),
])

# Discriminator model: 28x28 image -> probability that the image is real
discriminator = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(1024),
    LeakyReLU(alpha=0.2),
    Dense(512),
    LeakyReLU(alpha=0.2),
    Dense(256),
    LeakyReLU(alpha=0.2),
    Dense(1, activation='sigmoid'),
])

# Compile discriminator. `learning_rate` replaces the deprecated `lr` keyword,
# which newer Keras optimizers no longer accept.
discriminator.compile(optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
                      loss='binary_crossentropy', metrics=['accuracy'])

# Combined model (generator -> discriminator), used only to train the generator.
# The discriminator is frozen before compiling `gan` so that
# gan.train_on_batch updates the generator's weights only.
# NOTE(review): this relies on Keras preserving the discriminator's trainable
# state from its own compile() call above (tf.keras 2.x behavior) so that
# discriminator.train_on_batch still updates it; confirm on the Keras version
# in use.
discriminator.trainable = False
gan = Sequential([generator, discriminator])
gan.compile(optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
            loss='binary_crossentropy')

# Train GAN
batch_size = 64
epochs = 30000  # number of training steps (one batch each), not dataset passes

# Target labels are identical on every step, so build them once.
real_labels = np.ones((batch_size, 1))
fake_labels = np.zeros((batch_size, 1))

for epoch in range(epochs):
    # Train discriminator on one batch of generated images and one real batch.
    # predict_on_batch avoids predict()'s per-call setup and progress-bar output,
    # which would otherwise be paid on every one of the training steps.
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    gen_imgs = generator.predict_on_batch(noise)
    idx = np.random.randint(0, X_train.shape[0], batch_size)
    real_imgs = X_train[idx]
    d_loss_real = discriminator.train_on_batch(real_imgs, real_labels)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake_labels)
    # Element-wise mean of [loss, accuracy] over the real and fake batches.
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # Train generator: generated images are labelled "real" so the generator
    # is optimized to make the frozen discriminator output 1 for its samples.
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    g_loss = gan.train_on_batch(noise, real_labels)

    # Print progress
    if epoch % 100 == 0:
        print(f"Epoch: {epoch}, Discriminator Loss: {d_loss[0]}, Generator Loss: {g_loss}")

# Generate images
noise = np.random.normal(0, 1, (10, latent_dim))
gen_imgs = generator.predict_on_batch(noise)

# Plot generated images in a single row
plt.figure(figsize=(10, 10))
for i in range(10):
    plt.subplot(1, 10, i + 1)
    plt.imshow(gen_imgs[i], cmap='gray')
    plt.axis('off')
plt.show()