Skip to main content

gruel_air/
specialize.rs

1//! Generic function specialization pass.
2//!
3//! This module provides the specialization pass that transforms `CallGeneric`
4//! instructions into regular `Call` instructions by:
5//!
6//! 1. Collecting all `CallGeneric` instructions in the analyzed functions
7//! 2. For each unique (func_name, type_args, value_args) combination, creating a specialized function
8//! 3. Rewriting `CallGeneric` to `Call` with the specialized function name
9//!
10//! # Architecture
11//!
12//! The specialization pass runs after semantic analysis but before CFG building.
13//! It transforms the AIR in-place and adds new specialized functions to the output.
14
15use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
16
17use gruel_rir::{InstRef, RirParamMode};
18use gruel_util::Span;
19use gruel_util::{CompileError, CompileResult, ErrorKind};
20use lasso::{Spur, ThreadedRodeo};
21
22use crate::inst::{Air, AirInstData};
23use crate::param_arena::ParamRange;
24use crate::sema::{AnalyzedFunction, ConstValue, FunctionInfo, InferenceContext, MethodInfo, Sema};
25use crate::types::{StructId, Type};
26
27/// Function/method references discovered while specializing — feed back
28/// into the lazy work queue so reachability stays closed under
29/// specialization.
30#[derive(Debug, Default)]
31pub struct SpecializationRefs {
32    pub fns: HashSet<Spur>,
33    pub meths: HashSet<(StructId, Spur)>,
34}
35
36/// One row in the analyzed-functions accumulator: an analyzed body plus its
37/// per-function string and byte literal pools (remapped to global tables
38/// later).
39pub type AnalyzedRow = (AnalyzedFunction, Vec<String>, Vec<Vec<u8>>);
40
41/// A key for a specialized function: (base_function_name, type_arguments).
42#[derive(Debug, Clone, PartialEq, Eq, Hash)]
43pub struct SpecializationKey {
44    /// Base function name (e.g., "identity")
45    pub base_name: Spur,
46    /// Type arguments (e.g., [Type::I32])
47    pub type_args: Vec<Type>,
48    /// Comptime value arguments captured at the call site (e.g. `[Integer(7)]`
49    /// for `check_n(7)` where the parameter is `comptime n: i32`). Two calls
50    /// with the same type args but different value args produce different
51    /// specializations so per-call `comptime if`/`@compile_error` checks fire
52    /// only for the values they apply to.
53    pub value_args: Vec<ConstValue>,
54}
55
56/// Info about a specialization: the mangled name and the first call site span.
57struct SpecializationInfo {
58    /// The mangled name for the specialized function.
59    mangled_name: Spur,
60    /// The span of the first call site (for error reporting if the function doesn't exist).
61    call_site_span: Span,
62}
63
64/// Perform the specialization pass on the analyzed-functions accumulator.
65///
66/// Collects every `CallGeneric` instruction across the analyzed bodies,
67/// rewrites them to direct `Call`s by mangled name, and synthesizes the
68/// specialized bodies. Iterates until the accumulator is closed: each
69/// newly-synthesized body can introduce further `CallGeneric`s
70/// (transitively-generic specializations), so we re-collect and re-rewrite
71/// until no new keys appear.
72///
73/// Returns the set of regular function/method references discovered while
74/// analyzing the synthesized bodies. The caller feeds these back into the
75/// lazy work queue so reachability stays closed under specialization
76/// (e.g. `use_greeter[T=Foo]` exposes `Foo.greet` as a reachable method
77/// even though `main` only sees a `CallGeneric`).
78pub fn specialize(
79    functions_with_strings: &mut Vec<AnalyzedRow>,
80    name_map: &mut HashMap<SpecializationKey, Spur>,
81    sema: &mut Sema<'_>,
82    infer_ctx: &InferenceContext,
83    interner: &ThreadedRodeo,
84) -> CompileResult<SpecializationRefs> {
85    let mut accumulated_refs = SpecializationRefs::default();
86
87    loop {
88        // Phase 1: collect every CallGeneric across the current accumulator.
89        let mut seen: HashMap<SpecializationKey, SpecializationInfo> = HashMap::default();
90        for (func, _, _) in functions_with_strings.iter() {
91            collect_specializations(&func.air, interner, &mut seen);
92        }
93
94        // Take only keys we haven't already specialized in a prior round.
95        let new_specs: Vec<(SpecializationKey, SpecializationInfo)> = seen
96            .into_iter()
97            .filter(|(k, _)| !name_map.contains_key(k))
98            .collect();
99
100        for (k, info) in &new_specs {
101            name_map.insert(k.clone(), info.mangled_name);
102        }
103
104        // Phase 2: rewrite CallGeneric → Call across every body. We run this
105        // BEFORE the "no new specs → return" check because the previous
106        // round's Phase 3 may have appended freshly synthesized bodies that
107        // still contain unrewritten CallGenerics. If we returned without
108        // walking them, those CallGenerics would leak past specialization
109        // and panic in CFG building. The walk is cheap on already-rewritten
110        // bodies — they hold no CallGenerics — so doing one final pass
111        // when `new_specs` is empty costs only the linear scan.
112        for (func, _, _) in functions_with_strings.iter_mut() {
113            rewrite_call_generic(&mut func.air, name_map);
114        }
115
116        if new_specs.is_empty() {
117            return Ok(accumulated_refs);
118        }
119
120        // Phase 3: synthesize the specialized bodies for the new keys.
121        for (key, info) in new_specs {
122            let base = if let Some(fn_info) = sema.functions.get(&key.base_name).copied() {
123                SpecializeBase::function(&fn_info)
124            } else if let Some((struct_id, method_sym)) =
125                resolve_method_name(sema, interner, key.base_name)
126                && let Some(method_info) = sema.methods.get(&(struct_id, method_sym)).copied()
127            {
128                // ADR-0055: generic method encoded as "StructName.methodName".
129                SpecializeBase::method(&method_info)
130            } else {
131                let func_name = interner.resolve(&key.base_name);
132                return Err(CompileError::new(
133                    ErrorKind::UndefinedFunction(func_name.to_string()),
134                    info.call_site_span,
135                ));
136            };
137
138            let row = create_specialized(
139                sema,
140                infer_ctx,
141                &key,
142                info.mangled_name,
143                base,
144                interner,
145                &mut accumulated_refs,
146            )?;
147            functions_with_strings.push(row);
148        }
149    }
150}
151
152/// Parse a "StructName.methodName" mangled name into a (StructId, method Spur).
153/// Returns None if the name does not match the pattern or the struct is
154/// unknown.
155fn resolve_method_name(
156    sema: &Sema<'_>,
157    interner: &ThreadedRodeo,
158    name: Spur,
159) -> Option<(StructId, Spur)> {
160    let name_str = interner.resolve(&name);
161    let (struct_str, method_str) = name_str.rsplit_once('.')?;
162    let struct_sym = interner.get(struct_str)?;
163    let struct_id = *sema.structs.get(&struct_sym)?;
164    let method_sym = interner.get(method_str)?;
165    Some((struct_id, method_sym))
166}
167
168/// Collect all specializations needed from a function's AIR.
169fn collect_specializations(
170    air: &Air,
171    interner: &ThreadedRodeo,
172    specializations: &mut HashMap<SpecializationKey, SpecializationInfo>,
173) {
174    for (i, inst) in air.instructions().iter().enumerate() {
175        if let AirInstData::CallGeneric {
176            name,
177            type_args_start,
178            type_args_len,
179            ..
180        } = &inst.data
181        {
182            // Extract type arguments using the public accessor
183            let type_args: Vec<Type> = air
184                .get_extra(*type_args_start, *type_args_len)
185                .iter()
186                .map(|&encoded| Type::from_u32(encoded))
187                .collect();
188
189            // Comptime value arguments captured at the call site (sidecar
190            // populated by `analyze_call_impl` when the function has
191            // `comptime n: i32`-style parameters).
192            let value_args = air.comptime_value_args(i as u32).to_vec();
193
194            let key = SpecializationKey {
195                base_name: *name,
196                type_args: type_args.clone(),
197                value_args: value_args.clone(),
198            };
199
200            specializations.entry(key).or_insert_with(|| {
201                // Generate a mangled name for the specialized function
202                let base_name = interner.resolve(name);
203                let mangled = mangle_specialized_name(base_name, &type_args, &value_args);
204                let mangled_sym = interner.get_or_intern(&mangled);
205                SpecializationInfo {
206                    mangled_name: mangled_sym,
207                    call_site_span: inst.span,
208                }
209            });
210        }
211    }
212}
213
214/// Rewrite CallGeneric instructions to Call instructions.
215fn rewrite_call_generic(air: &mut Air, specializations: &HashMap<SpecializationKey, Spur>) {
216    // We need to collect the rewrites first, then apply them.
217    // This avoids borrowing issues with the extra array.
218    let mut rewrites: Vec<(usize, AirInstData)> = Vec::new();
219
220    for (i, inst) in air.instructions().iter().enumerate() {
221        if let AirInstData::CallGeneric {
222            name,
223            type_args_start,
224            type_args_len,
225            args_start,
226            args_len,
227        } = &inst.data
228        {
229            // Extract type arguments to form the key
230            let type_args: Vec<Type> = air
231                .get_extra(*type_args_start, *type_args_len)
232                .iter()
233                .map(|&encoded| Type::from_u32(encoded))
234                .collect();
235
236            let value_args = air.comptime_value_args(i as u32).to_vec();
237
238            let key = SpecializationKey {
239                base_name: *name,
240                type_args,
241                value_args,
242            };
243
244            if let Some(&specialized_name) = specializations.get(&key) {
245                // Rewrite to a regular Call
246                let new_data = AirInstData::Call {
247                    name: specialized_name,
248                    args_start: *args_start,
249                    args_len: *args_len,
250                };
251                rewrites.push((i, new_data));
252            }
253        }
254    }
255
256    // Apply all rewrites
257    for (index, new_data) in rewrites {
258        air.rewrite_inst_data(index, new_data);
259    }
260}
261
262/// Generate a mangled name for a specialized function.
263///
264/// `Type::name()` returns generic placeholders like `"<struct>"` for struct
265/// and enum types, which would collide across different structs — so we also
266/// append the raw `Type` discriminant, which is unique per type. Primitive
267/// types get their normal name for readability.
268///
269/// Comptime value arguments are appended after type args so two calls that
270/// differ only by a `comptime n: i32` produce distinct specializations.
271fn mangle_specialized_name(
272    base_name: &str,
273    type_args: &[Type],
274    value_args: &[ConstValue],
275) -> String {
276    let mut mangled = base_name.to_string();
277    for ty in type_args {
278        mangled.push_str("__");
279        mangled.push_str(ty.name());
280        // Disambiguate compound types (structs, enums, arrays) whose
281        // `name()` is a generic placeholder.
282        mangled.push('#');
283        mangled.push_str(&ty.as_u32().to_string());
284    }
285    for v in value_args {
286        mangled.push_str("__v");
287        match v {
288            ConstValue::Integer(n) => {
289                mangled.push('i');
290                if *n < 0 {
291                    mangled.push('m');
292                    mangled.push_str(&(-(*n as i128)).to_string());
293                } else {
294                    mangled.push_str(&n.to_string());
295                }
296            }
297            ConstValue::Bool(b) => {
298                mangled.push('b');
299                mangled.push(if *b { '1' } else { '0' });
300            }
301            ConstValue::Type(t) => {
302                mangled.push('t');
303                mangled.push_str(&t.as_u32().to_string());
304            }
305            ConstValue::ComptimeStr(idx) => {
306                mangled.push('s');
307                mangled.push_str(&idx.to_string());
308            }
309            ConstValue::Unit => mangled.push('u'),
310            // Composite/heap-backed and signal variants don't appear in
311            // call-site value args today; mangle defensively if they ever do.
312            _ => {
313                mangled.push('x');
314                mangled.push_str(&format!("{:?}", v));
315            }
316        }
317    }
318    mangled
319}
320
321/// View into the parts of a base function or method that specialization
322/// needs. Adapts `FunctionInfo` and `MethodInfo` to one shape so the synthesis
323/// logic can stay generic.
324struct SpecializeBase {
325    params: ParamRange,
326    return_type: Type,
327    return_type_sym: Spur,
328    body: InstRef,
329    span: Span,
330    /// `Some((struct_type, has_self))` for methods (ADR-0055); `None` for
331    /// free functions.
332    method: Option<(Type, bool)>,
333}
334
335impl SpecializeBase {
336    fn function(info: &FunctionInfo) -> Self {
337        Self {
338            params: info.params,
339            return_type: info.return_type,
340            return_type_sym: info.return_type_sym,
341            body: info.body,
342            span: info.span,
343            method: None,
344        }
345    }
346
347    fn method(info: &MethodInfo) -> Self {
348        Self {
349            params: info.params,
350            return_type: info.return_type,
351            return_type_sym: info.return_type_sym,
352            body: info.body,
353            span: info.span,
354            method: Some((info.struct_type, info.has_self)),
355        }
356    }
357}
358
359/// Synthesize a specialized function or method by re-analyzing the body with
360/// the type substitutions implied by `key.type_args`.
361///
362/// Comptime params are erased at runtime; references to them are substituted
363/// with concrete types via the resulting `type_subst` map. For methods,
364/// `Self` is also wired in so `Self { ... }` literals and `Self::Variant`
365/// paths resolve, and the receiver is prepended to the parameter list.
366///
367/// ADR-0055 (method-level comptime type params), ADR-0056 (interface bounds).
368fn create_specialized(
369    sema: &mut Sema<'_>,
370    infer_ctx: &InferenceContext,
371    key: &SpecializationKey,
372    specialized_name: Spur,
373    base: SpecializeBase,
374    interner: &ThreadedRodeo,
375    refs: &mut SpecializationRefs,
376) -> CompileResult<AnalyzedRow> {
377    let specialized_name_str = interner.resolve(&specialized_name).to_string();
378
379    let param_names = sema.param_arena.names(base.params).to_vec();
380    let param_types = sema.param_arena.types(base.params).to_vec();
381    let param_modes = sema.param_arena.modes(base.params).to_vec();
382    let param_comptime = sema.param_arena.comptime(base.params).to_vec();
383
384    // Determine which comptime params are type-shaped (`comptime T: type` or
385    // `comptime T: SomeInterface`) vs value-shaped (`comptime n: i32`). The
386    // call site emits separate `type_args` and `value_args` lists in matching
387    // declaration order, so we walk both side by side.
388    let type_sym = interner.get_or_intern("type");
389    let declared_ty_syms: Vec<Option<Spur>> = param_names
390        .iter()
391        .map(|n| param_declared_type_sym(sema, base.body, *n))
392        .collect();
393    let is_comptime_type_param: Vec<bool> = param_names
394        .iter()
395        .enumerate()
396        .map(|(i, _)| {
397            if !param_comptime[i] {
398                return false;
399            }
400            // Type-shaped if declared as `type` or as an interface name; the
401            // resolved param type is a fallback for synthesized methods whose
402            // declared symbol isn't in RIR.
403            match declared_ty_syms[i] {
404                Some(s) => s == type_sym || sema.interfaces.contains_key(&s),
405                None => param_types[i] == Type::COMPTIME_TYPE,
406            }
407        })
408        .collect();
409
410    let mut type_subst: HashMap<Spur, Type> = HashMap::default();
411    let mut value_subst: HashMap<Spur, ConstValue> = HashMap::default();
412    let mut type_arg_idx = 0;
413    let mut value_arg_idx = 0;
414    for (i, &is_type) in is_comptime_type_param.iter().enumerate() {
415        if is_type && type_arg_idx < key.type_args.len() {
416            type_subst.insert(param_names[i], key.type_args[type_arg_idx]);
417            type_arg_idx += 1;
418        } else if param_comptime[i] && !is_type && value_arg_idx < key.value_args.len() {
419            value_subst.insert(param_names[i], key.value_args[value_arg_idx]);
420            value_arg_idx += 1;
421        }
422    }
423    if let Some((struct_type, _)) = base.method {
424        let self_sym = interner.get_or_intern("Self");
425        type_subst.insert(self_sym, struct_type);
426    }
427
428    // ADR-0056: for any comptime param with an interface bound, verify the
429    // concrete type structurally conforms. The bound table keys by
430    // (owner, param) where owner is the function name or "StructName.method"
431    // — both are already encoded in `key.base_name`.
432    for p in &param_names {
433        if let Some(iid) = sema
434            .comptime_interface_bounds
435            .get(&(key.base_name, *p))
436            .copied()
437            && let Some(&concrete) = type_subst.get(p)
438        {
439            sema.check_conforms(concrete, iid, base.span)?;
440        }
441    }
442
443    // Substitute the return type if it references a type parameter (or
444    // `Self`). Three cases:
445    //
446    //   1. Bare type parameter (`-> T`): direct hit in `type_subst`.
447    //   2. Compound type (`-> Ptr(T)`, `-> MutRef(Vec(T))`): the source
448    //      symbol isn't in `type_subst` directly, but the resolver walks
449    //      its inner positions and substitutes there.
450    //   3. Concrete type (`-> i32`): both lookups miss and we keep the
451    //      already-resolved `base.return_type`.
452    //
453    // Without case 2, generic helpers like
454    //   `pub fn vec_ptr(...) -> Ptr(T) { ... }`
455    // get a `COMPTIME_TYPE` return type at specialization, which leaves
456    // body inference unable to fix the result type of `@ptr_cast` (or
457    // any other return-type-inferred intrinsic).
458    let return_type = if let Some(&concrete) = type_subst.get(&base.return_type_sym) {
459        concrete
460    } else if base.return_type == Type::COMPTIME_TYPE {
461        sema.resolve_type_for_comptime_with_subst(base.return_type_sym, &type_subst)
462            .unwrap_or(base.return_type)
463    } else {
464        base.return_type
465    };
466
467    // Specialized param list: prepend `self` for methods with a receiver,
468    // drop comptime params (erased), substitute `ComptimeType` references.
469    let mut specialized_params: Vec<(Spur, Type, RirParamMode)> = Vec::new();
470    if let Some((struct_type, true)) = base.method {
471        let self_val_sym = interner.get_or_intern("self");
472        specialized_params.push((self_val_sym, struct_type, RirParamMode::Normal));
473    }
474    for i in 0..param_names.len() {
475        if param_comptime[i] {
476            continue;
477        }
478        let name = param_names[i];
479        let ty = param_types[i];
480        let mode = param_modes[i];
481        let concrete_ty = if ty == Type::COMPTIME_TYPE {
482            substitute_param_type(sema, base.body, name, &type_subst).unwrap_or(ty)
483        } else {
484            ty
485        };
486        specialized_params.push((name, concrete_ty, mode));
487    }
488
489    let (
490        air,
491        num_locals,
492        num_param_slots,
493        modes_result,
494        param_slot_types,
495        _warnings,
496        local_strings,
497        local_bytes,
498        ref_fns,
499        ref_meths,
500    ) = sema.analyze_specialized_function(
501        infer_ctx,
502        return_type,
503        &specialized_params,
504        base.body,
505        &type_subst,
506        Some(&value_subst),
507    )?;
508
509    refs.fns.extend(ref_fns);
510    refs.meths.extend(ref_meths);
511
512    let analyzed = AnalyzedFunction {
513        name: specialized_name_str,
514        air,
515        num_locals,
516        num_param_slots,
517        param_modes: modes_result,
518        param_slot_types,
519        is_destructor: false,
520    };
521    Ok((analyzed, local_strings, local_bytes))
522}
523
524/// Resolve a type-parameter reference on `param_name` by walking the RIR to
525/// find the `FnDecl` whose body matches `body`, then resolving the
526/// parameter's source-text type symbol under `type_subst`. The substitution-
527/// aware resolver lets compound types (`MutRef(Vec(T))`, `[T; N]`,
528/// `Ptr(T)`, etc.) substitute through their inner positions, not just
529/// the bare-`T` case.
530fn substitute_param_type(
531    sema: &mut Sema<'_>,
532    body: InstRef,
533    param_name: Spur,
534    type_subst: &HashMap<Spur, Type>,
535) -> Option<Type> {
536    // Find the parameter's declared type symbol.
537    let mut declared_ty: Option<Spur> = None;
538    for (_, inst) in sema.rir.iter() {
539        if let gruel_rir::InstData::FnDecl {
540            body: fn_body,
541            params_start,
542            params_len,
543            ..
544        } = &inst.data
545            && *fn_body == body
546        {
547            let params = sema.rir.get_params(*params_start, *params_len);
548            for param in params {
549                if param.name == param_name {
550                    declared_ty = Some(param.ty);
551                    break;
552                }
553            }
554            if declared_ty.is_some() {
555                break;
556            }
557        }
558    }
559    let declared_ty = declared_ty?;
560    // Fast path: bare `T` — direct hit.
561    if let Some(&concrete) = type_subst.get(&declared_ty) {
562        return Some(concrete);
563    }
564    // Compound case: `MutRef(Vec(T))`, `Ptr(T)`, etc. Resolve under the
565    // substitution. Uses the same comptime-substitution resolver the
566    // method-side uses for inline `comptime T: type` methods.
567    sema.resolve_type_for_comptime_with_subst(declared_ty, type_subst)
568}
569
570/// Look up the *source-text* type symbol declared for `param_name` in the
571/// `FnDecl` whose body is `body`. Used by specialization to distinguish a
572/// `comptime T: type` parameter (declared symbol == "type") from a
573/// `comptime n: i32` parameter (declared symbol == "i32") — the resolved
574/// `Type` field on the param can be `COMPTIME_TYPE` for both.
575fn param_declared_type_sym(sema: &Sema<'_>, body: InstRef, param_name: Spur) -> Option<Spur> {
576    for (_, inst) in sema.rir.iter() {
577        if let gruel_rir::InstData::FnDecl {
578            body: fn_body,
579            params_start,
580            params_len,
581            ..
582        } = &inst.data
583            && *fn_body == body
584        {
585            let params = sema.rir.get_params(*params_start, *params_len);
586            for param in params {
587                if param.name == param_name {
588                    return Some(param.ty);
589                }
590            }
591        }
592    }
593    None
594}