1use std::collections::HashMap;
16
17use gruel_error::{CompileError, CompileResult, ErrorKind};
18use gruel_rir::RirParamMode;
19use gruel_span::Span;
20use lasso::{Spur, ThreadedRodeo};
21
22use crate::inst::{Air, AirInstData};
23use crate::sema::{AnalyzedFunction, FunctionInfo, InferenceContext, Sema, SemaOutput};
24use crate::types::Type;
25
26#[derive(Debug, Clone, PartialEq, Eq, Hash)]
28pub struct SpecializationKey {
29 pub base_name: Spur,
31 pub type_args: Vec<Type>,
33}
34
35struct SpecializationInfo {
37 mangled_name: Spur,
39 call_site_span: Span,
41}
42
43pub fn specialize(
48 output: &mut SemaOutput,
49 sema: &mut Sema<'_>,
50 infer_ctx: &InferenceContext,
51 interner: &ThreadedRodeo,
52) -> CompileResult<()> {
53 let mut specializations: HashMap<SpecializationKey, SpecializationInfo> = HashMap::new();
55
56 for func in &output.functions {
57 collect_specializations(&func.air, interner, &mut specializations);
58 }
59
60 if specializations.is_empty() {
61 return Ok(());
63 }
64
65 let name_map: HashMap<SpecializationKey, Spur> = specializations
67 .iter()
68 .map(|(k, v)| (k.clone(), v.mangled_name))
69 .collect();
70
71 for func in &mut output.functions {
73 rewrite_call_generic(&mut func.air, &name_map);
74 }
75
76 for (key, info) in &specializations {
78 let base_info = match sema.functions.get(&key.base_name) {
79 Some(info) => *info,
80 None => {
81 let func_name = interner.resolve(&key.base_name);
82 return Err(CompileError::new(
83 ErrorKind::UndefinedFunction(func_name.to_string()),
84 info.call_site_span,
85 ));
86 }
87 };
88 let specialized_func = create_specialized_function(
89 sema,
90 infer_ctx,
91 key,
92 info.mangled_name,
93 &base_info,
94 interner,
95 )?;
96 output.functions.push(specialized_func);
97 }
98
99 Ok(())
100}
101
102fn collect_specializations(
104 air: &Air,
105 interner: &ThreadedRodeo,
106 specializations: &mut HashMap<SpecializationKey, SpecializationInfo>,
107) {
108 for inst in air.instructions() {
109 if let AirInstData::CallGeneric {
110 name,
111 type_args_start,
112 type_args_len,
113 ..
114 } = &inst.data
115 {
116 let type_args: Vec<Type> = air
118 .get_extra(*type_args_start, *type_args_len)
119 .iter()
120 .map(|&encoded| Type::from_u32(encoded))
121 .collect();
122
123 let key = SpecializationKey {
124 base_name: *name,
125 type_args: type_args.clone(),
126 };
127
128 specializations.entry(key).or_insert_with(|| {
129 let base_name = interner.resolve(name);
131 let mangled = mangle_specialized_name(base_name, &type_args);
132 let mangled_sym = interner.get_or_intern(&mangled);
133 SpecializationInfo {
134 mangled_name: mangled_sym,
135 call_site_span: inst.span,
136 }
137 });
138 }
139 }
140}
141
142fn rewrite_call_generic(air: &mut Air, specializations: &HashMap<SpecializationKey, Spur>) {
144 let mut rewrites: Vec<(usize, AirInstData)> = Vec::new();
147
148 for (i, inst) in air.instructions().iter().enumerate() {
149 if let AirInstData::CallGeneric {
150 name,
151 type_args_start,
152 type_args_len,
153 args_start,
154 args_len,
155 } = &inst.data
156 {
157 let type_args: Vec<Type> = air
159 .get_extra(*type_args_start, *type_args_len)
160 .iter()
161 .map(|&encoded| Type::from_u32(encoded))
162 .collect();
163
164 let key = SpecializationKey {
165 base_name: *name,
166 type_args,
167 };
168
169 if let Some(&specialized_name) = specializations.get(&key) {
170 let new_data = AirInstData::Call {
172 name: specialized_name,
173 args_start: *args_start,
174 args_len: *args_len,
175 };
176 rewrites.push((i, new_data));
177 }
178 }
179 }
180
181 for (index, new_data) in rewrites {
183 air.rewrite_inst_data(index, new_data);
184 }
185}
186
187fn mangle_specialized_name(base_name: &str, type_args: &[Type]) -> String {
189 let mut mangled = base_name.to_string();
190 for ty in type_args {
191 mangled.push_str("__");
192 mangled.push_str(ty.name());
193 }
194 mangled
195}
196
197fn create_specialized_function(
202 sema: &mut Sema<'_>,
203 infer_ctx: &InferenceContext,
204 key: &SpecializationKey,
205 specialized_name: Spur,
206 base_info: &FunctionInfo,
207 interner: &ThreadedRodeo,
208) -> CompileResult<AnalyzedFunction> {
209 let specialized_name_str = interner.resolve(&specialized_name).to_string();
210
211 let param_names = sema.param_arena.names(base_info.params);
213 let param_types = sema.param_arena.types(base_info.params);
214 let param_modes = sema.param_arena.modes(base_info.params);
215 let param_comptime = sema.param_arena.comptime(base_info.params);
216
217 let mut type_subst: HashMap<Spur, Type> = HashMap::new();
219 let mut type_arg_idx = 0;
220 for (i, is_comptime) in param_comptime.iter().enumerate() {
221 if *is_comptime && type_arg_idx < key.type_args.len() {
222 type_subst.insert(param_names[i], key.type_args[type_arg_idx]);
223 type_arg_idx += 1;
224 }
225 }
226
227 let return_type = if base_info.return_type == Type::COMPTIME_TYPE {
229 type_subst
231 .get(&base_info.return_type_sym)
232 .copied()
233 .unwrap_or(Type::UNIT)
234 } else {
235 base_info.return_type
236 };
237
238 let specialized_params: Vec<(Spur, Type, RirParamMode)> = param_names
242 .iter()
243 .zip(param_types.iter())
244 .zip(param_modes.iter())
245 .zip(param_comptime.iter())
246 .filter(|(((_, _), _), is_comptime)| !*is_comptime)
247 .map(|(((name, ty), mode), _)| {
248 let concrete_ty = if *ty == Type::COMPTIME_TYPE {
253 substitute_param_type(sema, base_info, *name, &type_subst).unwrap_or(*ty)
258 } else {
259 *ty
260 };
261 (*name, concrete_ty, *mode)
262 })
263 .collect();
264
265 let (
267 air,
268 num_locals,
269 num_param_slots,
270 param_modes,
271 param_slot_types,
272 _warnings,
273 _local_strings,
274 _ref_fns,
275 _ref_meths,
276 ) = sema.analyze_specialized_function(
277 infer_ctx,
278 return_type,
279 &specialized_params,
280 base_info.body,
281 &type_subst,
282 )?;
283
284 Ok(AnalyzedFunction {
285 name: specialized_name_str,
286 air,
287 num_locals,
288 num_param_slots,
289 param_modes,
290 param_slot_types,
291 is_destructor: false,
292 })
293}
294
295fn substitute_param_type(
300 sema: &Sema<'_>,
301 base_info: &FunctionInfo,
302 param_name: Spur,
303 type_subst: &HashMap<Spur, Type>,
304) -> Option<Type> {
305 for (_, inst) in sema.rir.iter() {
307 if let gruel_rir::InstData::FnDecl {
308 body,
309 params_start,
310 params_len,
311 ..
312 } = &inst.data
313 && *body == base_info.body
314 {
315 let params = sema.rir.get_params(*params_start, *params_len);
317 for param in params {
318 if param.name == param_name {
319 if let Some(&concrete_ty) = type_subst.get(¶m.ty) {
322 return Some(concrete_ty);
323 }
324 }
325 }
326 }
327 }
328
329 None
330}