1use std::borrow::Cow;
4use std::hash::Hash;
5
6use proc_macro2::{Ident, Span, TokenStream};
7use quote::ToTokens;
8use serde::{Deserialize, Serialize};
9use syn::punctuated::Punctuated;
10use syn::spanned::Spanned;
11use syn::{Expr, ExprPath, GenericArgument, Token, Type};
12
13use self::ops::{OperatorConstraints, Persistence};
14use crate::diagnostic::{Diagnostic, Diagnostics, Level};
15use crate::parse::{DfirCode, IndexInt, Operator, PortIndex, Ported, SingletonRef};
16use crate::pretty_span::PrettySpan;
17
18mod di_mul_graph;
19mod eliminate_extra_unions_tees;
20mod flat_graph_builder;
21mod flat_to_partitioned;
22mod graph_write;
23mod meta_graph;
24mod meta_graph_debugging;
25
26use std::fmt::Display;
27
28pub use di_mul_graph::DiMulGraph;
29pub use eliminate_extra_unions_tees::eliminate_extra_unions_tees;
30pub use flat_graph_builder::{FlatGraphBuilder, FlatGraphBuilderOutput};
31pub use flat_to_partitioned::partition_graph;
32pub use meta_graph::{DfirGraph, WriteConfig, WriteGraphType};
33
34pub use crate::graph_ids::{GraphEdgeId, GraphLoopId, GraphNodeId, GraphSubgraphId};
35
36pub mod graph_algorithms;
37pub mod ops;
38
39impl GraphSubgraphId {
40 pub fn as_ident(self, span: Span) -> Ident {
42 use slotmap::Key;
43 Ident::new(&format!("sgid_{:?}", self.data()), span)
44 }
45}
46
47impl GraphLoopId {
48 pub fn as_ident(self, span: Span) -> Ident {
50 use slotmap::Key;
51 Ident::new(&format!("loop_{:?}", self.data()), span)
52 }
53}
54
/// The string `"context"`.
///
/// NOTE(review): not referenced in this part of the file; presumably the name of the
/// context identifier used in generated code. Confirm at the use sites.
const CONTEXT: &str = "context";
/// The string `"df"`.
///
/// NOTE(review): not referenced in this part of the file; presumably the name of the
/// graph object identifier used in generated code. Confirm at the use sites.
const GRAPH: &str = "df";

/// Label returned by [`GraphNode::to_pretty_string`] and [`GraphNode::to_name_string`]
/// for handoff nodes of kind [`HandoffKind::Vec`].
const HANDOFF_NODE_STR: &str = "handoff";
/// Label returned by [`GraphNode::to_pretty_string`] and [`GraphNode::to_name_string`]
/// for handoff nodes of kind [`HandoffKind::Singleton`] or [`HandoffKind::Optional`].
const SINGLETON_SLOT_NODE_STR: &str = "singleton";
/// Label returned by [`GraphNode::to_pretty_string`] and [`GraphNode::to_name_string`]
/// for [`GraphNode::ModuleBoundary`] nodes.
const MODULE_BOUNDARY_NODE_STR: &str = "module_boundary";
63
64mod serde_syn {
65 use serde::{Deserialize, Deserializer, Serializer};
66
67 pub fn serialize<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
68 where
69 S: Serializer,
70 T: quote::ToTokens,
71 {
72 serializer.serialize_str(&value.to_token_stream().to_string())
73 }
74
75 pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
76 where
77 D: Deserializer<'de>,
78 T: syn::parse::Parse,
79 {
80 let s = String::deserialize(deserializer)?;
81 syn::parse_str(&s).map_err(<D::Error as serde::de::Error>::custom)
82 }
83}
84
/// Newtype wrapper around an [`Ident`].
///
/// The identifier is (de)serialized through `serde_syn`, i.e. as its token string.
/// NOTE(review): presumably the name of a variable in DFIR source — confirm at use sites.
#[derive(Clone, Debug, Serialize, Deserialize, PartialOrd, Ord, PartialEq, Eq, Hash)]
pub struct Varname(#[serde(with = "serde_syn")] pub Ident);
90
/// The kind of a [`GraphNode::Handoff`] node.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum HandoffKind {
    /// Labelled `handoff` by [`GraphNode::to_pretty_string`] / [`GraphNode::to_name_string`].
    /// NOTE(review): presumably a `Vec`-backed buffer — confirm in codegen.
    Vec,
    /// Labelled `singleton` by [`GraphNode::to_pretty_string`] / [`GraphNode::to_name_string`].
    /// NOTE(review): presumably holds exactly one item — confirm in codegen.
    Singleton,
    /// Labelled `singleton` by [`GraphNode::to_pretty_string`] / [`GraphNode::to_name_string`]
    /// (same label as [`HandoffKind::Singleton`]).
    /// NOTE(review): presumably holds zero or one item — confirm in codegen.
    Optional,
}
103
/// A node in the graph: an operator, a handoff, or a module boundary.
///
/// `Debug` is implemented manually (not derived) for this type.
#[derive(Clone, Serialize, Deserialize)]
pub enum GraphNode {
    /// An operator; (de)serialized as its token string via `serde_syn`.
    Operator(#[serde(with = "serde_syn")] Operator),
    /// A handoff node of the given kind.
    Handoff {
        /// Which kind of handoff this is.
        kind: HandoffKind,
        /// Not serialized; set to `Span::call_site()` when deserialized.
        /// Joined with `dst_span` by [`GraphNode::span`].
        #[serde(skip, default = "Span::call_site")]
        src_span: Span,
        /// Not serialized; set to `Span::call_site()` when deserialized.
        #[serde(skip, default = "Span::call_site")]
        dst_span: Span,
    },

    /// A module boundary node.
    ModuleBoundary {
        /// NOTE(review): presumably `true` for a module input boundary and `false` for an
        /// output boundary — confirm against the module-merging code.
        input: bool,

        /// Span returned by [`GraphNode::span`] for this node.
        /// Not serialized; set to `Span::call_site()` when deserialized.
        #[serde(skip, default = "Span::call_site")]
        import_expr: Span,
    },
}
133impl GraphNode {
134 pub fn to_pretty_string(&self) -> Cow<'static, str> {
136 match self {
137 GraphNode::Operator(op) => op.to_pretty_string().into(),
138 GraphNode::Handoff {
139 kind: HandoffKind::Vec,
140 ..
141 } => HANDOFF_NODE_STR.into(),
142 GraphNode::Handoff {
143 kind: HandoffKind::Singleton | HandoffKind::Optional,
144 ..
145 } => SINGLETON_SLOT_NODE_STR.into(),
146 GraphNode::ModuleBoundary { .. } => MODULE_BOUNDARY_NODE_STR.into(),
147 }
148 }
149
150 pub fn to_name_string(&self) -> Cow<'static, str> {
152 match self {
153 GraphNode::Operator(op) => op.name_string().into(),
154 GraphNode::Handoff {
155 kind: HandoffKind::Vec,
156 ..
157 } => HANDOFF_NODE_STR.into(),
158 GraphNode::Handoff {
159 kind: HandoffKind::Singleton | HandoffKind::Optional,
160 ..
161 } => SINGLETON_SLOT_NODE_STR.into(),
162 GraphNode::ModuleBoundary { .. } => MODULE_BOUNDARY_NODE_STR.into(),
163 }
164 }
165
166 pub fn span(&self) -> Span {
168 match self {
169 Self::Operator(op) => op.span(),
170 &Self::Handoff {
171 src_span, dst_span, ..
172 } => src_span.join(dst_span).unwrap_or(src_span),
173 Self::ModuleBoundary { import_expr, .. } => *import_expr,
174 }
175 }
176}
177impl std::fmt::Debug for GraphNode {
178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179 match self {
180 Self::Operator(operator) => {
181 write!(f, "Node::Operator({} span)", PrettySpan(operator.span()))
182 }
183 Self::Handoff { kind, .. } => write!(f, "Node::Handoff({kind:?})"),
184 Self::ModuleBoundary { input, .. } => {
185 write!(f, "Node::ModuleBoundary{{input: {}}}", input)
186 }
187 }
188 }
189}
190
/// Data associated with a single use of an operator.
///
/// NOTE(review): the field descriptions below are inferred from names and types only;
/// confirm against the code that constructs this struct.
#[derive(Clone, Debug)]
pub struct OperatorInstance {
    /// Static constraints/definition of the operator being instantiated.
    pub op_constraints: &'static OperatorConstraints,
    /// Port values for the inputs (presumably one per incoming edge — confirm).
    pub input_ports: Vec<PortIndexValue>,
    /// Port values for the outputs (presumably one per outgoing edge — confirm).
    pub output_ports: Vec<PortIndexValue>,
    /// Singleton references associated with this instance (presumably those appearing
    /// in the operator's arguments — confirm).
    pub singletons_referenced: Vec<SingletonRef>,

    /// Generic arguments of the operator; see [`OpInstGenerics`].
    pub generics: OpInstGenerics,
    /// Arguments as parsed, comma-separated expressions (presumably before any
    /// singleton-related post-processing — confirm).
    pub arguments_pre: Punctuated<Expr, Token![,]>,
    /// Arguments as a raw, unparsed token stream.
    pub arguments_raw: TokenStream,
}
221
/// Generic arguments of an operator instance, as produced by [`get_operator_generics`].
#[derive(Clone, Debug)]
pub struct OpInstGenerics {
    /// All generic arguments as written on the operator, if any.
    pub generic_args: Option<Punctuated<GenericArgument, Token![,]>>,
    /// Persistence values parsed from the leading run of recognized lifetime arguments
    /// in [`Self::generic_args`].
    pub persistence_args: Vec<Persistence>,
    /// The run of type arguments in [`Self::generic_args`] that immediately follows the
    /// parsed persistence lifetimes.
    pub type_args: Vec<Type>,
}
232
233impl OpInstGenerics {
234 fn join_spans<I>(mut spans: I) -> Option<Span>
239 where
240 I: Iterator<Item = Span>,
241 {
242 let mut span = spans.next()?;
243 for s in spans {
244 span = span.join(s)?;
245 }
246 Some(span)
247 }
248
249 pub fn persistence_args_span(&self) -> Option<Span> {
251 self.generic_args.as_ref().and_then(|args| {
252 Self::join_spans(
253 args.iter()
254 .filter(|a| matches!(a, GenericArgument::Lifetime(_)))
255 .map(|a| a.span()),
256 )
257 })
258 }
259
260 pub fn type_args_span(&self) -> Option<Span> {
262 self.generic_args.as_ref().and_then(|args| {
263 Self::join_spans(
264 args.iter()
265 .filter(|a| matches!(a, GenericArgument::Type(_)))
266 .map(|a| a.span()),
267 )
268 })
269 }
270}
271
272pub fn get_operator_generics(diagnostics: &mut Diagnostics, operator: &Operator) -> OpInstGenerics {
277 let generic_args = operator.type_arguments().cloned();
279 let persistence_args = generic_args.iter().flatten().map_while(|generic_arg| match generic_arg {
280 GenericArgument::Lifetime(lifetime) => {
281 match &*lifetime.ident.to_string() {
282 "none" => Some(Persistence::None),
283 "loop" => Some(Persistence::Loop),
284 "tick" => Some(Persistence::Tick),
285 "static" => Some(Persistence::Static),
286 "mutable" => Some(Persistence::Mutable),
287 _ => {
288 diagnostics.push(Diagnostic::spanned(
289 generic_arg.span(),
290 Level::Error,
291 format!("Unknown lifetime generic argument `'{}`, expected `'none`, `'loop`, `'tick`, `'static`, or `'mutable`.", lifetime.ident),
292 ));
293 None
295 }
296 }
297 },
298 _ => None,
299 }).collect::<Vec<_>>();
300 let type_args = generic_args
301 .iter()
302 .flatten()
303 .skip(persistence_args.len())
304 .map_while(|generic_arg| match generic_arg {
305 GenericArgument::Type(typ) => Some(typ),
306 _ => None,
307 })
308 .cloned()
309 .collect::<Vec<_>>();
310
311 OpInstGenerics {
312 generic_args,
313 persistence_args,
314 type_args,
315 }
316}
317
/// A color assigned to graph nodes.
///
/// `PartialOrd`/`Ord` are derived, so comparisons follow the declaration order of the
/// variants below; reordering them changes comparison results.
///
/// NOTE(review): variant meanings below are inferred from the names only — confirm
/// against the partitioning code.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Color {
    /// Presumably a pull-based node.
    Pull,
    /// Presumably a push-based node.
    Push,
    /// Presumably a computation node (pull in, push out).
    Comp,
    /// Presumably a handoff node.
    Hoff,
}
330
/// A port index on an operator: an integer, a path, or elided (not written).
///
/// Equality and ordering are implemented manually and ignore the span stored in
/// [`PortIndexValue::Elided`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PortIndexValue {
    /// An integer port index; (de)serialized as its token string via `serde_syn`.
    Int(#[serde(with = "serde_syn")] IndexInt),
    /// A path port index; (de)serialized as its token string via `serde_syn`.
    Path(#[serde(with = "serde_syn")] ExprPath),
    /// No port index was specified. The optional span is not serialized;
    /// [`PortIndexValue::from_ported`] sets it to the span of the ported inner value.
    Elided(#[serde(skip)] Option<Span>),
}
342impl PortIndexValue {
343 pub fn from_ported<Inner>(ported: Ported<Inner>) -> (Self, Inner, Self)
346 where
347 Inner: Spanned,
348 {
349 let ported_span = Some(ported.inner.span());
350 let port_inn = ported
351 .inn
352 .map(|idx| idx.index.into())
353 .unwrap_or_else(|| Self::Elided(ported_span));
354 let inner = ported.inner;
355 let port_out = ported
356 .out
357 .map(|idx| idx.index.into())
358 .unwrap_or_else(|| Self::Elided(ported_span));
359 (port_inn, inner, port_out)
360 }
361
362 pub fn is_specified(&self) -> bool {
364 !matches!(self, Self::Elided(_))
365 }
366
367 #[allow(clippy::allow_attributes, reason = "Only triggered on nightly.")]
371 #[allow(
372 clippy::result_large_err,
373 reason = "variants are same size, error isn't to be propagated."
374 )]
375 pub fn combine(self, other: Self) -> Result<Self, Self> {
376 match (self.is_specified(), other.is_specified()) {
377 (false, _other) => Ok(other),
378 (true, false) => Ok(self),
379 (true, true) => Err(self),
380 }
381 }
382
383 pub fn as_error_message_string(&self) -> String {
385 match self {
386 PortIndexValue::Int(n) => format!("`{}`", n.value),
387 PortIndexValue::Path(path) => format!("`{}`", path.to_token_stream()),
388 PortIndexValue::Elided(_) => "<elided>".to_owned(),
389 }
390 }
391
392 pub fn span(&self) -> Span {
394 match self {
395 PortIndexValue::Int(x) => x.span(),
396 PortIndexValue::Path(x) => x.span(),
397 PortIndexValue::Elided(span) => span.unwrap_or_else(Span::call_site),
398 }
399 }
400}
401impl From<PortIndex> for PortIndexValue {
402 fn from(value: PortIndex) -> Self {
403 match value {
404 PortIndex::Int(x) => Self::Int(x),
405 PortIndex::Path(x) => Self::Path(x),
406 }
407 }
408}
409impl PartialEq for PortIndexValue {
410 fn eq(&self, other: &Self) -> bool {
411 match (self, other) {
412 (Self::Int(l0), Self::Int(r0)) => l0 == r0,
413 (Self::Path(l0), Self::Path(r0)) => l0 == r0,
414 (Self::Elided(_), Self::Elided(_)) => true,
415 _else => false,
416 }
417 }
418}
419impl Eq for PortIndexValue {}
420impl PartialOrd for PortIndexValue {
421 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
422 Some(self.cmp(other))
423 }
424}
425impl Ord for PortIndexValue {
426 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
427 match (self, other) {
428 (Self::Int(s), Self::Int(o)) => s.cmp(o),
429 (Self::Path(s), Self::Path(o)) => s
430 .to_token_stream()
431 .to_string()
432 .cmp(&o.to_token_stream().to_string()),
433 (Self::Elided(_), Self::Elided(_)) => std::cmp::Ordering::Equal,
434 (Self::Int(_), Self::Path(_)) => std::cmp::Ordering::Less,
435 (Self::Path(_), Self::Int(_)) => std::cmp::Ordering::Greater,
436 (_, Self::Elided(_)) => std::cmp::Ordering::Less,
437 (Self::Elided(_), _) => std::cmp::Ordering::Greater,
438 }
439 }
440}
441
442impl Display for PortIndexValue {
443 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
444 match self {
445 PortIndexValue::Int(x) => write!(f, "{}", x.to_token_stream()),
446 PortIndexValue::Path(x) => write!(f, "{}", x.to_token_stream()),
447 PortIndexValue::Elided(_) => write!(f, "[]"),
448 }
449 }
450}
451
/// Successful output of [`build_dfir_code`].
pub struct BuildDfirCodeOutput {
    /// The graph returned by [`partition_graph`].
    pub partitioned_graph: DfirGraph,
    /// The token stream generated from the partitioned graph.
    pub code: TokenStream,
    /// Diagnostics accumulated while building, returned alongside the successful output.
    pub diagnostics: Diagnostics,
}
461
462pub fn build_dfir_code(
464 dfir_code: DfirCode,
465 root: &TokenStream,
466) -> Result<BuildDfirCodeOutput, Diagnostics> {
467 let flat_graph_builder = FlatGraphBuilder::from_dfir(dfir_code);
468
469 let FlatGraphBuilderOutput {
470 mut flat_graph,
471 uses,
472 mut diagnostics,
473 } = flat_graph_builder.build()?;
474
475 let () = match flat_graph.merge_modules() {
476 Ok(()) => (),
477 Err(d) => {
478 diagnostics.push(d);
479 return Err(diagnostics);
480 }
481 };
482
483 eliminate_extra_unions_tees(&mut flat_graph);
484
485 for (edge_id, (src, dst)) in flat_graph.edges() {
488 let _ = edge_id;
489 if matches!(flat_graph.node(src), GraphNode::Handoff { .. })
490 && matches!(flat_graph.node(dst), GraphNode::Handoff { .. })
491 {
492 let span = flat_graph.node(dst).span();
493 diagnostics.push(Diagnostic::spanned(
494 span,
495 Level::Error,
496 "Adjacent handoff/singleton operators are not allowed. \
497 Remove one or insert an operator between them.",
498 ));
499 }
500 }
501
502 for (_loop_id, nodes) in flat_graph.loops() {
506 let span = nodes
507 .first()
508 .map_or_else(Span::call_site, |&n| flat_graph.node(n).span());
509 diagnostics.push(Diagnostic::spanned(
510 span,
511 Level::Error,
512 "`loop { }` blocks are not (yet) supported in `dfir_syntax!`.",
513 ));
514 }
515 if diagnostics.has_error() {
516 return Err(diagnostics);
517 }
518
519 let partitioned_graph = match partition_graph(flat_graph) {
520 Ok(partitioned_graph) => partitioned_graph,
521 Err(d) => {
522 diagnostics.push(d);
523 return Err(diagnostics);
524 }
525 };
526
527 let code =
528 partitioned_graph.as_code(root, true, quote::quote! { #( #uses )* }, &mut diagnostics)?;
529
530 Ok(BuildDfirCodeOutput {
531 partitioned_graph,
532 code,
533 diagnostics,
534 })
535}
536
537fn change_spans(tokens: TokenStream, span: Span) -> TokenStream {
539 use proc_macro2::{Group, TokenTree};
540 tokens
541 .into_iter()
542 .map(|token| match token {
543 TokenTree::Group(mut group) => {
544 group.set_span(span);
545 TokenTree::Group(Group::new(
546 group.delimiter(),
547 change_spans(group.stream(), span),
548 ))
549 }
550 TokenTree::Ident(mut ident) => {
551 ident.set_span(span.resolved_at(ident.span()));
552 TokenTree::Ident(ident)
553 }
554 TokenTree::Punct(mut punct) => {
555 punct.set_span(span);
556 TokenTree::Punct(punct)
557 }
558 TokenTree::Literal(mut literal) => {
559 literal.set_span(span);
560 TokenTree::Literal(literal)
561 }
562 })
563 .collect()
564}