Skip to main content

gruel_lsp/
references.rs

1//! Find references (ADR-0091 Phase 5).
2//!
3//! Walk the merged AST and collect every identifier whose name matches a
4//! given target name *and* whose enclosing scope is consistent with the
5//! definition. Without sema-level symbol resolution we use a textual
6//! match scoped to:
7//!
8//! - all references whose name matches a top-level item's defining name
9//!   (functions, structs, enums, interfaces, derives, consts) — Phase 5
10//!   conservatively returns every occurrence, since the same name at
11//!   top level can only resolve to that item under the current
12//!   "all symbols live in a flat namespace" rule (ADR-0023).
13//! - parameter and local-let references — limited to the enclosing
14//!   function body. The same name in a different function is a
15//!   different binding.
16
17use gruel_parser::ast::{
18    AssignTarget, Ast, BlockExpr, Expr, Function, Ident, Item, MatchArm, Method, Pattern,
19    Statement, TypeExpr,
20};
21use gruel_util::{FileId, Span};
22use lasso::{Spur, ThreadedRodeo};
23
24use crate::goto::definition_at;
25
26/// Find every reference to the identifier under the cursor.
27///
28/// `include_declaration` mirrors the LSP `referencesParams.context`.
29pub fn references_at(
30    ast: &Ast,
31    interner: &ThreadedRodeo,
32    file_id: FileId,
33    byte: u32,
34    include_declaration: bool,
35) -> Vec<Span> {
36    let target = match find_ident_at(ast, file_id, byte) {
37        Some(i) => i,
38        None => return Vec::new(),
39    };
40    let def_span = definition_at(ast, interner, file_id, byte);
41
42    // Decide scope: if def is a top-level item's name, scope = whole AST.
43    // Otherwise (def lives inside a function body — a param or let), scope
44    // = that function body.
45    let scope = if let Some(def) = def_span {
46        if is_top_level_item_name(ast, def) {
47            Scope::Workspace
48        } else if let Some(f) = enclosing_function(ast, def) {
49            Scope::Function(f)
50        } else {
51            Scope::Workspace
52        }
53    } else {
54        Scope::Workspace
55    };
56
57    let mut out = Vec::new();
58    collect_references(ast, target.name, scope, &mut out);
59
60    if !include_declaration {
61        if let Some(def) = def_span {
62            out.retain(|s| *s != def);
63        }
64    }
65    out.sort_by_key(|s| (s.file_id.0, s.start, s.end));
66    out.dedup();
67    out
68}
69
70#[derive(Clone)]
71enum Scope<'a> {
72    Workspace,
73    Function(&'a Function),
74}
75
76fn enclosing_function(ast: &Ast, def: Span) -> Option<&Function> {
77    for item in &ast.items {
78        if let Item::Function(f) = item {
79            if def.file_id == f.span.file_id && def.start >= f.span.start && def.end <= f.span.end {
80                return Some(f);
81            }
82        }
83    }
84    None
85}
86
87fn is_top_level_item_name(ast: &Ast, span: Span) -> bool {
88    for item in &ast.items {
89        let name_span = match item {
90            Item::Function(f) => f.name.span,
91            Item::Struct(s) => s.name.span,
92            Item::Enum(e) => e.name.span,
93            Item::Interface(i) => i.name.span,
94            Item::Derive(d) => d.name.span,
95            Item::Const(c) => c.name.span,
96            _ => continue,
97        };
98        if name_span == span {
99            return true;
100        }
101    }
102    false
103}
104
105fn collect_references(ast: &Ast, name: Spur, scope: Scope, out: &mut Vec<Span>) {
106    let mut walker = RefWalker { name, out };
107    match scope {
108        Scope::Workspace => {
109            for item in &ast.items {
110                walker.visit_item(item);
111            }
112        }
113        Scope::Function(f) => {
114            walker.visit_function_body(f);
115        }
116    }
117}
118
119struct RefWalker<'a> {
120    name: Spur,
121    out: &'a mut Vec<Span>,
122}
123
124impl<'a> RefWalker<'a> {
125    fn consider(&mut self, ident: Ident) {
126        if ident.name == self.name {
127            self.out.push(ident.span);
128        }
129    }
130
131    fn visit_item(&mut self, item: &Item) {
132        match item {
133            Item::Function(f) => self.visit_function_full(f),
134            Item::Struct(s) => {
135                self.consider(s.name);
136                for field in &s.fields {
137                    self.consider(field.name);
138                    self.visit_type(&field.ty);
139                }
140                for m in &s.methods {
141                    self.visit_method_full(m);
142                }
143            }
144            Item::Enum(e) => {
145                self.consider(e.name);
146                for v in &e.variants {
147                    self.consider(v.name);
148                }
149                for m in &e.methods {
150                    self.visit_method_full(m);
151                }
152            }
153            Item::Interface(i) => {
154                self.consider(i.name);
155                for sig in &i.methods {
156                    self.consider(sig.name);
157                    for p in &sig.params {
158                        self.consider(p.name);
159                        self.visit_type(&p.ty);
160                    }
161                    if let Some(rt) = &sig.return_type {
162                        self.visit_type(rt);
163                    }
164                }
165            }
166            Item::Derive(d) => {
167                self.consider(d.name);
168                for m in &d.methods {
169                    self.visit_method_full(m);
170                }
171            }
172            Item::Const(c) => {
173                self.consider(c.name);
174                if let Some(ty) = &c.ty {
175                    self.visit_type(ty);
176                }
177                self.visit_expr(&c.init);
178            }
179            Item::LinkExtern(b) => {
180                for ext in &b.items {
181                    self.consider(ext.name);
182                    for p in &ext.params {
183                        self.consider(p.name);
184                        self.visit_type(&p.ty);
185                    }
186                    if let Some(rt) = &ext.return_type {
187                        self.visit_type(rt);
188                    }
189                }
190            }
191            Item::Error(_) => {}
192        }
193    }
194
195    fn visit_function_full(&mut self, f: &Function) {
196        self.consider(f.name);
197        for p in &f.params {
198            self.consider(p.name);
199            self.visit_type(&p.ty);
200        }
201        if let Some(rt) = &f.return_type {
202            self.visit_type(rt);
203        }
204        self.visit_expr(&f.body);
205    }
206
207    fn visit_method_full(&mut self, m: &Method) {
208        self.consider(m.name);
209        for p in &m.params {
210            self.consider(p.name);
211            self.visit_type(&p.ty);
212        }
213        if let Some(rt) = &m.return_type {
214            self.visit_type(rt);
215        }
216        self.visit_expr(&m.body);
217    }
218
219    fn visit_function_body(&mut self, f: &Function) {
220        for p in &f.params {
221            self.consider(p.name);
222        }
223        self.visit_expr(&f.body);
224    }
225
226    fn visit_type(&mut self, ty: &TypeExpr) {
227        match ty {
228            TypeExpr::Named(ident) => self.consider(*ident),
229            TypeExpr::TypeCall { callee, args, .. } => {
230                self.consider(*callee);
231                for a in args {
232                    self.visit_type(a);
233                }
234            }
235            TypeExpr::Array { element, .. } => self.visit_type(element),
236            TypeExpr::Tuple { elems, .. } => {
237                for e in elems {
238                    self.visit_type(e);
239                }
240            }
241            _ => {}
242        }
243    }
244
245    fn visit_expr(&mut self, expr: &Expr) {
246        match expr {
247            Expr::Ident(ident) => self.consider(*ident),
248            Expr::Block(b) => self.visit_block(b),
249            Expr::Call(c) => {
250                self.consider(c.name);
251                for arg in &c.args {
252                    self.visit_expr(&arg.expr);
253                }
254            }
255            Expr::MethodCall(m) => {
256                self.visit_expr(&m.receiver);
257                self.consider(m.method);
258                for arg in &m.args {
259                    self.visit_expr(&arg.expr);
260                }
261            }
262            Expr::Field(f) => {
263                self.visit_expr(&f.base);
264                self.consider(f.field);
265            }
266            Expr::Binary(b) => {
267                self.visit_expr(&b.left);
268                self.visit_expr(&b.right);
269            }
270            Expr::Unary(u) => self.visit_expr(&u.operand),
271            Expr::Paren(p) => self.visit_expr(&p.inner),
272            Expr::If(i) => {
273                self.visit_expr(&i.cond);
274                self.visit_block(&i.then_block);
275                if let Some(b) = &i.else_block {
276                    self.visit_block(b);
277                }
278            }
279            Expr::While(w) => {
280                self.visit_expr(&w.cond);
281                self.visit_block(&w.body);
282            }
283            Expr::For(f) => {
284                self.consider(f.binding);
285                self.visit_expr(&f.iterable);
286                self.visit_block(&f.body);
287            }
288            Expr::Loop(l) => self.visit_block(&l.body),
289            Expr::Match(m) => {
290                self.visit_expr(&m.scrutinee);
291                for arm in &m.arms {
292                    self.visit_match_arm(arm);
293                }
294            }
295            Expr::Return(r) => {
296                if let Some(e) = &r.value {
297                    self.visit_expr(e);
298                }
299            }
300            Expr::Tuple(t) => {
301                for e in &t.elems {
302                    self.visit_expr(e);
303                }
304            }
305            Expr::Index(i) => {
306                self.visit_expr(&i.base);
307                self.visit_expr(&i.index);
308            }
309            Expr::TupleIndex(t) => self.visit_expr(&t.base),
310            Expr::StructLit(s) => {
311                if let Some(b) = &s.base {
312                    self.visit_expr(b);
313                }
314                self.consider(s.name);
315                for fi in &s.fields {
316                    self.consider(fi.name);
317                    self.visit_expr(&fi.value);
318                }
319            }
320            Expr::Path(p) => {
321                if let Some(b) = &p.base {
322                    self.visit_expr(b);
323                }
324                self.consider(p.type_name);
325                self.consider(p.variant);
326            }
327            Expr::ArrayLit(a) => {
328                for e in &a.elements {
329                    self.visit_expr(e);
330                }
331            }
332            Expr::IntrinsicCall(c) => {
333                self.consider(c.name);
334                for a in &c.args {
335                    if let gruel_parser::ast::IntrinsicArg::Expr(e) = a {
336                        self.visit_expr(e);
337                    }
338                }
339            }
340            _ => {}
341        }
342    }
343
344    fn visit_block(&mut self, b: &BlockExpr) {
345        for stmt in &b.statements {
346            self.visit_statement(stmt);
347        }
348        self.visit_expr(&b.expr);
349    }
350
351    fn visit_statement(&mut self, stmt: &Statement) {
352        match stmt {
353            Statement::Let(l) => {
354                if let Pattern::Ident { name, .. } = &l.pattern {
355                    self.consider(*name);
356                }
357                if let Some(ty) = &l.ty {
358                    self.visit_type(ty);
359                }
360                self.visit_expr(&l.init);
361            }
362            Statement::Assign(a) => {
363                match &a.target {
364                    AssignTarget::Var(i) => self.consider(*i),
365                    AssignTarget::Field(f) => {
366                        self.visit_expr(&f.base);
367                        self.consider(f.field);
368                    }
369                    AssignTarget::Index(i) => {
370                        self.visit_expr(&i.base);
371                        self.visit_expr(&i.index);
372                    }
373                }
374                self.visit_expr(&a.value);
375            }
376            Statement::Expr(e) => self.visit_expr(e),
377        }
378    }
379
380    fn visit_match_arm(&mut self, arm: &MatchArm) {
381        self.visit_expr(&arm.body);
382    }
383}
384
385fn find_ident_at(ast: &Ast, file_id: FileId, byte: u32) -> Option<Ident> {
386    let mut finder = IdentFinder::new(file_id, byte);
387    for item in &ast.items {
388        finder.visit_item(item);
389    }
390    finder.result
391}
392
393struct IdentFinder {
394    file_id: FileId,
395    byte: u32,
396    result: Option<Ident>,
397    best_size: u32,
398}
399
400impl IdentFinder {
401    fn new(file_id: FileId, byte: u32) -> Self {
402        Self {
403            file_id,
404            byte,
405            result: None,
406            best_size: u32::MAX,
407        }
408    }
409    fn consider(&mut self, ident: Ident) {
410        if ident.span.file_id != self.file_id {
411            return;
412        }
413        if self.byte < ident.span.start || self.byte >= ident.span.end {
414            return;
415        }
416        let size = ident.span.end.saturating_sub(ident.span.start);
417        if size <= self.best_size {
418            self.best_size = size;
419            self.result = Some(ident);
420        }
421    }
422
423    fn visit_item(&mut self, item: &Item) {
424        let mut walker = RefWalker {
425            name: lasso::Spur::default(),
426            out: &mut Vec::new(),
427        };
428        // We don't actually filter by name here — just collect every Ident
429        // and pick the smallest containing the byte.
430        let _ = &mut walker;
431        // Simpler: re-implement a passthrough walker that calls our consider().
432        self.walk_item(item);
433    }
434
435    fn walk_item(&mut self, item: &Item) {
436        match item {
437            Item::Function(f) => {
438                self.consider(f.name);
439                for p in &f.params {
440                    self.consider(p.name);
441                    self.walk_type(&p.ty);
442                }
443                if let Some(rt) = &f.return_type {
444                    self.walk_type(rt);
445                }
446                self.walk_expr(&f.body);
447            }
448            Item::Struct(s) => {
449                self.consider(s.name);
450                for field in &s.fields {
451                    self.consider(field.name);
452                    self.walk_type(&field.ty);
453                }
454                for m in &s.methods {
455                    self.walk_method(m);
456                }
457            }
458            Item::Enum(e) => {
459                self.consider(e.name);
460                for v in &e.variants {
461                    self.consider(v.name);
462                }
463                for m in &e.methods {
464                    self.walk_method(m);
465                }
466            }
467            Item::Interface(i) => {
468                self.consider(i.name);
469                for sig in &i.methods {
470                    self.consider(sig.name);
471                    for p in &sig.params {
472                        self.consider(p.name);
473                        self.walk_type(&p.ty);
474                    }
475                    if let Some(rt) = &sig.return_type {
476                        self.walk_type(rt);
477                    }
478                }
479            }
480            Item::Derive(d) => {
481                self.consider(d.name);
482                for m in &d.methods {
483                    self.walk_method(m);
484                }
485            }
486            Item::Const(c) => {
487                self.consider(c.name);
488                if let Some(ty) = &c.ty {
489                    self.walk_type(ty);
490                }
491                self.walk_expr(&c.init);
492            }
493            Item::LinkExtern(b) => {
494                for ext in &b.items {
495                    self.consider(ext.name);
496                    for p in &ext.params {
497                        self.consider(p.name);
498                        self.walk_type(&p.ty);
499                    }
500                    if let Some(rt) = &ext.return_type {
501                        self.walk_type(rt);
502                    }
503                }
504            }
505            Item::Error(_) => {}
506        }
507    }
508
509    fn walk_method(&mut self, m: &Method) {
510        self.consider(m.name);
511        for p in &m.params {
512            self.consider(p.name);
513            self.walk_type(&p.ty);
514        }
515        if let Some(rt) = &m.return_type {
516            self.walk_type(rt);
517        }
518        self.walk_expr(&m.body);
519    }
520
521    fn walk_type(&mut self, ty: &TypeExpr) {
522        match ty {
523            TypeExpr::Named(ident) => self.consider(*ident),
524            TypeExpr::TypeCall { callee, args, .. } => {
525                self.consider(*callee);
526                for a in args {
527                    self.walk_type(a);
528                }
529            }
530            TypeExpr::Array { element, .. } => self.walk_type(element),
531            TypeExpr::Tuple { elems, .. } => {
532                for e in elems {
533                    self.walk_type(e);
534                }
535            }
536            _ => {}
537        }
538    }
539
540    fn walk_expr(&mut self, expr: &Expr) {
541        match expr {
542            Expr::Ident(ident) => self.consider(*ident),
543            Expr::Block(b) => {
544                for stmt in &b.statements {
545                    self.walk_statement(stmt);
546                }
547                self.walk_expr(&b.expr);
548            }
549            Expr::Call(c) => {
550                self.consider(c.name);
551                for arg in &c.args {
552                    self.walk_expr(&arg.expr);
553                }
554            }
555            Expr::MethodCall(m) => {
556                self.walk_expr(&m.receiver);
557                self.consider(m.method);
558                for arg in &m.args {
559                    self.walk_expr(&arg.expr);
560                }
561            }
562            Expr::Field(f) => {
563                self.walk_expr(&f.base);
564                self.consider(f.field);
565            }
566            Expr::Binary(b) => {
567                self.walk_expr(&b.left);
568                self.walk_expr(&b.right);
569            }
570            Expr::Unary(u) => self.walk_expr(&u.operand),
571            Expr::Paren(p) => self.walk_expr(&p.inner),
572            Expr::If(i) => {
573                self.walk_expr(&i.cond);
574                self.walk_expr(&Expr::Block(i.then_block.clone()));
575                if let Some(b) = &i.else_block {
576                    self.walk_expr(&Expr::Block(b.clone()));
577                }
578            }
579            Expr::While(w) => {
580                self.walk_expr(&w.cond);
581                self.walk_expr(&Expr::Block(w.body.clone()));
582            }
583            Expr::For(f) => {
584                self.consider(f.binding);
585                self.walk_expr(&f.iterable);
586                self.walk_expr(&Expr::Block(f.body.clone()));
587            }
588            Expr::Match(m) => {
589                self.walk_expr(&m.scrutinee);
590                for arm in &m.arms {
591                    self.walk_expr(&arm.body);
592                }
593            }
594            Expr::Return(r) => {
595                if let Some(e) = &r.value {
596                    self.walk_expr(e);
597                }
598            }
599            Expr::Tuple(t) => {
600                for e in &t.elems {
601                    self.walk_expr(e);
602                }
603            }
604            Expr::Index(i) => {
605                self.walk_expr(&i.base);
606                self.walk_expr(&i.index);
607            }
608            Expr::TupleIndex(t) => self.walk_expr(&t.base),
609            Expr::StructLit(s) => {
610                if let Some(b) = &s.base {
611                    self.walk_expr(b);
612                }
613                self.consider(s.name);
614                for fi in &s.fields {
615                    self.consider(fi.name);
616                    self.walk_expr(&fi.value);
617                }
618            }
619            Expr::Path(p) => {
620                if let Some(b) = &p.base {
621                    self.walk_expr(b);
622                }
623                self.consider(p.type_name);
624                self.consider(p.variant);
625            }
626            _ => {}
627        }
628    }
629
630    fn walk_statement(&mut self, stmt: &Statement) {
631        match stmt {
632            Statement::Let(l) => {
633                if let Pattern::Ident { name, .. } = &l.pattern {
634                    self.consider(*name);
635                }
636                if let Some(ty) = &l.ty {
637                    self.walk_type(ty);
638                }
639                self.walk_expr(&l.init);
640            }
641            Statement::Assign(a) => {
642                match &a.target {
643                    AssignTarget::Var(i) => self.consider(*i),
644                    AssignTarget::Field(f) => {
645                        self.walk_expr(&f.base);
646                        self.consider(f.field);
647                    }
648                    AssignTarget::Index(i) => {
649                        self.walk_expr(&i.base);
650                        self.walk_expr(&i.index);
651                    }
652                }
653                self.walk_expr(&a.value);
654            }
655            Statement::Expr(e) => self.walk_expr(e),
656        }
657    }
658}
659
#[cfg(test)]
mod tests {
    use super::*;
    use gruel_compiler::{
        PreviewFeatures, SourceFile, merge_symbols, parse_all_files_with_preview,
    };

    /// Parse and merge a single-file program registered as `FileId::new(1)`.
    fn parse(source: &str) -> (Ast, ThreadedRodeo) {
        let sources = vec![SourceFile::new("main.gruel", source, FileId::new(1))];
        let parsed = parse_all_files_with_preview(&sources, &PreviewFeatures::default()).unwrap();
        let merged = merge_symbols(parsed).unwrap();
        (merged.ast, merged.interner)
    }

    /// Source text covered by `span` (all test programs are a single file).
    fn text(src: &str, span: Span) -> &str {
        &src[span.start as usize..span.end as usize]
    }

    #[test]
    fn references_to_function_includes_call_sites() {
        let src = "fn foo() -> i32 { 0 }\nfn main() -> i32 { foo() + foo() }";
        let (ast, interner) = parse(src);
        let byte = src.find("foo").unwrap() as u32;
        let refs = references_at(&ast, &interner, FileId::new(1), byte, true);
        // Exactly 1 def + 2 call sites: an exact count also catches
        // duplicated or over-reported spans, which `>=` would let through.
        assert_eq!(refs.len(), 3, "got: {:?}", refs);
        assert!(refs.iter().all(|s| text(src, *s) == "foo"), "got: {:?}", refs);
    }

    #[test]
    fn references_to_local_limited_to_scope() {
        let src = "fn main() -> i32 { let x = 1; x + x }\nfn other() -> i32 { let x = 2; x }";
        let (ast, interner) = parse(src);
        let byte = src.find("let x").unwrap() as u32 + 4;
        let refs = references_at(&ast, &interner, FileId::new(1), byte, true);
        // 1 binding `x` + 2 references in main, NOT the `x` in other()
        assert_eq!(refs.len(), 3, "got: {:?}", refs);
        let other_start = src.find("fn other").unwrap() as u32;
        assert!(refs.iter().all(|s| s.end <= other_start), "got: {:?}", refs);
        assert!(refs.iter().all(|s| text(src, *s) == "x"), "got: {:?}", refs);
    }

    #[test]
    fn references_excludes_declaration_when_requested() {
        let src = "fn foo() -> i32 { 0 }\nfn main() -> i32 { foo() }";
        let (ast, interner) = parse(src);
        let byte = src.find("foo").unwrap() as u32;
        let refs = references_at(&ast, &interner, FileId::new(1), byte, false);
        // Without declaration: only the call site remains.
        assert_eq!(refs.len(), 1, "got: {:?}", refs);
        let call_site = src.rfind("foo").unwrap() as u32;
        assert_eq!(refs[0].start, call_site, "got: {:?}", refs);
    }
}