Created
May 13, 2019 09:27
-
-
Save Belrestro/a6b8e5b8d59924ad2dac97d0b4ca5260 to your computer and use it in GitHub Desktop.
Train Rasa Core and Rasa NLU models programmatically (rasa_core 0.14.3, rasa_nlu 0.15).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from rasa_core.agent import Agent | |
from rasa_core.policies import KerasPolicy | |
from rasa_nlu.training_data import load_data | |
from rasa_nlu import config | |
from rasa_nlu.model import Trainer | |
import datetime | |
def _archive_name (type, environment): | |
t = datetime.datetime.now() | |
timestring = t.strftime('%Y%m%d-%H%M%S') | |
return '%s__model_%s' % (prefix, timestring) | |
def train_core(params):
    """Train a Rasa Core dialogue model and persist it to disk.

    Args:
        params: Mapping of training options. Recognised keys:

            * ``'domain'`` - value handed to ``Agent`` as its first
              argument (defaults to ``''``).
            * ``'stories'`` - value handed to ``agent.load_data``
              (defaults to ``''``).
            * ``'environment'`` - optional label forwarded to
              ``_archive_name`` (defaults to ``None``).

    Returns:
        The directory path the trained agent was persisted to.
    """
    # ``dict.get`` already falls back to the default for a missing key, so
    # no separate membership test is needed.
    domain = params.get('domain', '')
    stories = params.get('stories', '')
    environment = params.get('environment')

    model_name = _archive_name('core', environment)
    # Same '<BASE_DIR>/<model_name>' layout that train_nlu uses.
    # NOTE(review): BASE_DIR is expected to be a module-level constant; it
    # is not defined in the visible part of this file - confirm.
    model_dir = '%s/%s' % (BASE_DIR, model_name)

    additional_arguments = {
        "epochs": 100,
        "batch_size": 20,
        "validation_split": 0.1,
        "augmentation_factor": 50,
        "debug_plots": True,
        "max_history": 5,
    }

    # The previous version referenced undefined names (domain_path,
    # stories_path, md_stories_file_path, stories_in_json, model_dir) and
    # so failed with NameError; use the values read from ``params``.
    agent = Agent(domain,
                  policies=[KerasPolicy(**additional_arguments)])
    training_data = agent.load_data(stories)
    agent.train(training_data)

    # persist
    agent.persist(model_dir)
    return model_dir
def train_nlu(params):
    """Train a Rasa NLU model and persist it under ``BASE_DIR``.

    Args:
        params: Mapping of training options. Recognised keys:

            * ``'intents'`` - value handed to ``load_data`` as the
              training data source (defaults to ``''``).
            * ``'config'`` - value handed to ``config.load`` as the NLU
              configuration (defaults to ``{}``).
            * ``'environment'`` - optional label forwarded to
              ``_archive_name`` (defaults to ``None``).

    Returns:
        Whatever ``Trainer.persist`` returns.
        NOTE(review): presumably the persisted model directory - confirm
        against the rasa_nlu version in use.
    """
    intents_file = params.get('intents', '')
    config_file = params.get('config', {})
    environment = params.get('environment')

    model_name = _archive_name('nlu', environment)

    # The previous version passed ``intents_file`` to ``config.load``,
    # loading the training data file as the configuration and leaving
    # ``config_file`` unused.
    nlu_config = config.load(config_file)
    data = load_data(intents_file)

    trainer = Trainer(nlu_config)
    trainer.train(data)
    # NOTE(review): BASE_DIR is expected to be a module-level constant; it
    # is not defined in the visible part of this file - confirm.
    return trainer.persist(BASE_DIR, project_name='',
                           fixed_model_name=model_name)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment