Created
July 6, 2025 04:48
-
-
Save viksit/c67d1d960c4cec89488290496defb324 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Code for the blog post
# "Optimizing Tool Selection for LLM Workflows: Differentiable Programming with PyTorch and DSPy"
# How local, learnable routers can reduce token overhead, lower costs, and bring structure back to agentic workflows.
# https://viksit.substack.com/p/optimizing-tool-selection-for-llm
# Ping @viksit on X with feedback/questions.
# ----------------------------------------------------
import torch, torch.nn as nn, torch.nn.functional as F | |
## --- simple RNN for tool controller --
## can be a more sophisticated encoder model too!
class ToolController(nn.Module):
    """Tiny GRU classifier that maps a text query to a 2-way tool choice.

    Text is lower-cased and whitespace-tokenised into integer ids drawn from
    ``vocab``. Unseen words are added to ``vocab`` on the fly until the
    embedding table is full; after that they all share one reserved
    "unknown" id instead of indexing past the table.

    Args:
        vocab: word -> id mapping, mutated in place as new words are seen.
        dim: embedding and GRU hidden size.
        vocab_size: rows in the embedding table. The last row
            (``vocab_size - 1``) is reserved for overflow words and empty input.
    """

    def __init__(self, vocab: dict, dim: int = 64, vocab_size: int = 256):
        super().__init__()
        if vocab_size < 1:
            raise ValueError("vocab_size must be >= 1")
        self.vocab, self.emb = vocab, nn.Embedding(vocab_size, dim)
        self.rnn, self.lin = nn.GRU(dim, dim, batch_first=True), nn.Linear(dim, 2)
        # Reserving the LAST row keeps ids 0..vocab_size-2 identical to the
        # first-come-first-served assignment used before the cap existed.
        self.unk_id = vocab_size - 1

    def _tok(self, txt: str) -> torch.Tensor:
        """Return a ``(1, seq_len)`` LongTensor of token ids for ``txt``."""
        ids = []
        for w in txt.lower().split():
            idx = self.vocab.get(w)
            if idx is None:
                if len(self.vocab) < self.unk_id:
                    idx = self.vocab[w] = len(self.vocab)
                else:
                    # Table is full: share the reserved id rather than crash.
                    idx = self.unk_id
            ids.append(idx)
        # An empty query would yield a zero-length (and float-typed) tensor that
        # neither the embedding nor the GRU accepts; use one "unknown" token.
        return torch.tensor([ids or [self.unk_id]], dtype=torch.long)

    def forward(self, txt: str) -> torch.Tensor:
        """Return a ``(1, 2)`` tensor of tool probabilities for ``txt``.

        In training mode the output is a soft Gumbel-softmax sample, keeping the
        routing stochastic and differentiable. In eval mode (after ``.eval()``)
        it is the plain softmax of the logits, so a given query always yields
        the same probabilities.
        """
        out, _ = self.rnn(self.emb(self._tok(txt)))
        logits = self.lin(out[:, -1])  # GRU output at the last time step -> (1, 2)
        if self.training:
            return F.gumbel_softmax(logits, hard=False)
        return F.softmax(logits, dim=-1)
## --- train via a synthetic dataset -----
import random

NUM_SYN = 400  # total synthetic examples; keep tiny so the cell runs fast
search_queries = ["who is ceo of dropbox",
                  "population of japan",
                  "capital of france",
                  "define large language model"]
calc_queries = ["2 + 2", "15 * 7", "sqrt(81)", "log(100, 10)"]

# Split the budget evenly between the two classes (label 0 = search,
# label 1 = calculate), then evenly across each class's templates.
# A hard-coded ``NUM_SYN // 8`` would silently assume exactly four
# templates per class.
_per_class = NUM_SYN // 2
dataset = (
    [(q, 0) for q in search_queries for _ in range(_per_class // len(search_queries))]
    + [(q, 1) for q in calc_queries for _ in range(_per_class // len(calc_queries))]
)
random.shuffle(dataset)
# ------------------ 3. Train controller ------------------
# Share one dict so the module-level `vocab` is the controller's live
# vocabulary (previously `vocab` was an unrelated, always-empty dict).
vocab = {}
net = ToolController(vocab=vocab)
opt = torch.optim.Adam(net.parameters(), lr=3e-3)
loss_fn = nn.NLLLoss()
NUM_EPOCHS = 4

net.train()  # be explicit: this cell may be re-run after net.eval() below
for epoch in range(NUM_EPOCHS):
    tot = 0.0
    for q, label in dataset:
        probs = net(q)  # tensor [1, 2]
        # NLLLoss expects log-probabilities; the epsilon guards against log(0).
        loss = loss_fn(torch.log(probs + 1e-9), torch.tensor([label]))
        opt.zero_grad()
        loss.backward()
        opt.step()
        tot += loss.item()
    print(f"epoch {epoch} avg-loss {tot / max(len(dataset), 1):.4f}")

# Training is done: switch to inference mode before the routing demos below.
net.eval()
## --- sample training output (values vary between runs) ---
# epoch 0 avg-loss 0.0669
# epoch 1 avg-loss 0.0003
# epoch 2 avg-loss 0.0002
# epoch 3 avg-loss 0.0001
## exercise the code above
def route(q: str, model=None):
    """Pick a tool for query ``q`` using a trained controller.

    Args:
        q: the user query.
        model: controller to call; defaults to the module-level ``net``.
            Output index 0 means SEARCH, any other argmax means CALCULATE.

    Returns:
        ``(tool_name, probs)``, where ``probs`` is the controller output
        converted to nested Python lists and rounded to 3 decimals.
    """
    model = net if model is None else model
    with torch.no_grad():  # inference only: don't build an autograd graph
        probs = model(q)
    tool = "SEARCH" if torch.argmax(probs).item() == 0 else "CALCULATE"
    # Round in float64: rounding the float32 array and then converting to
    # Python floats re-exposes float32 error (e.g. 0.9990000128746033).
    return tool, probs.detach().numpy().astype(float).round(3).tolist()
# Smoke-test the router: two queries seen in training, two unseen ones.
tests = [
    "who is ceo of dropbox",
    "2 + 2",
    "define transformers architecture",
    "sqrt(256)",
]
for t, (tool, p) in zip(tests, map(route, tests)):
    print(f"{t:<40} → {tool:10} probs={p}")
## sample output (values vary between runs)
# who is ceo of dropbox                    → SEARCH     probs=[[1.0, 0.0]]
# 2 + 2                                    → CALCULATE  probs=[[0.0, 1.0]]
# define transformers architecture         → SEARCH     probs=[[0.9990000128746033, 0.0010000000474974513]]
# sqrt(256)                                → CALCULATE  probs=[[0.18299999833106995, 0.8169999718666077]]
## Plug the ToolController into a DSPy pipeline
import dspy | |
# Define DSPy tools
class SearchTool(dspy.Module):
    """Placeholder search tool.

    Returns a ``dspy.Prediction`` whose ``result`` is the literal string
    ``"SEARCH"``, so the demo can show which tool the router picked. It does
    not use ``query``.
    """

    def forward(self, query: str) -> dspy.Prediction:
        # Plain literal: the former f-string had no placeholders (ruff F541).
        return dspy.Prediction(result="SEARCH")
class CalcTool(dspy.Module):
    """Placeholder calculator tool.

    Returns a ``dspy.Prediction`` whose ``result`` is the literal string
    ``"CALCULATE"``, so the demo can show which tool the router picked. It
    does not evaluate ``query``.
    """

    def forward(self, query: str) -> dspy.Prediction:
        # Plain literal: the former f-string had no placeholders (ruff F541).
        return dspy.Prediction(result="CALCULATE")
# Router module that uses the controller
class DiffRouter(dspy.Module):
    """DSPy module that sends a query to the single tool a controller picks.

    Args:
        controller: callable mapping one query string to a tensor with one
            score per tool; score ``i`` selects the ``i``-th key of ``tools``.
        tools: name -> DSPy module. Insertion order defines which controller
            output index selects which tool.
    """

    def __init__(self, controller, tools: dict[str, dspy.Module]):
        super().__init__()
        self.controller = controller
        self.tools = tools
        self.tool_keys = list(tools.keys())

    def forward(self, query: str) -> dspy.Prediction:
        with torch.no_grad():  # routing is inference only
            probs = self.controller(query)
        # A size mismatch would otherwise raise an opaque IndexError (too many
        # scores) or make some tools silently unreachable (too few).
        if probs.numel() != len(self.tool_keys):
            raise ValueError(
                f"controller returned {probs.numel()} scores but "
                f"{len(self.tool_keys)} tools are registered"
            )
        selected = self.tool_keys[int(probs.argmax())]
        return self.tools[selected](query=query)
# Instantiate the tools and the router.
# Key order matters: DiffRouter maps controller output index i to the i-th
# key, and the controller was trained with label 0 = search, 1 = calc.
# (The former `vocab = {}` here was dead code and did not reset the
# controller's vocabulary, so it has been removed.)
tools = {
    "search": SearchTool(),
    "calc": CalcTool(),
}
router = DiffRouter(net, tools)
# Route each demo query through DSPy and show which tool produced the output.
tests = [
    "who is ceo of dropbox",
    "2 + 2",
    "define transformers architecture",
    "sqrt(256)",
]
for query, result in zip(tests, map(router, tests)):
    print("Selected tool output:", result.result)
## DSPy output (sample run)
# Selected tool output: SEARCH
# Selected tool output: CALCULATE
# Selected tool output: SEARCH
# Selected tool output: CALCULATE
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment