gruel_air/
specialize.rs

//! Generic function specialization pass.
//!
//! This module provides the specialization pass that transforms `CallGeneric`
//! instructions into regular `Call` instructions by:
//!
//! 1. Collecting all `CallGeneric` instructions in the analyzed functions
//! 2. Creating one specialized function for each unique (func_name, type_args) combination
//! 3. Rewriting each `CallGeneric` into a `Call` to the specialized function's name
//!
//! # Architecture
//!
//! The specialization pass runs after semantic analysis but before CFG building.
//! It transforms the AIR in place and appends the new specialized functions to the output.
14
15use std::collections::HashMap;
16
17use gruel_error::{CompileError, CompileResult, ErrorKind};
18use gruel_rir::RirParamMode;
19use gruel_span::Span;
20use lasso::{Spur, ThreadedRodeo};
21
22use crate::inst::{Air, AirInstData};
23use crate::sema::{AnalyzedFunction, FunctionInfo, InferenceContext, Sema, SemaOutput};
24use crate::types::Type;
25
26/// A key for a specialized function: (base_function_name, type_arguments).
27#[derive(Debug, Clone, PartialEq, Eq, Hash)]
28pub struct SpecializationKey {
29    /// Base function name (e.g., "identity")
30    pub base_name: Spur,
31    /// Type arguments (e.g., [Type::I32])
32    pub type_args: Vec<Type>,
33}
34
35/// Info about a specialization: the mangled name and the first call site span.
36struct SpecializationInfo {
37    /// The mangled name for the specialized function.
38    mangled_name: Spur,
39    /// The span of the first call site (for error reporting if the function doesn't exist).
40    call_site_span: Span,
41}
42
43/// Perform the specialization pass on the sema output.
44///
45/// This collects all `CallGeneric` instructions, creates specialized functions,
46/// and rewrites calls to point to the specialized versions.
47pub fn specialize(
48    output: &mut SemaOutput,
49    sema: &mut Sema<'_>,
50    infer_ctx: &InferenceContext,
51    interner: &ThreadedRodeo,
52) -> CompileResult<()> {
53    // Phase 1: Collect all specialization requests
54    let mut specializations: HashMap<SpecializationKey, SpecializationInfo> = HashMap::new();
55
56    for func in &output.functions {
57        collect_specializations(&func.air, interner, &mut specializations);
58    }
59
60    if specializations.is_empty() {
61        // No generic calls, nothing to do
62        return Ok(());
63    }
64
65    // Build a map from key to just the mangled name for the rewrite phase
66    let name_map: HashMap<SpecializationKey, Spur> = specializations
67        .iter()
68        .map(|(k, v)| (k.clone(), v.mangled_name))
69        .collect();
70
71    // Phase 2: Rewrite CallGeneric to Call in all functions
72    for func in &mut output.functions {
73        rewrite_call_generic(&mut func.air, &name_map);
74    }
75
76    // Phase 3: Create specialized function bodies by re-analyzing with type substitution
77    for (key, info) in &specializations {
78        let base_info = match sema.functions.get(&key.base_name) {
79            Some(info) => *info,
80            None => {
81                let func_name = interner.resolve(&key.base_name);
82                return Err(CompileError::new(
83                    ErrorKind::UndefinedFunction(func_name.to_string()),
84                    info.call_site_span,
85                ));
86            }
87        };
88        let specialized_func = create_specialized_function(
89            sema,
90            infer_ctx,
91            key,
92            info.mangled_name,
93            &base_info,
94            interner,
95        )?;
96        output.functions.push(specialized_func);
97    }
98
99    Ok(())
100}
101
102/// Collect all specializations needed from a function's AIR.
103fn collect_specializations(
104    air: &Air,
105    interner: &ThreadedRodeo,
106    specializations: &mut HashMap<SpecializationKey, SpecializationInfo>,
107) {
108    for inst in air.instructions() {
109        if let AirInstData::CallGeneric {
110            name,
111            type_args_start,
112            type_args_len,
113            ..
114        } = &inst.data
115        {
116            // Extract type arguments using the public accessor
117            let type_args: Vec<Type> = air
118                .get_extra(*type_args_start, *type_args_len)
119                .iter()
120                .map(|&encoded| Type::from_u32(encoded))
121                .collect();
122
123            let key = SpecializationKey {
124                base_name: *name,
125                type_args: type_args.clone(),
126            };
127
128            specializations.entry(key).or_insert_with(|| {
129                // Generate a mangled name for the specialized function
130                let base_name = interner.resolve(name);
131                let mangled = mangle_specialized_name(base_name, &type_args);
132                let mangled_sym = interner.get_or_intern(&mangled);
133                SpecializationInfo {
134                    mangled_name: mangled_sym,
135                    call_site_span: inst.span,
136                }
137            });
138        }
139    }
140}
141
142/// Rewrite CallGeneric instructions to Call instructions.
143fn rewrite_call_generic(air: &mut Air, specializations: &HashMap<SpecializationKey, Spur>) {
144    // We need to collect the rewrites first, then apply them.
145    // This avoids borrowing issues with the extra array.
146    let mut rewrites: Vec<(usize, AirInstData)> = Vec::new();
147
148    for (i, inst) in air.instructions().iter().enumerate() {
149        if let AirInstData::CallGeneric {
150            name,
151            type_args_start,
152            type_args_len,
153            args_start,
154            args_len,
155        } = &inst.data
156        {
157            // Extract type arguments to form the key
158            let type_args: Vec<Type> = air
159                .get_extra(*type_args_start, *type_args_len)
160                .iter()
161                .map(|&encoded| Type::from_u32(encoded))
162                .collect();
163
164            let key = SpecializationKey {
165                base_name: *name,
166                type_args,
167            };
168
169            if let Some(&specialized_name) = specializations.get(&key) {
170                // Rewrite to a regular Call
171                let new_data = AirInstData::Call {
172                    name: specialized_name,
173                    args_start: *args_start,
174                    args_len: *args_len,
175                };
176                rewrites.push((i, new_data));
177            }
178        }
179    }
180
181    // Apply all rewrites
182    for (index, new_data) in rewrites {
183        air.rewrite_inst_data(index, new_data);
184    }
185}
186
187/// Generate a mangled name for a specialized function.
188fn mangle_specialized_name(base_name: &str, type_args: &[Type]) -> String {
189    let mut mangled = base_name.to_string();
190    for ty in type_args {
191        mangled.push_str("__");
192        mangled.push_str(ty.name());
193    }
194    mangled
195}
196
197/// Create a specialized function by re-analyzing the body with type substitution.
198///
199/// This builds a type substitution map from the comptime parameters to their concrete
200/// type arguments, then re-analyzes the function body with these substitutions.
201fn create_specialized_function(
202    sema: &mut Sema<'_>,
203    infer_ctx: &InferenceContext,
204    key: &SpecializationKey,
205    specialized_name: Spur,
206    base_info: &FunctionInfo,
207    interner: &ThreadedRodeo,
208) -> CompileResult<AnalyzedFunction> {
209    let specialized_name_str = interner.resolve(&specialized_name).to_string();
210
211    // Get parameter data from the arena
212    let param_names = sema.param_arena.names(base_info.params);
213    let param_types = sema.param_arena.types(base_info.params);
214    let param_modes = sema.param_arena.modes(base_info.params);
215    let param_comptime = sema.param_arena.comptime(base_info.params);
216
217    // Build the type substitution map: comptime param name -> concrete Type
218    let mut type_subst: HashMap<Spur, Type> = HashMap::new();
219    let mut type_arg_idx = 0;
220    for (i, is_comptime) in param_comptime.iter().enumerate() {
221        if *is_comptime && type_arg_idx < key.type_args.len() {
222            type_subst.insert(param_names[i], key.type_args[type_arg_idx]);
223            type_arg_idx += 1;
224        }
225    }
226
227    // Calculate the return type by substituting type parameters
228    let return_type = if base_info.return_type == Type::COMPTIME_TYPE {
229        // The return type references a type parameter - substitute it
230        type_subst
231            .get(&base_info.return_type_sym)
232            .copied()
233            .unwrap_or(Type::UNIT)
234    } else {
235        base_info.return_type
236    };
237
238    // Build the specialized parameter list by:
239    // 1. Filtering out comptime parameters (they're erased at runtime)
240    // 2. Substituting type parameters in non-comptime parameter types
241    let specialized_params: Vec<(Spur, Type, RirParamMode)> = param_names
242        .iter()
243        .zip(param_types.iter())
244        .zip(param_modes.iter())
245        .zip(param_comptime.iter())
246        .filter(|(((_, _), _), is_comptime)| !*is_comptime)
247        .map(|(((name, ty), mode), _)| {
248            // If the type is ComptimeType, look it up in the substitution map
249            // The param name's type symbol is stored in param_types as ComptimeType,
250            // but we need to find which type param it references.
251            // For now, we'll need to look at the original RIR to get the type name.
252            let concrete_ty = if *ty == Type::COMPTIME_TYPE {
253                // This parameter's type is a type parameter. We need to find which one.
254                // The type name in RIR is stored in the param's ty field as a Spur.
255                // Unfortunately, we've lost that information by this point.
256                // We need to look at the original function in RIR.
257                substitute_param_type(sema, base_info, *name, &type_subst).unwrap_or(*ty)
258            } else {
259                *ty
260            };
261            (*name, concrete_ty, *mode)
262        })
263        .collect();
264
265    // Now analyze the function body with the specialized types
266    let (
267        air,
268        num_locals,
269        num_param_slots,
270        param_modes,
271        param_slot_types,
272        _warnings,
273        _local_strings,
274        _ref_fns,
275        _ref_meths,
276    ) = sema.analyze_specialized_function(
277        infer_ctx,
278        return_type,
279        &specialized_params,
280        base_info.body,
281        &type_subst,
282    )?;
283
284    Ok(AnalyzedFunction {
285        name: specialized_name_str,
286        air,
287        num_locals,
288        num_param_slots,
289        param_modes,
290        param_slot_types,
291        is_destructor: false,
292    })
293}
294
295/// Substitute a parameter's type using the type substitution map.
296///
297/// This looks up the parameter's type symbol in the original RIR function
298/// and substitutes it with the concrete type if it's a type parameter.
299fn substitute_param_type(
300    sema: &Sema<'_>,
301    base_info: &FunctionInfo,
302    param_name: Spur,
303    type_subst: &HashMap<Spur, Type>,
304) -> Option<Type> {
305    // Walk up to find the FnDecl that contains this body
306    for (_, inst) in sema.rir.iter() {
307        if let gruel_rir::InstData::FnDecl {
308            body,
309            params_start,
310            params_len,
311            ..
312        } = &inst.data
313            && *body == base_info.body
314        {
315            // Found the function declaration
316            let params = sema.rir.get_params(*params_start, *params_len);
317            for param in params {
318                if param.name == param_name {
319                    // Found the parameter - get its type symbol
320                    // If the type symbol is in our substitution map, use that
321                    if let Some(&concrete_ty) = type_subst.get(&param.ty) {
322                        return Some(concrete_ty);
323                    }
324                }
325            }
326        }
327    }
328
329    None
330}