Skip to main content

dfir_lang/graph/
flat_to_partitioned.rs

//! Subgraph partitioning algorithm
2
3use std::collections::BTreeSet;
4
5use itertools::Itertools;
6use slotmap::{SecondaryMap, SparseSecondaryMap};
7
8use super::meta_graph::DfirGraph;
9use super::ops::{DelayType, FloType};
10use super::{Color, GraphEdgeId, GraphNode, GraphNodeId, HandoffKind};
11use crate::diagnostic::{Diagnostic, Level};
12use crate::graph::graph_algorithms::SubgraphMerge;
13
14/// Find edge barriers: edges whose destination operator declares an input delay type.
15/// Excludes edges within `loop {}` blocks.
16///
17/// Returns:
18/// - Tick/TickLazy edges keyed by edge ID (for topo-sort exclusion and handoff marking).
19/// - All barrier (src, dst) node pairs (for the enemies set).
20fn find_edge_barriers(
21    partitioned_graph: &DfirGraph,
22) -> (
23    SecondaryMap<GraphEdgeId, DelayType>,
24    Vec<(GraphNodeId, GraphNodeId)>,
25) {
26    let mut tick_edges = SecondaryMap::new();
27    let mut barrier_pairs = Vec::new();
28
29    for (edge_id, (src, dst)) in partitioned_graph.edges() {
30        // Ignore barriers within `loop {` blocks.
31        if partitioned_graph.node_loop(dst).is_some() {
32            continue;
33        }
34        let Some(op_inst) = partitioned_graph.node_op_inst(dst) else {
35            continue;
36        };
37        let (_src_port, dst_port) = partitioned_graph.edge_ports(edge_id);
38        let Some(delay_type) = (op_inst.op_constraints.input_delaytype_fn)(dst_port) else {
39            continue;
40        };
41
42        barrier_pairs.push((src, dst));
43        if matches!(delay_type, DelayType::Tick | DelayType::TickLazy) {
44            tick_edges.insert(edge_id, delay_type);
45        }
46    }
47
48    (tick_edges, barrier_pairs)
49}
50
51/// Find handoff reference access group ordering constraints: for the same handoff target, operators in
52/// lower access groups must run before operators in higher access groups.
53fn find_access_group_ordering(partitioned_graph: &DfirGraph) -> Vec<(GraphNodeId, GraphNodeId)> {
54    let mut pairs = Vec::new();
55    let refs_by_target = partitioned_graph.node_handoff_reference_groups();
56    for (_handoff, groups) in refs_by_target {
57        for (group_a, group_b) in groups.values().tuple_windows() {
58            for &(node_a, _, _) in group_a {
59                for &(node_b, _, _) in group_b {
60                    // TODO(mingwei): handle with diagnostics.
61                    assert_ne!(
62                        node_a, node_b,
63                        "encounted conflicted or cyclical handoff references\n{:?}\n{:?}",
64                        group_a, group_b,
65                    );
66                    pairs.push((node_a, node_b));
67                }
68            }
69        }
70    }
71    pairs
72}
73
74fn find_subgraph_unionfind(
75    partitioned_graph: &DfirGraph,
76    tick_edges: &SecondaryMap<GraphEdgeId, DelayType>,
77    edge_barrier_pairs: &[(GraphNodeId, GraphNodeId)],
78    access_group_pairs: &[(GraphNodeId, GraphNodeId)],
79) -> Result<(SubgraphMerge<GraphNodeId>, BTreeSet<GraphEdgeId>), Diagnostic> {
80    // Modality (color) of nodes, push or pull.
81    // TODO(mingwei)? This does NOT consider `DelayType` barriers (which generally imply `Pull`),
82    // which makes it inconsistant with the final output in `as_code()`. But this doesn't create
83    // any bugs since we exclude `DelayType` edges from joining subgraphs anyway.
84    let mut node_color = partitioned_graph
85        .node_ids()
86        .filter_map(|node_id| {
87            let op_color = partitioned_graph.node_color(node_id)?;
88            Some((node_id, op_color))
89        })
90        .collect::<SparseSecondaryMap<_, _>>();
91
92    // Pre-compute all predecessor edges for the topological sort.
93    let mut all_preds: SecondaryMap<GraphNodeId, Vec<GraphNodeId>> = SecondaryMap::new();
94
95    // Pipe predecessors (excluding tick edges which are cross-tick).
96    for (edge_id, (src, dst)) in partitioned_graph.edges() {
97        if !tick_edges.contains_key(edge_id) {
98            all_preds.entry(dst).unwrap().or_default().push(src);
99        }
100    }
101
102    // Handoff references: producer must run before consumer.
103    for node_id in partitioned_graph.node_ids() {
104        for handoff_ref in partitioned_graph.node_handoff_references(node_id).iter() {
105            if let Some(src) = handoff_ref.node_id {
106                all_preds.entry(node_id).unwrap().or_default().push(src);
107                // Extra ordering: if the ref target is a handoff, its pipe consumers
108                // depend on the borrower (borrower runs before consumer).
109                if let GraphNode::Handoff { .. } = partitioned_graph.node(src) {
110                    for (_edge, consumer) in partitioned_graph.node_successors(src) {
111                        all_preds
112                            .entry(consumer)
113                            .unwrap()
114                            .or_default()
115                            .push(node_id);
116                    }
117                }
118            }
119        }
120    }
121
122    // Access group ordering.
123    for &(src, dst) in access_group_pairs {
124        all_preds.entry(dst).unwrap().or_default().push(src);
125    }
126
127    // Build enemies: all node pairs that must not be in the same subgraph.
128    let enemies = edge_barrier_pairs
129        .iter()
130        .copied()
131        .chain(access_group_pairs.iter().copied())
132        .chain(partitioned_graph.node_ids().flat_map(|dst| {
133            partitioned_graph
134                .node_handoff_references(dst)
135                .iter()
136                .filter_map(|r| r.node_id)
137                .map(move |src| (src, dst))
138        }));
139
140    let mut subgraph_unionfind = SubgraphMerge::<GraphNodeId>::new(
141        partitioned_graph.node_ids(),
142        |node_id| all_preds.get(node_id).into_iter().flatten().copied(),
143        enemies,
144    )
145    .map_err(|cycle| {
146        let span = cycle
147            .first()
148            .map(|&node_id| partitioned_graph.node(node_id).span())
149            .unwrap_or_else(proc_macro2::Span::call_site);
150        let node_cycle = cycle
151            .iter()
152            .map(|&node_id| partitioned_graph.node(node_id).to_pretty_string())
153            .collect::<Vec<_>>();
154        Diagnostic::spanned(
155            span,
156            Level::Error,
157            format!(
158                "Cyclical dataflow within a tick is not supported. Use `defer_tick()` or `defer_tick_lazy()` to break the cycle across ticks. \
159                Cycle: {:?}",
160                node_cycle,
161            ),
162        )
163    })?;
164
165    // Will contain all edges which need handoffs added. Starts out with all edges and
166    // we remove from this set as we combine nodes into subgraphs.
167    let mut handoff_edges: BTreeSet<GraphEdgeId> = partitioned_graph.edge_ids().collect();
168    // Would sort edges here for priority (for now, no sort/priority).
169
170    // Each edge gets looked at in order. However we may not know if a linear
171    // chain of operators is PUSH vs PULL until we look at the ends. A fancier
172    // algorithm would know to handle linear chains from the outside inward.
173    // But instead we just run through the edges in a loop until no more
174    // progress is made. Could have some sort of O(N^2) pathological worst
175    // case.
176    let mut progress = true;
177    while progress {
178        progress = false;
179        // TODO(mingwei): Could this iterate `handoff_edges` instead? (Modulo ownership). Then no case (1) below.
180        for (edge_id, (src, dst)) in partitioned_graph.edges().collect::<Vec<_>>() {
181            // Ignore existing handoffs, remove `edge_id` from `handoff_edges`.
182            if matches!(partitioned_graph.node(src), GraphNode::Handoff { .. })
183                || matches!(partitioned_graph.node(dst), GraphNode::Handoff { .. })
184            {
185                handoff_edges.remove(&edge_id);
186                continue;
187            }
188
189            // Ignore (1) already added edges as well as (2) new self-cycles. (Unless reference edge).
190            if subgraph_unionfind.same_set(src, dst) {
191                // Note that the _edge_ `edge_id` might not be in the subgraph even when both `src` and `dst` are. This prevents case 2.
192                // Handoffs will be inserted later for this self-loop.
193                continue;
194            }
195
196            // Do not connect across loop contexts.
197            if partitioned_graph.node_loop(src) != partitioned_graph.node_loop(dst) {
198                continue;
199            }
200            // Do not connect `next_iteration()`.
201            if partitioned_graph.node_op_inst(dst).is_some_and(|op_inst| {
202                Some(FloType::NextIteration) == op_inst.op_constraints.flo_type
203            }) {
204                continue;
205            }
206
207            if can_connect_colorize(&mut node_color, src, dst) {
208                // At this point we have selected this edge and its src & dst to be
209                // within a single subgraph.
210                let ok = subgraph_unionfind.try_merge(src, dst);
211                if ok {
212                    assert!(handoff_edges.remove(&edge_id));
213                    progress = true;
214                }
215            }
216        }
217    }
218
219    Ok((subgraph_unionfind, handoff_edges))
220}
221
222/// Find subgraphs and insert handoffs.
223fn make_subgraphs(
224    partitioned_graph: &mut DfirGraph,
225    tick_edges: &mut SecondaryMap<GraphEdgeId, DelayType>,
226    edge_barrier_pairs: &[(GraphNodeId, GraphNodeId)],
227    access_group_pairs: &[(GraphNodeId, GraphNodeId)],
228) -> Result<(), Diagnostic> {
229    // Algorithm:
230    // 1. Each node begins as its own subgraph.
231    // 2. Collect edges. (Future optimization: sort so edges which should not be split across a handoff come first).
232    // 3. For each edge, try to join `(to, from)` into the same subgraph.
233
234    // TODO(mingwei):
235    // self.partitioned_graph.assert_valid();
236
237    let (subgraph_merge, handoff_edges) = find_subgraph_unionfind(
238        partitioned_graph,
239        tick_edges,
240        edge_barrier_pairs,
241        access_group_pairs,
242    )?;
243
244    // Insert handoffs between subgraphs (or on subgraph self-loop edges)
245    for edge_id in handoff_edges {
246        let (src_id, dst_id) = partitioned_graph.edge(edge_id);
247
248        // Already has a handoff, no need to insert one.
249        let src_node = partitioned_graph.node(src_id);
250        let dst_node = partitioned_graph.node(dst_id);
251        if matches!(src_node, GraphNode::Handoff { .. })
252            || matches!(dst_node, GraphNode::Handoff { .. })
253        {
254            continue;
255        }
256
257        let hoff = GraphNode::Handoff {
258            kind: HandoffKind::Vec,
259            src_span: src_node.span(),
260            dst_span: dst_node.span(),
261        };
262        let (_node_id, out_edge_id) = partitioned_graph.insert_intermediate_node(edge_id, hoff);
263
264        // Update tick_edges for inserted node.
265        if let Some(delay_type) = tick_edges.remove(edge_id) {
266            tick_edges.insert(out_edge_id, delay_type);
267        }
268    }
269
270    // Register subgraphs. SubgraphMerge maintains operators in topo-sorted order per subgraph.
271    // Filter out handoff nodes — they are not part of any subgraph.
272    let mut subgraph_toposort = Vec::new();
273    for nodes in subgraph_merge.subgraphs() {
274        if nodes.is_empty() {
275            continue;
276        }
277        // Skip single-node "subgraphs" that are handoff nodes.
278        if nodes
279            .iter()
280            .any(|&n| matches!(partitioned_graph.node(n), GraphNode::Handoff { .. }))
281        {
282            continue;
283        }
284        let sg_id = partitioned_graph.insert_subgraph(nodes.to_vec()).unwrap();
285        subgraph_toposort.push(sg_id);
286    }
287    partitioned_graph.set_subgraph_toposort(subgraph_toposort);
288    Ok(())
289}
290
291/// Set `src` or `dst` color if `None` based on the other (if possible):
292/// `None` indicates an op could be pull or push i.e. unary-in & unary-out.
293/// So in that case we color `src` or `dst` based on its newfound neighbor (the other one).
294///
295/// Returns if `src` and `dst` can be in the same subgraph.
296fn can_connect_colorize(
297    node_color: &mut SparseSecondaryMap<GraphNodeId, Color>,
298    src: GraphNodeId,
299    dst: GraphNodeId,
300) -> bool {
301    // Pull -> Pull
302    // Push -> Push
303    // Pull -> [Computation] -> Push
304    // Push -> [Handoff] -> Pull
305    let can_connect = match (node_color.get(src), node_color.get(dst)) {
306        // Linear chain, can't connect because it may cause future conflicts.
307        // But if it doesn't in the _future_ we can connect it (once either/both ends are determined).
308        (None, None) => false,
309
310        // Infer left side.
311        (None, Some(Color::Pull | Color::Comp)) => {
312            node_color.insert(src, Color::Pull);
313            true
314        }
315        (None, Some(Color::Push | Color::Hoff)) => {
316            node_color.insert(src, Color::Push);
317            true
318        }
319
320        // Infer right side.
321        (Some(Color::Pull | Color::Hoff), None) => {
322            node_color.insert(dst, Color::Pull);
323            true
324        }
325        (Some(Color::Comp | Color::Push), None) => {
326            node_color.insert(dst, Color::Push);
327            true
328        }
329
330        // Both sides already specified.
331        (Some(Color::Pull), Some(Color::Pull)) => true,
332        (Some(Color::Pull), Some(Color::Comp)) => true,
333        (Some(Color::Pull), Some(Color::Push)) => true,
334
335        (Some(Color::Comp), Some(Color::Pull)) => false,
336        (Some(Color::Comp), Some(Color::Comp)) => false,
337        (Some(Color::Comp), Some(Color::Push)) => true,
338
339        (Some(Color::Push), Some(Color::Pull)) => false,
340        (Some(Color::Push), Some(Color::Comp)) => false,
341        (Some(Color::Push), Some(Color::Push)) => true,
342
343        // Handoffs are not part of subgraphs.
344        (Some(Color::Hoff), Some(_)) => false,
345        (Some(_), Some(Color::Hoff)) => false,
346    };
347    can_connect
348}
349
350/// Marks tick-boundary (`defer_tick` / `defer_tick_lazy`) handoffs with their delay type
351/// for double-buffered codegen in `as_code`.
352fn mark_tick_boundary_handoffs(
353    partitioned_graph: &mut DfirGraph,
354    tick_edges: &SecondaryMap<GraphEdgeId, DelayType>,
355) {
356    let tick_handoffs: Vec<_> = partitioned_graph
357        .nodes()
358        .filter_map(|(hoff_id, hoff)| {
359            if !matches!(hoff, GraphNode::Handoff { .. }) {
360                return None;
361            }
362            if partitioned_graph.node_degree_out(hoff_id) == 0 {
363                return None;
364            }
365            let (succ_edge, _) = partitioned_graph.node_successors(hoff_id).next().unwrap();
366            let &delay_type = tick_edges.get(succ_edge)?;
367            Some((hoff_id, delay_type))
368        })
369        .collect();
370
371    for (hoff_id, delay_type) in tick_handoffs {
372        partitioned_graph.set_handoff_delay_type(hoff_id, delay_type);
373    }
374}
375
376/// Main method for this module. Partitions a flat [`DfirGraph`] into one with subgraphs.
377///
378/// Returns an error if an intra-tick cycle exists in the graph.
379pub fn partition_graph(flat_graph: DfirGraph) -> Result<DfirGraph, Diagnostic> {
380    let (mut tick_edges, edge_barrier_pairs) = find_edge_barriers(&flat_graph);
381    let access_group_pairs = find_access_group_ordering(&flat_graph);
382    let mut partitioned_graph = flat_graph;
383
384    // Partition into subgraphs and insert handoffs.
385    make_subgraphs(
386        &mut partitioned_graph,
387        &mut tick_edges,
388        &edge_barrier_pairs,
389        &access_group_pairs,
390    )?;
391
392    // Mark tick-boundary handoffs for double-buffering.
393    mark_tick_boundary_handoffs(&mut partitioned_graph, &tick_edges);
394
395    Ok(partitioned_graph)
396}