Last active
September 13, 2024 14:40
-
-
Save felipemello1/5f2002433c6da3a21f33d6cdf82e702a to your computer and use it in GitHub Desktop.
script update configs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Script to update configs in torchtune in bulk | |
Goes over every .yaml file in configs that also has "lora" in the name | |
Finds the line that has "lora_alpha: 16" | |
Replaces the "lora_alpha: 16" with "lora_dropout: 0.0", while keeping the spacing and \n | |
Saves the file | |
Prints every file that was not updated | |
""" | |
import os
import re
import shutil
def modify_yaml_file(file_path):
    """Insert ``lora_dropout: 0.0`` after every ``lora_alpha: 16`` line of a YAML file.

    The inserted line is a copy of the matching line with ``lora_alpha: 16``
    replaced by ``lora_dropout: 0.0``, so it inherits the original indentation
    and line terminator. The file is only rewritten when a line was inserted.

    Args:
        file_path: Path of the YAML file to update in place.

    Returns:
        True if the file was modified, False otherwise (no match, or every
        match is already followed by a ``lora_dropout`` line).
    """
    # Match "lora_alpha: 16" only when the number ends there, so values such
    # as 160 or 16.5 are left untouched.
    alpha_pattern = re.compile(r"lora_alpha: 16(?![\w.])")

    # newline="" keeps each line's original terminator (\n or \r\n) intact.
    with open(file_path, "r", encoding="utf-8", newline="") as file:
        lines = file.readlines()

    new_lines = []
    updated = False
    for index, line in enumerate(lines):
        new_lines.append(line)
        if not alpha_pattern.search(line):
            continue
        next_line = lines[index + 1] if index + 1 < len(lines) else ""
        # Already followed by a lora_dropout key (e.g. from a previous run):
        # inserting again would create a duplicate key.
        if next_line.lstrip().startswith("lora_dropout:"):
            continue
        if not line.endswith("\n"):
            # Final line without a terminator: add one so the inserted line
            # does not end up glued onto the same physical line.
            new_lines[-1] = line + "\n"
        new_lines.append(alpha_pattern.sub("lora_dropout: 0.0", line))
        updated = True

    # Only rewrite files that actually changed, so unrelated configs are never
    # truncated or touched.
    if updated:
        with open(file_path, "w", encoding="utf-8", newline="") as file:
            file.writelines(new_lines)
    return updated
def search_yaml_files(directory):
    """Update every ``.yaml`` file below ``directory`` and report the outcome.

    Walks ``directory`` recursively with ``os.walk``, passes each ``.yaml``
    file to ``modify_yaml_file``, then prints the paths that were updated
    followed by the paths that were left unchanged.

    Args:
        directory: Root directory to search.
    """
    changed, unchanged = [], []
    for dirpath, _subdirs, filenames in os.walk(directory):
        yaml_paths = (
            os.path.join(dirpath, name)
            for name in filenames
            if name.endswith('.yaml')
        )
        for path in yaml_paths:
            bucket = changed if modify_yaml_file(path) else unchanged
            bucket.append(path)

    print("Updated files:")
    for path in changed:
        print(path)
    print("\nFiles not updated (no 'lora_alpha: 16' found):")
    for path in unchanged:
        print(path)
if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse modify_yaml_file)
    # does not rewrite config files as a side effect.
    import sys

    # Default path is relative to the current working directory (presumably
    # the torchtune repository root); an explicit path may be passed instead.
    directory = sys.argv[1] if len(sys.argv) > 1 else 'recipes/configs'
    # os.walk silently yields nothing for a missing directory, which would
    # print two empty lists; fail loudly instead.
    if not os.path.isdir(directory):
        sys.exit(f"Config directory not found: {directory}")
    search_yaml_files(directory)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment