Last active
February 14, 2025 08:43
-
-
Save gaocegege/11cb5a0acf370ea8ca72a05eb69da0f8 to your computer and use it in GitHub Desktop.
simulator-LSH.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division, unicode_literals | |
import collections | |
import hashlib | |
import logging | |
import numbers | |
import re | |
import sys | |
from itertools import groupby | |
from bisect import bisect | |
import numpy as np | |
try: | |
from collections.abc import Iterable | |
except ImportError: | |
from collections import Iterable | |
def bytes_to_int(b):
    """Decode the byte string ``b`` as an unsigned, big-endian integer."""
    order = "big"
    return int.from_bytes(b, order, signed=False)
# List of texts (simulate requests)
# A single chat transcript as alternating "<|User|>" / "<|Assistant|>" turns;
# only assistant turns end with the "<|end▁of▁sentence|>" marker.
# NOTE: "▁" is U+2581 (LOWER ONE EIGHTH BLOCK), not an underscore.
# The __main__ driver concatenates all of these into one shared prefix for
# every simulated request, and simulate_load_balancing skips the first
# len(''.join(base_texts)) characters of each request when printing it.
base_texts = [
    "<|User|>Hello! How are you?",
    "<|Assistant|>I'm doing well, thank you! How can I assist you today?<|end▁of▁sentence|>",
    "<|User|>Can you help me with Python programming?",
    "<|Assistant|>Of course! I'd be happy to help you with Python programming. What specific topic would you like to discuss?<|end▁of▁sentence|>",
    "<|User|>How do I implement a binary search tree?",
    "<|Assistant|>I'll help you implement a binary search tree. Here's a step-by-step guide...<|end▁of▁sentence|>",
    "<|User|>What about sorting algorithms?",
    "<|Assistant|>There are several sorting algorithms we can discuss. The most common ones are:<|end▁of▁sentence|>",
    "<|User|>Can you explain quicksort?",
    "<|Assistant|>Quicksort is an efficient sorting algorithm that uses a divide-and-conquer strategy...<|end▁of▁sentence|>",
    "<|User|>How about merge sort?",
    "<|Assistant|>Merge sort is another divide-and-conquer algorithm that splits the array into smaller subarrays...<|end▁of▁sentence|>",
    "<|User|>What's the time complexity of these algorithms?",
    "<|Assistant|>Let me break down the time complexities:<|end▁of▁sentence|>",
    "<|User|>Thank you for explaining!",
    "<|Assistant|>You're welcome! Let me know if you have any other questions.<|end▁of▁sentence|>",
    "<|User|>What about space complexity?",
    "<|Assistant|>The space complexity varies for different algorithms. Let me explain each one...<|end▁of▁sentence|>",
    "<|User|>Can you compare bubble sort and insertion sort?",
    "<|Assistant|>Both bubble sort and insertion sort are simple sorting algorithms, but they have different characteristics...<|end▁of▁sentence|>",
    "<|User|>Which sorting algorithm is best for small datasets?",
    "<|Assistant|>For small datasets, insertion sort often performs well because...<|end▁of▁sentence|>",
    "<|User|>Goodbye and thanks for all the help!",
    "<|Assistant|>You're welcome! Good luck with your programming journey!<|end▁of▁sentence|>"
]
class ConsistentHashing:
    """Consistent-hash ring that maps string keys onto a set of servers.

    Each server is placed on the ring ``virtual_nodes`` times, under the names
    ``"<server>-<i>"``. Ring positions are the MD5 digest of a name read as a
    128-bit integer. A key is served by the first ring position strictly
    greater than the key's own position, wrapping around to the start.
    """

    def __init__(self, servers, virtual_nodes=100):
        """
        Initialize the consistent hash ring.
        :param servers: list of servers
        :param virtual_nodes: number of virtual nodes per server
        """
        self.virtual_nodes = virtual_nodes
        self.ring = {}  # ring position (int) -> server name
        self.sorted_keys = []  # every ring position, ascending, for bisect lookups
        # Create virtual nodes for each server
        for server in servers:
            self.add_server(server)

    def add_server(self, server):
        """
        Add a server and its virtual nodes to the hash ring.

        Positions already present on the ring (e.g. when the same server is
        added twice) are skipped. This keeps ``ring`` and ``sorted_keys`` in
        one-to-one correspondence, so a later ``remove_server`` cannot leave
        dangling positions that would make ``get_server`` raise KeyError.
        :param server: server name
        """
        new_keys = []
        for i in range(self.virtual_nodes):
            hash_key = self._hash(f"{server}-{i}")
            if hash_key in self.ring:
                continue
            self.ring[hash_key] = server
            new_keys.append(hash_key)
        if new_keys:
            self.sorted_keys.extend(new_keys)
            self.sorted_keys.sort()

    def remove_server(self, server):
        """
        Remove a server and its virtual nodes from the hash ring.
        :param server: server name
        :raises KeyError: if one of the server's virtual nodes is not on the ring
        """
        for i in range(self.virtual_nodes):
            hash_key = self._hash(f"{server}-{i}")
            del self.ring[hash_key]
            self.sorted_keys.remove(hash_key)

    def get_server(self, key):
        """
        Find the corresponding server for a given key.
        :param key: key (string)
        :return: server name, or None when the ring is empty
        """
        if not self.ring:
            return None
        hash_key = self._hash(key)
        # bisect returns the insertion point to the right of equal positions;
        # the modulo wraps keys past the last position back to the first one.
        idx = bisect(self.sorted_keys, hash_key) % len(self.sorted_keys)
        return self.ring[self.sorted_keys[idx]]

    def _hash(self, key):
        """
        Calculate the hash value of a key.
        :param key: key (string; encoded with the default UTF-8 codec)
        :return: the MD5 digest as a 128-bit non-negative integer
        """
        return int(hashlib.md5(key.encode()).hexdigest(), 16)
