1use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
16
17use gruel_rir::{InstRef, RirParamMode};
18use gruel_util::Span;
19use gruel_util::{CompileError, CompileResult, ErrorKind};
20use lasso::{Spur, ThreadedRodeo};
21
22use crate::inst::{Air, AirInstData};
23use crate::param_arena::ParamRange;
24use crate::sema::{AnalyzedFunction, ConstValue, FunctionInfo, InferenceContext, MethodInfo, Sema};
25use crate::types::{StructId, Type};
26
/// Functions and methods referenced by the specialized bodies that
/// [`specialize`] creates.
///
/// Accumulates the `ref_fns` / `ref_meths` sets returned by
/// `Sema::analyze_specialized_function` for every new specialization.
/// NOTE(review): presumably the caller uses these to make sure the referenced
/// items are analyzed/emitted as well — confirm against the caller.
#[derive(Debug, Default)]
pub struct SpecializationRefs {
    /// Function symbols referenced by specialized bodies.
    pub fns: HashSet<Spur>,
    /// `(struct, method name)` pairs (same key shape as `Sema::methods`)
    /// referenced by specialized bodies.
    pub meths: HashSet<(StructId, Spur)>,
}
35
/// One analyzed function row: the analyzed function followed by the
/// `local_strings` and `local_bytes` tables that
/// `Sema::analyze_specialized_function` returns alongside it.
pub type AnalyzedRow = (AnalyzedFunction, Vec<String>, Vec<Vec<u8>>);
40
/// Identity of one specialization: the generic callee's name plus the
/// compile-time arguments supplied at a `CallGeneric` site.
///
/// Call sites with equal keys are rewritten to the same specialized function.
///
/// NOTE: the derived `Hash` hashes the fields in declaration order, and the
/// resulting hashes decide map iteration order (and therefore the order in
/// which `specialize` emits new specializations) — keep the field order stable.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SpecializationKey {
    /// Interned callee name: a function name, or `"Struct.method"` for methods
    /// (see `resolve_method_name`).
    pub base_name: Spur,
    /// Type arguments decoded from the call instruction's extra data.
    pub type_args: Vec<Type>,
    /// Compile-time value arguments recorded for the call instruction.
    pub value_args: Vec<ConstValue>,
}
55
/// Data recorded the first time a given `SpecializationKey` is encountered.
struct SpecializationInfo {
    /// Interned name produced by `mangle_specialized_name` for the key.
    mangled_name: Spur,
    /// Span of the first call site seen for the key; reported when the callee
    /// cannot be resolved.
    call_site_span: Span,
}
63
64pub fn specialize(
79 functions_with_strings: &mut Vec<AnalyzedRow>,
80 name_map: &mut HashMap<SpecializationKey, Spur>,
81 sema: &mut Sema<'_>,
82 infer_ctx: &InferenceContext,
83 interner: &ThreadedRodeo,
84) -> CompileResult<SpecializationRefs> {
85 let mut accumulated_refs = SpecializationRefs::default();
86
87 loop {
88 let mut seen: HashMap<SpecializationKey, SpecializationInfo> = HashMap::default();
90 for (func, _, _) in functions_with_strings.iter() {
91 collect_specializations(&func.air, interner, &mut seen);
92 }
93
94 let new_specs: Vec<(SpecializationKey, SpecializationInfo)> = seen
96 .into_iter()
97 .filter(|(k, _)| !name_map.contains_key(k))
98 .collect();
99
100 for (k, info) in &new_specs {
101 name_map.insert(k.clone(), info.mangled_name);
102 }
103
104 for (func, _, _) in functions_with_strings.iter_mut() {
113 rewrite_call_generic(&mut func.air, name_map);
114 }
115
116 if new_specs.is_empty() {
117 return Ok(accumulated_refs);
118 }
119
120 for (key, info) in new_specs {
122 let base = if let Some(fn_info) = sema.functions.get(&key.base_name).copied() {
123 SpecializeBase::function(&fn_info)
124 } else if let Some((struct_id, method_sym)) =
125 resolve_method_name(sema, interner, key.base_name)
126 && let Some(method_info) = sema.methods.get(&(struct_id, method_sym)).copied()
127 {
128 SpecializeBase::method(&method_info)
130 } else {
131 let func_name = interner.resolve(&key.base_name);
132 return Err(CompileError::new(
133 ErrorKind::UndefinedFunction(func_name.to_string()),
134 info.call_site_span,
135 ));
136 };
137
138 let row = create_specialized(
139 sema,
140 infer_ctx,
141 &key,
142 info.mangled_name,
143 base,
144 interner,
145 &mut accumulated_refs,
146 )?;
147 functions_with_strings.push(row);
148 }
149 }
150}
151
152fn resolve_method_name(
156 sema: &Sema<'_>,
157 interner: &ThreadedRodeo,
158 name: Spur,
159) -> Option<(StructId, Spur)> {
160 let name_str = interner.resolve(&name);
161 let (struct_str, method_str) = name_str.rsplit_once('.')?;
162 let struct_sym = interner.get(struct_str)?;
163 let struct_id = *sema.structs.get(&struct_sym)?;
164 let method_sym = interner.get(method_str)?;
165 Some((struct_id, method_sym))
166}
167
168fn collect_specializations(
170 air: &Air,
171 interner: &ThreadedRodeo,
172 specializations: &mut HashMap<SpecializationKey, SpecializationInfo>,
173) {
174 for (i, inst) in air.instructions().iter().enumerate() {
175 if let AirInstData::CallGeneric {
176 name,
177 type_args_start,
178 type_args_len,
179 ..
180 } = &inst.data
181 {
182 let type_args: Vec<Type> = air
184 .get_extra(*type_args_start, *type_args_len)
185 .iter()
186 .map(|&encoded| Type::from_u32(encoded))
187 .collect();
188
189 let value_args = air.comptime_value_args(i as u32).to_vec();
193
194 let key = SpecializationKey {
195 base_name: *name,
196 type_args: type_args.clone(),
197 value_args: value_args.clone(),
198 };
199
200 specializations.entry(key).or_insert_with(|| {
201 let base_name = interner.resolve(name);
203 let mangled = mangle_specialized_name(base_name, &type_args, &value_args);
204 let mangled_sym = interner.get_or_intern(&mangled);
205 SpecializationInfo {
206 mangled_name: mangled_sym,
207 call_site_span: inst.span,
208 }
209 });
210 }
211 }
212}
213
214fn rewrite_call_generic(air: &mut Air, specializations: &HashMap<SpecializationKey, Spur>) {
216 let mut rewrites: Vec<(usize, AirInstData)> = Vec::new();
219
220 for (i, inst) in air.instructions().iter().enumerate() {
221 if let AirInstData::CallGeneric {
222 name,
223 type_args_start,
224 type_args_len,
225 args_start,
226 args_len,
227 } = &inst.data
228 {
229 let type_args: Vec<Type> = air
231 .get_extra(*type_args_start, *type_args_len)
232 .iter()
233 .map(|&encoded| Type::from_u32(encoded))
234 .collect();
235
236 let value_args = air.comptime_value_args(i as u32).to_vec();
237
238 let key = SpecializationKey {
239 base_name: *name,
240 type_args,
241 value_args,
242 };
243
244 if let Some(&specialized_name) = specializations.get(&key) {
245 let new_data = AirInstData::Call {
247 name: specialized_name,
248 args_start: *args_start,
249 args_len: *args_len,
250 };
251 rewrites.push((i, new_data));
252 }
253 }
254 }
255
256 for (index, new_data) in rewrites {
258 air.rewrite_inst_data(index, new_data);
259 }
260}
261
262fn mangle_specialized_name(
272 base_name: &str,
273 type_args: &[Type],
274 value_args: &[ConstValue],
275) -> String {
276 let mut mangled = base_name.to_string();
277 for ty in type_args {
278 mangled.push_str("__");
279 mangled.push_str(ty.name());
280 mangled.push('#');
283 mangled.push_str(&ty.as_u32().to_string());
284 }
285 for v in value_args {
286 mangled.push_str("__v");
287 match v {
288 ConstValue::Integer(n) => {
289 mangled.push('i');
290 if *n < 0 {
291 mangled.push('m');
292 mangled.push_str(&(-(*n as i128)).to_string());
293 } else {
294 mangled.push_str(&n.to_string());
295 }
296 }
297 ConstValue::Bool(b) => {
298 mangled.push('b');
299 mangled.push(if *b { '1' } else { '0' });
300 }
301 ConstValue::Type(t) => {
302 mangled.push('t');
303 mangled.push_str(&t.as_u32().to_string());
304 }
305 ConstValue::ComptimeStr(idx) => {
306 mangled.push('s');
307 mangled.push_str(&idx.to_string());
308 }
309 ConstValue::Unit => mangled.push('u'),
310 _ => {
313 mangled.push('x');
314 mangled.push_str(&format!("{:?}", v));
315 }
316 }
317 }
318 mangled
319}
320
321struct SpecializeBase {
325 params: ParamRange,
326 return_type: Type,
327 return_type_sym: Spur,
328 body: InstRef,
329 span: Span,
330 method: Option<(Type, bool)>,
333}
334
335impl SpecializeBase {
336 fn function(info: &FunctionInfo) -> Self {
337 Self {
338 params: info.params,
339 return_type: info.return_type,
340 return_type_sym: info.return_type_sym,
341 body: info.body,
342 span: info.span,
343 method: None,
344 }
345 }
346
347 fn method(info: &MethodInfo) -> Self {
348 Self {
349 params: info.params,
350 return_type: info.return_type,
351 return_type_sym: info.return_type_sym,
352 body: info.body,
353 span: info.span,
354 method: Some((info.struct_type, info.has_self)),
355 }
356 }
357}
358
359fn create_specialized(
369 sema: &mut Sema<'_>,
370 infer_ctx: &InferenceContext,
371 key: &SpecializationKey,
372 specialized_name: Spur,
373 base: SpecializeBase,
374 interner: &ThreadedRodeo,
375 refs: &mut SpecializationRefs,
376) -> CompileResult<AnalyzedRow> {
377 let specialized_name_str = interner.resolve(&specialized_name).to_string();
378
379 let param_names = sema.param_arena.names(base.params).to_vec();
380 let param_types = sema.param_arena.types(base.params).to_vec();
381 let param_modes = sema.param_arena.modes(base.params).to_vec();
382 let param_comptime = sema.param_arena.comptime(base.params).to_vec();
383
384 let type_sym = interner.get_or_intern("type");
389 let declared_ty_syms: Vec<Option<Spur>> = param_names
390 .iter()
391 .map(|n| param_declared_type_sym(sema, base.body, *n))
392 .collect();
393 let is_comptime_type_param: Vec<bool> = param_names
394 .iter()
395 .enumerate()
396 .map(|(i, _)| {
397 if !param_comptime[i] {
398 return false;
399 }
400 match declared_ty_syms[i] {
404 Some(s) => s == type_sym || sema.interfaces.contains_key(&s),
405 None => param_types[i] == Type::COMPTIME_TYPE,
406 }
407 })
408 .collect();
409
410 let mut type_subst: HashMap<Spur, Type> = HashMap::default();
411 let mut value_subst: HashMap<Spur, ConstValue> = HashMap::default();
412 let mut type_arg_idx = 0;
413 let mut value_arg_idx = 0;
414 for (i, &is_type) in is_comptime_type_param.iter().enumerate() {
415 if is_type && type_arg_idx < key.type_args.len() {
416 type_subst.insert(param_names[i], key.type_args[type_arg_idx]);
417 type_arg_idx += 1;
418 } else if param_comptime[i] && !is_type && value_arg_idx < key.value_args.len() {
419 value_subst.insert(param_names[i], key.value_args[value_arg_idx]);
420 value_arg_idx += 1;
421 }
422 }
423 if let Some((struct_type, _)) = base.method {
424 let self_sym = interner.get_or_intern("Self");
425 type_subst.insert(self_sym, struct_type);
426 }
427
428 for p in ¶m_names {
433 if let Some(iid) = sema
434 .comptime_interface_bounds
435 .get(&(key.base_name, *p))
436 .copied()
437 && let Some(&concrete) = type_subst.get(p)
438 {
439 sema.check_conforms(concrete, iid, base.span)?;
440 }
441 }
442
443 let return_type = if let Some(&concrete) = type_subst.get(&base.return_type_sym) {
459 concrete
460 } else if base.return_type == Type::COMPTIME_TYPE {
461 sema.resolve_type_for_comptime_with_subst(base.return_type_sym, &type_subst)
462 .unwrap_or(base.return_type)
463 } else {
464 base.return_type
465 };
466
467 let mut specialized_params: Vec<(Spur, Type, RirParamMode)> = Vec::new();
470 if let Some((struct_type, true)) = base.method {
471 let self_val_sym = interner.get_or_intern("self");
472 specialized_params.push((self_val_sym, struct_type, RirParamMode::Normal));
473 }
474 for i in 0..param_names.len() {
475 if param_comptime[i] {
476 continue;
477 }
478 let name = param_names[i];
479 let ty = param_types[i];
480 let mode = param_modes[i];
481 let concrete_ty = if ty == Type::COMPTIME_TYPE {
482 substitute_param_type(sema, base.body, name, &type_subst).unwrap_or(ty)
483 } else {
484 ty
485 };
486 specialized_params.push((name, concrete_ty, mode));
487 }
488
489 let (
490 air,
491 num_locals,
492 num_param_slots,
493 modes_result,
494 param_slot_types,
495 _warnings,
496 local_strings,
497 local_bytes,
498 ref_fns,
499 ref_meths,
500 ) = sema.analyze_specialized_function(
501 infer_ctx,
502 return_type,
503 &specialized_params,
504 base.body,
505 &type_subst,
506 Some(&value_subst),
507 )?;
508
509 refs.fns.extend(ref_fns);
510 refs.meths.extend(ref_meths);
511
512 let analyzed = AnalyzedFunction {
513 name: specialized_name_str,
514 air,
515 num_locals,
516 num_param_slots,
517 param_modes: modes_result,
518 param_slot_types,
519 is_destructor: false,
520 };
521 Ok((analyzed, local_strings, local_bytes))
522}
523
524fn substitute_param_type(
531 sema: &mut Sema<'_>,
532 body: InstRef,
533 param_name: Spur,
534 type_subst: &HashMap<Spur, Type>,
535) -> Option<Type> {
536 let mut declared_ty: Option<Spur> = None;
538 for (_, inst) in sema.rir.iter() {
539 if let gruel_rir::InstData::FnDecl {
540 body: fn_body,
541 params_start,
542 params_len,
543 ..
544 } = &inst.data
545 && *fn_body == body
546 {
547 let params = sema.rir.get_params(*params_start, *params_len);
548 for param in params {
549 if param.name == param_name {
550 declared_ty = Some(param.ty);
551 break;
552 }
553 }
554 if declared_ty.is_some() {
555 break;
556 }
557 }
558 }
559 let declared_ty = declared_ty?;
560 if let Some(&concrete) = type_subst.get(&declared_ty) {
562 return Some(concrete);
563 }
564 sema.resolve_type_for_comptime_with_subst(declared_ty, type_subst)
568}
569
570fn param_declared_type_sym(sema: &Sema<'_>, body: InstRef, param_name: Spur) -> Option<Spur> {
576 for (_, inst) in sema.rir.iter() {
577 if let gruel_rir::InstData::FnDecl {
578 body: fn_body,
579 params_start,
580 params_len,
581 ..
582 } = &inst.data
583 && *fn_body == body
584 {
585 let params = sema.rir.get_params(*params_start, *params_len);
586 for param in params {
587 if param.name == param_name {
588 return Some(param.ty);
589 }
590 }
591 }
592 }
593 None
594}