Skip to content

Instantly share code, notes, and snippets.

@seyyedaliayati
Created May 20, 2025 20:03
Show Gist options
  • Save seyyedaliayati/683f566cccf93f636375276e1f95d5af to your computer and use it in GitHub Desktop.
Save seyyedaliayati/683f566cccf93f636375276e1f95d5af to your computer and use it in GitHub Desktop.
from itertools import combinations
from typing import Optional, Tuple
import time
import tracemalloc
from ds import *
# Running maxima of the per-node Malware / Goodware counters; insert_path()
# raises them as labelled paths are added to the tree.
total_mw, total_gw = 0, 0
def insert_path(root, path):
    """Add one labelled path to the tree rooted at *root*.

    The path is lower-cased, escaped double backslashes are collapsed to
    single ones, and the result is split on BACKSLASH (empty segments are
    ignored). The last segment must end in a label token: ``(m)`` for
    Malware or ``(g)`` for Goodware. The label is case-insensitive because
    of the lower-casing. The text before that token becomes a filename node
    and every earlier segment becomes a sub-directory node. Existing child
    nodes with the same name are reused.

    Every node along the path, including the filename node, has its
    malware_count or goodware_count incremented. The module-level
    ``total_mw`` / ``total_gw`` are raised to the largest count seen.

    Raises:
        KeyError: if the label token is neither ``(m)`` nor ``(g)``.
        IndexError: if the path has no non-empty segments.
    """
    global total_mw, total_gw
    cleaned = path.lower().replace("\\\\", "\\")
    debug_print("Inserting path:", cleaned)
    label_of = {"(m)": NodeTypes.Malware, "(g)": NodeTypes.Goodware}
    segments = [seg for seg in cleaned.split(BACKSLASH) if seg != ""]
    tail_words = segments[-1].split(" ")
    kind = label_of[tail_words[-1].strip()]

    # Directory nodes first, then the filename node (label token removed).
    chain = [TreeNode(seg, is_filename=False, is_subdir=True)
             for seg in segments[:-1]]
    file_name = " ".join(tail_words[:-1]).strip()
    chain.append(TreeNode(name=file_name, is_filename=True, is_subdir=False))

    cursor = root
    for fresh in chain:
        existing = cursor.children
        if fresh.name not in existing:
            existing[fresh.name] = fresh
        cursor = existing[fresh.name]
        # Every node on the path counts this sample.
        if kind == NodeTypes.Malware:
            cursor.malware_count += 1
            total_mw = max(total_mw, cursor.malware_count)
        elif kind == NodeTypes.Goodware:
            cursor.goodware_count += 1
            total_gw = max(total_gw, cursor.goodware_count)
def dfs(node, depth=0):
    """Print *node* and its descendants in pre-order, one line per node.

    Each line is indented by one space per level. It shows the node name,
    a ``[subdir]`` and/or ``(filename)`` marker, the Malware and Goodware
    counters and ``str_type``. Output goes to stdout and nothing is returned.
    """
    subdir_tag = "[subdir]" if node.is_subdir else ""
    file_tag = "(filename)" if node.is_filename else ""
    line = (f"{node.name} {subdir_tag} {file_tag} "
            f"Malware: {node.malware_count} Goodware: {node.goodware_count}, "
            f"Type: {node.str_type}")
    print(" " * depth + line)
    for kid in node.children.values():
        dfs(kid, depth + 1)
def gen_local_candidates(node: TreeNode):
    """Build the node-local rule candidates for *node*.

    Every candidate is a RuleSet holding a single Rule. The rule's action is
    Block when ``node.is_malware()`` is true and Allow otherwise, and it
    carries the node's malware/goodware counters.

    * Sub-directory nodes: the exact name and ANY_WILDCARD.
    * Filename nodes without a DOT: the exact name and ANY_WILDCARD.
    * Filename nodes with a DOT: the exact name and ANY_WILDCARD. Unless
      NO_REGEX_AT_ALL is set, also every prefix pattern (``a.*``), every
      suffix pattern (``*.b``) and, for three or more parts, every distinct
      infix pattern (``a.*.c``).

    Any other node (neither sub-directory nor filename) yields an empty
    RuleCandidates.
    """
    assert node is not None, "Node is None"
    assert isinstance(node, TreeNode), "Node is not a TreeNode"
    candidates = RuleCandidates()
    action = Actions.Block if node.is_malware() else Actions.Allow

    def add_single(pattern):
        # Wrap one rule for this node in its own RuleSet and register it.
        one_rule = RuleSet(node)
        one_rule.add(Rule(action, pattern, node.malware_count,
                          node.goodware_count, node))
        candidates.add(one_rule)

    if not node.is_subdir and not node.is_filename:
        return candidates

    # Exact name and "match anything" apply to every dir/file node.
    add_single(node.name)
    add_single(ANY_WILDCARD)

    # Pattern expansion only for dotted filenames (sub-directories take
    # precedence even if also flagged as filenames).
    if node.is_subdir or DOT not in node.name or NO_REGEX_AT_ALL:
        return candidates

    parts = node.name.split(DOT)
    count = len(parts)
    for cut in range(1, count):
        add_single(".".join(parts[:cut]) + ".*")
        add_single("*." + ".".join(parts[cut:]))

    if count > 2:
        seen = set()
        for left, right in combinations(range(1, count), 2):
            pattern = ".".join(parts[:left]) + ".*." + ".".join(parts[right:])
            if pattern in seen:
                continue
            seen.add(pattern)
            add_single(pattern)
    return candidates
def merge_two_rules_same_dir(r1: Rule, r2: Rule, rs1_has_block_all=False, rs2_has_block_all=False):
    """Merge two rules that come from sibling entries of the same directory.

    Returns a tuple ``(rules, update_r1, fork_required)``:

    * ``rules``: a list of zero, one or two Rule objects.
        - Both rules simple: ``[r2]`` if they are equal, else ``[r1, r2]``.
        - Same action: one merged rule with summed mw/gw when one rule is a
          subset of the other (the more general name is kept) or they are
          equal; otherwise ``[r1, r2]``.
        - Different actions: a simple rule is placed first. For wildcard
          rules in a subset relation, the higher ``specific_score`` goes
          first. Presumably the order encodes precedence -- TODO confirm.
        - Same name with different actions ("strict conflict"): two
          alternative rules when SOLVE_STRICT_CONFLICTS is set, else ``[]``.
    * ``update_r1``: True when the single returned rule replaces *r1*
      (r2 subset of r1, or same-action equal wildcard rules).
    * ``fork_required``: True only in the strict-conflict case. The two
      returned rules are then mutually exclusive alternatives.

    ``rs1_has_block_all`` / ``rs2_has_block_all`` are accepted but not read
    in this body.

    Raises:
        Exception: for conflicting wildcard rules in a subset relation with
            equal ``specific_score``, or a strict conflict whose r1 is
            neither allowed nor blocked.
    """
    debug_print(f"[merge_two_rules_same_dir] Merging {r1} and {r2}")
    update_r1 = False
    # NOTE(review): update_r2 is assigned below but never returned.
    update_r2 = False
    fork_required = False
    if r1.is_simple_rule() and r2.is_simple_rule():
        if r1 == r2:
            return [r2], update_r1, fork_required
        return [r1, r2], update_r1, fork_required
    # at least one has a wildcard
    # if same action, merge --> more general
    if r1.action == r2.action:
        # NOTE(review): merged rules are built without the `node` argument
        # that gen_local_candidates passes to Rule -- confirm intended.
        if r1.is_subset_of(r2):
            # r2 is more general --> return r2
            update_r2 = True
            merged = Rule(r2.action, r2.name, r1.mw + r2.mw, r1.gw + r2.gw)
            return [merged], update_r1, fork_required
        elif r2.is_subset_of(r1):
            # r1 is more general --> return r1
            update_r1 = True
            merged = Rule(r1.action, r1.name, r1.mw + r2.mw, r1.gw + r2.gw)
            return [merged], update_r1, fork_required
        elif r1 == r2:
            merged = Rule(r1.action, r1.name, r1.mw + r2.mw, r1.gw + r2.gw)
            update_r1 = True
            return [merged], update_r1, fork_required
        else:
            return [r1, r2], update_r1, fork_required
    # conflict
    if r1.is_simple_rule():
        # pseudo conflict
        return [r1, r2], update_r1, fork_required
    elif r2.is_simple_rule():
        # pseudo conflict
        return [r2, r1], update_r1, fork_required
    elif r1.is_subset_of(r2) or r2.is_subset_of(r1):
        # Both wildcards, one nested in the other: higher specific_score first.
        if r1.specific_score > r2.specific_score:
            return [r1, r2], update_r1, fork_required
        elif r1.specific_score < r2.specific_score:
            return [r2, r1], update_r1, fork_required
        else:
            raise Exception("should not happen")
    elif r1.name == r2.name:
        # strict conflict
        if SOLVE_STRICT_CONFLICTS:
            # NOTE(review): the count arithmetic of the two alternatives
            # applies abs() to some terms only; kept as-is -- confirm against
            # the cost model used by RuleSet.get_value().
            if r1.is_allowed():
                debug_print("[merge_two_rules_same_dir] Fork case 1")
                # allow *, block * --> either we accept the allow (mr1) or the block (mr2)
                # e.g. Merging Rule(action=Allow, name=*, mw=0, gw=2) and Rule(action=Block, name=*, mw=1, gw=0)
                fork_required = True
                mr1 = Rule(Actions.Allow, r1.name, r1.mw - r2.mw, r1.gw + abs(r2.gw))
                mr2 = Rule(Actions.Block, r2.name, r2.mw - abs(r1.mw), r2.gw - r1.gw)
                return [mr1, mr2], update_r1, fork_required
            elif r1.is_blocked():
                debug_print("[merge_two_rules_same_dir] Fork case 2")
                # block *, allow * --> either we accept the block (mr1) or the allow (mr2)
                # e.g. Merging Rule(action=Block, name=*, mw=2, gw=0) and Rule(action=Allow, name=*, mw=0, gw=1)
                fork_required = True
                mr1 = Rule(Actions.Block, r1.name, r1.mw + abs(r2.mw), r1.gw - r2.gw)
                mr2 = Rule(Actions.Allow, r2.name, r2.mw - r1.mw, r2.gw + abs(r1.gw))
                return [mr1, mr2], update_r1, fork_required
            else:
                raise Exception("should not happen")
        else:
            # Strict conflicts are not being solved: both rules are dropped.
            return [], update_r1, fork_required
    else:
        # Unrelated wildcard rules with different actions: keep both.
        return [r1, r2], update_r1, fork_required
def merge_same_dir(node, rc) -> RuleCandidates:
    """Merge the per-child RuleCandidates of *node* into one RuleCandidates.

    *rc* is a list of RuleCandidates, one per child. Zero entries give an
    empty RuleCandidates and one entry is returned as-is. More than two
    entries are split in half, each half is merged recursively, and the two
    results are merged.

    For exactly two inputs:
    * If exactly one side is empty, the other side is returned.
    * Otherwise every pair (rs1, rs2) of rule sets is combined by merging
      each rule of rs1 with each rule of rs2 via merge_two_rules_same_dir().
    * When a merge requires a fork, every partial rule set for that pair is
      duplicated with ``fork()``. The original receives the first
      alternative and the copy receives the second.
    * A resulting rule set is kept only if it is non-empty and its
      ``get_value()`` (after ``calculate_cost_benefit()``) equals
      ``rs1_value + rs2_value``.
    """
    if len(rc) == 0:
        return RuleCandidates()
    if len(rc) == 1:
        return rc[0]
    if len(rc) > 2:
        # Divide and conquer: merge each half, then merge the two results.
        mid = int(len(rc)/2)
        p1 = merge_same_dir(node, rc[:mid])
        p2 = merge_same_dir(node, rc[mid:])
        return merge_same_dir(node, [p1, p2])
    assert len(rc) == 2, "rc should be 2"
    debug_print("[merge_same_dir] to merge:", rc)
    rc1 = rc[0]
    rc2 = rc[1]
    # One side empty: nothing to combine. (If both are empty, execution
    # falls through to the merge loops below.)
    if len(rc1) > 0 and len(rc2) == 0:
        return rc1
    if len(rc1) == 0 and len(rc2) > 0:
        return rc2
    # NOTE(review): no_validation=True presumably disables RuleCandidates'
    # own filtering; the value check at the bottom is applied instead.
    new_rc = RuleCandidates(no_validation=True)
    for rs1 in rc1.candidates:
        rs1_has_block_all = rs1.has_block_all()
        rs1_value = rs1.get_value()
        for rs2 in rc2.candidates:
            rs2_has_block_all = rs2.has_block_all()
            rs2_value = rs2.get_value()
            # Partial rule sets for this (rs1, rs2) pair; forks add more.
            new_rs_list = [RuleSet(node)]
            debug_print(f"[merge_same_dir] current list:\n", new_rs_list)
            debug_print(f"[merge_same_dir] merging\nrs1: {rs1}, \nrs2: {rs2}")
            for r1 in rs1.rules:
                for r2 in rs2.rules:
                    update_r1 = False
                    merged_rules, update_r1, fork_required = merge_two_rules_same_dir(r1, r2, rs1_has_block_all, rs2_has_block_all)
                    assert len(merged_rules) < 3, "More than 2 rules merged"
                    # Pad to exactly two slots; absent rules become None.
                    mr1, mr2 = (merged_rules + [None, None])[:2]
                    debug_print(f"[merge_same_dir] MR1:", mr1)
                    debug_print(f"[merge_same_dir] MR2:", mr2)
                    if fork_required:
                        assert mr1 is not None and mr2 is not None, "Fork required but no rules"
                        debug_print("Applying fork")
                        debug_print(
                            f"[merge_same_dir] rs before fork: {new_rs_list}")
                        # Each partial set splits in two: one keeps mr1,
                        # its fork() copy takes mr2.
                        forked_rs_list = []
                        for existing_rs in new_rs_list:
                            forked_rs = existing_rs.fork()
                            existing_rs.add(mr1)
                            forked_rs.add(mr2)
                            forked_rs_list.append(forked_rs)
                        new_rs_list.extend(forked_rs_list)
                        debug_print(
                            f"[merge_same_dir] rs after fork: {new_rs_list}")
                    else:
                        # NOTE(review): mr1/mr2 may be None here (fewer than
                        # two merged rules); this relies on RuleSet.add
                        # tolerating None -- confirm.
                        for existing_rs in new_rs_list:
                            existing_rs.add(mr1)
                            existing_rs.add(mr2)
                    # Carry the merged rule forward so the remaining r2's
                    # are merged against it rather than the original r1.
                    if update_r1 and mr1:
                        r1 = mr1
            # Keep only non-empty sets whose value is conserved by the merge.
            for new_rs in new_rs_list:
                new_rs.calculate_cost_benefit()
                new_rs_value = new_rs.get_value()
                is_valid = rs1_value + rs2_value == new_rs_value
                if len(new_rs) > 0:
                    if is_valid:
                        new_rc.add(new_rs)
                        debug_print("[merge_same_dir] New Rule Set:", new_rs)
                    else:
                        debug_print("[merge_same_dir] Invalid Rule Set:", new_rs,
                                    f"Because {rs1_value} + {rs2_value} != {new_rs_value}")
    return new_rc
def merge_two_rules_cross_dir(r1: Rule, r2: Rule) -> Optional[Rule]:
    """Combine a parent-directory rule *r1* with a child rule *r2*.

    The result always takes *r2*'s action. Its malware/goodware counts are
    the minimum of the two rules' counts.

    * Equal rules: *r2* itself is returned unchanged.
    * *r1* is the any-wildcard and *r2* already starts with it: the result
      keeps ``r2.name``, so no redundant wildcard segment is prepended.
    * Otherwise the two names are joined with BACKSLASH.

    The ``Optional`` return annotation is kept for compatibility, but every
    path returns a Rule.
    """
    # always pick the action of r2
    new_action = r2.action
    if r1 == r2:
        return r2
    debug_print(f"[merge_two_rules_cross_dir] Merging {r1} and {r2}")
    mw = min(r1.mw, r2.mw)
    gw = min(r1.gw, r2.gw)
    if r1.is_any_wildcard() and r2.starts_with_any_wildcard():
        return Rule(new_action, r2.name, mw, gw)
    # A single-wildcard parent used to have its own branch, but it returned
    # exactly this joined rule, so it is handled by the general case.
    return Rule(new_action, r1.name + BACKSLASH + r2.name, mw, gw)
def merge_cross_dir(node: TreeNode, current, childs):
    """Combine *node*'s own candidates with its children's merged candidates.

    Each (current rule set, child rule set) pair yields one RuleSet for
    *node*. It holds the truthy results of merge_two_rules_cross_dir() over
    every (parent rule, child rule) pair. A combined set is discarded when
    its mw exceeds ``node.malware_count`` or its gw exceeds
    ``node.goodware_count``.

    * Both inputs empty: an empty RuleCandidates, with a debug warning.
    * No child candidates: *current* is returned unchanged.
    * Child candidates but no current candidates: raises Exception.
    """
    combined = RuleCandidates()
    if current.is_empty() and childs.is_empty():
        debug_print(
            "[merge_cross_dir][Warning] Both current and childs are empty")
        return combined
    if childs.is_empty():
        return current
    if current.is_empty():
        raise Exception("Current is empty")

    for own_set in current.candidates:
        for child_set in childs.candidates:
            joined = RuleSet(node)
            for parent_rule in own_set.rules:
                for child_rule in child_set.rules:
                    rule = merge_two_rules_cross_dir(parent_rule, child_rule)
                    if rule:
                        joined.add(rule)
            # A subtree cannot cover more samples than the node itself holds.
            if joined.mw > node.malware_count or joined.gw > node.goodware_count:
                debug_print(
                    f"[merge_cross_dir] RuleSet too big: {joined.mw} > {node.malware_count} "
                    f"OR {joined.gw} > {node.goodware_count}")
                continue
            combined.add(joined)
            debug_print(
                f"[merge_cross_dir] Merged RuleSet: {joined}, "
                f"Node: {node.malware_count} {node.goodware_count}")
    return combined
def reduce(node, depth=0):
    """Recursively compute candidate rule sets for the subtree at *node*.

    Children are reduced first and their results are merged with
    merge_same_dir(). For the root, that merged result is returned as-is.
    For any other node:
    1. Child rule sets that fail ``is_valid()`` are dropped.
    2. The node's own candidates come from gen_local_candidates().
    3. Both are combined with merge_cross_dir().

    *depth* is only passed down the recursion; it is not otherwise used here.
    """
    debug_print("[reduce] Entered node:", node)
    child_results = [reduce(child, depth + 1) for child in node.children.values()]
    child_candidates = merge_same_dir(node, child_results)
    if node.is_root:
        return child_candidates
    # apply validation.
    # Iterate over a snapshot: drop() presumably removes rs from
    # child_candidates.candidates. Mutating a collection while iterating it
    # skips the element after each removal (list) or raises (set/dict).
    for rs in list(child_candidates.candidates):
        if not rs.is_valid():
            debug_print("[reduce] Dropping invalid RuleSet:", rs)
            child_candidates.drop(rs)
    debug_print("[reduce] Children of", node.name, "are", child_candidates)
    current_candidates = gen_local_candidates(node)
    debug_print("[reduce] Candidates of", node.name, "are", current_candidates)
    merged = merge_cross_dir(node, current_candidates, child_candidates)
    debug_print("[reduce] Merge result for", node.name, "is", merged)
    return merged
def reduce_parallel(node, depth=0):
    """Placeholder for a parallel variant of reduce().

    Raises:
        NotImplementedError: always.
    """
    message = "Parallel processing is not implemented yet"
    raise NotImplementedError(message)
def measure_time_and_memory(testcase_function, *args, **kwargs):
    """Run ``testcase_function(*args, **kwargs)`` under timing and tracemalloc.

    Returns:
        ``(result, elapsed_seconds, peak_bytes, top_stats)``. *top_stats* is
        the list of ``tracemalloc`` snapshot statistics grouped by
        ``'lineno'``.

    If the function raises, tracing is still stopped and the exception
    propagates, so tracemalloc is not left running for the rest of the
    process.
    """
    tracemalloc.start()
    try:
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock and can jump.
        start_time = time.perf_counter()
        result = testcase_function(*args, **kwargs)
        elapsed_time = time.perf_counter() - start_time
        # Read the peak before building the snapshot, so allocations made
        # while building it cannot inflate the reported figure.
        _, peak_memory = tracemalloc.get_traced_memory()
        top_stats = tracemalloc.take_snapshot().statistics('lineno')
    finally:
        tracemalloc.stop()
    return result, elapsed_time, peak_memory, top_stats
def process_testcase(testcase):
    """Build a tree from *testcase* and return the result of reducing it.

    *testcase* is an iterable of labelled path strings. Each one is echoed
    and inserted with insert_path(). The finished tree is printed with dfs()
    between "=== DFS ===" banners, and reduce() is then run on the root.
    """
    tree = TreeNode("root", is_root=True)
    for entry in testcase:
        print(f"Inserting path: {entry}")
        insert_path(tree, entry)
    print("\n=== DFS ===")
    dfs(tree)  # visualization only
    print("===========\n")
    return reduce(tree)
def test():
    """Run every case in ``testcases.all_testcases`` and print a report.

    Each case goes through process_testcase() under
    measure_time_and_memory(). The report shows the candidate rules, the
    ten largest allocation sites, elapsed time, peak memory and the number
    of candidates.
    """
    # Imported here, so the testcases module is only needed when test() runs.
    from testcases import all_testcases
    # from realcases import all_testcases
    for counter, testcase in enumerate(all_testcases, start=1):
        print(f"=== Testcase {counter} ===")
        # Measure time and memory for the test case
        candidate_rules, elapsed_time, peak_memory, top_stats = measure_time_and_memory(
            process_testcase, testcase)
        print("\n=== Candidate Rules ===")
        print(candidate_rules)
        print("========================\n")
        print(f"[main] Testcase {counter} completed")
        print("[main] Top 10 memory usage:")
        for stat in top_stats[:10]:
            print(stat)
        print(f"[main] Time taken: {elapsed_time:.4f} seconds")
        print(f"[main] Peak memory usage: {peak_memory / 1024:.2f} KB")
        print(f"[main] Number of candidates: {len(candidate_rules)}")
        print("========================\n")
def _read_paths(input_file):
    """Return the stripped, non-blank, non-comment lines of *input_file*.

    A line is a comment when its first non-blank character is ``#``.
    """
    paths = []
    with open(input_file, "r", encoding="utf-8") as f:
        for line in f:
            entry = line.strip()
            # Check for comments after stripping, so indented "#" lines are
            # skipped instead of being treated as paths.
            if not entry or entry.startswith("#"):
                continue
            paths.append(entry)
    return paths


def main(input_file="temp/input_test_2.txt"):
    """Load labelled paths from *input_file*, run the pipeline and report.

    Blank lines and ``#`` comment lines are ignored. Every other line is
    stripped and treated as one path. The candidate rules and the
    time/memory figures from measure_time_and_memory() are printed to stdout.

    Args:
        input_file: path of the input file. Defaults to the path previously
            hard-coded here, so a bare ``main()`` reads the same file.
    """
    print(f"Input file: {input_file}")
    paths = _read_paths(input_file)
    print(f"Number of paths: {len(paths)}")
    print("Processing test case")
    candidate_rules, elapsed_time, peak_memory, top_stats = measure_time_and_memory(
        process_testcase, paths)
    print("\n=== Candidate Rules ===")
    print(candidate_rules)
    print("========================\n")
    print("[main] Testcase completed")
    print("[main] Top 10 memory usage:")
    for stat in top_stats[:10]:
        print(stat)
    print(f"[main] Time taken: {elapsed_time:.4f} seconds")
    print(f"[main] Peak memory usage: {peak_memory / 1024:.2f} KB")
    print(f"[main] Number of candidates: {len(candidate_rules)}")
    print("========================\n")
if __name__ == "__main__":
    # Entry point: process the configured input file. Call test() instead
    # to run the bundled cases from testcases.py.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment