@viksit
Created July 6, 2025 04:48

# Code for the blog post
# Optimizing Tool Selection for LLM Workflows: Differentiable Programming with PyTorch and DSPy
# How local, learnable routers can reduce token overhead, lower costs, and bring structure back to agentic workflows.
# https://viksit.substack.com/p/optimizing-tool-selection-for-llm
# Ping @viksit on X with feedback/questions
# ----------------------------------------------------
import torch, torch.nn as nn, torch.nn.functional as F
## --- simple RNN for tool controller ---
## can be a more sophisticated encoder model too!
class ToolController(nn.Module):
    def __init__(self, vocab: dict, dim: int = 64):
        super().__init__()
        self.vocab, self.emb = vocab, nn.Embedding(256, dim)  # embedding table capped at 256 tokens
        self.rnn, self.lin = nn.GRU(dim, dim, batch_first=True), nn.Linear(dim, 2)

    def _tok(self, txt):
        # whitespace tokenization; the vocab grows on the fly
        ids = [self.vocab.setdefault(w, len(self.vocab)) for w in txt.lower().split()]
        return torch.tensor([ids])

    def forward(self, txt):
        x, _ = self.rnn(self.emb(self._tok(txt)))
        return F.gumbel_softmax(self.lin(x[:, -1]), hard=False)  # soft tool probabilities, shape [1, 2]
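
## the comment above notes the encoder can be swapped for something more sophisticated;
## a minimal sketch (my own, not from the post) using torch's built-in transformer encoder
## with the same interface. TransformerController and its hyperparameters are my choices.
class TransformerController(nn.Module):
    def __init__(self, vocab: dict, dim: int = 64, n_tools: int = 2):
        super().__init__()
        self.vocab = vocab
        self.emb = nn.Embedding(256, dim)
        self.enc = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=4, batch_first=True),
            num_layers=1)
        self.lin = nn.Linear(dim, n_tools)

    def _tok(self, txt):
        ids = [self.vocab.setdefault(w, len(self.vocab)) for w in txt.lower().split()]
        return torch.tensor([ids])

    def forward(self, txt):
        h = self.enc(self.emb(self._tok(txt)))                        # [1, L, dim]
        return F.gumbel_softmax(self.lin(h.mean(dim=1)), hard=False)  # mean-pool, then route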
## --- train via a synthetic dataset -----
import random
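## optional (my addition, not from the post): seed both RNGs so the shuffle below
## and the Gumbel noise during training are reproducible across runs
random.seed(0)
torch.manual_seed(0)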
NUM_SYN = 400  # keep tiny so the cell runs fast
search_queries = ["who is ceo of dropbox",
                  "population of japan",
                  "capital of france",
                  "define large language model"]
calc_queries = ["2 + 2", "15 * 7", "sqrt(81)", "log(100, 10)"]
dataset = [(q, 0) for q in search_queries for _ in range(NUM_SYN // 8)] + \
          [(q, 1) for q in calc_queries for _ in range(NUM_SYN // 8)]  # labels: 0 = search, 1 = calc
random.shuffle(dataset)
## --- train the controller ---
net = ToolController(vocab={})
opt = torch.optim.Adam(net.parameters(), lr=3e-3)
loss_fn = nn.NLLLoss()

for epoch in range(4):
    tot = 0
    for q, label in dataset:
        probs = net(q)  # tensor [1, 2]
        loss = loss_fn(torch.log(probs + 1e-9), torch.tensor([label]))
        opt.zero_grad(); loss.backward(); opt.step()
        tot += loss.item()
    print(f"epoch {epoch} avg-loss {tot/len(dataset):.4f}")
## training output
epoch 0 avg-loss 0.0669
epoch 1 avg-loss 0.0003
epoch 2 avg-loss 0.0002
epoch 3 avg-loss 0.0001
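
## quick sanity check (my addition, not from the post): accuracy over the synthetic set;
## note that forward() samples Gumbel noise, so this is a slightly stochastic estimate
with torch.no_grad():
    correct = sum(int(net(q).argmax().item() == label) for q, label in dataset)
print(f"accuracy on synthetic set: {correct / len(dataset):.3f}")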
## exercise the code above
def route(q: str):
    probs = net(q)
    tool = "SEARCH" if torch.argmax(probs).item() == 0 else "CALCULATE"
    return tool, probs.detach().numpy().round(3).tolist()

tests = ["who is ceo of dropbox",
         "2 + 2",
         "define transformers architecture",
         "sqrt(256)"]
for t in tests:
    tool, p = route(t)
    print(f"{t:<40}{tool:10} probs={p}")
## output
who is ceo of dropbox                   SEARCH     probs=[[1.0, 0.0]]
2 + 2                                   CALCULATE  probs=[[0.0, 1.0]]
define transformers architecture        SEARCH     probs=[[0.9990000128746033, 0.0010000000474974513]]
sqrt(256)                               CALCULATE  probs=[[0.18299999833106995, 0.8169999718666077]]
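
## a possible extension (my own sketch, not from the post): abstain when the controller
## is unsure and defer the query to the LLM instead of a local tool
def route_with_fallback(q: str, threshold: float = 0.8):
    probs = net(q)
    conf, idx = probs.max(dim=-1)   # max probability and its index, each shape [1]
    if conf.item() < threshold:
        return "LLM_FALLBACK", probs.detach().numpy().round(3).tolist()
    tool = "SEARCH" if idx.item() == 0 else "CALCULATE"
    return tool, probs.detach().numpy().round(3).tolist()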
## Add the ToolController to DSPy
import dspy

# Define DSPy tools (stubs that just return the tool name)
class SearchTool(dspy.Module):
    def forward(self, query: str) -> dspy.Prediction:
        return dspy.Prediction(result="SEARCH")

class CalcTool(dspy.Module):
    def forward(self, query: str) -> dspy.Prediction:
        return dspy.Prediction(result="CALCULATE")
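
## the tools above are placeholders; a minimal sketch (my own illustration, not from the
## post) of a CalcTool that actually evaluates the expression. eval restricted to math
## functions is fine for a demo, not for untrusted input.
import math

class EvalCalcTool(dspy.Module):
    def forward(self, query: str) -> dspy.Prediction:
        try:
            value = eval(query, {"__builtins__": {}}, vars(math))  # sqrt, log, etc. resolve via math
        except Exception as exc:
            value = f"error: {exc}"
        return dspy.Prediction(result=str(value))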
# Router module that uses the controller
class DiffRouter(dspy.Module):
    def __init__(self, controller, tools: dict[str, dspy.Module]):
        super().__init__()
        self.controller = controller
        self.tools = tools
        self.tool_keys = list(tools.keys())

    def forward(self, query: str) -> dspy.Prediction:
        with torch.no_grad():
            probs = self.controller(query)
        selected = self.tool_keys[int(probs.argmax())]
        return self.tools[selected](query=query)
# Instantiate and run
tools = {
    "search": SearchTool(),  # key order matches the controller's labels: 0 = search
    "calc": CalcTool(),      # 1 = calc
}
router = DiffRouter(net, tools)

tests = ["who is ceo of dropbox",
         "2 + 2",
         "define transformers architecture",
         "sqrt(256)"]
for query in tests:
    result = router(query)
    print("Selected tool output:", result.result)
## DSPy output
Selected tool output: SEARCH
Selected tool output: CALCULATE
Selected tool output: SEARCH
Selected tool output: CALCULATE
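
## to reuse the trained controller outside this script, persist the weights and the
## vocab together (my addition; the filename is arbitrary)
torch.save({"state_dict": net.state_dict(), "vocab": net.vocab}, "tool_controller.pt")
## later: ckpt = torch.load("tool_controller.pt")
##        restored = ToolController(vocab=ckpt["vocab"]); restored.load_state_dict(ckpt["state_dict"])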