This post is auto-generated from RSS feed The Rust Programming Language Forum - Latest topics. Source: BFS is encountering OOM in Rust but not Python
So, the goal of my script is to identify all of the nodes that lie on any path between any source-sink pair. The issue is that my graph can be pretty wide. I prototyped a solution in Python first, got that working, and translated the solution over to Rust (as a sanity check). Unsurprisingly, the Python code is slow but works just fine, while (very surprisingly) the Rust version actually runs out of memory.
This is my working version that still OOMs, but much more slowly after adding an Arc (very much a skill issue). For completeness, I've listed my original Rust and Python code below.
fn find_paths_with_progress(
    start: CMatIdx,
    end: CMatIdx,
    adjacency: &HashMap<CMatIdx, HashSet<CMatIdx>>,
    nodes_on_path: &mut HashSet<CMatIdx>,
    max_depth: usize,
) {
    let mut queue = VecDeque::new();
    let start_visited = Arc::new(HashSet::new());
    let start_path = Arc::new(Vec::new());
    queue.push_back((start, start_path, start_visited));
    while let Some((current, path_arc, visited_arc)) = queue.pop_front() {
        // Clone and update both visited and path ONCE at the top
        let mut new_visited = visited_arc.as_ref().clone();
        new_visited.insert(current);
        let new_visited_arc = Arc::new(new_visited);
        let mut new_path = path_arc.as_ref().clone();
        new_path.push(current);
        let new_path_arc = Arc::new(new_path);
        if current == end {
            for node in new_path_arc.iter() {
                nodes_on_path.insert(*node);
            }
            continue;
        }
        if new_path_arc.len() >= max_depth {
            continue;
        }
        // All neighbors share the same Arc instances
        for &neighbor in adjacency.get(&current).unwrap_or(&HashSet::new()) {
            if !new_visited_arc.contains(&neighbor) {
                let shared_visited = Arc::clone(&new_visited_arc);
                let shared_path = Arc::clone(&new_path_arc);
                queue.push_back((neighbor, shared_path, shared_visited));
            }
        }
    }
}

// Original Rust version, before adding Arc
use std::collections::VecDeque;
use std::collections::{HashMap, HashSet};
use tqdm::pbar;

pub(crate) type NeuronID = i64;
pub(crate) type CMatIdx = i64;

fn find_paths_with_progress(
    start: CMatIdx,
    end: CMatIdx,
    adjacency: &HashMap<CMatIdx, HashSet<CMatIdx>>,
    nodes_on_path: &mut HashSet<CMatIdx>,
    max_depth: usize,
) {
    let mut queue = VecDeque::new();
    let mut start_visited = HashSet::new();
    start_visited.insert(start);
    queue.push_back((start, vec![start], start_visited));
    while !queue.is_empty() {
        let (current, path, visited) = queue.pop_front().unwrap();
        if current == end {
            for node in path.iter() {
                nodes_on_path.insert(*node);
            }
            continue;
        }
        if path.len() >= max_depth {
            continue;
        }
        for neighbor in adjacency.get(&current).unwrap_or(&HashSet::new()) {
            if !visited.contains(neighbor) {
                // Every queued entry gets its own full copy of the visited set and path
                let mut new_visited = visited.clone();
                new_visited.insert(*neighbor);
                let mut new_path = path.clone();
                new_path.push(*neighbor);
                queue.push_back((*neighbor, new_path, new_visited));
            }
        }
    }
}
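
In both Rust versions above, each partial path ends up with its own full copy of the path and the visited set, so the cost per queue entry grows with depth. Below is a minimal sketch of one way to keep the same breadth-first search while sharing path prefixes: each queue entry is a single node with a reference-counted pointer to its parent. The names `PathNode`, `on_path`, and `find_nodes_on_paths_shared` are hypothetical and not from the original code, and the visited check walks the parent chain instead of using a set. This only shrinks the per-entry cost; the number of queue entries can still grow very quickly on a wide graph, so it is not guaranteed to avoid the OOM.

```rust
use std::collections::{HashMap, HashSet, VecDeque};
use std::rc::Rc;

type CMatIdx = i64;

/// One step of a partial path (hypothetical helper type).
/// `parent` points at the shared prefix instead of copying it.
struct PathNode {
    node: CMatIdx,
    depth: usize, // number of nodes on the path, including `node`
    parent: Option<Rc<PathNode>>,
}

/// Walks the parent chain to check whether `target` is already on this path.
fn on_path(tail: &PathNode, target: CMatIdx) -> bool {
    let mut cur = Some(tail);
    while let Some(p) = cur {
        if p.node == target {
            return true;
        }
        cur = p.parent.as_deref();
    }
    false
}

fn find_nodes_on_paths_shared(
    start: CMatIdx,
    end: CMatIdx,
    adjacency: &HashMap<CMatIdx, HashSet<CMatIdx>>,
    nodes_on_path: &mut HashSet<CMatIdx>,
    max_depth: usize,
) {
    let mut queue = VecDeque::new();
    queue.push_back(Rc::new(PathNode { node: start, depth: 1, parent: None }));
    while let Some(tail) = queue.pop_front() {
        if tail.node == end {
            // Record every node on this path by following parent pointers.
            let mut cur: Option<&PathNode> = Some(&*tail);
            while let Some(p) = cur {
                nodes_on_path.insert(p.node);
                cur = p.parent.as_deref();
            }
            continue;
        }
        if tail.depth >= max_depth {
            continue;
        }
        if let Some(neighbors) = adjacency.get(&tail.node) {
            for &neighbor in neighbors {
                if !on_path(&tail, neighbor) {
                    // Only one small node is allocated per queued entry.
                    queue.push_back(Rc::new(PathNode {
                        node: neighbor,
                        depth: tail.depth + 1,
                        parent: Some(Rc::clone(&tail)),
                    }));
                }
            }
        }
    }
}
```

The trade-off is that the membership check becomes a walk of at most `max_depth` nodes rather than a hash lookup, in exchange for not cloning a `HashSet` and a `Vec` for every partial path.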
What's wild to me is that the Python version naively stores all of the paths between the start and end before returning, but seems to work just fine.
from collections import deque

def find_paths_with_progress(
    start: int, end: int, adjacency: dict[int, set[int]], max_depth: int
) -> list[list[int]]:
    """Find all simple paths from start to end with depth limit."""
    paths = []
    queue = deque([(start, [start], {start})])
    while queue:
        current, path, visited = queue.popleft()
        if current == end:
            paths.append(path)
            continue
        if len(path) >= max_depth:
            continue
        for neighbor in adjacency.get(current, set()):
            if neighbor not in visited:
                new_visited = visited | {neighbor}
                queue.append((neighbor, path + [neighbor], new_visited))
    return paths
Note that I'm exposing the Rust implementation to Python via maturin and PyO3, so the function call is essentially a simple swap of the function from:
find_paths_with_progress(start, end, adjacency, max_depth)
# Over to
rust_pkg_name.find_paths_with_progress(start, end, adjacency, max_depth)
so all the parameters and such are exactly the same.
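
Since the stated goal is only the set of nodes that lie on some simple path within the depth limit, one alternative worth sketching is a depth-first search with backtracking. This is a different technique from the BFS in the post, not the original method: it keeps a single `path` and a single `visited` set that are updated on the way down and undone on the way back up, so memory is bounded by `max_depth` plus the result set rather than by the number of partial paths in a queue. The function names `dfs` and `find_nodes_on_paths_dfs` are hypothetical. It visits the same simple paths under the same `max_depth` rule as the BFS versions, so the running time can still be large on a wide graph.

```rust
use std::collections::{HashMap, HashSet};

type CMatIdx = i64;

/// Recursive helper: extends the current path by `current`, explores, then backtracks.
fn dfs(
    current: CMatIdx,
    end: CMatIdx,
    adjacency: &HashMap<CMatIdx, HashSet<CMatIdx>>,
    path: &mut Vec<CMatIdx>,
    visited: &mut HashSet<CMatIdx>,
    nodes_on_path: &mut HashSet<CMatIdx>,
    max_depth: usize,
) {
    path.push(current);
    visited.insert(current);
    if current == end {
        // A complete path: record all of its nodes.
        nodes_on_path.extend(path.iter().copied());
    } else if path.len() < max_depth {
        if let Some(neighbors) = adjacency.get(&current) {
            for &neighbor in neighbors {
                if !visited.contains(&neighbor) {
                    dfs(neighbor, end, adjacency, path, visited, nodes_on_path, max_depth);
                }
            }
        }
    }
    // Backtrack so sibling branches see the state they started with.
    visited.remove(&current);
    path.pop();
}

fn find_nodes_on_paths_dfs(
    start: CMatIdx,
    end: CMatIdx,
    adjacency: &HashMap<CMatIdx, HashSet<CMatIdx>>,
    max_depth: usize,
) -> HashSet<CMatIdx> {
    let mut nodes_on_path = HashSet::new();
    let mut path = Vec::new();
    let mut visited = HashSet::new();
    dfs(start, end, adjacency, &mut path, &mut visited, &mut nodes_on_path, max_depth);
    nodes_on_path
}
```

The recursion depth is at most `max_depth`, so this sketch assumes the depth limit is small enough for the call stack; an explicit stack would be needed otherwise.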