Train classifiers
#imports assumed by this snippet (using TensorFlow's bundled Keras)
import numpy as np
from tensorflow.keras.layers import (Input, Dense, GaussianNoise,
                                     Concatenate, BatchNormalization)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Nadam
from tensorflow.keras.utils import to_categorical

#NOTE: numeric_data, string_data, n_tokens, n_embeddings, target, test_df,
#test_target_df, generate_real_samples, and generate_synthetic_samples are
#assumed to be defined in the accompanying data-prep and GAN code.

#function for building classifier (very similar to the discriminator)
def compile_classifier():
    inputs = {}
    numeric_nets = []
    string_nets = []

    #one noisy input per numeric feature
    for name in numeric_data:
        numeric_input = Input(shape=(1,), name=name)
        inputs[name] = numeric_input
        numeric_net = GaussianNoise(0.01)(numeric_input)
        numeric_nets.append(numeric_net)

    #one noisy one-hot input per categorical feature, embedded via a Dense layer
    for name, n_token in n_tokens.items():
        string_input = Input(shape=(n_token,), name=name)
        inputs[name] = string_input
        string_net = GaussianNoise(0.05)(string_input)
        string_net = Dense(n_embeddings[name], activation='relu',
                           kernel_initializer='he_uniform')(string_net)
        string_nets.append(string_net)

    string_nets = Concatenate()(string_nets)
    string_nets = BatchNormalization()(string_nets)
    string_nets = [Dense(len(string_data), activation='relu',
                         kernel_initializer='he_uniform')(string_nets)]

    #merge numeric and categorical branches, then stack dense blocks
    net = Concatenate()(numeric_nets + string_nets)
    net = BatchNormalization()(net)
    for _ in range(4):
        net = Dense(128, activation='relu',
                    kernel_initializer='he_uniform')(net)
        net = BatchNormalization()(net)

    outputs = Dense(2, activation='softmax',
                    kernel_initializer='glorot_uniform')(net)

    classifier = Model(inputs=inputs, outputs=outputs)
    classifier.compile(loss='categorical_crossentropy',
                       optimizer=Nadam(clipnorm=1.),
                       metrics=['categorical_accuracy'])
    return classifier

#classifier to be trained on real data only
classifier = compile_classifier()
#classifier to be trained on a combination of real and synthetic data
gan_classifier = compile_classifier()

batch_size = 512

#train classifier with real data for 1000 batches
for _ in range(1000):
    x_real, y_real = generate_real_samples(batch_size)
    #drop the real/fake column; keep only the class labels
    classifier.train_on_batch(x_real, y_real[:, 1:])

#train classifier with real and synthetic data for 1000 batches
for _ in range(1000):
    #split each batch into half real and half synthetic data
    x_real, y_real = generate_real_samples(batch_size // 2)
    x_synth, y_synth = generate_synthetic_samples(batch_size // 2)
    x_total = {}
    for key in x_real.keys():
        x_total[key] = np.vstack([x_real[key], x_synth[key]])
    y_total = np.vstack([y_real, y_synth])
    gan_classifier.train_on_batch(x_total, y_total[:, 1:])

#set up test data for evaluating classifier results
test_inputs = {}
for name in numeric_data:
    test_inputs[name] = test_df[[name]].values
for name in string_data:
    test_inputs[name] = to_categorical(test_df[name].values, n_tokens[name])
test_outputs = to_categorical(test_target_df[target].values, 2)

#compare test accuracy of the two classifiers
classifier_eval = classifier.evaluate(test_inputs, test_outputs)
print('classifier accuracy:', classifier_eval[1])
gan_classifier_eval = gan_classifier.evaluate(test_inputs, test_outputs)
print('gan classifier accuracy:', gan_classifier_eval[1])
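
As an optional follow-up, the sketch below compares the two models on per-class precision and recall rather than accuracy alone. It assumes scikit-learn is installed and reuses the test_inputs and test_outputs built above; it is an illustrative addition, not part of the original gist.

#optional: per-class comparison of the two classifiers (assumes scikit-learn)
from sklearn.metrics import classification_report

y_true = np.argmax(test_outputs, axis=1)
for label, model in [('classifier (real data only)', classifier),
                     ('gan classifier (real + synthetic)', gan_classifier)]:
    y_pred = np.argmax(model.predict(test_inputs), axis=1)
    print(label)
    print(classification_report(y_true, y_pred, digits=3))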