Last active: April 13, 2020 09:35
-
-
Save emuccino/bb14e6c18958210d8cdab10ce4f82cc1 to your computer and use it in GitHub Desktop.
Train GAN
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
import matplotlib.pyplot as plt | |
def train_gan(n_epochs, n_batch, n_plot, n_eval):
    """Adversarially train the module-level `discriminator` and `gan` models.

    NOTE(review): this function depends on module-level globals defined
    elsewhere: `discriminator`, `gan`, `generate_real_samples`,
    `generate_synthetic_samples`, `generate_latent_samples`, `n_tokens`,
    `numeric_data`, and `string_data` — confirm they exist before calling.

    Args:
        n_epochs: number of training iterations (one batch per iteration).
        n_batch: total batch size; the discriminator step uses half real
            and half synthetic samples.
        n_plot: number of samples drawn for the periodic evaluation plots.
        n_eval: evaluate and plot every `n_eval` epochs.
    """
    # Discriminator / generator training loss logs.
    disc_loss_hist = []
    gen_loss_hist = []

    for epoch in range(n_epochs):
        if epoch % 100 == 0:
            print(epoch, end=' ')

        # Enable discriminator weight updates for the discriminator step.
        discriminator.trainable = True

        # Sample equal portions of real/synthetic data
        # (integer // avoids the float produced by int(n_batch / 2)'s
        # true division before truncation).
        x_real, y_real = generate_real_samples(n_batch // 2)
        x_synth, y_synth = generate_synthetic_samples(n_batch // 2)

        # Stack real and synthetic inputs per feature name.
        x_total = {key: np.vstack([x_real[key], x_synth[key]])
                   for key in x_real}
        y_total = np.vstack([y_real, y_synth])

        # Train discriminator on the combined batch.
        disc_loss_hist.append(discriminator.train_on_batch(x_total, y_total))

        # Freeze discriminator weights while training the generator
        # through the combined GAN model.
        discriminator.trainable = False
        x_gan, y_gan = generate_latent_samples(n_batch)
        gen_loss_hist.append(gan.train_on_batch(x_gan, y_gan))

        # After a set number of epochs, evaluate GAN training progress.
        if (epoch + 1) % n_eval == 0:
            print('\n')
            # Pull real and synthetic data to compare distributions
            # and relationships.
            x_real, _ = generate_real_samples(n_plot // 2)
            x_synth, _ = generate_synthetic_samples(n_plot // 2)
            # Collapse one-hot token columns to integer indices for plotting.
            for name in n_tokens:
                x_real[name] = x_real[name].argmax(1).reshape(-1, 1)
                x_synth[name] = x_synth[name].argmax(1).reshape(-1, 1)

            print('numeric data')
            for i, name1 in enumerate(numeric_data):
                print(name1)
                # Compare marginal data distributions.
                plt.hist([x_real[name1].flatten(), x_synth[name1].flatten()],
                         bins=16)
                plt.legend(['Real', 'Synthetic'])
                plt.show()
                for name2 in numeric_data[i + 1:]:
                    print(name1, name2)
                    # Compare pairwise data relationships.
                    plt.scatter(x_real[name1], x_real[name2], s=1)
                    plt.scatter(x_synth[name1], x_synth[name2], s=1)
                    plt.legend(['Real', 'Synthetic'])
                    plt.show()

            print('string data')
            for i, name1 in enumerate(string_data):
                print(name1)
                # Compare marginal token distributions.
                plt.hist([x_real[name1].flatten(), x_synth[name1].flatten()],
                         bins=n_tokens[name1])
                plt.legend(['Real', 'Synthetic'])
                plt.show()
                for name2 in string_data[i + 1:]:
                    print(name1, name2)
                    # Create a numerical index to represent each
                    # combination of tokens from the two columns.
                    lookup = {tup: p for p, tup in enumerate(
                        itertools.product(range(n_tokens[name1]),
                                          range(n_tokens[name2])))}
                    hist_real = [lookup[tuple(x)] for x in
                                 np.hstack([x_real[name1], x_real[name2]])]
                    hist_synth = [lookup[tuple(x)] for x in
                                  np.hstack([x_synth[name1], x_synth[name2]])]
                    # Compare joint token relationships.
                    plt.hist([hist_real, hist_synth],
                             bins=len(set(hist_real + hist_synth)),
                             color=['blue', 'orange'])
                    plt.legend(['Real', 'Synthetic'])
                    plt.show()

            # Plot loss history accumulated so far.
            print('loss history')
            plt.plot(disc_loss_hist, linewidth=2)
            plt.plot(gen_loss_hist, linewidth=2)
            plt.legend(['Discriminator', 'Generator'])
            plt.show()
            print('\n')
# Training configuration.
n_epochs = 3000      # total training iterations
n_batch = 16 * 1024  # samples per batch (half real, half synthetic)
n_eval = 500         # evaluate/plot every n_eval epochs
n_plot = 2048        # samples drawn for the evaluation plots

train_gan(n_epochs, n_batch, n_plot, n_eval)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.