Skip to content

Instantly share code, notes, and snippets.

@thomasmarwitz
Last active July 19, 2024 17:54
Show Gist options
  • Save thomasmarwitz/dc3a15e280c6ea367f96f87157973362 to your computer and use it in GitHub Desktop.
Save thomasmarwitz/dc3a15e280c6ea367f96f87157973362 to your computer and use it in GitHub Desktop.
import { Embedding, UMAP, Edge, Path } from "./types";
/**
 * Euclidean (straight-line) distance between two UMAP coordinate triples.
 * @param umap1 - First point
 * @param umap2 - Second point
 * @returns sqrt(dx² + dy² + dz²)
 */
function calculateDistance(umap1: UMAP, umap2: UMAP): number {
  const [x1, y1, z1] = umap1;
  const [x2, y2, z2] = umap2;
  const sumOfSquares = (x1 - x2) ** 2 + (y1 - y2) ** 2 + (z1 - z2) ** 2;
  return Math.sqrt(sumOfSquares);
}
/**
 * Finds the shortest path between two embeddings using Dijkstra's algorithm.
 *
 * Each embedding's `neighbors` list supplies edges from that embedding to
 * each listed neighbor. Each edge is weighted by the Euclidean distance
 * between the two embeddings' UMAP coordinates.
 *
 * @param embeddings - Array of all embeddings (the graph's nodes). If an id
 *   appears more than once, the first occurrence is used.
 * @param startId - ID of the starting embedding
 * @param endId - ID of the ending embedding
 * @returns An object containing the path (edges ordered from start to end) and
 *   its total distance, or null if no path is found. This includes the case
 *   where either endpoint is not present in `embeddings`.
 * @throws Error if `startId` and `endId` are the same
 */
export function dijkstra(
  embeddings: Embedding[],
  startId: string,
  endId: string
): { path: Path; totalDistance: number } | null {
  if (startId === endId) {
    throw new Error("Start and end nodes are the same");
  }

  // Index nodes by id once so lookups are O(1) instead of a linear `find`.
  // Keep the first occurrence of a duplicated id, matching `Array.find`.
  const nodesById = new Map<string, Embedding>();
  for (const embedding of embeddings) {
    if (!nodesById.has(embedding.id)) {
      nodesById.set(embedding.id, embedding);
    }
  }

  // An endpoint outside the graph can never be reached. Returning early also
  // keeps the path reconstruction below from looping on a missing entry.
  if (!nodesById.has(startId) || !nodesById.has(endId)) {
    return null;
  }

  // Maps rather than plain objects, so ids such as "constructor" or
  // "__proto__" cannot collide with Object.prototype members.
  const distances = new Map<string, number>();
  const previous = new Map<string, string>();
  const unvisited = new Set<string>(nodesById.keys());
  for (const id of unvisited) {
    distances.set(id, id === startId ? 0 : Infinity);
  }
  const distanceTo = (id: string): number => distances.get(id) ?? Infinity;

  while (unvisited.size > 0) {
    // Select the unvisited node with the smallest tentative distance
    // (on ties the later node in insertion order wins).
    let current: string | undefined;
    for (const id of unvisited) {
      if (current === undefined || distanceTo(id) <= distanceTo(current)) {
        current = id;
      }
    }
    if (current === undefined) break;

    const currentDistance = distanceTo(current);
    // Stop once the end is settled, or when every remaining node is unreachable.
    if (current === endId || currentDistance === Infinity) {
      break;
    }

    unvisited.delete(current);
    const currentEmbedding = nodesById.get(current);
    if (!currentEmbedding) continue;

    // Relax every edge to a not-yet-settled neighbor.
    for (const neighborId of currentEmbedding.neighbors) {
      if (!unvisited.has(neighborId)) continue;
      const neighbor = nodesById.get(neighborId);
      if (!neighbor) continue;

      const candidate =
        currentDistance + calculateDistance(currentEmbedding.umap, neighbor.umap);
      if (candidate < distanceTo(neighborId)) {
        distances.set(neighborId, candidate);
        previous.set(neighborId, current);
      }
    }
  }

  // Since start !== end, the end has a predecessor only if it was reached.
  if (!previous.has(endId)) {
    return null;
  }

  // Walk predecessors back from the end, then reverse into start -> end order.
  const path: Path = [];
  let toId = endId;
  let fromId = previous.get(toId);
  while (fromId !== undefined) {
    const fromEmbedding = nodesById.get(fromId);
    const toEmbedding = nodesById.get(toId);
    if (fromEmbedding && toEmbedding) {
      const edge: Edge = {
        from: fromEmbedding,
        to: toEmbedding,
        distance: calculateDistance(fromEmbedding.umap, toEmbedding.umap),
      };
      path.push(edge);
    }
    toId = fromId;
    fromId = previous.get(toId);
  }
  path.reverse();

  return { path, totalDistance: distanceTo(endId) };
}
/**
 * Runs `dijkstra` and prints the resulting route to the console, one line
 * per edge, followed by the total distance (both to two decimals). Prints
 * "No path found" when `dijkstra` returns null.
 * @param embeddings - Array of all embeddings
 * @param startId - ID of the starting embedding
 * @param endId - ID of the ending embedding
 */
export function findPath(
  embeddings: Embedding[],
  startId: string,
  endId: string
) {
  const result = dijkstra(embeddings, startId, endId);
  if (!result) {
    console.log("No path found");
    return;
  }

  console.log("Shortest path:");
  for (const [stepIndex, { from, to, distance }] of result.path.entries()) {
    const stepNumber = stepIndex + 1;
    console.log(
      `Step ${stepNumber}: ${from.id} -> ${to.id} (distance: ${distance.toFixed(2)})`
    );
  }
  console.log(`Total distance: ${result.totalDistance.toFixed(2)}`);
}
import { Embedding, UMAP } from './types';
/**
 * Displacement vector that carries `point1` onto `point2`, i.e. point2 - point1
 * component-wise.
 * @param point1 - Origin of the vector.
 * @param point2 - Destination of the vector.
 * @returns The difference [x2 - x1, y2 - y1, z2 - z1].
 */
function computeVector(point1: UMAP, point2: UMAP): UMAP {
  const [x1, y1, z1] = point1;
  const [x2, y2, z2] = point2;
  return [x2 - x1, y2 - y1, z2 - z1];
}
/**
 * Translates `point` by `vector`, component-wise.
 * @param point - Starting position.
 * @param vector - Offset to apply.
 * @returns A new triple [px + vx, py + vy, pz + vz]; inputs are not modified.
 */
function addVector(point: UMAP, vector: UMAP): UMAP {
  const [px, py, pz] = point;
  const [vx, vy, vz] = vector;
  return [px + vx, py + vy, pz + vz];
}
/**
 * Euclidean distance between two 3-component points.
 * @param point1 - The first point.
 * @param point2 - The second point.
 * @returns The straight-line distance, always >= 0.
 */
function distance(point1: UMAP, point2: UMAP): number {
  const dx = point2[0] - point1[0];
  const dy = point2[1] - point1[1];
  const dz = point2[2] - point1[2];
  return Math.sqrt(dx ** 2 + dy ** 2 + dz ** 2);
}
/**
 * Finds the embedding nearest to `potentialEnd`, ignoring the one whose id is
 * `currentId`. A candidate qualifies when its distance is <= `threshold`
 * (inclusive). When several candidates are equally near, the one appearing
 * latest in `embeddings` is returned.
 * @param potentialEnd - The point to search around.
 * @param embeddings - The array of all embeddings.
 * @param threshold - Largest distance still accepted.
 * @param currentId - ID of the embedding to exclude from the search.
 * @returns The ID of the closest qualifying embedding, or null if none qualifies.
 */
function findClosestEmbedding(
  potentialEnd: UMAP,
  embeddings: Embedding[],
  threshold: number,
  currentId: string
): string | null {
  const best = embeddings.reduce<{ id: string | null; dist: number }>(
    (acc, candidate) => {
      if (candidate.id === currentId) return acc;
      const candidateDist = distance(potentialEnd, candidate.umap);
      return candidateDist <= acc.dist ? { id: candidate.id, dist: candidateDist } : acc;
    },
    { id: null, dist: threshold }
  );
  return best.id;
}
/**
 * Finds embedding pairs that resemble the relationship startPoint -> endPoint.
 *
 * The offset endPoint - startPoint is added to every embedding's coordinates.
 * The closest other embedding within `threshold` of that shifted point, if
 * any, becomes the embedding's partner. Embeddings with no partner are
 * omitted from the result.
 * @param embeddings - The array of all embeddings.
 * @param startPoint - The starting point of the relationship vector.
 * @param endPoint - The ending point of the relationship vector.
 * @param threshold - The maximum distance to consider for similar relationships.
 * @returns [sourceId, partnerId] pairs, in the order of `embeddings`.
 */
function findSimilarRelationships(
  embeddings: Embedding[],
  startPoint: UMAP,
  endPoint: UMAP,
  threshold: number
): Array<[string, string]> {
  const offset = computeVector(startPoint, endPoint);
  const pairs: Array<[string, string]> = [];
  for (const source of embeddings) {
    const projected = addVector(source.umap, offset);
    const partnerId = findClosestEmbedding(projected, embeddings, threshold, source.id);
    if (partnerId) {
      pairs.push([source.id, partnerId]);
    }
  }
  return pairs;
}
/**
 * Entry point: looks up the two reference embeddings by id and finds other
 * embedding pairs related like point1 -> point2. The search itself is
 * delegated to `findSimilarRelationships` with the two points' UMAP
 * coordinates. If an id occurs more than once, its first occurrence is used.
 * @param embeddings - The array of all embeddings.
 * @param point1Id - The ID of the first point defining the relationship.
 * @param point2Id - The ID of the second point defining the relationship.
 * @param threshold - The maximum distance to consider for similar relationships.
 * @returns An array of pairs of embedding IDs representing similar relationships.
 * @throws Error if either of the specified points is not found in the embeddings.
 */
function findRelationships(
  embeddings: Embedding[],
  point1Id: string,
  point2Id: string,
  threshold: number
): Array<[string, string]> {
  const lookup = (wanted: string) => embeddings.find(({ id }) => id === wanted);
  const source = lookup(point1Id);
  const target = lookup(point2Id);
  if (source === undefined || target === undefined) {
    throw new Error('One or both points not found');
  }
  return findSimilarRelationships(embeddings, source.umap, target.umap, threshold);
}
export { findRelationships };
import { Embedding, UMAP } from "./types";
import { findPath, dijkstra } from "./dijkstra";
/**
 * Minimal assertion helper: throws `new Error(message)` when `condition` is
 * false. Declared with `asserts condition`, so after the call the type
 * checker treats `condition` as true.
 */
function assert(condition: boolean, message: string): asserts condition {
  if (condition) return;
  throw new Error(message);
}
/**
 * Shallow, order-sensitive array equality: same length and `===` on every
 * index. Takes `readonly unknown[]` instead of `any[]` so element types stay
 * checked and readonly arrays/tuples are accepted.
 * @param arr1 - First array
 * @param arr2 - Second array
 * @returns true when both arrays have identical elements in identical order
 */
function arrayEqual(arr1: readonly unknown[], arr2: readonly unknown[]): boolean {
  if (arr1.length !== arr2.length) {
    return false;
  }
  // Index loop (rather than `every`) so sparse holes are compared too.
  for (let i = 0; i < arr1.length; i++) {
    if (arr1[i] !== arr2[i]) {
      return false;
    }
  }
  return true;
}
/**
 * Builds an Embedding for tests with the placeholder image URL "<not-found>".
 * The `umap` and `neighbors` arguments are stored by reference, not copied.
 */
function createEmbedding(
  id: string,
  umap: UMAP,
  neighbors: string[]
): Embedding {
  const embedding: Embedding = { id, umap, neighbors, image_url: "<not-found>" };
  return embedding;
}
// Test graph as [id, umap, neighbors] rows. Every neighbor relation is listed
// on both of its endpoints, so paths exist in both directions.
const testGraphRows: Array<[string, UMAP, string[]]> = [
  ["A", [0, 0, 0], ["B", "C"]],
  ["B", [1, 1, 1], ["A", "D", "E"]],
  ["C", [2, 0, 1], ["A", "D"]],
  ["D", [3, 1, 2], ["B", "C", "E"]],
  ["E", [4, 2, 2], ["B", "D"]],
];
const testEmbeddings: Embedding[] = testGraphRows.map(([id, umap, neighbors]) =>
  createEmbedding(id, umap, neighbors)
);
/**
 * Asserts that a `dijkstra` result is non-null and visits exactly the node
 * ids in `expectedPath` (start first, end last). Logs actual and expected
 * paths before checking.
 * @param result - Return value of `dijkstra`
 * @param expectedPath - Expected node ids from start to end
 */
function assertPath(
  result: ReturnType<typeof dijkstra>,
  expectedPath: string[]
) {
  assert(result !== null, "Expected a path, but got null");
  console.log(
    "Actual path:",
    result.path.map((edge) => `${edge.from.id}->${edge.to.id}`).join(", ")
  );
  console.log("Expected path:", expectedPath.join(", "));
  // Every edge's source must match the expected ids except the final one.
  assert(
    arrayEqual(
      result.path.map((edge) => edge.from.id),
      expectedPath.slice(0, -1)
    ),
    "Path does not match expected path"
  );
  // Guard before reading the last edge so an empty path fails with a clear
  // assertion instead of a TypeError on `undefined.to`.
  const lastEdge = result.path[result.path.length - 1];
  assert(lastEdge !== undefined, "Path is empty");
  assert(
    lastEdge.to.id === expectedPath[expectedPath.length - 1],
    "Last node in path does not match expected"
  );
}
/**
 * Runs the Dijkstra scenarios against `testEmbeddings`. Each path is logged
 * via `findPath`, and the first failed expectation throws through `assert`.
 */
function runTests() {
  console.log("Testing Dijkstra's algorithm with example embeddings:");

  // Test case 1: Path from A to E
  console.log("\nFinding path from A to E:");
  const resultAE = dijkstra(testEmbeddings, "A", "E");
  findPath(testEmbeddings, "A", "E");
  assertPath(resultAE, ["A", "B", "E"]);

  // Test case 2: Path from C to E
  console.log("\nFinding path from C to E:");
  const resultCE = dijkstra(testEmbeddings, "C", "E");
  findPath(testEmbeddings, "C", "E");
  assertPath(resultCE, ["C", "D", "E"]);

  // Test case 3: Path from E to A (reverse of case 1)
  console.log("\nFinding path from E to A:");
  const resultEA = dijkstra(testEmbeddings, "E", "A");
  findPath(testEmbeddings, "E", "A");
  assertPath(resultEA, ["E", "B", "A"]);

  // Test case 4: Path to unconnected node
  console.log("\nTrying to find path between unconnected nodes:");
  const unconnectedEmbedding = createEmbedding("F", [10, 10, 10], []);
  const extendedEmbeddings = [...testEmbeddings, unconnectedEmbedding];
  const resultAF = dijkstra(extendedEmbeddings, "A", "F");
  findPath(extendedEmbeddings, "A", "F");
  assert(resultAF === null, "Expected null result for unconnected nodes");

  // Test case 5: Path to self (should throw an error)
  console.log("\nFinding path from A to A (should throw an error):");
  // Record the outcome inside the try and assert outside it. Otherwise the
  // "did not throw" failure would be caught by this same catch and
  // misreported as "Unexpected error message".
  let threw = false;
  let thrown: unknown;
  try {
    dijkstra(testEmbeddings, "A", "A");
  } catch (error) {
    threw = true;
    thrown = error;
  }
  assert(
    threw,
    "Expected dijkstra to throw an error for same start and end nodes"
  );
  // Catch variables are `unknown` under strict mode; narrow before reading.
  assert(
    thrown instanceof Error &&
      thrown.message === "Start and end nodes are the same",
    "Unexpected error message"
  );
  console.log("Correctly threw error for same start and end nodes");

  console.log("\nAll tests passed successfully!");
}
// Script entry point: run the suite and report the first failure.
try {
  runTests();
} catch (error) {
  // `error` is `unknown` under strict mode, and a thrown non-Error has no
  // `.message`, so fall back to its string form.
  const message = error instanceof Error ? error.message : String(error);
  console.error("Test failed:", message);
}
/** Three numeric coordinates (components are accessed as indices 0, 1 and 2). */
export type UMAP = [number, number, number];

/** A graph node: an identified point with outgoing neighbor links. */
export interface Embedding {
  id: string;
  umap: UMAP;
  /** Ids of embeddings this one links to. */
  neighbors: Embedding["id"][];
  image_url: string;
}

/** A single hop between two embeddings together with its length. */
export interface Edge {
  from: Embedding;
  to: Embedding;
  distance: number;
}

/** Ordered sequence of edges. */
export type Path = Array<Edge>;

/**
 * An Embedding extended with episode metadata fields.
 * NOTE(review): field semantics are not shown here; confirm against the data source.
 */
export interface Episode extends Embedding {
  episode: number;
  title: string;
  year: number;
  url: string;
  summary: string;
  speakers: string;
  details: string;
  cover_img_url: string;
  hash: string;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment