Skip to content

Instantly share code, notes, and snippets.

@Teivaz
Created December 20, 2024 08:53
Show Gist options
  • Save Teivaz/ee48e66733eb21a2f00004f02845f99a to your computer and use it in GitHub Desktop.
Save Teivaz/ee48e66733eb21a2f00004f02845f99a to your computer and use it in GitHub Desktop.
from queue import PriorityQueue
from typing import *
from dataclasses import dataclass, field
from typing import Any
# Graph nodes may be any hashable value (they are used as dict keys below).
Node = Any


def null_heuristic(node: Node) -> int:
    """Heuristic that always estimates 0; with it, `astar` behaves like Dijkstra's algorithm."""
    return 0


def astar(
    start: Node,
    adjacent: Callable[[Node], Iterator[Tuple[Node, int]]],
    is_target: Callable[[Node], bool],
    heuristic: Callable[[Node], int] = null_heuristic,
    trail: bool = False,
) -> Tuple[Optional[int], List[Node]]:
    '''
    Find the cheapest path from `start` to any node for which `is_target` is true.

    @param `start` - initial node; must be hashable
    @param `adjacent` - function that takes the position and returns sequence of position + cost tuples
    ```
    def adjacent(position: Any) -> Iterator[Tuple[Any, float]]:
        yield adj_position, cost
    ```
    @param `is_target` - predicate that is true for goal nodes
    @param `heuristic` - estimate of the remaining cost from a node to a goal; it should not
        overestimate (i.e. be admissible) for the returned path to be optimal
    @param `trail` - when true, also record the path that was taken
    @return `(cost, path)` - accumulated cost of the path to the first target reached, and,
        if `trail` is set, the nodes after `start` up to and including that target
        (otherwise `[]`). Returns `(None, [])` if no target is reachable.
    '''
    @dataclass(order=True)
    class PrioritizedNode:
        priority: int                            # cost so far + heuristic; the only compared field
        cost: int = field(compare=False)         # accumulated path cost from `start`
        item: Any = field(compare=False)
        trail: List[Any] = field(compare=False)

    nodes = PriorityQueue()
    nodes.put(PrioritizedNode(0, 0, start, []))
    # Cheapest cost recorded for every node that has been queued.
    visited = {start: 0}
    while not nodes.empty():
        node = nodes.get()
        # A cheaper route to this node was queued after this entry and has already
        # been popped (it had a lower priority), so expanding this one adds nothing.
        if node.cost > visited[node.item]:
            continue
        if is_target(node.item):
            # Report the real path cost, not the priority: they differ whenever
            # heuristic(target) != 0.
            return (node.cost, node.trail)
        for nextnode, nextcost in adjacent(node.item):
            nextnodecost = node.cost + nextcost
            if nextnode in visited and visited[nextnode] <= nextnodecost:
                continue
            visited[nextnode] = nextnodecost
            nextnodepriority = nextnodecost + heuristic(nextnode)
            nexttrail = [*node.trail, nextnode] if trail else []
            nodes.put(PrioritizedNode(nextnodepriority, nextnodecost, nextnode, nexttrail))
    return None, []
def dijkstra(start, adjacent, is_target) -> Optional[int]:
    '''
    Return the cost of the cheapest path from `start` to any node for which `is_target` is true.

    @param `start` - initial node; must be hashable
    @param `adjacent` - function that takes the position and returns sequence of position + cost tuples
    ```
    def adjacent(position: Any) -> Iterator[Tuple[Any, float]]:
        yield adj_position, cost
    ```
    @param `is_target` - predicate that is true for goal nodes
    @return the cheapest cost, or `None` if no target is reachable
    '''
    @dataclass(order=True)
    class PrioritizedItem:
        priority: int                      # accumulated cost from `start`
        item: Any = field(compare=False)

    position_pq = PriorityQueue()
    position_pq.put(PrioritizedItem(0, start))
    # Cheapest cost recorded for every node that has been queued.
    visited = {start: 0}
    while not position_pq.empty():
        item = position_pq.get()
        # Superseded by a cheaper entry for the same node that was already popped.
        if item.priority > visited[item.item]:
            continue
        if is_target(item.item):
            return item.priority
        for nextpos, nextcost in adjacent(item.item):
            nextpriority = item.priority + nextcost
            # `<=` rather than `<`: an equally cheap route adds nothing. Re-queueing
            # ties duplicates work and never terminates on a zero-cost cycle.
            if nextpos in visited and visited[nextpos] <= nextpriority:
                continue
            visited[nextpos] = nextpriority
            position_pq.put(PrioritizedItem(nextpriority, nextpos))
    return None
def astar_all_top(start, adjacent, is_target) -> List[Tuple[Any, List[Any]]]:
    '''
    Find every cheapest path from `start` to nodes for which `is_target` is true.

    No heuristic is used, so nodes are explored in Dijkstra order. Equally cheap routes
    to a node are all kept, so every path that ties for the optimum is reported. Target
    nodes are expanded like any other node.

    @param `start` - initial node; must be hashable
    @param `adjacent` - function that takes the position and returns sequence of position + cost tuples
    ```
    def adjacent(position: Any) -> Iterator[Tuple[Any, float]]:
        yield adj_position, cost
    ```
    @param `is_target` - predicate that is true for goal nodes
    @return list of `(cost, trail)` tuples that all share the minimal cost; `trail` holds
        the nodes after `start` up to and including the target. Empty if no target is reachable.

    NOTE: costs are assumed to be non-negative. Because ties are re-queued, a zero-cost
    cycle reachable at a cost not exceeding the optimum yields unboundedly many optimal
    paths, and the search will not terminate.
    '''
    @dataclass(order=True)
    class PrioritizedItem:
        priority: int                            # accumulated cost from `start`
        item: Any = field(compare=False)
        trail: List[Any] = field(compare=False)

    position_pq = PriorityQueue()
    position_pq.put(PrioritizedItem(0, start, []))
    # Cheapest cost recorded for every node that has been queued.
    visited = {start: 0}
    best = []
    while not position_pq.empty():
        item = position_pq.get()
        # Items leave the queue in non-decreasing cost order. Once one costs more than
        # the best target found, no further optimal path can appear. Without this check
        # the search kept exploring the whole (possibly infinite) graph.
        if best and item.priority > best[0][0]:
            break
        # Superseded by a strictly cheaper route (ties are deliberately kept).
        if item.priority > visited[item.item]:
            continue
        if is_target(item.item):
            best.append((item.priority, item.trail))
        for nextpos, nextcost in adjacent(item.item):
            nextpriority = item.priority + nextcost
            # `<` (not `<=`) keeps equally cheap routes so every tied path is enumerated.
            if nextpos in visited and visited[nextpos] < nextpriority:
                continue
            visited[nextpos] = nextpriority
            position_pq.put(PrioritizedItem(nextpriority, nextpos, [*item.trail, nextpos]))
    return best
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment