Created May 29, 2019 15:02
Save gorzechowski/939af6778dd364df8e7c946e379eb898 to your computer and use it in GitHub Desktop.
Simple text classification using tflearn
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters.
import nltk
import tensorflow
import tflearn
from nltk.stem.lancaster import LancasterStemmer
# Single Lancaster stemmer instance, shared by every process_sentence() call.
stemmer = LancasterStemmer()

# Labelled training data: a flat tuple of (intent_label, example_sentence)
# pairs, expanded from the example phrases grouped under each intent.
intents = tuple(
    (label, phrase)
    for label, phrases in (
        ('greeting', ('hi', 'hi there', 'hello')),
        ('farewell', ('good bye', 'bye', 'see you')),
    )
    for phrase in phrases
)
def process_sentence(sentence):
    """Split *sentence* into word tokens and stem each one.

    Each token produced by ``nltk.word_tokenize`` is lower-cased and then
    reduced with the module-level Lancaster ``stemmer``.

    Args:
        sentence: raw text to tokenize.

    Returns:
        list of stemmed token strings, in the order they occur in *sentence*.
    """
    # NOTE(review): nltk.word_tokenize depends on NLTK's tokenizer data
    # (e.g. 'punkt') being installed -- confirm it is downloaded where this runs.
    stems = []
    for raw_token in nltk.word_tokenize(sentence):
        stems.append(stemmer.stem(raw_token.lower()))
    return stems
def normalize_sentence(sentence, words):
    """Encode *sentence* as a binary bag-of-words vector over *words*.

    The sentence is tokenized and stemmed with process_sentence(). The
    result has one entry per vocabulary word: 1 if that stem occurs in the
    sentence, otherwise 0.

    Args:
        sentence: raw text to encode.
        words: vocabulary of stems; its order fixes the layout of the vector.

    Returns:
        list of ints (0 or 1) with ``len(words)`` entries.
    """
    # Build the set once so each vocabulary lookup is O(1) instead of a
    # linear scan of the token list for every word.
    stems = set(process_sentence(sentence))
    return [1 if word in stems else 0 for word in words]
# Sorted, de-duplicated intent labels. A label's position in this list is
# the index used for its 1 in the one-hot training targets.
classes = sorted({label for label, _ in intents})

# Sorted vocabulary: every distinct stem that process_sentence() produces
# across all training sentences. Comprehension scope keeps the loop
# variables out of the module namespace.
words = sorted({
    stem
    for _, sentence in intents
    for stem in process_sentence(sentence)
})
# Training inputs: one bag-of-words row per example sentence, in the same
# order as `intents`.
X = [normalize_sentence(sentence, words) for _, sentence in intents]

# Training targets: one-hot rows aligned with X. The single 1 sits at the
# intent label's position in `classes`, which contains each label once.
Y = [[int(label == class_name) for label in classes] for class_name, _ in intents]
# Clear the default TensorFlow graph before building the network.
# NOTE(review): reset_default_graph is a TensorFlow 1.x API (it lives under
# tf.compat.v1 in 2.x); tflearn targets TF 1.x -- confirm the pinned version.
tensorflow.reset_default_graph()

# Input width = vocabulary size; two 64-unit hidden layers; softmax output
# with one unit per class.
# NOTE(review): tflearn.fully_connected defaults to activation='linear', so
# the hidden layers are linear as written -- confirm this is intended.
network = tflearn.input_data(shape=[None, len(X[0])])
for _ in range(2):
    network = tflearn.fully_connected(network, 64)
network = tflearn.fully_connected(network, len(Y[0]), activation='softmax')
network = tflearn.regression(network)

# Wrap the graph in a trainable model and fit it on the bag-of-words data.
model = tflearn.DNN(network)
model.fit(X, Y, n_epoch=1000, batch_size=8, show_metric=False)
def predict(sentence):
    """Return the trained model's softmax class scores for *sentence*.

    The sentence is encoded with normalize_sentence() against the training
    vocabulary ``words``. It is then passed to ``model.predict`` as a batch
    of one.

    Args:
        sentence: raw text to classify.

    Returns:
        the value ``model.predict`` returns for that one-row batch. Index
        ``[0]`` holds the scores, one per entry of ``classes``.
    """
    # Named `features` rather than `input` so the builtin input() is not shadowed.
    features = normalize_sentence(sentence, words)
    return model.predict([features])
# Demo: score a sentence that is not verbatim in the training data and
# print (label, score) pairs.
print([(label, score) for label, score in zip(classes, predict("Hello there")[0])])
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.