#!/usr/bin/env python
# This is a simple script to set up a Twitter 'bot' based on a character-level recurrent neural network. Clone sherjilozair's
# char-rnn-tensorflow (https://github.com/sherjilozair/char-rnn-tensorflow) and train it on the material of your choice.
# Then drop this script into the main directory, create a Twitter account and Twitter app for the bot and enter the
# relevant authentication information at the commented points below. Run this script and whenever somebody
# @mentions the bot it will reply with a sample from your neural network.
# Louis Goddard <[email protected]>
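# The end-to-end setup might look roughly like this (a sketch only: the train.py flags
# and the filename used for this script are illustrative, so check the
# char-rnn-tensorflow README for the current options):
#
#   git clone https://github.com/sherjilozair/char-rnn-tensorflow
#   cd char-rnn-tensorflow
#   python train.py --data_dir data/your_corpus   # train on your own text
#   cp /path/to/twitterbot.py .                   # this script, saved under any name you like
#   python twitterbot.py --save_dir save          # start replying to @mentions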
import numpy as np
import tensorflow as tf
import argparse
import time
import os
import re
import cPickle
from utils import TextLoader
from model import Model
from twython import Twython
APP_KEY = ''  # consumer key
APP_SECRET = ''  # consumer secret
OAUTH_TOKEN = ''  # access token
OAUTH_TOKEN_SECRET = ''  # access token secret

twitter = Twython(APP_KEY, APP_SECRET, OAUTH_TOKEN, OAUTH_TOKEN_SECRET)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--save_dir', type=str, default='save',
                        help='model directory to store checkpointed models')
    # -n and --prime are parsed but not used below: the sample length and the
    # priming text are derived from each incoming mention instead.
    parser.add_argument('-n', type=int, default=140,
                        help='number of characters to sample')
    parser.add_argument('--prime', type=str, default=' ',
                        help='prime text')
    args = parser.parse_args()
    sample(args)
def sample(args):
    # Load the training configuration and the character/vocabulary mappings
    # written out by char-rnn-tensorflow's train.py.
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    # Rebuild the network in sampling (inference) mode.
    model = Model(saved_args, True)
    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # Restore the most recent checkpoint, then start polling Twitter.
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Seed the since_id cursor from the newest tweet on the home timeline,
            # so only mentions that arrive after startup get a reply.
            latest = twitter.get_home_timeline(count=1)
            ident = latest[0]['id']
            while True:
                # Fetch any mentions newer than the last one seen.
                mentions = twitter.get_mentions_timeline(contributor_details=True, since_id=ident)
                for mention in mentions:
                    ident = mention['id']
                    target = mention['user']['screen_name']
                    # Reduce the mention text to alphanumerics and use it to prime the sampler.
                    incoming = re.sub('[^A-Za-z0-9]+', '', mention['text'])
                    output = str(model.sample(sess, chars, vocab, len(mention['text']) + (140 - len(target)), incoming))
                    # Strip the priming text from the front of the sample, trim the rest to
                    # roughly tweet length and reply to the sender.
                    twitter.update_status(status='@' + target + ' ' + output[len(incoming):139], in_reply_to_status_id=ident)
                time.sleep(60)
if __name__ == '__main__':
    main()
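# To make the reply-length arithmetic above concrete, here is a worked example with a
# hypothetical bot called @mybot receiving "@mybot hello there" from @alice, assuming
# (as in char-rnn-tensorflow's sampler) that model.sample returns the priming text
# followed by the requested number of newly generated characters:
#
#   mention['text']           = "@mybot hello there"   (18 characters)
#   incoming                  = "mybothellothere"      (15 characters after stripping)
#   characters sampled        = 18 + (140 - 5) = 153, so len(output) = 15 + 153 = 168
#   output[len(incoming):139] = at most 139 - 15 = 124 characters of generated text
#   status                    = "@alice " + those 124 characters = at most 131 characters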