# This submits server logs to Google Gemini and lets Gemini analyze them.
#
# Run like this:
# beaker experiment logs <experiment id> | python logs.py
#
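# Alternatively, pass a log file as the first argument:
# python logs.py <log file>
# ("-" or no argument reads from stdin.)
#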
import sys
import re
from collections import defaultdict, Counter
from typing import List, Dict

import vertexai
from vertexai.generative_models import GenerativeModel, Part

NUM_LINES = 100  # log lines to keep per node rank

def main():
    vertexai.init(project="ai2-allennlp", location="us-central1")
    model = GenerativeModel("gemini-1.5-pro-002")
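
    # Sends the given log lines to Gemini as a plain-text attachment, followed
    # by the analysis prompt, and returns the model's answer as text.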
    def response_from_gemini(logs: List[str]) -> str:
        response = model.generate_content(
            [
                Part.from_data(
                    mime_type="text/plain",
                    data="\n".join(logs).encode("UTF8")
                ),
                f"""These are the last {NUM_LINES} lines of logs from compute nodes of a distributed computation job that's failing, or has already failed. It is likely that one or two of these nodes show something different than the others. Which nodes are suspicious?""",
            ],
            generation_config={
                "max_output_tokens": 1024,
                "temperature": 1,
                "top_p": 0.95,
            },
            stream=False,
        )
        return response.text

    # read the logs from the file named on the command line, or from stdin
    if len(sys.argv) <= 1 or sys.argv[1] == "-":
        f = sys.stdin
    else:
        f = open(sys.argv[1])
    raw_logs = [line.rstrip() for line in f]
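
    # Beaker prefixes each node's output with a "Logs for job ..." header line;
    # use those headers to split the combined log into per-node chunks.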
    node_starts = [i for i, line in enumerate(raw_logs) if line.startswith("Logs for job ")]
    node_starts.append(len(raw_logs))
    logs_per_node: Dict[int, List[str]] = {}
    for node_rank, (node_start, node_end) in enumerate(zip(node_starts[:-1], node_starts[1:])):
        logs_per_node[node_rank] = raw_logs[node_start:node_end]
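
    # Last-resort truncation: keep each node's header line plus its last
    # NUM_LINES log lines. Used whenever step-based truncation isn't possible.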
    def fallback():
        truncated_logs = []
        for node_start, node_end in zip(node_starts[:-1], node_starts[1:]):
            truncated_logs.append(raw_logs[node_start])
            # node_start + 1 so the header line appended above isn't duplicated
            truncated_logs.extend(raw_logs[max(node_start + 1, node_end - NUM_LINES):node_end])
        print(response_from_gemini(truncated_logs))
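
    # Step-progress lines are assumed to look like "... [step=<current>/<total>] ...";
    # the regex below captures both numbers from each such line.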
    # find steps
    find_steps_re = re.compile(r".*\[step=(\d+)/(\d+)")
    steps_per_node = defaultdict(set)
    for node_rank, logs in logs_per_node.items():
        for line in logs:
            match = find_steps_re.match(line)
            if match:
                steps_per_node[node_rank].add((int(match.group(1)), int(match.group(2))))

    # If any rank has no step lines at all, fall back to the last NUM_LINES lines from each rank
    if len(steps_per_node) < len(logs_per_node):
        sys.stderr.write(f"Could not find step counts in these logs. Falling back to the last {NUM_LINES} lines for each node rank.\n")
        fallback()
        return

    # find the last step; ranks may report different step totals, so trust the
    # total that appears most often
    final_step_counts = Counter()
    for steps in steps_per_node.values():
        for _, final_step in steps:
            final_step_counts[final_step] += 1
    most_common_final_step = final_step_counts.most_common(1)[0][0]

    # delete all step counters that don't use the most common total
    for node_rank in steps_per_node.keys():
        steps_per_node[node_rank] = {
            step for step, final_step in steps_per_node[node_rank]
            if final_step == most_common_final_step
        }
        if len(steps_per_node[node_rank]) <= 0:
            sys.stderr.write(f"Node rank {node_rank} doesn't seem to have any proper steps. Falling back to the last {NUM_LINES} lines for each node rank.\n")
            fallback()
            return

    # truncate the logs from each node rank to start three steps before the last
    # step, and extend for NUM_LINES lines
    steps_in_all_ranks = set.intersection(*steps_per_node.values())
    if len(steps_in_all_ranks) <= 0:
        sys.stderr.write(f"No common steps across node ranks. Falling back to the last {NUM_LINES} lines for each node rank.\n")
        fallback()
        return
    # aim for three steps before the last common step; if that exact step wasn't
    # logged by every rank, walk forward to the next one that was
    last_included_step = max(1, max(steps_in_all_ranks) - 3)
    while last_included_step < max(steps_in_all_ranks):
        if last_included_step in steps_in_all_ranks:
            break
        last_included_step += 1
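
    # Cut each rank's log to NUM_LINES lines, starting at the last line that
    # mentions the chosen step, so the excerpts line up across ranks.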
    last_included_step_substring = f"[step={last_included_step}/{most_common_final_step}"
    for node_rank, node_logs in logs_per_node.items():
        first_line = node_logs.pop(0)  # keep the "Logs for job ..." header line
        first_included_line_number = max(
            line_number for line_number, line in enumerate(node_logs)
            if last_included_step_substring in line
        )
        last_included_line_number = min(first_included_line_number + NUM_LINES, len(node_logs))
        logs_per_node[node_rank] = [first_line] + node_logs[first_included_line_number:last_included_line_number]

    # send it all to Gemini
    concatenated_logs = []
    for logs in logs_per_node.values():
        concatenated_logs.extend(logs)
    print(response_from_gemini(concatenated_logs))

if __name__ == "__main__":
    main()