def simhash(text, hash_size=64):
    """
    Fingerprint ``text`` with this module's Simhash class.
    :param text: input text
    :param hash_size: fingerprint width in bits (default 64); Simhash raises
        ValueError unless it is a multiple of 8
    :return: SimHash value (integer)
    """
    fingerprint = Simhash(text, f=hash_size)
    return fingerprint.value
def fixed_window_lsh(text, window_size=10):
    """
    Combine per-window SimHash values into one fingerprint.

    The text is cut at fixed offsets 0, window_size, 2*window_size, ... (the
    last window may be shorter). Each window is fingerprinted with
    ``simhash`` and the fingerprints are XOR-ed together.

    Caveats that follow from this construction:
    * identical windows cancel out under XOR, so e.g. an even number of
      repetitions of the same window contributes nothing;
    * inserting or deleting a character shifts every later window boundary;
    * an empty text yields 0, and ``window_size=0`` raises ValueError
      (from ``range``).

    :param text: input text
    :param window_size: size of the fixed window (default 10 characters)
    :return: fixed-window LSH value (integer)
    """
    combined = 0
    for start in range(0, len(text), window_size):
        window = text[start:start + window_size]
        combined ^= simhash(window)
    return combined
def simulate_load_balancing(servers, texts, virtual_nodes=100, hash_algo=simhash,
                            common_prefix=None):
    """
    Simulate load balancing.

    Each text is fingerprinted with ``hash_algo``. The fingerprint's decimal
    string is then looked up on a ConsistentHashing ring, which MD5-hashes it
    again before placing it.
    NOTE: that second hash discards locality. Only texts with *identical*
    fingerprints are guaranteed to reach the same server; fingerprints that
    differ in a single bit are routed independently.

    For every request, the text (minus ``common_prefix``) and the chosen
    server are printed.

    :param servers: list of servers
    :param texts: list of texts
    :param virtual_nodes: number of virtual nodes per server
    :param hash_algo: callable mapping a text to an integer fingerprint
        (default: ``simhash``)
    :param common_prefix: prefix omitted when printing each request; defaults
        to the concatenation of ``base_texts``. Texts that do not start with
        it are printed in full.
    :return: dict mapping each server to the number of requests routed to it
    """
    ch = ConsistentHashing(servers, virtual_nodes)
    server_load = {server: 0 for server in servers}
    # Computed once instead of re-joining base_texts for every request.
    if common_prefix is None:
        common_prefix = ''.join(base_texts)
    # Assigning requests
    for text in texts:
        hash_value = hash_algo(text)
        server = ch.get_server(str(hash_value))
        # Only strip the prefix when it is really there; slicing blindly would
        # cut an arbitrary number of characters off unrelated texts.
        if text.startswith(common_prefix):
            shown = text[len(common_prefix):]
        else:
            shown = text
        print(f"Request: ...{shown}")
        print(f"Server: {server}")
        print("---")
        server_load[server] += 1
    return server_load
# Python 3 aliases for the Python 2 names used by the Simhash code below.
# The former Python 2 branch was unreachable: this module uses f-strings,
# which are a syntax error on Python 2, so it can only ever run on Python 3.
basestring = str
unicode = str
long = int


def int_to_bytes(n, length):
    """Encode the non-negative int ``n`` as exactly ``length`` big-endian bytes.

    Raises OverflowError if ``n`` does not fit in ``length`` bytes.
    """
    return n.to_bytes(length, 'big')


def bytes_to_int(b):
    """Decode ``b`` as an unsigned big-endian integer.

    This rebinds the module's earlier ``bytes_to_int`` with the same behavior.
    """
    return int.from_bytes(b, 'big')
def _hashfunc(x): | |
return hashlib.md5(x).digest() | |
class Simhash(object):
    """SimHash fingerprint of a text or of a (weighted) feature collection.

    ``value`` holds the ``f``-bit fingerprint as an int. Each bit is set when
    more than half of the total feature weight hashes to a 1 in that bit
    position (a weighted majority vote). The intent is that inputs sharing
    most of their features get fingerprints with a small Hamming ``distance``.
    """

    # Constants used in calculating simhash. Larger values will use more RAM.
    # Int weights above this cutoff (and any non-int weight) are added as a
    # scaled bit vector instead of repeating the digest `w` times in a batch.
    large_weight_cutoff = 50
    # Number of digests (or partial sums) accumulated before they are summed.
    batch_size = 200

    def __init__(
        self, value, f=64, reg=r'[\w\u4e00-\u9fcc]+', hashfunc=_hashfunc, log=None
    ):
        """
        `value` may be another Simhash (its value is copied), a string
        (tokenized with `reg`), an iterable of features (see
        `build_by_features`) or an integer fingerprint (used as-is).
        `f` is the dimensions of fingerprints, in bits. Must be a multiple of 8.
        `reg` is meaningful only when `value` is a string and describes
        what is considered to be a letter inside the parsed string. A compiled
        regexp object can also be given (for example one compiled with the
        re.UNICODE flag to accept any letter).
        `hashfunc` accepts a utf-8 encoded string and returns either bytes
        (preferred) or an unsigned integer, in at least `f // 8` bytes.
        `log` is an optional logger; defaults to the "simhash" logger.
        Raises ValueError if `f` is not a multiple of 8, and Exception for an
        unsupported `value` type.
        """
        if f % 8:
            raise ValueError('f must be a multiple of 8')
        self.f = f
        self.f_bytes = f // 8
        self.reg = reg
        self.value = None
        self.hashfunc = hashfunc
        # Probe once to learn whether hashfunc yields ints or raw digest bytes.
        self.hashfunc_returns_int = isinstance(hashfunc(b"test"), numbers.Integral)
        if log is None:
            self.log = logging.getLogger("simhash")
        else:
            self.log = log
        if isinstance(value, Simhash):
            self.value = value.value
        elif isinstance(value, basestring):
            self.build_by_text(unicode(value))
        elif isinstance(value, Iterable):
            self.build_by_features(value)
        elif isinstance(value, numbers.Integral):
            self.value = value
        else:
            raise Exception('Bad parameter with type {}'.format(type(value)))

    def __eq__(self, other):
        """
        Compare two simhashes by their value.
        :param Simhash other: The Simhash object to compare to
        Objects without a ``value`` attribute yield NotImplemented, so Python
        falls back to its default comparison instead of raising AttributeError.
        """
        try:
            other_value = other.value
        except AttributeError:
            return NotImplemented
        return self.value == other_value

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None in Python 3, which makes
        # instances unhashable; hash on the same field that __eq__ compares.
        return hash(self.value)

    def _slide(self, content, width=4):
        """Overlapping substrings of length `width` (the whole string if shorter)."""
        return [content[i:i + width] for i in range(max(len(content) - width + 1, 1))]

    def _tokenize(self, content):
        """Lower-case, keep only `reg` matches, then cut into 4-char shingles."""
        content = content.lower()
        content = ''.join(re.findall(self.reg, content))
        ans = self._slide(content)
        return ans

    def build_by_text(self, content):
        """Fingerprint `content` using each shingle's occurrence count as its weight."""
        features = self._tokenize(content)
        # token -> number of occurrences. Bit sums are integer additions, so
        # the iteration order of the counts does not affect the result.
        features = collections.Counter(features)
        return self.build_by_features(features)

    def build_by_features(self, features):
        """
        `features` might be a list of unweighted tokens (a weight of 1
        will be assumed), a list of (token, weight) tuples or
        a token -> weight dict. Sets `self.value`.
        """
        sums = []
        batch = []
        count = 0  # total weight; a bit is set when its weighted sum exceeds count / 2
        truncate_mask = 2 ** self.f - 1
        if isinstance(features, dict):
            features = features.items()
        for f in features:
            if isinstance(f, basestring):
                # Reset per item so a bare token never inherits the weight of a
                # preceding (token, weight) pair.
                w = 1
                skip_batch = False
            else:
                f, w = f
                skip_batch = w > self.large_weight_cutoff or not isinstance(w, int)
            count += w
            if self.hashfunc_returns_int:
                h = int_to_bytes(self.hashfunc(f.encode('utf-8')) & truncate_mask, self.f_bytes)
            else:
                h = self.hashfunc(f.encode('utf-8'))[-self.f_bytes:]
            if skip_batch:
                # Widen from uint8 before scaling: under NumPy 2 promotion rules
                # a uint8 array times a Python int > 255 raises OverflowError.
                sums.append(self._bitarray_from_bytes(h).astype(np.uint64) * w)
            else:
                # Repeating the digest w times adds its bits w times when summed.
                batch.append(h * w)
                if len(batch) >= self.batch_size:
                    sums.append(self._sum_hashes(batch))
                    batch = []
            if len(sums) >= self.batch_size:
                sums = [np.sum(sums, 0)]
        if batch:
            sums.append(self._sum_hashes(batch))
        combined_sums = np.sum(sums, 0)
        self.value = bytes_to_int(np.packbits(combined_sums > count / 2).tobytes())

    def _sum_hashes(self, digests):
        """Per-bit column sums over the concatenated `f`-bit digests."""
        bitarray = self._bitarray_from_bytes(b''.join(digests))
        rows = np.reshape(bitarray, (-1, self.f))
        return np.sum(rows, 0)

    @staticmethod
    def _bitarray_from_bytes(b):
        """Unpack bytes into a uint8 array of 0/1 bits, most significant bit first."""
        return np.unpackbits(np.frombuffer(b, dtype='>B'))

    def distance(self, another):
        """Hamming distance between the low `f` bits of the two fingerprints."""
        assert self.f == another.f
        x = (self.value ^ another.value) & ((1 << self.f) - 1)
        ans = 0
        while x:
            ans += 1
            x &= x - 1  # clears the lowest set bit
        return ans
class SimhashIndex(object):
    """Bucketed index of Simhash fingerprints for near-duplicate lookup.

    Each fingerprint is split into `k + 1` bit ranges (see `offsets`) and
    stored under one bucket key per range. Two `f`-bit fingerprints that
    differ in at most `k` bits must agree exactly on at least one of those
    `k + 1` ranges, so a query only inspects buckets sharing one of its keys.
    """

    def __init__(self, objs, f=64, k=2, log=None):
        """
        `objs` is a list of (obj_id, simhash)
        obj_id is a string, simhash is an instance of Simhash
        `f` is the same with the one for Simhash
        `k` is the tolerance (maximum Hamming distance reported as a match)
        `log` is an optional logger; defaults to the "simhash" logger
        """
        self.k = k
        self.f = f
        count = len(objs)
        if log is None:
            self.log = logging.getLogger("simhash")
        else:
            self.log = log
        self.log.info('Initializing %s data.', count)
        # bucket key "<range bits hex>:<range index hex>" -> {"<value hex>,<obj_id>", ...}
        self.bucket = collections.defaultdict(set)
        for i, q in enumerate(objs):
            if i % 10000 == 0 or i == count - 1:
                self.log.info('%s/%s', i + 1, count)
            self.add(*q)

    def get_near_dups(self, simhash):
        """
        `simhash` is an instance of Simhash
        return a list of obj_id, which is in type of str
        """
        assert simhash.f == self.f
        ans = set()
        for key in self.get_keys(simhash):
            # .get() instead of [] so a read never inserts an empty bucket
            # into the defaultdict (which would inflate bucket_size and memory).
            dups = self.bucket.get(key, ())
            self.log.debug('key:%s', key)
            if len(dups) > 200:
                self.log.warning('Big bucket found. key:%s, len:%s', key, len(dups))
            for dup in dups:
                # Entries are "<value hex>,<obj_id>"; obj_id may itself contain commas.
                sim2, obj_id = dup.split(',', 1)
                sim2 = Simhash(int(sim2, 16), self.f)
                d = simhash.distance(sim2)
                if d <= self.k:
                    ans.add(obj_id)
        return list(ans)

    def add(self, obj_id, simhash):
        """
        `obj_id` is a string
        `simhash` is an instance of Simhash
        """
        assert simhash.f == self.f
        v = '%x,%s' % (simhash.value, obj_id)  # same entry in every bucket
        for key in self.get_keys(simhash):
            self.bucket[key].add(v)

    def delete(self, obj_id, simhash):
        """
        `obj_id` is a string
        `simhash` is an instance of Simhash
        Deleting an entry that is not indexed is a no-op.
        """
        assert simhash.f == self.f
        v = '%x,%s' % (simhash.value, obj_id)
        for key in self.get_keys(simhash):
            bucket = self.bucket.get(key)
            if bucket is None:
                continue
            bucket.discard(v)
            if not bucket:
                # Drop emptied buckets so bucket_size() counts live buckets only.
                del self.bucket[key]

    @property
    def offsets(self):
        """
        Start bit of each of the `k + 1` ranges (the last range runs to bit `f`).
        You may optimize this method according to <http://static.googleusercontent.com/media/research.google.com/en//pubs/archive/33026.pdf>
        """
        return [self.f // (self.k + 1) * i for i in range(self.k + 1)]

    def get_keys(self, simhash):
        """Yield one "<range bits hex>:<range index hex>" bucket key per range."""
        for i, offset in enumerate(self.offsets):
            if i == (len(self.offsets) - 1):
                m = 2 ** (self.f - offset) - 1
            else:
                m = 2 ** (self.offsets[i + 1] - offset) - 1
            c = simhash.value >> offset & m  # (value >> offset) & m
            yield '%x:%x' % (c, i)

    def bucket_size(self):
        """Number of buckets currently holding at least one entry."""
        return len(self.bucket)
# Test
if __name__ == "__main__":
    # List of servers
    servers = ["Server1", "Server2", "Server3", "Server4"]
    # Each request is the whole shared conversation (base_texts) followed by
    # one new suffix, so all requests share a long common prefix.
    suffixes = [
        "<|User|>Hello! How are you?<|end▁of▁sentence|>",
        "<|User|>Hello! How are you? What's new?<|end▁of▁sentence|>",
        "<|User|>Can you help me?<|end▁of▁sentence|>",
        "<|User|>Can you help me with Python programming?<|end▁of▁sentence|>",
        "<|User|>How do I implement a binary search tree?<|end▁of▁sentence|>",
        "<|User|>How do I implement a binary search tree? Can you help?<|end▁of▁sentence|>",
        "<|User|>What's the time complexity?<|end▁of▁sentence|>",
        "<|User|>What's the time complexity of these algorithms?<|end▁of▁sentence|>",
        "<|User|>Can you compare bubble sort and insertion sort?<|end▁of▁sentence|>",
        "<|User|>Can you compare bubble sort and insertion sort? Which is better?<|end▁of▁sentence|>",
        "<|User|>Hello! How are you?<|end▁of▁sentence|><|Assistant|>I'm doing well, thank you! How can I assist you today?<|end▁of▁sentence|>",
        "<|User|>Hello! How are you?<|end▁of▁sentence|><|Assistant|>I'm doing well, thank you! How can I assist you today?<|end▁of▁sentence|><|User|>Can you help me?<|end▁of▁sentence|>",
    ]
    # Join the shared history once instead of rebuilding it for every suffix.
    shared_prefix = "".join(base_texts)
    texts = [shared_prefix + suffix for suffix in suffixes]
    # Simulate load balancing
    server_load = simulate_load_balancing(servers, texts, virtual_nodes=100)
    # Output result
    print("Server Load Distribution:")
    for server, load in server_load.items():
        print(f"{server}: {load} requests")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment