Created
August 25, 2024 00:44
-
-
Save jooni22/1af1cebc6b01784a5920c2710ad0d219 to your computer and use it in GitHub Desktop.
Basic execution of the Dialog Act model from Hugging Face: pzelasko/longformer-swda-nolower
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import warnings | |
from transformers import AutoTokenizer, AutoModelForTokenClassification, logging | |
import torch | |
from termcolor import colored | |
# Silence Python warnings so only the script's own output is printed
warnings.filterwarnings("ignore")
# Restrict transformers logging to errors
logging.set_verbosity_error()

# Model identifier handed to from_pretrained()
model_path = "pzelasko/longformer-swda-nolower"

# Load the tokenizer and the token-classification model
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForTokenClassification.from_pretrained(model_path, ignore_mismatched_sizes=True)

# Example text for classification
sentence = """
Hello there! How's it going? You're planning to attend the conference next week, aren't you? Or... What exactly are your plans? I mean, it's not like you have to go, but— Actually, forget I asked.
Oh, you will be there? Fantastic! I'm glad you're attending, right? Well, I think it'll be a great opportunity. The speakers are top-notch, from what I've heard.
By the way, did you manage to book your hotel yet? No? Well, there are still some options available. I could help you find a place if you'd like.
Hmm, let me see... The Hilton is nice, but the Marriott is closer to the venue. Which would you prefer? Or perhaps neither of those?
Oh, I see you're leaning towards the Hilton. That's a good choice. Although, now that I think about it... The Marriott does have that rooftop bar. But who needs a fancy bar, right?
Anyway, as I was saying... 'Location, location, location,' as they always say. So, have you considered the proximity to the conference center?
I'm sorry, I didn't catch that. Could you repeat what you said? Oh, I see. Well, if that's what you prefer, go for it!
Just to summarize, you're going to the conference and staying at the Hilton, correct? Great! I'm sure you'll have a wonderful time.
Oh, before I forget, John wanted me to ask you about the presentation materials. He said something about needing them by Friday?
What's that? You can't get them done by then? I understand, no worries. These things happen. Perhaps we could extend the deadline?
Well, I should probably let you go now. Thanks for chatting! Have a great day! Bye!
"""

# Tokenize once and request the character offsets in the same call, so every
# offset pair belongs to exactly the same token as the matching input id and
# prediction (same truncation, same padding).
inputs = tokenizer(
    sentence,
    return_tensors="pt",
    truncation=True,
    padding="max_length",
    max_length=512,
    return_offsets_mapping=True,
)
# The offsets are not a model input: take them out before the forward pass and
# keep them as a plain list of [start, end] pairs for the single sequence.
offset_mapping = inputs.pop("offset_mapping")[0].tolist()

# Prediction (no gradients are needed for inference)
with torch.no_grad():
    outputs = model(**inputs)

# Highest-scoring label id for every token position
predictions = torch.argmax(outputs.logits, dim=2)

# Mapping from label id to label name, taken from the model configuration
id2label = model.config.id2label
# Labels grouped by the terminal colour they are displayed in
_LABELS_BY_COLOR = {
    'yellow': (
        'YES-NO-QUESTION',
        'WH-QUESTION',
        'DECLARATIVE-WH-QUESTION',
        'DECLARATIVE-YES-NO-QUESTION',
        'OPEN-QUESTION',
        'TAG-QUESTION',
        'RHETORICAL-QUESTIONS',
        'BACKCHANNEL-IN-QUESTION-FORM',
    ),
    'green': (
        'YES-ANSWERS',
        'NO-ANSWERS',
        'AFFIRMATIVE-NON-YES-ANSWERS',
        'NEGATIVE-NON-NO-ANSWERS',
        'OTHER-ANSWERS',
        'DISPREFERRED-ANSWERS',
    ),
    'blue': (
        'CONVENTIONAL-OPENING',
        'CONVENTIONAL-CLOSING',
        'HOLD-BEFORE-ANSWER-AGREEMENT',
        'COLLABORATIVE-COMPLETION',
        'OR-CLAUSE',
    ),
    'magenta': (
        'ACKNOWLEDGE-BACKCHANNEL',
        'RESPONSE-ACKNOWLEDGEMENT',
        'SIGNAL-NON-UNDERSTANDING',
    ),
    'cyan': (
        'STATEMENT-OPINION',
        'STATEMENT-NON-OPINION',
    ),
    'red': (
        'THANKING',
        'APOLOGY',
        'APPRECIATION',
    ),
    'white': (
        'AGREE-ACCEPT',
        'REJECT',
        'MAYBE-ACCEPT-PART',
    ),
    'light_green': (
        'ACTION-DIRECTIVE',
        'OFFERS-OPTIONS-COMMITS',
    ),
    'light_blue': (
        'QUOTATION',
        'REPEAT-PHRASE',
        'SUMMARIZE/REFORMULATE',
    ),
    'light_magenta': (
        'HEDGE',
        'DOWNPLAYER',
    ),
    'light_cyan': (
        'NON-VERBAL',
        'UNINTERPRETABLE',
        '3RD-PARTY-TALK',
        'SELF-TALK',
    ),
}

# Flat lookup table derived from the groups above: label name -> colour name
_LABEL_COLORS = {
    name: color
    for color, names in _LABELS_BY_COLOR.items()
    for name in names
}


def color_label(label):
    """Return the upper-cased label wrapped in its display colour.

    The label is upper-cased before the lookup; labels that are not in the
    table are rendered with the fallback colour 'grey'.
    """
    name = label.upper()
    return colored(name, _LABEL_COLORS.get(name, 'grey'))
# Labels that are treated as "no label to display"
_UNLABELED = ("O", "I-")
# Tokens that open a new phrase even without the "Ġ" word-start marker
_PHRASE_PUNCTUATION = (".", "?", "!")


def _emit(pieces, phrase, label):
    """Append the stripped phrase to pieces, followed by its coloured label if it has one."""
    text = phrase.strip()
    if label and label not in _UNLABELED:
        text = f"{text} {color_label(label)}"
    pieces.append(text)


# Rebuild the text phrase by phrase; each phrase carries the label predicted
# for the token that opened it.
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
result = []
phrase = ""
label = ""
for token, prediction, span in zip(tokens, predictions[0], offset_mapping):
    start, end = span
    piece = sentence[start:end]
    opens_phrase = token.startswith("Ġ") or token in _PHRASE_PUNCTUATION

    if not opens_phrase:
        # Continuation token: extend the phrase that is being collected
        phrase += piece
        if label in _UNLABELED:
            label = ""
        continue

    # A new word or punctuation mark starts: flush what was collected so far
    if phrase:
        _emit(result, phrase, label)
    phrase = piece.replace("Ġ", " ")
    label = id2label[prediction.item()]

# Flush the phrase that was still open when the tokens ran out
if phrase:
    _emit(result, phrase, label)

# Display the results
print(" ".join(result))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment