Last active
July 19, 2024 17:54
-
-
Save thomasmarwitz/dc3a15e280c6ea367f96f87157973362 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { Embedding, UMAP, Edge, Path } from "./types"; | |
/**
 * Calculates the Euclidean distance between two UMAP points.
 * @param umap1 - The first UMAP point
 * @param umap2 - The second UMAP point
 * @returns The Euclidean distance between the two points
 */
function calculateDistance(umap1: UMAP, umap2: UMAP): number {
  return Math.sqrt(
    Math.pow(umap1[0] - umap2[0], 2) +
      Math.pow(umap1[1] - umap2[1], 2) +
      Math.pow(umap1[2] - umap2[2], 2)
  );
}

/**
 * Finds the shortest path between two embeddings with Dijkstra's algorithm.
 *
 * The graph is defined by each embedding's `neighbors` list, treated as
 * directed out-edges. Edge weights are the Euclidean distance between the
 * endpoints' UMAP coordinates. Neighbor ids that match no embedding are
 * ignored. If several embeddings share an id, the first one is used.
 *
 * @param embeddings - Array of all embeddings (the graph's nodes)
 * @param startId - ID of the starting embedding
 * @param endId - ID of the ending embedding
 * @returns The path as an ordered list of edges plus its total distance, or
 *   null if `endId` is unreachable from `startId`. This includes the case
 *   where either id does not appear in `embeddings`.
 * @throws Error if `startId === endId`
 */
export function dijkstra(
  embeddings: Embedding[],
  startId: string,
  endId: string
): { path: Path; totalDistance: number } | null {
  if (startId === endId) {
    throw new Error("Start and end nodes are the same");
  }

  // Index nodes by id once so lookups in the main loop are O(1) rather than a
  // linear `find`. The first embedding wins on duplicate ids, matching `find`.
  // A Map also avoids plain-object pitfalls with ids like "constructor".
  const nodesById = new Map<string, Embedding>();
  for (const embedding of embeddings) {
    if (!nodesById.has(embedding.id)) {
      nodesById.set(embedding.id, embedding);
    }
  }

  // An id that is not in the graph cannot lie on any path. This guard also
  // keeps the reconstruction loop below from following a missing
  // predecessor entry forever.
  if (!nodesById.has(startId) || !nodesById.has(endId)) {
    return null;
  }

  const distances = new Map<string, number>();
  const previous = new Map<string, string | null>();
  const unvisited = new Set<string>();
  nodesById.forEach((_embedding, id) => {
    distances.set(id, id === startId ? 0 : Infinity);
    previous.set(id, null);
    unvisited.add(id);
  });
  const distanceTo = (id: string): number => distances.get(id) ?? Infinity;

  while (unvisited.size > 0) {
    // Linear scan for the closest unvisited node. On ties, the node that
    // comes later in insertion order wins.
    let current: string | undefined = undefined;
    for (const id of Array.from(unvisited)) {
      if (current === undefined || distanceTo(id) <= distanceTo(current)) {
        current = id;
      }
    }

    // Stop once the target is the closest node, or when every remaining
    // node is unreachable from the start.
    if (
      current === undefined ||
      current === endId ||
      distanceTo(current) === Infinity
    ) {
      break;
    }
    unvisited.delete(current);

    const currentEmbedding = nodesById.get(current);
    if (!currentEmbedding) continue;
    const currentDistance = distanceTo(current);

    // Relax each out-edge. Neighbors already settled or absent from
    // `embeddings` are skipped.
    for (const neighborId of currentEmbedding.neighbors) {
      if (!unvisited.has(neighborId)) continue;
      const neighbor = nodesById.get(neighborId);
      if (!neighbor) continue;

      const candidate =
        currentDistance +
        calculateDistance(currentEmbedding.umap, neighbor.umap);
      if (candidate < distanceTo(neighborId)) {
        distances.set(neighborId, candidate);
        previous.set(neighborId, current);
      }
    }
  }

  // No predecessor means the end node was never reached.
  if (previous.get(endId) == null) {
    return null;
  }

  // Walk predecessors back from the end node, then reverse into start→end order.
  const path: Path = [];
  let to: string = endId;
  let from: string | null = previous.get(endId) ?? null;
  while (from !== null) {
    const fromEmbedding = nodesById.get(from);
    const toEmbedding = nodesById.get(to);
    if (fromEmbedding && toEmbedding) {
      const edge: Edge = {
        from: fromEmbedding,
        to: toEmbedding,
        distance: calculateDistance(fromEmbedding.umap, toEmbedding.umap),
      };
      path.push(edge);
    }
    to = from;
    from = previous.get(to) ?? null;
  }
  path.reverse();

  return { path, totalDistance: distanceTo(endId) };
}
/**
 * Runs `dijkstra` for the given endpoints and prints the route to the console.
 *
 * On success, prints "Shortest path:", one numbered line per edge (with its
 * length to two decimals), and the total distance. If no route exists, prints
 * "No path found". Errors thrown by `dijkstra` (e.g. identical endpoints)
 * propagate to the caller.
 *
 * @param embeddings - Array of all embeddings
 * @param startId - ID of the starting embedding
 * @param endId - ID of the ending embedding
 */
export function findPath(
  embeddings: Embedding[],
  startId: string,
  endId: string
) {
  const outcome = dijkstra(embeddings, startId, endId);
  if (!outcome) {
    console.log("No path found");
    return;
  }

  console.log("Shortest path:");
  let stepNumber = 0;
  for (const { from, to, distance } of outcome.path) {
    stepNumber += 1;
    console.log(
      `Step ${stepNumber}: ${from.id} -> ${to.id} (distance: ${distance.toFixed(2)})`
    );
  }
  console.log(`Total distance: ${outcome.totalDistance.toFixed(2)}`);
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { Embedding, UMAP } from './types'; | |
/**
 * Returns the displacement that carries `point1` onto `point2`, i.e. the
 * component-wise difference `point2 - point1`.
 * @param point1 - Tail of the vector.
 * @param point2 - Head of the vector.
 * @returns The per-axis differences as a new 3-tuple.
 */
function computeVector(point1: UMAP, point2: UMAP): UMAP {
  const [fromX, fromY, fromZ] = point1;
  const [toX, toY, toZ] = point2;
  return [toX - fromX, toY - fromY, toZ - fromZ];
}
/**
 * Translates `point` by `vector`, adding them component-wise.
 * @param point - The point to move.
 * @param vector - The per-axis offset to apply.
 * @returns A new 3-tuple; neither input is modified.
 */
function addVector(point: UMAP, vector: UMAP): UMAP {
  const x = point[0] + vector[0];
  const y = point[1] + vector[1];
  const z = point[2] + vector[2];
  const translated: UMAP = [x, y, z];
  return translated;
}
/**
 * Euclidean (straight-line) distance between two 3-D points.
 * @param point1 - One endpoint.
 * @param point2 - The other endpoint.
 * @returns The square root of the summed squared per-axis differences.
 */
function distance(point1: UMAP, point2: UMAP): number {
  const deltas = [0, 1, 2].map((axis) => point2[axis] - point1[axis]);
  const sumOfSquares = deltas.reduce(
    (total, delta) => total + Math.pow(delta, 2),
    0
  );
  return Math.sqrt(sumOfSquares);
}
/**
 * Returns the id of the embedding nearest to `potentialEnd`, ignoring the
 * embedding whose id is `currentId`.
 *
 * Only candidates at a distance `<= threshold` qualify (the bound is
 * inclusive). When several candidates are equally close, the one appearing
 * later in `embeddings` wins.
 *
 * @param potentialEnd - The point to search around.
 * @param embeddings - The array of all embeddings.
 * @param threshold - The maximum distance to consider.
 * @param currentId - The ID of the embedding to exclude from the search.
 * @returns The ID of the closest qualifying embedding, or null if none qualifies.
 */
function findClosestEmbedding(
  potentialEnd: UMAP,
  embeddings: Embedding[],
  threshold: number,
  currentId: string
): string | null {
  const best = embeddings.reduce<{ id: string | null; dist: number }>(
    (sofar, candidate) => {
      if (candidate.id === currentId) return sofar;
      const candidateDist = distance(potentialEnd, candidate.umap);
      return candidateDist <= sofar.dist
        ? { id: candidate.id, dist: candidateDist }
        : sofar;
    },
    { id: null, dist: threshold }
  );
  return best.id;
}
/**
 * Finds pairs of embeddings related the same way `startPoint` relates to `endPoint`.
 *
 * For every embedding, the offset `endPoint - startPoint` is added to its UMAP
 * position. The closest *other* embedding within `threshold` of that
 * translated point (as chosen by `findClosestEmbedding`) becomes its partner.
 * Embeddings with no partner inside the threshold are omitted. Every
 * embedding is compared against every other one, so the cost is quadratic
 * in `embeddings.length`.
 *
 * @param embeddings - The array of all embeddings.
 * @param startPoint - The starting point of the relationship vector.
 * @param endPoint - The ending point of the relationship vector.
 * @param threshold - The maximum distance to consider for similar relationships.
 * @returns `[sourceId, partnerId]` pairs, in the order of `embeddings`.
 */
function findSimilarRelationships(
  embeddings: Embedding[],
  startPoint: UMAP,
  endPoint: UMAP,
  threshold: number
): Array<[string, string]> {
  const relationshipVector = computeVector(startPoint, endPoint);
  const pairs: Array<[string, string]> = [];

  for (const embedding of embeddings) {
    const potentialEnd = addVector(embedding.umap, relationshipVector);
    const closestId = findClosestEmbedding(
      potentialEnd,
      embeddings,
      threshold,
      embedding.id
    );
    // Compare against null explicitly: "" is a valid id and must not be
    // mistaken for "no match" by a truthiness check.
    if (closestId !== null) {
      pairs.push([embedding.id, closestId]);
    }
  }

  return pairs;
}
/**
 * Entry point: finds relationships similar to the one from `point1Id` to `point2Id`.
 *
 * Looks up both reference embeddings by id, then delegates to
 * `findSimilarRelationships` with their UMAP positions.
 *
 * @param embeddings - The array of all embeddings.
 * @param point1Id - The ID of the first point defining the relationship.
 * @param point2Id - The ID of the second point defining the relationship.
 * @param threshold - The maximum distance to consider for similar relationships.
 * @returns An array of `[sourceId, partnerId]` pairs representing similar relationships.
 * @throws Error if either of the specified points is not found in the embeddings.
 */
function findRelationships(
  embeddings: Embedding[],
  point1Id: string,
  point2Id: string,
  threshold: number
): Array<[string, string]> {
  const lookup = (wanted: string) => embeddings.find((e) => e.id === wanted);
  const origin = lookup(point1Id);
  const target = lookup(point2Id);

  if (origin === undefined || target === undefined) {
    throw new Error('One or both points not found');
  }

  return findSimilarRelationships(
    embeddings,
    origin.umap,
    target.umap,
    threshold
  );
}

export { findRelationships };
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { Embedding, UMAP } from "./types"; | |
import { findPath, dijkstra } from "./dijkstra"; | |
/**
 * Minimal assertion helper for this script.
 *
 * Throws an `Error` carrying `message` when `condition` is false. Otherwise
 * it returns normally, and the `asserts` signature lets the type checker
 * treat `condition` as true afterwards.
 */
function assert(condition: boolean, message: string): asserts condition {
  if (condition) return;
  throw new Error(message);
}
/**
 * Shallow, order-sensitive array comparison.
 *
 * Elements are compared with strict equality (`!==`), so objects compare by
 * reference and `NaN` never equals `NaN`. Parameters are `readonly unknown[]`
 * rather than `any[]`. Any array, readonly ones included, is accepted, and
 * no element is used in a type-unsafe way.
 *
 * @returns true when both arrays have the same length and identical elements.
 */
function arrayEqual(
  arr1: readonly unknown[],
  arr2: readonly unknown[]
): boolean {
  if (arr1.length !== arr2.length) return false;
  for (let i = 0; i < arr1.length; i++) {
    if (arr1[i] !== arr2[i]) return false;
  }
  return true;
}
/**
 * Builds a test `Embedding` from an id, UMAP coordinates and neighbor ids.
 * `image_url` is set to the placeholder "<not-found>"; nothing in this
 * test script reads it.
 */
function createEmbedding(
  id: string,
  umap: UMAP,
  neighbors: string[]
): Embedding {
  const placeholderUrl = "<not-found>";
  const embedding: Embedding = {
    id,
    umap,
    neighbors,
    image_url: placeholderUrl,
  };
  return embedding;
}
/**
 * Fixture graph for the Dijkstra tests, as `[id, umap, neighbors]` rows.
 * Neighbor lists are symmetric, giving the undirected edges
 * A–B, A–C, B–D, B–E, C–D and D–E.
 */
const testEmbeddingSpecs: Array<[string, UMAP, string[]]> = [
  ["A", [0, 0, 0], ["B", "C"]],
  ["B", [1, 1, 1], ["A", "D", "E"]],
  ["C", [2, 0, 1], ["A", "D"]],
  ["D", [3, 1, 2], ["B", "C", "E"]],
  ["E", [4, 2, 2], ["B", "D"]],
];

const testEmbeddings: Embedding[] = testEmbeddingSpecs.map(
  ([id, umap, neighbors]) => createEmbedding(id, umap, neighbors)
);
/**
 * Asserts that `result` describes exactly the node sequence `expectedPath`.
 *
 * For example, `["A", "B", "E"]` expects the edges A->B and B->E. The actual
 * and expected routes are logged before any check runs.
 *
 * @param result - Return value of `dijkstra`
 * @param expectedPath - Expected node ids from start to end, inclusive
 * @throws Error (via `assert`) on a null result, wrong or disconnected edges,
 *   or a wrong final node
 */
function assertPath(
  result: ReturnType<typeof dijkstra>,
  expectedPath: string[]
) {
  assert(result !== null, "Expected a path, but got null");
  const path = result.path;

  console.log(
    "Actual path:",
    path.map((edge) => `${edge.from.id}->${edge.to.id}`).join(", ")
  );
  console.log("Expected path:", expectedPath.join(", "));

  assert(
    arrayEqual(
      path.map((edge) => edge.from.id),
      expectedPath.slice(0, -1)
    ),
    "Path does not match expected path"
  );

  // Matching `from` ids alone would accept a broken chain such as
  // A->B, D->E for ["A", "D", "E"]. Require each edge to start where the
  // previous edge ended.
  assert(
    path.every((edge, i) => i === 0 || path[i - 1].to.id === edge.from.id),
    "Path does not match expected path"
  );

  // Guard the empty path so a mismatch surfaces as an assertion failure
  // rather than a TypeError on `undefined.to`.
  const lastEdge = path.length > 0 ? path[path.length - 1] : undefined;
  assert(
    lastEdge !== undefined &&
      lastEdge.to.id === expectedPath[expectedPath.length - 1],
    "Last node in path does not match expected"
  );
}
/**
 * Runs the Dijkstra test scenarios against `testEmbeddings`.
 *
 * Each path is logged via `findPath`. The function throws (via `assert`) on
 * the first failed expectation; if all pass, it prints a success line.
 */
function runTests() {
  console.log("Testing Dijkstra's algorithm with example embeddings:");

  // Test case 1: Path from A to E
  console.log("\nFinding path from A to E:");
  const resultAE = dijkstra(testEmbeddings, "A", "E");
  findPath(testEmbeddings, "A", "E");
  assertPath(resultAE, ["A", "B", "E"]);

  // Test case 2: Path from C to E
  console.log("\nFinding path from C to E:");
  const resultCE = dijkstra(testEmbeddings, "C", "E");
  findPath(testEmbeddings, "C", "E");
  assertPath(resultCE, ["C", "D", "E"]);

  // Test case 3: Path from E to A (reverse of case 1)
  console.log("\nFinding path from E to A:");
  const resultEA = dijkstra(testEmbeddings, "E", "A");
  findPath(testEmbeddings, "E", "A");
  assertPath(resultEA, ["E", "B", "A"]);

  // Test case 4: Path to unconnected node
  console.log("\nTrying to find path between unconnected nodes:");
  const unconnectedEmbedding = createEmbedding("F", [10, 10, 10], []);
  const extendedEmbeddings = [...testEmbeddings, unconnectedEmbedding];
  const resultAF = dijkstra(extendedEmbeddings, "A", "F");
  findPath(extendedEmbeddings, "A", "F");
  assert(resultAF === null, "Expected null result for unconnected nodes");

  // Test case 5: Path to self (should throw an error)
  console.log("\nFinding path from A to A (should throw an error):");
  // Record the outcome and assert *after* the try/catch. If the "expected to
  // throw" failure were raised inside the try, our own catch would swallow it
  // and report "Unexpected error message" instead.
  let didThrow = false;
  let thrown: unknown = undefined;
  try {
    dijkstra(testEmbeddings, "A", "A");
  } catch (error: unknown) {
    didThrow = true;
    thrown = error;
  }
  assert(
    didThrow,
    "Expected dijkstra to throw an error for same start and end nodes"
  );
  // The catch variable is `unknown` under strict mode; narrow before reading `.message`.
  assert(
    thrown instanceof Error &&
      thrown.message === "Start and end nodes are the same",
    "Unexpected error message"
  );
  console.log("Correctly threw error for same start and end nodes");

  console.log("\nAll tests passed successfully!");
}
// Script entry point: run the suite and report the first failure.
// NOTE(review): failures are only logged, so the process still exits with
// status 0. Consider a non-zero exit code if this runs in CI.
try {
  runTests();
} catch (error: unknown) {
  // The catch binding is `unknown` under strict mode. A non-Error throw
  // has no `.message`, so fall back to its string form.
  console.error(
    "Test failed:",
    error instanceof Error ? error.message : String(error)
  );
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** A position in the 3-D UMAP projection: three numeric coordinates. */
export type UMAP = [number, number, number];

/** A node of the embedding graph. */
export interface Embedding {
  /** Identifier; other embeddings refer to this node by it in `neighbors`. */
  id: string;
  /** Position of this embedding in the UMAP projection. */
  umap: UMAP;
  /** Ids of adjacent embeddings (the out-edges followed by path finding). */
  neighbors: Embedding["id"][];
  /** Image location; test fixtures fill it with a placeholder. */
  image_url: string;
}

/** A directed hop between two embeddings, with its length. */
export interface Edge {
  from: Embedding;
  to: Embedding;
  distance: number;
}

/** An ordered sequence of edges from a start node to an end node. */
export type Path = Array<Edge>;

/**
 * An embedding carrying extra descriptive metadata (presumably one episode
 * of a series — field meanings are not used by the code shown here).
 */
export interface Episode extends Embedding {
  episode: number;
  title: string;
  year: number;
  url: string;
  summary: string;
  speakers: string;
  details: string;
  cover_img_url: string;
  hash: string;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment