@jooni22
Created August 25, 2024 00:44
Basic execution of the Dialog Act model pzelasko/longformer-swda-nolower from HuggingFace
import warnings
from transformers import AutoTokenizer, AutoModelForTokenClassification, logging
import torch
from termcolor import colored
# Disable warnings
warnings.filterwarnings("ignore")
# Disable transformers logging
logging.set_verbosity_error()
# Path to the folder containing the model
model_path = "pzelasko/longformer-swda-nolower"
# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForTokenClassification.from_pretrained(model_path, ignore_mismatched_sizes=True)
# Example sentence for classification
sentence = """
Hello there! How's it going? You're planning to attend the conference next week, aren't you? Or... What exactly are your plans? I mean, it's not like you have to go, but— Actually, forget I asked.
Oh, you will be there? Fantastic! I'm glad you're attending, right? Well, I think it'll be a great opportunity. The speakers are top-notch, from what I've heard.
By the way, did you manage to book your hotel yet? No? Well, there are still some options available. I could help you find a place if you'd like.
Hmm, let me see... The Hilton is nice, but the Marriott is closer to the venue. Which would you prefer? Or perhaps neither of those?
Oh, I see you're leaning towards the Hilton. That's a good choice. Although, now that I think about it... The Marriott does have that rooftop bar. But who needs a fancy bar, right?
Anyway, as I was saying... 'Location, location, location,' as they always say. So, have you considered the proximity to the conference center?
I'm sorry, I didn't catch that. Could you repeat what you said? Oh, I see. Well, if that's what you prefer, go for it!
Just to summarize, you're going to the conference and staying at the Hilton, correct? Great! I'm sure you'll have a wonderful time.
Oh, before I forget, John wanted me to ask you about the presentation materials. He said something about needing them by Friday?
What's that? You can't get them done by then? I understand, no worries. These things happen. Perhaps we could extend the deadline?
Well, I should probably let you go now. Thanks for chatting! Have a great day! Bye!
"""
# Tokenize the sentence (request offsets in the same call so they stay
# aligned with the truncated/padded input IDs)
inputs = tokenizer(sentence, return_tensors="pt", truncation=True,
                   padding="max_length", max_length=512,
                   return_offsets_mapping=True)
# Pop the offsets so they are not passed to the model's forward()
offset_mapping = inputs.pop("offset_mapping")[0].tolist()
# Prediction
with torch.no_grad():
    outputs = model(**inputs)
# Get the predicted labels
predictions = torch.argmax(outputs.logits, dim=2)
# Mapping from id to label
id2label = model.config.id2label
# Function to color the labels
def color_label(label):
    colors = {
        'YES-NO-QUESTION': 'yellow',
        'WH-QUESTION': 'yellow',
        'DECLARATIVE-WH-QUESTION': 'yellow',
        'DECLARATIVE-YES-NO-QUESTION': 'yellow',
        'OPEN-QUESTION': 'yellow',
        'TAG-QUESTION': 'yellow',
        'RHETORICAL-QUESTIONS': 'yellow',
        'BACKCHANNEL-IN-QUESTION-FORM': 'yellow',
        'YES-ANSWERS': 'green',
        'NO-ANSWERS': 'green',
        'AFFIRMATIVE-NON-YES-ANSWERS': 'green',
        'NEGATIVE-NON-NO-ANSWERS': 'green',
        'OTHER-ANSWERS': 'green',
        'DISPREFERRED-ANSWERS': 'green',
        'CONVENTIONAL-OPENING': 'blue',
        'CONVENTIONAL-CLOSING': 'blue',
        'HOLD-BEFORE-ANSWER-AGREEMENT': 'blue',
        'COLLABORATIVE-COMPLETION': 'blue',
        'OR-CLAUSE': 'blue',
        'ACKNOWLEDGE-BACKCHANNEL': 'magenta',
        'RESPONSE-ACKNOWLEDGEMENT': 'magenta',
        'SIGNAL-NON-UNDERSTANDING': 'magenta',
        'STATEMENT-OPINION': 'cyan',
        'STATEMENT-NON-OPINION': 'cyan',
        'THANKING': 'red',
        'APOLOGY': 'red',
        'APPRECIATION': 'red',
        'AGREE-ACCEPT': 'white',
        'REJECT': 'white',
        'MAYBE-ACCEPT-PART': 'white',
        'ACTION-DIRECTIVE': 'light_green',
        'OFFERS-OPTIONS-COMMITS': 'light_green',
        'QUOTATION': 'light_blue',
        'REPEAT-PHRASE': 'light_blue',
        'SUMMARIZE/REFORMULATE': 'light_blue',
        'HEDGE': 'light_magenta',
        'DOWNPLAYER': 'light_magenta',
        'NON-VERBAL': 'light_cyan',
        'UNINTERPRETABLE': 'light_cyan',
        '3RD-PARTY-TALK': 'light_cyan',
        'SELF-TALK': 'light_cyan',
    }
    return colored(label.upper(), colors.get(label.upper(), 'grey'))
# Prepare the results
result = []
current_phrase = ""
current_label = ""
for token, prediction, (start, end) in zip(
        tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]),
        predictions[0],
        offset_mapping):
    if token.startswith("Ġ") or token in [".", "?", "!"]:  # New word or punctuation mark
        if current_phrase:
            if current_label and current_label not in ["O", "I-"]:
                result.append(f"{current_phrase.strip()} {color_label(current_label)}")
            else:
                result.append(current_phrase.strip())
        current_phrase = sentence[start:end].replace("Ġ", " ")
        current_label = id2label[prediction.item()]
    else:
        current_phrase += sentence[start:end]
    if current_label in ["O", "I-"]:
        current_label = ""
# Add the last phrase
if current_phrase:
    if current_label and current_label not in ["O", "I-"]:
        result.append(f"{current_phrase.strip()} {color_label(current_label)}")
    else:
        result.append(current_phrase.strip())
# Display the results
print(" ".join(result))
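The grouping loop in the script relies on the byte-level BPE convention used by Longformer's tokenizer, where a leading "Ġ" marks a token that begins a new word. A minimal standalone sketch of that word-merging idea, using hypothetical tokens rather than actual tokenizer output:

```python
# Toy illustration of merging subword tokens into words via the
# byte-level BPE convention: a leading "Ġ" signals a word boundary.
# (The token list below is hypothetical, not real Longformer output.)

def group_tokens(tokens):
    words = []
    current = ""
    for tok in tokens:
        if tok.startswith("Ġ"):      # boundary: flush the word in progress
            if current:
                words.append(current)
            current = tok[1:]        # strip the boundary marker
        else:
            current += tok           # continuation of the current word
    if current:                      # flush the final word
        words.append(current)
    return words

print(group_tokens(["Hello", "Ġthere", "!", "ĠHow", "'s", "Ġit", "Ġgoing", "?"]))
# → ['Hello', 'there!', "How's", 'it', 'going?']
```

The full script additionally treats sentence-final punctuation as a boundary and attaches a dialog-act label to each flushed phrase, but the core merge step is the same.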