Last active
April 16, 2023 02:25
-
-
Save ramiil/389faa6798df038d349212b19259f124 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import os | |
import tiktoken | |
import time | |
import multiprocessing | |
# Directory containing this script; input files are read from <working_dir>/<dataset>.
working_dir = os.path.dirname(os.path.realpath(__file__))
# Input sub-directory name; also used as the prefix of the per-worker output files.
dataset = 'data'
# Flush threshold for the text buffer, measured in characters (len() of a decoded
# str): 512 * 1024 * 1024 = 512 Mi characters per encode call.
ws = 512 * 1024 * 1024
def chunks(arr, size):
    """Lazily split *arr* into consecutive slices of length *size*.

    Every slice except possibly the last holds exactly *size* items.  A zero
    *size* makes ``range`` raise ``ValueError`` when the generator is first
    advanced; a negative *size* yields nothing for non-empty input.
    """
    offsets = range(0, len(arr), size)
    yield from (arr[start:start + size] for start in offsets)
def tofile(lst, name):
    """Append every integer in *lst* to file *name* as a 2-byte little-endian
    unsigned value.

    All values are packed into one ``array('H')`` buffer and written with a
    single call, instead of one ``to_bytes`` + ``write`` per element, which
    dominated runtime for large token lists.  Values outside ``0..65535`` raise
    ``OverflowError``, as the per-element ``to_bytes(2, 'little')`` did.

    :param lst: iterable of ints (token ids).
    :param name: output path; opened in append mode, so repeated calls extend
        the file rather than overwrite it.
    """
    import array
    import sys

    packed = array.array('H', lst)
    if packed.itemsize == 2:
        # array stores native byte order; the file format is little-endian.
        if sys.byteorder != 'little':
            packed.byteswap()
        payload = packed
    else:
        # 'H' (C unsigned short) wider than 2 bytes on this platform: pack explicitly.
        payload = b''.join(v.to_bytes(2, 'little') for v in packed)
    with open(name, 'ab') as fh:
        fh.write(payload)
def _encode_and_append(enc, pid, parts, out_name):
    """Join and clear *parts*, encode the text with *enc*, and append the ids to *out_name*."""
    data = ''.join(parts)
    parts.clear()  # drop the pieces before encoding; only the joined copy is needed
    print(' [{0}] Encoding {1} mb of data'.format(pid, len(data) // (1024 * 1024)))
    # encode with tiktoken gpt2 bpe
    ids = enc.encode_ordinary(data)
    tofile(ids, out_name)


def process_files(pid, lst):
    """Tokenize the ``.txt`` files named in *lst* and append the token ids to
    ``<dataset>_<pid>.bin`` (relative to the current working directory).

    File contents are concatenated in order, with no separator, and encoded
    with tiktoken's "gpt2" encoding each time at least ``ws`` characters are
    buffered.  Whatever is still buffered after the last file is encoded too,
    so input smaller than ``ws`` is no longer silently dropped.

    :param pid: worker index; used in log lines and in the output file name.
    :param lst: bare file names inside ``<working_dir>/<dataset>``; names that
        do not end in ``.txt`` are skipped.
    """
    enc = tiktoken.get_encoding("gpt2")
    src_dir = os.path.join(working_dir, dataset)  # portable, unlike a hard-coded '\\'
    out_name = dataset + '_' + str(pid) + '.bin'
    parts = []     # file contents waiting to be encoded (joined once per flush)
    buffered = 0   # total characters currently held in ``parts``
    for fname in lst:
        if not fname.endswith('.txt'):
            continue
        print(' [{0}] {1}'.format(pid, fname))
        with open(os.path.join(src_dir, fname), 'r', encoding="utf8") as f:
            parts.append(f.read())
        buffered += len(parts[-1])
        if buffered >= ws:
            _encode_and_append(enc, pid, parts, out_name)
            buffered = 0
    # Encode the tail that never reached the ``ws`` threshold.
    if buffered:
        _encode_and_append(enc, pid, parts, out_name)
MAX_THREADS = multiprocessing.cpu_count()
threads = []
if __name__ == "__main__":
    nowtime = time.time()
    # sorted() makes the file-to-worker assignment reproducible across runs.
    files = sorted(os.listdir(os.path.join(working_dir, dataset)))
    # Ceiling division (at least 1) keeps the number of chunks <= MAX_THREADS.
    # Floor division raised ValueError (zero step) with fewer files than CPUs,
    # and otherwise could spawn extra processes that the join loop never waited on.
    per_proc = max(1, -(-len(files) // MAX_THREADS))
    for pid, ch in enumerate(chunks(files, per_proc)):
        print('Running process {0} of {1}'.format(pid, MAX_THREADS))
        threads.append(multiprocessing.Process(target=process_files, args=(pid, ch)))
        threads[pid].start()
    # Join every process actually started, however many that is.
    for proc in threads:
        proc.join()
    print(time.time() - nowtime)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment