Created
May 20, 2025 20:03
-
-
Save seyyedaliayati/683f566cccf93f636375276e1f95d5af to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from itertools import combinations | |
from typing import Optional, Tuple | |
import time | |
import tracemalloc | |
from ds import * | |
# Global variables to keep track of the maximum Malware and Goodware counts.
# Both are running maxima of the per-node counters and are refreshed as a
# side effect of insert_path() each time a node counter is incremented.
total_mw = 0
total_gw = 0
def insert_path(root, path):
    """Insert one labelled sample path into the tree rooted at ``root``.

    ``path`` is a backslash-separated path whose final component ends with a
    space-separated label: ``(M)`` for malware or ``(G)`` for goodware. The
    whole path is lower-cased first, so the label is matched
    case-insensitively. Every node along the path (each directory and the
    final file) gets its malware or goodware counter incremented, and the
    module-level ``total_mw`` / ``total_gw`` maxima are refreshed.

    Raises:
        IndexError: if the path contains no non-empty component.
        KeyError: if the trailing label is neither ``(m)`` nor ``(g)``.
    """
    global total_mw, total_gw

    # Normalise case and collapse doubled backslashes into single ones.
    path = path.lower().replace("\\\\", "\\")
    debug_print("Inserting path:", path)

    labels = {"(m)": NodeTypes.Malware, "(g)": NodeTypes.Goodware}

    # Empty pieces come from leading/trailing/doubled separators; drop them.
    parts = [piece for piece in path.split(BACKSLASH) if piece != ""]
    labelled_leaf = parts[-1]
    dir_names = parts[:-1]

    # Everything before the last space is the file name, the rest the label.
    leaf_name, _, label = labelled_leaf.rpartition(" ")
    node_type = labels[label.strip()]

    chain = [TreeNode(name, is_filename=False, is_subdir=True)
             for name in dir_names]
    chain.append(
        TreeNode(name=leaf_name.strip(), is_filename=True, is_subdir=False))

    counts_as_malware = node_type == NodeTypes.Malware
    counts_as_goodware = node_type == NodeTypes.Goodware

    # Walk down from the root, creating nodes that do not exist yet and
    # bumping the relevant counter on every node that is visited.
    cursor = root
    for candidate in chain:
        siblings = cursor.children
        if candidate.name not in siblings:
            siblings[candidate.name] = candidate
        cursor = siblings[candidate.name]

        if counts_as_malware:
            cursor.malware_count += 1
            total_mw = max(total_mw, cursor.malware_count)
        elif counts_as_goodware:
            cursor.goodware_count += 1
            total_gw = max(total_gw, cursor.goodware_count)
def dfs(node, depth=0):
    """Print the subtree rooted at ``node`` in pre-order, one node per line.

    Each line is prefixed with one space per depth level (``depth`` is the
    level of ``node`` itself) and shows the node's name, its subdir/filename
    markers, its malware and goodware counters and its ``str_type``.
    """
    # Explicit stack instead of recursion; children are pushed in reverse so
    # that they are popped (and printed) in their original order.
    pending = [(node, depth)]
    while pending:
        current, level = pending.pop()

        subdir_str = ""
        if current.is_subdir:
            subdir_str = "[subdir]"
        extension_str = ""
        if current.is_filename:
            extension_str = "(filename)"

        indent = " " * level
        print(indent +
              f"{current.name} {subdir_str} {extension_str} Malware: {current.malware_count} Goodware: {current.goodware_count}, Type: {current.str_type}")

        children = list(current.children.values())
        pending.extend((child, level + 1) for child in reversed(children))
def gen_local_candidates(node: TreeNode):
    """Build the candidate rule sets that describe ``node`` on its own.

    Every candidate is a ``RuleSet`` holding exactly one ``Rule`` whose
    action is Block when ``node.is_malware()`` is true and Allow otherwise,
    and whose counters are copied from the node.

    * Sub-directories and dot-less file names yield two candidates: the
      exact name and ``ANY_WILDCARD``.
    * File names containing ``DOT`` additionally yield, unless
      ``NO_REGEX_AT_ALL`` is set, prefix patterns (``"a.*"``), suffix
      patterns (``"*.c"``) and, for three or more parts, patterns with a
      wildcard in the middle (``"a.*.c"``).
    * A node that is neither a sub-directory nor a file name yields an
      empty ``RuleCandidates``.
    """
    assert node is not None, "Node is None"
    assert isinstance(node, TreeNode), "Node is not a TreeNode"

    candidates = RuleCandidates()
    if node.is_malware():
        action = Actions.Block
    else:
        action = Actions.Allow

    def propose(pattern):
        # One single-rule RuleSet per pattern, appended in call order.
        rule_set = RuleSet(node)
        rule_set.add(Rule(action, pattern, node.malware_count,
                          node.goodware_count, node))
        candidates.add(rule_set)

    if not (node.is_subdir or node.is_filename):
        return candidates

    # Common to every kind of node: the literal name and the catch-all.
    propose(node.name)
    propose(ANY_WILDCARD)

    # Only dotted file names get the partial-wildcard patterns.
    if node.is_subdir or DOT not in node.name or NO_REGEX_AT_ALL:
        return candidates

    parts = node.name.split(DOT)
    count = len(parts)

    for cut in range(1, count):
        propose(".".join(parts[:cut]) + ".*")
        propose("*." + ".".join(parts[cut:]))

    if count > 2:
        seen = set()  # repeated name parts can produce the same pattern twice
        for left, right in combinations(range(1, count), 2):
            pattern = ".".join(parts[:left]) + \
                ".*." + ".".join(parts[right:])
            if pattern in seen:
                continue
            seen.add(pattern)
            propose(pattern)

    return candidates
def merge_two_rules_same_dir(r1: Rule, r2: Rule, rs1_has_block_all=False, rs2_has_block_all=False):
    """Combine two rules that live at the same directory level.

    Returns a 3-tuple ``(rules, update_r1, fork_required)``:

    * ``rules`` -- list of zero, one or two rules that replace the pair.
    * ``update_r1`` -- True when the single merged rule was built from
      ``r1``'s name (``r2`` is a subset of ``r1``, or the two rules are
      equal and not both simple).
    * ``fork_required`` -- True only for a strict conflict (same name,
      different action) when ``SOLVE_STRICT_CONFLICTS`` is set; ``rules``
      then holds the two alternative outcomes.

    ``rs1_has_block_all`` and ``rs2_has_block_all`` are accepted but not read
    anywhere in this body.

    Raises:
        Exception: for states the author marked as "should not happen"
            (equal ``specific_score`` on a subset conflict, or an ``r1`` that
            is neither allowed nor blocked).
    """
    debug_print(f"[merge_two_rules_same_dir] Merging {r1} and {r2}")
    update_r1 = False
    # NOTE(review): update_r2 is assigned below but never returned or read.
    update_r2 = False
    fork_required = False
    if r1.is_simple_rule() and r2.is_simple_rule():
        # Two simple rules: keep one if they are equal, otherwise keep both.
        if r1 == r2:
            return [r2], update_r1, fork_required
        return [r1, r2], update_r1, fork_required
    # at least one has a wildcard
    # if same action, merge --> more general
    if r1.action == r2.action:
        if r1.is_subset_of(r2):
            # r2 is more general --> return r2
            update_r2 = True
            # The merged rule carries the summed counters of both rules.
            merged = Rule(r2.action, r2.name, r1.mw + r2.mw, r1.gw + r2.gw)
            return [merged], update_r1, fork_required
        elif r2.is_subset_of(r1):
            # r1 is more general --> return r1
            update_r1 = True
            merged = Rule(r1.action, r1.name, r1.mw + r2.mw, r1.gw + r2.gw)
            return [merged], update_r1, fork_required
        elif r1 == r2:
            merged = Rule(r1.action, r1.name, r1.mw + r2.mw, r1.gw + r2.gw)
            update_r1 = True
            return [merged], update_r1, fork_required
        else:
            # Same action but unrelated patterns: nothing to merge.
            return [r1, r2], update_r1, fork_required
    # conflict
    if r1.is_simple_rule():
        # pseudo conflict
        # The simple rule is listed first in the result.
        return [r1, r2], update_r1, fork_required
    elif r2.is_simple_rule():
        # pseudo conflict
        return [r2, r1], update_r1, fork_required
    elif r1.is_subset_of(r2) or r2.is_subset_of(r1):
        # Overlapping wildcard rules: the one with the higher specific_score
        # is listed first.
        if r1.specific_score > r2.specific_score:
            return [r1, r2], update_r1, fork_required
        elif r1.specific_score < r2.specific_score:
            return [r2, r1], update_r1, fork_required
        else:
            raise Exception("should not happen")
    elif r1.name == r2.name:
        # strict conflict
        if SOLVE_STRICT_CONFLICTS:
            if r1.is_allowed():
                debug_print("[merge_two_rules_same_dir] Fork case 1")
                # allow *, block * --> either we accept the allow (mr1) or the block (mr2)
                # e.g. Merging Rule(action=Allow, name=*, mw=0, gw=2) and Rule(action=Block, name=*, mw=1, gw=0)
                fork_required = True
                # NOTE(review): the meaning of the signed counter arithmetic
                # below is not evident from this function -- confirm against
                # how RuleSet values are computed.
                mr1 = Rule(Actions.Allow, r1.name, r1.mw - r2.mw, r1.gw + abs(r2.gw))
                mr2 = Rule(Actions.Block, r2.name, r2.mw - abs(r1.mw), r2.gw - r1.gw)
                return [mr1, mr2], update_r1, fork_required
            elif r1.is_blocked():
                debug_print("[merge_two_rules_same_dir] Fork case 2")
                # block *, allow * --> either we accept the block (mr1) or the allow (mr2)
                # e.g. Merging Rule(action=Block, name=*, mw=2, gw=0) and Rule(action=Allow, name=*, mw=0, gw=1)
                fork_required = True
                mr1 = Rule(Actions.Block, r1.name, r1.mw + abs(r2.mw), r1.gw - r2.gw)
                mr2 = Rule(Actions.Allow, r2.name, r2.mw - r1.mw, r2.gw + abs(r1.gw))
                return [mr1, mr2], update_r1, fork_required
            else:
                raise Exception("should not happen")
        else:
            # Strict conflicts are not being solved: drop both rules.
            return [], update_r1, fork_required
    else:
        # Different actions, unrelated patterns: keep both untouched.
        return [r1, r2], update_r1, fork_required
def merge_same_dir(node, rc) -> RuleCandidates:
    """Merge the candidate collections of sibling nodes under ``node``.

    ``rc`` is a list of ``RuleCandidates``. Zero entries give an empty
    ``RuleCandidates``; one entry is returned as is; more than two are split
    in half and merged recursively. For exactly two, every pair of rule sets
    (one from each side) is combined rule by rule through
    ``merge_two_rules_same_dir`` and a resulting rule set is kept only when
    it is non-empty and its value equals the sum of the two input values.
    """
    if len(rc) == 0:
        return RuleCandidates()
    if len(rc) == 1:
        return rc[0]
    if len(rc) > 2:
        # Divide and conquer until only pairs are left.
        mid = int(len(rc)/2)
        p1 = merge_same_dir(node, rc[:mid])
        p2 = merge_same_dir(node, rc[mid:])
        return merge_same_dir(node, [p1, p2])
    assert len(rc) == 2, "rc should be 2"
    debug_print("[merge_same_dir] to merge:", rc)
    rc1 = rc[0]
    rc2 = rc[1]
    # If exactly one side has no candidates, the other side is the result.
    if len(rc1) > 0 and len(rc2) == 0:
        return rc1
    if len(rc1) == 0 and len(rc2) > 0:
        return rc2
    new_rc = RuleCandidates(no_validation=True)
    for rs1 in rc1.candidates:
        rs1_has_block_all = rs1.has_block_all()
        rs1_value = rs1.get_value()
        for rs2 in rc2.candidates:
            rs2_has_block_all = rs2.has_block_all()
            rs2_value = rs2.get_value()
            # Starts with one empty rule set; each fork below doubles the list.
            new_rs_list = [RuleSet(node)]
            debug_print(f"[merge_same_dir] current list:\n", new_rs_list)
            debug_print(f"[merge_same_dir] merging\nrs1: {rs1}, \nrs2: {rs2}")
            for r1 in rs1.rules:
                for r2 in rs2.rules:
                    update_r1 = False
                    merged_rules, update_r1, fork_required = merge_two_rules_same_dir(r1, r2, rs1_has_block_all, rs2_has_block_all)
                    assert len(merged_rules) < 3, "More than 2 rules merged"
                    # Pad with None so that mr1/mr2 always unpack.
                    mr1, mr2 = (merged_rules + [None, None])[:2]
                    debug_print(f"[merge_same_dir] MR1:", mr1)
                    debug_print(f"[merge_same_dir] MR2:", mr2)
                    if fork_required:
                        assert mr1 is not None and mr2 is not None, "Fork required but no rules"
                        debug_print("Applying fork")
                        debug_print(
                            f"[merge_same_dir] rs before fork: {new_rs_list}")
                        # Each existing rule set takes mr1, its fork takes mr2.
                        forked_rs_list = []
                        for existing_rs in new_rs_list:
                            forked_rs = existing_rs.fork()
                            existing_rs.add(mr1)
                            forked_rs.add(mr2)
                            forked_rs_list.append(forked_rs)
                        # Extended after the loop so the list is not grown
                        # while it is being iterated.
                        new_rs_list.extend(forked_rs_list)
                        debug_print(
                            f"[merge_same_dir] rs after fork: {new_rs_list}")
                    else:
                        # NOTE(review): mr1/mr2 may be None here; presumably
                        # RuleSet.add ignores None -- confirm.
                        for existing_rs in new_rs_list:
                            existing_rs.add(mr1)
                            existing_rs.add(mr2)
                    if update_r1 and mr1:
                        # Later r2 rules are merged against the merged rule.
                        r1 = mr1
            for new_rs in new_rs_list:
                new_rs.calculate_cost_benefit()
                new_rs_value = new_rs.get_value()
                # Valid only when merging neither gained nor lost value.
                is_valid = rs1_value + rs2_value == new_rs_value
                if len(new_rs) > 0:
                    if is_valid:
                        new_rc.add(new_rs)
                        debug_print("[merge_same_dir] New Rule Set:", new_rs)
                    else:
                        debug_print("[merge_same_dir] Invalid Rule Set:", new_rs,
                                    f"Because {rs1_value} + {rs2_value} != {new_rs_value}")
    return new_rc
def merge_two_rules_cross_dir(r1: Rule, r2: Rule) -> Optional[Rule]:
    """Combine a parent-level rule ``r1`` with a child-level rule ``r2``.

    The result always takes the action of ``r2`` and the smaller of the two
    malware counts and of the two goodware counts.

    * Equal rules: ``r2`` is returned unchanged.
    * ``r1`` is the any-wildcard and ``r2`` starts with one: the parent
      wildcard is absorbed and the result keeps ``r2``'s name.
    * Otherwise the names are joined as ``r1.name + BACKSLASH + r2.name``.

    The previous implementation had a separate branch for a single-wildcard
    ``r1`` that built exactly the same rule as the general case; it has been
    folded into the general case.
    """
    if r1 == r2:
        return r2
    debug_print(f"[merge_two_rules_cross_dir] Merging {r1} and {r2}")

    # always pick the action of r2
    new_action = r2.action
    mw = min(r1.mw, r2.mw)
    gw = min(r1.gw, r2.gw)

    if r1.is_any_wildcard() and r2.starts_with_any_wildcard():
        return Rule(new_action, r2.name, mw, gw)
    return Rule(new_action, r1.name + BACKSLASH + r2.name, mw, gw)
def merge_cross_dir(node: TreeNode, current, childs):
    """Cross-combine ``node``'s own candidates with those of its children.

    For every (own rule set, child rule set) pair a new ``RuleSet`` is built
    from the pairwise ``merge_two_rules_cross_dir`` results. A combined rule
    set is discarded when its ``mw`` or ``gw`` exceeds the corresponding
    counter of ``node``.

    Returns an empty ``RuleCandidates`` when both inputs are empty and
    ``current`` itself when only ``childs`` is empty.

    Raises:
        Exception: if ``current`` is empty while ``childs`` is not.
    """
    merged = RuleCandidates()

    if childs.is_empty():
        if current.is_empty():
            debug_print(
                "[merge_cross_dir][Warning] Both current and childs are empty")
            return merged
        return current
    if current.is_empty():
        raise Exception("Current is empty")

    for own_set in current.candidates:
        for child_set in childs.candidates:
            # each candidate is a RuleSet
            combined = RuleSet(node)
            for own_rule in own_set.rules:
                for child_rule in child_set.rules:
                    joined = merge_two_rules_cross_dir(own_rule, child_rule)
                    if joined:
                        combined.add(joined)

            too_big = (combined.mw > node.malware_count
                       or combined.gw > node.goodware_count)
            if too_big:
                debug_print(
                    f"[merge_cross_dir] RuleSet too big: {combined.mw} > {node.malware_count} OR {combined.gw} > {node.goodware_count}")
            else:
                merged.add(combined)
                debug_print(
                    f"[merge_cross_dir] Merged RuleSet: {combined}, Node: {node.malware_count} {node.goodware_count}")
    return merged
def reduce(node, depth=0):
    """Compute the candidate rule sets for the subtree rooted at ``node``.

    The children are reduced recursively and their results merged with
    ``merge_same_dir``. For the root that merged result is returned directly.
    For any other node, invalid child rule sets are dropped and the remainder
    is cross-merged with the node's own local candidates.

    ``depth`` is only threaded through the recursion.
    """
    debug_print("[reduce] Entered node:", node)
    child_candidates = merge_same_dir(
        node, [reduce(child, depth + 1) for child in node.children.values()])
    if node.is_root:
        return child_candidates

    # apply validation
    # Iterate over a snapshot: drop() is called on the same collection, and
    # removing entries while iterating it live can skip rule sets (or raise).
    for rs in list(child_candidates.candidates):
        if not rs.is_valid():
            # drop the ruleset
            debug_print(
                "[reduce] Dropping invalid RuleSet:", rs)
            child_candidates.drop(rs)
    debug_print("[reduce] Children of", node.name, "are", child_candidates)

    current_candidates = gen_local_candidates(node)
    debug_print("[reduce] Candidates of", node.name, "are", current_candidates)

    merged = merge_cross_dir(node, current_candidates, child_candidates)
    debug_print("[reduce] Merge result for", node.name, "is", merged)
    return merged
def reduce_parallel(node, depth=0):
    """Placeholder for a parallel variant of ``reduce``.

    Raises:
        NotImplementedError: always.
    """
    message = "Parallel processing is not implemented yet"
    raise NotImplementedError(message)
def measure_time_and_memory(testcase_function, *args, **kwargs):
    """
    Measure the time and memory consumed by a function.

    Calls ``testcase_function(*args, **kwargs)`` while ``tracemalloc`` is
    tracing and returns ``(result, elapsed_time, peak_memory, top_stats)``:
    the function's return value, the elapsed seconds, the peak traced memory
    in bytes and the per-line statistics of a snapshot taken right after the
    call.

    Tracing is always stopped, even when the function raises; the exception
    itself propagates unchanged.
    """
    # Start tracking memory usage
    tracemalloc.start()
    try:
        # perf_counter is monotonic, so the interval cannot be distorted by
        # system clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        result = testcase_function(*args, **kwargs)
        elapsed_time = time.perf_counter() - start_time

        # Collect the statistics before tracing is switched off.
        snapshot = tracemalloc.take_snapshot()
        top_stats = snapshot.statistics('lineno')
        _, peak_memory = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return result, elapsed_time, peak_memory, top_stats
def process_testcase(testcase):
    """
    Build the tree for one test case and reduce it to candidate rules.

    ``testcase`` is an iterable of labelled path strings. Each path is
    inserted under a fresh root node, the tree is printed, and the result of
    ``reduce`` on the root is returned.
    """
    tree = TreeNode("root", is_root=True)
    for sample in testcase:
        print(f"Inserting path: {sample}")
        insert_path(tree, sample)

    # The dump is for visualization only.
    print("\n=== DFS ===")
    dfs(tree)
    print("===========\n")

    return reduce(tree)
def test():
    """Run every case from ``testcases.all_testcases`` and print a report.

    For each case the candidate rules, the ten largest memory statistics,
    the elapsed time, the peak memory and the number of candidates are
    printed.
    """
    # Imported here so that the module can be used without the test data.
    # A `realcases` module was previously referenced as an alternative source
    # of `all_testcases`.
    from testcases import all_testcases

    for counter, testcase in enumerate(all_testcases, start=1):
        print(f"=== Testcase {counter} ===")

        # Measure time and memory for the test case
        candidate_rules, elapsed_time, peak_memory, top_stats = \
            measure_time_and_memory(process_testcase, testcase)

        print("\n=== Candidate Rules ===")
        print(candidate_rules)
        print("========================\n")

        # Print time and memory results
        print(f"[main] Testcase {counter} completed")
        print("[main] Top 10 memory usage:")
        for stat in top_stats[:10]:
            print(stat)
        print(f"[main] Time taken: {elapsed_time:.4f} seconds")
        print(f"[main] Peak memory usage: {peak_memory / 1024:.2f} KB")
        print(f"[main] Number of candidates: {len(candidate_rules)}")
        print("========================\n")
def main(input_file="temp/input_test_2.txt"):
    """Read labelled paths from ``input_file``, process them and report.

    Blank lines and comment lines (first non-blank character ``#``) are
    skipped; every other line is stripped and treated as one labelled path.
    The candidate rules, the ten largest memory statistics, the elapsed
    time, the peak memory and the number of candidates are printed.

    Args:
        input_file: path of the text file to read. Defaults to the file that
            used to be hard-coded here.
    """
    print(f"Input file: {input_file}")

    paths = []
    with open(input_file, "r") as f:
        for line in f:
            entry = line.strip()
            # Strip before testing for '#', so that indented comment lines
            # are skipped instead of being parsed as paths.
            if not entry or entry.startswith("#"):
                continue
            paths.append(entry)
    print(f"Number of paths: {len(paths)}")

    print("Processing test case")
    candidate_rules, elapsed_time, peak_memory, top_stats = \
        measure_time_and_memory(process_testcase, paths)

    print("\n=== Candidate Rules ===")
    print(candidate_rules)
    print("========================\n")

    # Print time and memory results
    print("[main] Testcase completed")
    print("[main] Top 10 memory usage:")
    for stat in top_stats[:10]:
        print(stat)
    print(f"[main] Time taken: {elapsed_time:.4f} seconds")
    print(f"[main] Peak memory usage: {peak_memory / 1024:.2f} KB")
    print(f"[main] Number of candidates: {len(candidate_rules)}")
    print("========================\n")
if __name__ == "__main__":
    # Scratch rule pairs for probing Rule.is_subset_of() by hand. Each pair
    # overwrites the previous one, so only the last pair is live when the
    # commented-out print below is re-enabled.
    # [RuleSet] Rule is a subset of an existing rule: <
    r1 = Rule(action=Actions.Allow, name="*.temp", mw=0, gw=1)
    r2 = Rule(action=Actions.Allow, name="*\\file1.*", mw=0, gw=1)
    r1 = Rule(Actions.Allow, "temp\\mal\\file.txt", 0, 0)
    r2 = Rule(Actions.Allow, "temp\\*", 0, 0)
    r1 = Rule(Actions.Allow, "temp\\file.txt", 0, 0)
    r2 = Rule(Actions.Allow, "?\\file.txt", 0, 0)
    r1 = Rule(Actions.Allow, "temp\\file.txt", 0, 0)
    r2 = Rule(Actions.Allow, "temp\\*.txt", 0, 0)
    r1 = Rule(Actions.Allow, "file3.*", 0, 0)
    r2 = Rule(Actions.Allow, "*", 0, 0)
    # The backslash is written as "\\" here: "\*" is not a valid escape
    # sequence and triggers a warning, while producing the same string.
    r1 = Rule(Actions.Allow, "?\\*.dll", 0, 0)
    r2 = Rule(Actions.Allow, "?\\*", 0, 0)
    r1 = Rule(Actions.Allow, "", 0, 0)
    r2 = Rule(Actions.Allow, "", 0, 0)
    r1 = Rule(Actions.Allow, "temp\\file.txt", 0, 0)
    r2 = Rule(Actions.Allow, "?\\file.txt", 0, 0)
    # print(r1.is_subset_of(r2))
    # test()
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment