1use lasso::{Spur, ThreadedRodeo};
7
/// Intrinsics whose operand is a type.
///
/// When one of these is called with exactly one argument that is a type expression
/// (or a bare identifier), `AstGen` lowers it to `InstData::TypeIntrinsic` instead of
/// a regular `InstData::Intrinsic`.
const TYPE_INTRINSICS: &[&str] = &[
    "size_of",
    "align_of",
    "typeName",
    "typeInfo",
];
11use gruel_parser::ast::{ConstDecl, DestructureBinding, DropFn, PatternBinding};
12use gruel_parser::{
13 ArgMode, AssignTarget, Ast, BinaryOp, CallArg, Directive, DirectiveArg, EnumDecl, Expr,
14 Function, IntrinsicArg, Item, LetPattern, Method, ParamMode, Pattern, Statement, StructDecl,
15 TypeExpr, UnaryOp, ast::Visibility,
16};
17
18use crate::inst::{
19 FunctionSpan, Inst, InstData, InstRef, Rir, RirArgMode, RirCallArg, RirDestructureField,
20 RirDirective, RirParam, RirParamMode, RirPattern, RirPatternBinding, RirStructPatternBinding,
21};
22
23pub struct AstGen<'a> {
25 ast: &'a Ast,
27 interner: &'a ThreadedRodeo,
29 rir: Rir,
31}
32
33impl<'a> AstGen<'a> {
34 pub fn new(ast: &'a Ast, interner: &'a ThreadedRodeo) -> Self {
36 Self {
37 ast,
38 interner,
39 rir: Rir::new(),
40 }
41 }
42
43 pub fn generate(mut self) -> Rir {
45 for item in &self.ast.items {
46 self.gen_item(item);
47 }
48 self.rir
49 }
50
51 fn gen_item(&mut self, item: &Item) {
52 match item {
53 Item::Function(func) => {
54 self.gen_function(func);
55 }
56 Item::Struct(struct_decl) => {
57 self.gen_struct(struct_decl);
58 }
59 Item::Enum(enum_decl) => {
60 self.gen_enum(enum_decl);
61 }
62 Item::DropFn(drop_fn) => {
63 self.gen_drop_fn(drop_fn);
64 }
65 Item::Const(const_decl) => {
66 self.gen_const(const_decl);
67 }
68 Item::Error(_) => {}
70 }
71 }
72
73 fn intern_type(&mut self, ty: &TypeExpr) -> Spur {
76 match ty {
77 TypeExpr::Named(ident) => ident.name, TypeExpr::Unit(_) => self.interner.get_or_intern("()"),
79 TypeExpr::Never(_) => self.interner.get_or_intern("!"),
80 TypeExpr::Array {
81 element, length, ..
82 } => {
83 let elem_sym = self.intern_type(element);
86 let elem_name = self.interner.resolve(&elem_sym);
87 let s = format!("[{}; {}]", elem_name, length);
88 self.interner.get_or_intern(&s)
89 }
90 TypeExpr::AnonymousStruct { fields, .. } => {
91 let mut s = String::from("struct { ");
93 for (i, field) in fields.iter().enumerate() {
94 if i > 0 {
95 s.push_str(", ");
96 }
97 let name = self.interner.resolve(&field.name.name);
98 let ty_sym = self.intern_type(&field.ty);
99 let ty_name = self.interner.resolve(&ty_sym);
100 s.push_str(name);
101 s.push_str(": ");
102 s.push_str(ty_name);
103 }
104 s.push_str(" }");
105 self.interner.get_or_intern(&s)
106 }
107 TypeExpr::AnonymousEnum { variants, .. } => {
108 use gruel_parser::ast::EnumVariantKind;
110 let mut s = String::from("enum { ");
111 for (i, v) in variants.iter().enumerate() {
112 if i > 0 {
113 s.push_str(", ");
114 }
115 let name = self.interner.resolve(&v.name.name);
116 s.push_str(name);
117 match &v.kind {
118 EnumVariantKind::Unit => {}
119 EnumVariantKind::Tuple(types) => {
120 s.push('(');
121 for (j, ty) in types.iter().enumerate() {
122 if j > 0 {
123 s.push_str(", ");
124 }
125 let ty_sym = self.intern_type(ty);
126 s.push_str(self.interner.resolve(&ty_sym));
127 }
128 s.push(')');
129 }
130 EnumVariantKind::Struct(fields) => {
131 s.push_str(" { ");
132 for (j, f) in fields.iter().enumerate() {
133 if j > 0 {
134 s.push_str(", ");
135 }
136 let fname = self.interner.resolve(&f.name.name);
137 let ty_sym = self.intern_type(&f.ty);
138 s.push_str(fname);
139 s.push_str(": ");
140 s.push_str(self.interner.resolve(&ty_sym));
141 }
142 s.push_str(" }");
143 }
144 }
145 }
146 s.push_str(" }");
147 self.interner.get_or_intern(&s)
148 }
149 TypeExpr::PointerConst { pointee, .. } => {
150 let pointee_sym = self.intern_type(pointee);
152 let pointee_name = self.interner.resolve(&pointee_sym);
153 let s = format!("ptr const {}", pointee_name);
154 self.interner.get_or_intern(&s)
155 }
156 TypeExpr::PointerMut { pointee, .. } => {
157 let pointee_sym = self.intern_type(pointee);
159 let pointee_name = self.interner.resolve(&pointee_sym);
160 let s = format!("ptr mut {}", pointee_name);
161 self.interner.get_or_intern(&s)
162 }
163 }
164 }
165
166 fn gen_struct(&mut self, struct_decl: &StructDecl) -> InstRef {
167 let directives = self.convert_directives(&struct_decl.directives);
168 let (directives_start, directives_len) = self.rir.add_directives(&directives);
169 let name = struct_decl.name.name; let fields: Vec<_> = struct_decl
171 .fields
172 .iter()
173 .map(|f| {
174 let field_name = f.name.name; let field_type = self.intern_type(&f.ty);
176 (field_name, field_type)
177 })
178 .collect();
179 let (fields_start, fields_len) = self.rir.add_field_decls(&fields);
180
181 let methods: Vec<_> = struct_decl
183 .methods
184 .iter()
185 .map(|m| self.gen_method(m))
186 .collect();
187 let (methods_start, methods_len) = self.rir.add_inst_refs(&methods);
188
189 self.rir.add_inst(Inst {
190 data: InstData::StructDecl {
191 directives_start,
192 directives_len,
193 is_pub: struct_decl.visibility == Visibility::Public,
194 is_linear: struct_decl.is_linear,
195 name,
196 fields_start,
197 fields_len,
198 methods_start,
199 methods_len,
200 },
201 span: struct_decl.span,
202 })
203 }
204
205 fn gen_enum(&mut self, enum_decl: &EnumDecl) -> InstRef {
206 use gruel_parser::ast::EnumVariantKind;
207
208 let name = enum_decl.name.name; let variants: Vec<(Spur, Vec<Spur>, Vec<Spur>)> = enum_decl
210 .variants
211 .iter()
212 .map(|v| {
213 let variant_name = v.name.name;
214 match &v.kind {
215 EnumVariantKind::Unit => (variant_name, vec![], vec![]),
216 EnumVariantKind::Tuple(types) => {
217 let field_types: Vec<Spur> =
218 types.iter().map(|ty| self.intern_type(ty)).collect();
219 (variant_name, field_types, vec![])
220 }
221 EnumVariantKind::Struct(fields) => {
222 let field_types: Vec<Spur> =
223 fields.iter().map(|f| self.intern_type(&f.ty)).collect();
224 let field_names: Vec<Spur> = fields.iter().map(|f| f.name.name).collect();
225 (variant_name, field_types, field_names)
226 }
227 }
228 })
229 .collect();
230 let (variants_start, variants_len) = self.rir.add_enum_variant_decls(&variants);
231
232 self.rir.add_inst(Inst {
233 data: InstData::EnumDecl {
234 is_pub: enum_decl.visibility == Visibility::Public,
235 name,
236 variants_start,
237 variants_len,
238 },
239 span: enum_decl.span,
240 })
241 }
242
243 fn gen_const(&mut self, const_decl: &ConstDecl) -> InstRef {
244 let directives = self.convert_directives(&const_decl.directives);
245 let (directives_start, directives_len) = self.rir.add_directives(&directives);
246 let name = const_decl.name.name; let ty = const_decl.ty.as_ref().map(|t| self.intern_type(t));
248 let init = self.gen_expr(&const_decl.init);
249
250 self.rir.add_inst(Inst {
251 data: InstData::ConstDecl {
252 directives_start,
253 directives_len,
254 is_pub: const_decl.visibility == Visibility::Public,
255 name,
256 ty,
257 init,
258 },
259 span: const_decl.span,
260 })
261 }
262
263 fn gen_drop_fn(&mut self, drop_fn: &DropFn) -> InstRef {
264 let type_name = drop_fn.type_name.name; let body = self.gen_expr(&drop_fn.body);
268
269 self.rir.add_inst(Inst {
270 data: InstData::DropFnDecl { type_name, body },
271 span: drop_fn.span,
272 })
273 }
274
275 fn gen_method(&mut self, method: &Method) -> InstRef {
276 let directives = self.convert_directives(&method.directives);
278 let (directives_start, directives_len) = self.rir.add_directives(&directives);
279
280 let name = method.name.name; let return_type = match &method.return_type {
283 Some(ty) => self.intern_type(ty),
284 None => self.interner.get_or_intern("()"), };
286
287 let params: Vec<_> = method
289 .params
290 .iter()
291 .map(|p| RirParam {
292 name: p.name.name, ty: self.intern_type(&p.ty),
294 mode: self.convert_param_mode(p.mode),
295 is_comptime: p.is_comptime,
296 })
297 .collect();
298 let (params_start, params_len) = self.rir.add_params(¶ms);
299
300 let body_start = InstRef::from_raw(self.rir.current_inst_index());
302
303 let body = self.gen_expr(&method.body);
305
306 let has_self = method.receiver.is_some();
308
309 let decl = self.rir.add_inst(Inst {
314 data: InstData::FnDecl {
315 directives_start,
316 directives_len,
317 is_pub: false,
318 is_unchecked: false,
319 name,
320 params_start,
321 params_len,
322 return_type,
323 body,
324 has_self,
325 },
326 span: method.span,
327 });
328
329 self.rir
332 .add_function_span(FunctionSpan::new(name, body_start, decl));
333
334 decl
335 }
336
337 fn convert_directives(&mut self, directives: &[Directive]) -> Vec<RirDirective> {
339 directives
340 .iter()
341 .map(|d| RirDirective {
342 name: d.name.name, args: d
344 .args
345 .iter()
346 .map(|arg| match arg {
347 DirectiveArg::Ident(ident) => ident.name, })
349 .collect(),
350 span: d.span,
351 })
352 .collect()
353 }
354
355 fn convert_param_mode(&self, mode: ParamMode) -> RirParamMode {
357 match mode {
358 ParamMode::Normal => RirParamMode::Normal,
359 ParamMode::Inout => RirParamMode::Inout,
360 ParamMode::Borrow => RirParamMode::Borrow,
361 ParamMode::Comptime => RirParamMode::Comptime,
362 }
363 }
364
365 fn convert_arg_mode(&self, mode: ArgMode) -> RirArgMode {
367 match mode {
368 ArgMode::Normal => RirArgMode::Normal,
369 ArgMode::Inout => RirArgMode::Inout,
370 ArgMode::Borrow => RirArgMode::Borrow,
371 }
372 }
373
374 fn convert_call_arg(&mut self, arg: &CallArg) -> RirCallArg {
376 RirCallArg {
377 value: self.gen_expr(&arg.expr),
378 mode: self.convert_arg_mode(arg.mode),
379 }
380 }
381
382 fn gen_function(&mut self, func: &Function) -> InstRef {
383 let directives = self.convert_directives(&func.directives);
385 let (directives_start, directives_len) = self.rir.add_directives(&directives);
386
387 let name = func.name.name; let return_type = match &func.return_type {
390 Some(ty) => self.intern_type(ty),
391 None => self.interner.get_or_intern("()"), };
393
394 let params: Vec<_> = func
396 .params
397 .iter()
398 .map(|p| RirParam {
399 name: p.name.name, ty: self.intern_type(&p.ty),
401 mode: self.convert_param_mode(p.mode),
402 is_comptime: p.is_comptime,
403 })
404 .collect();
405 let (params_start, params_len) = self.rir.add_params(¶ms);
406
407 let body_start = InstRef::from_raw(self.rir.current_inst_index());
409
410 let body = self.gen_expr(&func.body);
412
413 let decl = self.rir.add_inst(Inst {
416 data: InstData::FnDecl {
417 directives_start,
418 directives_len,
419 is_pub: func.visibility == Visibility::Public,
420 is_unchecked: func.is_unchecked,
421 name,
422 params_start,
423 params_len,
424 return_type,
425 body,
426 has_self: false,
427 },
428 span: func.span,
429 });
430
431 self.rir
433 .add_function_span(FunctionSpan::new(name, body_start, decl));
434
435 decl
436 }
437
    /// Lowers an expression to RIR and returns the instruction that yields its value.
    ///
    /// Operands, sub-blocks and side-table entries are always produced before the
    /// instruction that consumes them, so instruction numbering follows the emission
    /// order in this function. Reordering the calls here changes the `InstRef`s that the
    /// unit tests assert on.
    ///
    /// Two variants emit no instruction of their own:
    /// - `Expr::Paren` returns its inner expression's ref directly;
    /// - `Expr::Error` is lowered to a `UnitConst` placeholder.
    fn gen_expr(&mut self, expr: &Expr) -> InstRef {
        match expr {
            // Literals map one-to-one onto constant instructions.
            Expr::Int(lit) => self.rir.add_inst(Inst {
                data: InstData::IntConst(lit.value),
                span: lit.span,
            }),
            // The float is carried via `lit.bits` (presumably its IEEE-754 bit pattern).
            Expr::Float(lit) => self.rir.add_inst(Inst {
                data: InstData::FloatConst(lit.bits),
                span: lit.span,
            }),
            Expr::Bool(lit) => self.rir.add_inst(Inst {
                data: InstData::BoolConst(lit.value),
                span: lit.span,
            }),
            Expr::String(lit) => {
                self.rir.add_inst(Inst {
                    data: InstData::StringConst(lit.value),
                    span: lit.span,
                })
            }
            Expr::Unit(lit) => self.rir.add_inst(Inst {
                data: InstData::UnitConst,
                span: lit.span,
            }),
            // Identifiers become by-name variable references.
            Expr::Ident(ident) => {
                self.rir.add_inst(Inst {
                    data: InstData::VarRef { name: ident.name },
                    span: ident.span,
                })
            }
            // Left operand, then right operand, then the operator instruction.
            Expr::Binary(bin) => {
                let lhs = self.gen_expr(&bin.left);
                let rhs = self.gen_expr(&bin.right);
                let data = match bin.op {
                    BinaryOp::Add => InstData::Add { lhs, rhs },
                    BinaryOp::Sub => InstData::Sub { lhs, rhs },
                    BinaryOp::Mul => InstData::Mul { lhs, rhs },
                    BinaryOp::Div => InstData::Div { lhs, rhs },
                    BinaryOp::Mod => InstData::Mod { lhs, rhs },
                    BinaryOp::Eq => InstData::Eq { lhs, rhs },
                    BinaryOp::Ne => InstData::Ne { lhs, rhs },
                    BinaryOp::Lt => InstData::Lt { lhs, rhs },
                    BinaryOp::Gt => InstData::Gt { lhs, rhs },
                    BinaryOp::Le => InstData::Le { lhs, rhs },
                    BinaryOp::Ge => InstData::Ge { lhs, rhs },
                    BinaryOp::And => InstData::And { lhs, rhs },
                    BinaryOp::Or => InstData::Or { lhs, rhs },
                    BinaryOp::BitAnd => InstData::BitAnd { lhs, rhs },
                    BinaryOp::BitOr => InstData::BitOr { lhs, rhs },
                    BinaryOp::BitXor => InstData::BitXor { lhs, rhs },
                    BinaryOp::Shl => InstData::Shl { lhs, rhs },
                    BinaryOp::Shr => InstData::Shr { lhs, rhs },
                };
                self.rir.add_inst(Inst {
                    data,
                    span: bin.span,
                })
            }
            Expr::Unary(un) => {
                let operand = self.gen_expr(&un.operand);
                let data = match un.op {
                    UnaryOp::Neg => InstData::Neg { operand },
                    UnaryOp::Not => InstData::Not { operand },
                    UnaryOp::BitNot => InstData::BitNot { operand },
                };
                self.rir.add_inst(Inst {
                    data,
                    span: un.span,
                })
            }
            // Parentheses only affect parsing; no instruction is emitted for them.
            Expr::Paren(paren) => {
                self.gen_expr(&paren.inner)
            }
            Expr::Block(block) => self.gen_block(block),
            // Condition, then-block and optional else-block are lowered before `Branch`.
            Expr::If(if_expr) => {
                let cond = self.gen_expr(&if_expr.cond);
                let then_block = self.gen_block(&if_expr.then_block);
                let else_block = if_expr.else_block.as_ref().map(|b| self.gen_block(b));

                self.rir.add_inst(Inst {
                    data: InstData::Branch {
                        cond,
                        then_block,
                        else_block,
                    },
                    span: if_expr.span,
                })
            }
            // `while` becomes a conditional `Loop`.
            Expr::While(while_expr) => {
                let cond = self.gen_expr(&while_expr.cond);
                let body = self.gen_block(&while_expr.body);
                self.rir.add_inst(Inst {
                    data: InstData::Loop { cond, body },
                    span: while_expr.span,
                })
            }
            Expr::For(for_expr) => {
                let iterable = self.gen_expr(&for_expr.iterable);
                let body = self.gen_block(&for_expr.body);
                self.rir.add_inst(Inst {
                    data: InstData::For {
                        binding: for_expr.binding.name,
                        is_mut: for_expr.is_mut,
                        iterable,
                        body,
                    },
                    span: for_expr.span,
                })
            }
            // `loop` becomes an `InfiniteLoop` with no condition.
            Expr::Loop(loop_expr) => {
                let body = self.gen_block(&loop_expr.body);
                self.rir.add_inst(Inst {
                    data: InstData::InfiniteLoop { body },
                    span: loop_expr.span,
                })
            }
            // Scrutinee first. Then, for each arm in order, its pattern and then its body.
            // The (pattern, body) pairs are stored in a side table.
            Expr::Match(match_expr) => {
                let scrutinee = self.gen_expr(&match_expr.scrutinee);
                let arms: Vec<_> = match_expr
                    .arms
                    .iter()
                    .map(|arm| {
                        let pattern = self.gen_pattern(&arm.pattern);
                        let body = self.gen_expr(&arm.body);
                        (pattern, body)
                    })
                    .collect();
                let (arms_start, arms_len) = self.rir.add_match_arms(&arms);

                self.rir.add_inst(Inst {
                    data: InstData::Match {
                        scrutinee,
                        arms_start,
                        arms_len,
                    },
                    span: match_expr.span,
                })
            }
            // Arguments are lowered left to right (with their passing mode) into a side
            // table. The callee is referenced by name only.
            Expr::Call(call) => {
                let args: Vec<_> = call.args.iter().map(|a| self.convert_call_arg(a)).collect();
                let (args_start, args_len) = self.rir.add_call_args(&args);

                self.rir.add_inst(Inst {
                    data: InstData::Call {
                        name: call.name.name,
                        args_start,
                        args_len,
                    },
                    span: call.span,
                })
            }
            Expr::Break(break_expr) => self.rir.add_inst(Inst {
                data: InstData::Break,
                span: break_expr.span,
            }),
            Expr::Continue(continue_expr) => self.rir.add_inst(Inst {
                data: InstData::Continue,
                span: continue_expr.span,
            }),
            // `Ret(None)` for a bare `return`.
            Expr::Return(return_expr) => {
                let value = return_expr.value.as_ref().map(|v| self.gen_expr(v));
                self.rir.add_inst(Inst {
                    data: InstData::Ret(value),
                    span: return_expr.span,
                })
            }
            // The optional `base` expression is lowered first and stored as `module`
            // (presumably a module qualifier — confirm). Field values follow in
            // source order.
            Expr::StructLit(struct_lit) => {
                let module = struct_lit
                    .base
                    .as_ref()
                    .map(|base_expr| self.gen_expr(base_expr));

                let fields: Vec<_> = struct_lit
                    .fields
                    .iter()
                    .map(|f| {
                        let field_value = self.gen_expr(&f.value);
                        (f.name.name, field_value)
                    })
                    .collect();
                let (fields_start, fields_len) = self.rir.add_field_inits(&fields);

                self.rir.add_inst(Inst {
                    data: InstData::StructInit {
                        module,
                        type_name: struct_lit.name.name,
                        fields_start,
                        fields_len,
                    },
                    span: struct_lit.span,
                })
            }
            // Same shape as `StructLit`, but targets a struct-like enum variant.
            Expr::EnumStructLit(lit) => {
                let module = lit.base.as_ref().map(|base_expr| self.gen_expr(base_expr));

                let fields: Vec<_> = lit
                    .fields
                    .iter()
                    .map(|f| {
                        let field_value = self.gen_expr(&f.value);
                        (f.name.name, field_value)
                    })
                    .collect();
                let (fields_start, fields_len) = self.rir.add_field_inits(&fields);

                self.rir.add_inst(Inst {
                    data: InstData::EnumStructVariant {
                        module,
                        type_name: lit.type_name.name,
                        variant: lit.variant.name,
                        fields_start,
                        fields_len,
                    },
                    span: lit.span,
                })
            }
            Expr::Field(field_expr) => {
                let base = self.gen_expr(&field_expr.base);

                self.rir.add_inst(Inst {
                    data: InstData::FieldGet {
                        base,
                        field: field_expr.field.name,
                    },
                    span: field_expr.span,
                })
            }
            Expr::IntrinsicCall(intrinsic) => {
                let name = intrinsic.name.name;
                let intrinsic_name_str = self.interner.resolve(&name);

                let is_type_intrinsic = TYPE_INTRINSICS.contains(&intrinsic_name_str);

                // A type intrinsic with exactly one argument lowers to `TypeIntrinsic`
                // when that argument is a type expression or a bare identifier. The
                // identifier is taken to be a type name (presumably because the parser
                // cannot tell type names from value names here — confirm). Any other
                // argument shape falls through to the generic `Intrinsic` path below.
                if is_type_intrinsic && intrinsic.args.len() == 1 {
                    if let IntrinsicArg::Type(ty) = &intrinsic.args[0] {
                        let type_arg = self.intern_type(ty);
                        return self.rir.add_inst(Inst {
                            data: InstData::TypeIntrinsic { name, type_arg },
                            span: intrinsic.span,
                        });
                    }

                    if let IntrinsicArg::Expr(Expr::Ident(ident)) = &intrinsic.args[0] {
                        return self.rir.add_inst(Inst {
                            data: InstData::TypeIntrinsic {
                                name,
                                type_arg: ident.name,
                            },
                            span: intrinsic.span,
                        });
                    }
                }

                // Generic path: only expression arguments are lowered.
                // NOTE(review): `IntrinsicArg::Type` arguments are silently dropped here,
                // so they are invisible to later passes — confirm this is intended.
                let args: Vec<_> = intrinsic
                    .args
                    .iter()
                    .filter_map(|a| match a {
                        IntrinsicArg::Expr(expr) => Some(self.gen_expr(expr)),
                        IntrinsicArg::Type(_) => None,
                    })
                    .collect();
                let (args_start, args_len) = self.rir.add_inst_refs(&args);

                self.rir.add_inst(Inst {
                    data: InstData::Intrinsic {
                        name,
                        args_start,
                        args_len,
                    },
                    span: intrinsic.span,
                })
            }
            // Elements are lowered in order and stored as an `InstRef` list.
            Expr::ArrayLit(array_lit) => {
                let elements: Vec<_> = array_lit
                    .elements
                    .iter()
                    .map(|e| self.gen_expr(e))
                    .collect();
                let (elems_start, elems_len) = self.rir.add_inst_refs(&elements);

                self.rir.add_inst(Inst {
                    data: InstData::ArrayInit {
                        elems_start,
                        elems_len,
                    },
                    span: array_lit.span,
                })
            }
            Expr::Index(index_expr) => {
                let base = self.gen_expr(&index_expr.base);
                let index = self.gen_expr(&index_expr.index);

                self.rir.add_inst(Inst {
                    data: InstData::IndexGet { base, index },
                    span: index_expr.span,
                })
            }
            // `Type::Variant` paths become (unit) enum variant references.
            Expr::Path(path_expr) => {
                let module = path_expr
                    .base
                    .as_ref()
                    .map(|base_expr| self.gen_expr(base_expr));

                self.rir.add_inst(Inst {
                    data: InstData::EnumVariant {
                        module,
                        type_name: path_expr.type_name.name,
                        variant: path_expr.variant.name,
                    },
                    span: path_expr.span,
                })
            }
            // Receiver first, then arguments left to right.
            Expr::MethodCall(method_call) => {
                let receiver = self.gen_expr(&method_call.receiver);
                let args: Vec<_> = method_call
                    .args
                    .iter()
                    .map(|a| self.convert_call_arg(a))
                    .collect();
                let (args_start, args_len) = self.rir.add_call_args(&args);

                self.rir.add_inst(Inst {
                    data: InstData::MethodCall {
                        receiver,
                        method: method_call.method.name,
                        args_start,
                        args_len,
                    },
                    span: method_call.span,
                })
            }
            // `Type::function(args)` with no receiver.
            Expr::AssocFnCall(assoc_fn_call) => {
                let args: Vec<_> = assoc_fn_call
                    .args
                    .iter()
                    .map(|a| self.convert_call_arg(a))
                    .collect();
                let (args_start, args_len) = self.rir.add_call_args(&args);

                self.rir.add_inst(Inst {
                    data: InstData::AssocFnCall {
                        type_name: assoc_fn_call.type_name.name,
                        function: assoc_fn_call.function.name,
                        args_start,
                        args_len,
                    },
                    span: assoc_fn_call.span,
                })
            }
            // `self` is lowered as an ordinary variable reference named "self".
            Expr::SelfExpr(self_expr) => {
                let name = self.interner.get_or_intern("self");
                self.rir.add_inst(Inst {
                    data: InstData::VarRef { name },
                    span: self_expr.span,
                })
            }
            // The inner expression is lowered normally and wrapped in a `Comptime` marker.
            Expr::Comptime(comptime_block) => {
                let inner_expr = self.gen_expr(&comptime_block.expr);
                self.rir.add_inst(Inst {
                    data: InstData::Comptime { expr: inner_expr },
                    span: comptime_block.span,
                })
            }
            Expr::ComptimeUnrollFor(unroll) => {
                let iterable = self.gen_expr(&unroll.iterable);
                let body = self.gen_block(&unroll.body);
                self.rir.add_inst(Inst {
                    data: InstData::ComptimeUnrollFor {
                        binding: unroll.binding.name,
                        iterable,
                        body,
                    },
                    span: unroll.span,
                })
            }
            // The inner expression is lowered normally and wrapped in a `Checked` marker.
            Expr::Checked(checked_block) => {
                let inner_expr = self.gen_expr(&checked_block.expr);
                self.rir.add_inst(Inst {
                    data: InstData::Checked { expr: inner_expr },
                    span: checked_block.span,
                })
            }
            // A type used in expression position.
            Expr::TypeLit(type_lit) => {
                match &type_lit.type_expr {
                    // Anonymous struct type: field decls go to a side table, then each
                    // method is lowered through `gen_method`.
                    TypeExpr::AnonymousStruct {
                        fields, methods, ..
                    } => {
                        let field_decls: Vec<(Spur, Spur)> = fields
                            .iter()
                            .map(|f| {
                                let name = f.name.name;
                                let ty = self.intern_type(&f.ty);
                                (name, ty)
                            })
                            .collect();
                        let (fields_start, fields_len) = self.rir.add_field_decls(&field_decls);

                        let method_refs: Vec<InstRef> =
                            methods.iter().map(|m| self.gen_method(m)).collect();
                        let (methods_start, methods_len) = self.rir.add_inst_refs(&method_refs);

                        self.rir.add_inst(Inst {
                            data: InstData::AnonStructType {
                                fields_start,
                                fields_len,
                                methods_start,
                                methods_len,
                            },
                            span: type_lit.span,
                        })
                    }
                    // Anonymous enum type: the variant encoding matches `gen_enum`,
                    // followed by method lowering.
                    TypeExpr::AnonymousEnum {
                        variants, methods, ..
                    } => {
                        use gruel_parser::ast::EnumVariantKind;
                        let variant_decls: Vec<(Spur, Vec<Spur>, Vec<Spur>)> = variants
                            .iter()
                            .map(|v| {
                                let variant_name = v.name.name;
                                match &v.kind {
                                    EnumVariantKind::Unit => (variant_name, vec![], vec![]),
                                    EnumVariantKind::Tuple(types) => {
                                        let field_types: Vec<Spur> =
                                            types.iter().map(|ty| self.intern_type(ty)).collect();
                                        (variant_name, field_types, vec![])
                                    }
                                    EnumVariantKind::Struct(fields) => {
                                        let field_types: Vec<Spur> = fields
                                            .iter()
                                            .map(|f| self.intern_type(&f.ty))
                                            .collect();
                                        let field_names: Vec<Spur> =
                                            fields.iter().map(|f| f.name.name).collect();
                                        (variant_name, field_types, field_names)
                                    }
                                }
                            })
                            .collect();
                        let (variants_start, variants_len) =
                            self.rir.add_enum_variant_decls(&variant_decls);

                        let method_refs: Vec<InstRef> =
                            methods.iter().map(|m| self.gen_method(m)).collect();
                        let (methods_start, methods_len) = self.rir.add_inst_refs(&method_refs);

                        self.rir.add_inst(Inst {
                            data: InstData::AnonEnumType {
                                variants_start,
                                variants_len,
                                methods_start,
                                methods_len,
                            },
                            span: type_lit.span,
                        })
                    }
                    // Every other type becomes a `TypeConst` naming the type by symbol.
                    _ => {
                        let type_name = match &type_lit.type_expr {
                            TypeExpr::Named(ident) => ident.name,
                            TypeExpr::Unit(_) => self.interner.get_or_intern_static("()"),
                            TypeExpr::Never(_) => self.interner.get_or_intern_static("!"),
                            // NOTE(review): every array type collapses to the single
                            // symbol "array" here. Element type and length are lost,
                            // unlike `intern_type`, which renders `[T; N]` — confirm
                            // this is intentional.
                            TypeExpr::Array { .. } => {
                                self.interner.get_or_intern_static("array")
                            }
                            TypeExpr::AnonymousStruct { .. } | TypeExpr::AnonymousEnum { .. } => {
                                unreachable!("handled above")
                            }
                            // Pointer types use their full rendered spelling.
                            TypeExpr::PointerConst { .. } | TypeExpr::PointerMut { .. } => {
                                self.intern_type(&type_lit.type_expr)
                            }
                        };
                        self.rir.add_inst(Inst {
                            data: InstData::TypeConst { type_name },
                            span: type_lit.span,
                        })
                    }
                }
            }
            // Parse-error placeholder: lowered to a `UnitConst` at the error's span.
            Expr::Error(span) => self.rir.add_inst(Inst {
                data: InstData::UnitConst,
                span: *span,
            }),
        }
    }
945
946 fn gen_pattern(&mut self, pattern: &Pattern) -> RirPattern {
947 match pattern {
948 Pattern::Wildcard(span) => RirPattern::Wildcard(*span),
949 Pattern::Int(lit) => RirPattern::Int(lit.value as i64, lit.span),
950 Pattern::NegInt(lit) => RirPattern::Int((lit.value as i64).wrapping_neg(), lit.span),
952 Pattern::Bool(lit) => RirPattern::Bool(lit.value, lit.span),
953 Pattern::Path(path) => {
954 let module = path.base.as_ref().map(|base| self.gen_expr(base));
956 RirPattern::Path {
957 module,
958 type_name: path.type_name.name, variant: path.variant.name, span: path.span,
961 }
962 }
963 Pattern::DataVariant {
964 base,
965 type_name,
966 variant,
967 bindings,
968 span,
969 } => {
970 let module = base.as_ref().map(|b| self.gen_expr(b));
971 let rir_bindings = bindings
972 .iter()
973 .map(|b| match b {
974 PatternBinding::Wildcard(_) => RirPatternBinding {
975 is_wildcard: true,
976 is_mut: false,
977 name: None,
978 },
979 PatternBinding::Ident { is_mut, name } => RirPatternBinding {
980 is_wildcard: false,
981 is_mut: *is_mut,
982 name: Some(name.name),
983 },
984 })
985 .collect();
986 RirPattern::DataVariant {
987 module,
988 type_name: type_name.name,
989 variant: variant.name,
990 bindings: rir_bindings,
991 span: *span,
992 }
993 }
994 Pattern::StructVariant {
995 base,
996 type_name,
997 variant,
998 fields,
999 span,
1000 } => {
1001 let module = base.as_ref().map(|b| self.gen_expr(b));
1002 let field_bindings = fields
1003 .iter()
1004 .map(|fb| {
1005 let binding = match &fb.binding {
1006 PatternBinding::Wildcard(_) => RirPatternBinding {
1007 is_wildcard: true,
1008 is_mut: false,
1009 name: None,
1010 },
1011 PatternBinding::Ident { is_mut, name } => RirPatternBinding {
1012 is_wildcard: false,
1013 is_mut: *is_mut,
1014 name: Some(name.name),
1015 },
1016 };
1017 RirStructPatternBinding {
1018 field_name: fb.field_name.name,
1019 binding,
1020 }
1021 })
1022 .collect();
1023 RirPattern::StructVariant {
1024 module,
1025 type_name: type_name.name,
1026 variant: variant.name,
1027 field_bindings,
1028 span: *span,
1029 }
1030 }
1031 }
1032 }
1033
1034 fn gen_block(&mut self, block: &gruel_parser::BlockExpr) -> InstRef {
1035 if block.statements.is_empty() {
1036 self.gen_expr(&block.expr)
1038 } else {
1039 let mut inst_refs = Vec::with_capacity(block.statements.len() + 1);
1042
1043 for stmt in &block.statements {
1045 let inst_ref = self.gen_statement(stmt);
1046 inst_refs.push(inst_ref.as_u32());
1047 }
1048
1049 let final_expr = self.gen_expr(&block.expr);
1051 inst_refs.push(final_expr.as_u32());
1052
1053 let extra_start = self.rir.add_extra(&inst_refs);
1055 let len = inst_refs.len() as u32;
1056
1057 self.rir.add_inst(Inst {
1058 data: InstData::Block { extra_start, len },
1059 span: block.span,
1060 })
1061 }
1062 }
1063
1064 fn gen_statement(&mut self, stmt: &Statement) -> InstRef {
1065 match stmt {
1066 Statement::Let(let_stmt) => match &let_stmt.pattern {
1067 LetPattern::Struct {
1068 type_name, fields, ..
1069 } => {
1070 let rir_fields: Vec<RirDestructureField> = fields
1071 .iter()
1072 .map(|f| {
1073 let binding_name = match &f.binding {
1074 DestructureBinding::Shorthand => None,
1075 DestructureBinding::Renamed(ident) => Some(ident.name),
1076 DestructureBinding::Wildcard(_) => None,
1077 };
1078 let is_wildcard = matches!(&f.binding, DestructureBinding::Wildcard(_));
1079 RirDestructureField {
1080 field_name: f.field_name.name,
1081 binding_name,
1082 is_wildcard,
1083 is_mut: f.is_mut,
1084 }
1085 })
1086 .collect();
1087 let (fields_start, fields_len) = self.rir.add_destructure_fields(&rir_fields);
1088 let init = self.gen_expr(&let_stmt.init);
1089 self.rir.add_inst(Inst {
1090 data: InstData::StructDestructure {
1091 type_name: type_name.name,
1092 fields_start,
1093 fields_len,
1094 init,
1095 },
1096 span: let_stmt.span,
1097 })
1098 }
1099 pattern => {
1100 let directives = self.convert_directives(&let_stmt.directives);
1101 let (directives_start, directives_len) = self.rir.add_directives(&directives);
1102 let name = match pattern {
1103 LetPattern::Ident(ident) => Some(ident.name),
1104 LetPattern::Wildcard(_) => None,
1105 LetPattern::Struct { .. } => unreachable!(),
1106 };
1107 let ty = let_stmt.ty.as_ref().map(|t| self.intern_type(t));
1108 let init = self.gen_expr(&let_stmt.init);
1109 self.rir.add_inst(Inst {
1110 data: InstData::Alloc {
1111 directives_start,
1112 directives_len,
1113 name,
1114 is_mut: let_stmt.is_mut,
1115 ty,
1116 init,
1117 },
1118 span: let_stmt.span,
1119 })
1120 }
1121 },
1122 Statement::Assign(assign) => {
1123 let value = self.gen_expr(&assign.value);
1124 match &assign.target {
1125 AssignTarget::Var(ident) => {
1126 self.rir.add_inst(Inst {
1127 data: InstData::Assign {
1128 name: ident.name, value,
1130 },
1131 span: assign.span,
1132 })
1133 }
1134 AssignTarget::Field(field_expr) => {
1135 let base = self.gen_expr(&field_expr.base);
1136 self.rir.add_inst(Inst {
1137 data: InstData::FieldSet {
1138 base,
1139 field: field_expr.field.name, value,
1141 },
1142 span: assign.span,
1143 })
1144 }
1145 AssignTarget::Index(index_expr) => {
1146 let base = self.gen_expr(&index_expr.base);
1147 let index = self.gen_expr(&index_expr.index);
1148 self.rir.add_inst(Inst {
1149 data: InstData::IndexSet { base, index, value },
1150 span: assign.span,
1151 })
1152 }
1153 }
1154 }
1155 Statement::Expr(expr) => {
1156 self.gen_expr(expr)
1159 }
1160 }
1161 }
1162}
1163
1164#[cfg(test)]
1165mod tests {
1166 use super::*;
1167 use crate::inst::RirPrinter;
1168 use gruel_lexer::Lexer;
1169 use gruel_parser::Parser;
1170
1171 fn gen_rir(source: &str) -> (Rir, ThreadedRodeo) {
1172 let lexer = Lexer::new(source);
1173 let (tokens, interner) = lexer.tokenize().unwrap();
1174 let parser = Parser::new(tokens, interner);
1175 let (ast, interner) = parser.parse().unwrap();
1176
1177 let astgen = AstGen::new(&ast, &interner);
1178 let rir = astgen.generate();
1179 (rir, interner)
1180 }
1181
1182 #[test]
1183 fn test_gen_simple_function() {
1184 let (rir, interner) = gen_rir("fn main() -> i32 { 42 }");
1185
1186 assert_eq!(rir.len(), 2);
1188
1189 let (_, fn_inst) = rir.iter().last().unwrap();
1191 match &fn_inst.data {
1192 InstData::FnDecl {
1193 name,
1194 params_start,
1195 params_len,
1196 return_type,
1197 body,
1198 has_self,
1199 ..
1200 } => {
1201 assert_eq!(interner.resolve(name), "main");
1202 let params = rir.get_params(*params_start, *params_len);
1203 assert!(params.is_empty());
1204 assert_eq!(interner.resolve(return_type), "i32");
1205 assert!(!has_self); let body_inst = rir.get(*body);
1208 assert!(matches!(body_inst.data, InstData::IntConst(42)));
1209 }
1210 _ => panic!("expected FnDecl"),
1211 }
1212 }
1213
1214 #[test]
1215 fn test_gen_addition() {
1216 let (rir, _) = gen_rir("fn main() -> i32 { 1 + 2 }");
1217
1218 assert_eq!(rir.len(), 4);
1220
1221 let add_inst = rir.get(InstRef::from_raw(2));
1223 match &add_inst.data {
1224 InstData::Add { lhs, rhs } => {
1225 assert!(matches!(rir.get(*lhs).data, InstData::IntConst(1)));
1226 assert!(matches!(rir.get(*rhs).data, InstData::IntConst(2)));
1227 }
1228 _ => panic!("expected Add"),
1229 }
1230 }
1231
1232 #[test]
1233 fn test_gen_precedence() {
1234 let (rir, _) = gen_rir("fn main() -> i32 { 1 + 2 * 3 }");
1235
1236 assert_eq!(rir.len(), 6);
1238
1239 let fn_inst = rir.iter().last().unwrap().1;
1241 match &fn_inst.data {
1242 InstData::FnDecl { body, .. } => {
1243 let body_inst = rir.get(*body);
1244 match &body_inst.data {
1245 InstData::Add { lhs, rhs } => {
1246 assert!(matches!(rir.get(*lhs).data, InstData::IntConst(1)));
1248 assert!(matches!(rir.get(*rhs).data, InstData::Mul { .. }));
1250 }
1251 _ => panic!("expected Add"),
1252 }
1253 }
1254 _ => panic!("expected FnDecl"),
1255 }
1256 }
1257
1258 #[test]
1259 fn test_gen_negation() {
1260 let (rir, _) = gen_rir("fn main() -> i32 { -42 }");
1261
1262 assert_eq!(rir.len(), 3);
1264
1265 let neg_inst = rir.get(InstRef::from_raw(1));
1267 match &neg_inst.data {
1268 InstData::Neg { operand } => {
1269 assert!(matches!(rir.get(*operand).data, InstData::IntConst(42)));
1270 }
1271 _ => panic!("expected Neg"),
1272 }
1273 }
1274
/// Parentheses override precedence: `(1 + 2) * 3` becomes
/// `Mul(Add(1, 2), 3)`.
#[test]
fn test_gen_parens() {
    let (rir, _) = gen_rir("fn main() -> i32 { (1 + 2) * 3 }");

    // IntConst(1), IntConst(2), Add, IntConst(3), Mul, FnDecl.
    assert_eq!(rir.len(), 6);

    // The FnDecl comes last; its body is the outermost expression.
    let (_, fn_inst) = rir.iter().last().unwrap();
    let body = match &fn_inst.data {
        InstData::FnDecl { body, .. } => rir.get(*body),
        _ => panic!("expected FnDecl"),
    };

    if let InstData::Mul { lhs, rhs } = &body.data {
        // The parenthesized sum is the left operand of the product.
        assert!(matches!(rir.get(*lhs).data, InstData::Add { .. }));
        assert!(matches!(rir.get(*rhs).data, InstData::IntConst(3)));
    } else {
        panic!("expected Mul");
    }
}
1301
/// Each arithmetic operator lowers to its own binary instruction.
///
/// For `1 <op> 2` the operator instruction sits at index 2, after its two
/// constant operands.
#[test]
fn test_gen_all_binary_ops() {
    // (source, predicate identifying the expected instruction kind)
    let cases: [(&str, fn(&InstData) -> bool); 5] = [
        ("fn main() -> i32 { 1 + 2 }", |d| matches!(d, InstData::Add { .. })),
        ("fn main() -> i32 { 1 - 2 }", |d| matches!(d, InstData::Sub { .. })),
        ("fn main() -> i32 { 1 * 2 }", |d| matches!(d, InstData::Mul { .. })),
        ("fn main() -> i32 { 1 / 2 }", |d| matches!(d, InstData::Div { .. })),
        ("fn main() -> i32 { 1 % 2 }", |d| matches!(d, InstData::Mod { .. })),
    ];

    for (src, is_expected_op) in cases {
        let (rir, _) = gen_rir(src);
        assert!(is_expected_op(&rir.get(InstRef::from_raw(2)).data));
    }
}
1335
/// An immutable, unannotated `let` lowers to an `Alloc` that carries the
/// binding name and a reference to its initializer.
#[test]
fn test_gen_let_binding() {
    let (rir, interner) = gen_rir("fn main() -> i32 { let x = 42; x }");

    let found = rir
        .iter()
        .find(|(_, inst)| matches!(inst.data, InstData::Alloc { .. }));
    assert!(found.is_some());

    let (_, alloc) = found.unwrap();
    if let InstData::Alloc {
        name,
        is_mut,
        ty,
        init,
        ..
    } = &alloc.data
    {
        // Named `x`, not `mut`, and no explicit type annotation.
        assert_eq!(interner.resolve(&name.unwrap()), "x");
        assert!(!is_mut);
        assert!(ty.is_none());
        assert!(matches!(rir.get(*init).data, InstData::IntConst(42)));
    } else {
        panic!("expected Alloc");
    }
}
1363
/// `let mut` sets the `is_mut` flag on the generated `Alloc`.
#[test]
fn test_gen_let_mut() {
    let (rir, interner) = gen_rir("fn main() -> i32 { let mut x = 10; x }");

    let found = rir
        .iter()
        .find(|(_, inst)| matches!(inst.data, InstData::Alloc { .. }));
    assert!(found.is_some());

    let (_, alloc) = found.unwrap();
    if let InstData::Alloc { name, is_mut, .. } = &alloc.data {
        assert_eq!(interner.resolve(&name.unwrap()), "x");
        assert!(*is_mut);
    } else {
        panic!("expected Alloc");
    }
}
1382
/// A bare identifier in tail position lowers to a `VarRef`.
///
/// The identifier is stored as the second entry of the function's body block,
/// after the `let`.
#[test]
fn test_gen_var_ref() {
    let (rir, interner) = gen_rir("fn main() -> i32 { let x = 42; x }");

    // The FnDecl is emitted last.
    let (_, fn_inst) = rir.iter().last().unwrap();
    let body_inst = match &fn_inst.data {
        InstData::FnDecl { body, .. } => rir.get(*body),
        _ => panic!("expected FnDecl"),
    };

    let (extra_start, len) = match &body_inst.data {
        InstData::Block { extra_start, len } => (*extra_start, *len),
        _ => panic!("expected Block, got {:?}", body_inst.data),
    };
    // Block entries: the `let x = 42` statement, then the tail `x`.
    assert_eq!(len, 2);

    let entries = rir.get_extra(extra_start, len);
    let tail = rir.get(InstRef::from_raw(entries[1]));
    match &tail.data {
        InstData::VarRef { name } => assert_eq!(interner.resolve(name), "x"),
        _ => panic!("expected VarRef"),
    }
}
1412
/// Reassigning a mutable local lowers to an `Assign` that names the target
/// variable and references the new value.
#[test]
fn test_gen_assignment() {
    let (rir, interner) = gen_rir("fn main() -> i32 { let mut x = 10; x = 20; x }");

    let found = rir
        .iter()
        .find(|(_, inst)| matches!(inst.data, InstData::Assign { .. }));
    assert!(found.is_some());

    let (_, assign) = found.unwrap();
    if let InstData::Assign { name, value } = &assign.data {
        assert_eq!(interner.resolve(name), "x");
        assert!(matches!(rir.get(*value).data, InstData::IntConst(20)));
    } else {
        panic!("expected Assign");
    }
}
1432
/// Two `let`s plus a tail expression produce two `Alloc`s and a three-entry
/// body block whose last entry is the `x + y` addition.
#[test]
fn test_gen_multiple_statements() {
    let (rir, _interner) = gen_rir("fn main() -> i32 { let x = 1; let y = 2; x + y }");

    let allocs = rir
        .iter()
        .filter(|(_, inst)| matches!(inst.data, InstData::Alloc { .. }))
        .count();
    assert_eq!(allocs, 2);

    // The FnDecl is emitted last.
    let (_, fn_inst) = rir.iter().last().unwrap();
    let body_inst = match &fn_inst.data {
        InstData::FnDecl { body, .. } => rir.get(*body),
        _ => panic!("expected FnDecl"),
    };

    let (extra_start, len) = match &body_inst.data {
        InstData::Block { extra_start, len } => (*extra_start, *len),
        _ => panic!("expected Block"),
    };
    // Block entries: `let x`, `let y`, then the tail expression.
    assert_eq!(len, 3);

    let entries = rir.get_extra(extra_start, len);
    let tail = rir.get(InstRef::from_raw(entries[2]));
    assert!(matches!(tail.data, InstData::Add { .. }));
}
1464
/// An inline method is recorded in its struct's `StructDecl` method list.
///
/// The method's `FnDecl` is marked as taking `self`.
#[test]
fn test_gen_struct_with_method() {
    let source = r#"
        struct Point {
            x: i32,
            y: i32,
            fn get_x(self) -> i32 {
                self.x
            }
        }
        fn main() -> i32 { 0 }
    "#;
    let (rir, interner) = gen_rir(source);

    let (_, decl) = rir
        .iter()
        .find(|(_, inst)| matches!(inst.data, InstData::StructDecl { .. }))
        .expect("Expected StructDecl instruction");

    match &decl.data {
        InstData::StructDecl {
            name,
            methods_start,
            methods_len,
            ..
        } => {
            assert_eq!(interner.resolve(name), "Point");

            let methods = rir.get_inst_refs(*methods_start, *methods_len);
            assert_eq!(methods.len(), 1);

            let method = rir.get(methods[0]);
            if let InstData::FnDecl { name, has_self, .. } = &method.data {
                assert_eq!(interner.resolve(name), "get_x");
                assert!(*has_self);
            } else {
                panic!("expected FnDecl");
            }
        }
        _ => panic!("expected StructDecl"),
    }
}
1511
/// A struct with several inline methods records every one of them.
///
/// Only the methods declared with `self` are flagged `has_self`. The exact set
/// of method names is verified as well: without that check, a dropped, renamed
/// or duplicated method would go unnoticed as long as the `has_self` flags
/// happened to line up.
#[test]
fn test_gen_struct_with_multiple_methods() {
    let source = r#"
        struct Point {
            x: i32,
            y: i32,
            fn get_x(self) -> i32 { self.x }
            fn get_y(self) -> i32 { self.y }
            fn origin() -> Point { Point { x: 0, y: 0 } }
        }
        fn main() -> i32 { 0 }
    "#;
    let (rir, interner) = gen_rir(source);

    let struct_decl = rir
        .iter()
        .find(|(_, inst)| matches!(inst.data, InstData::StructDecl { .. }));
    assert!(struct_decl.is_some());

    let (_, inst) = struct_decl.unwrap();
    match &inst.data {
        InstData::StructDecl {
            methods_start,
            methods_len,
            ..
        } => {
            let methods = rir.get_inst_refs(*methods_start, *methods_len);
            assert_eq!(methods.len(), 3);

            let mut names = Vec::new();
            for method_ref in methods {
                let method_inst = rir.get(method_ref);
                match &method_inst.data {
                    InstData::FnDecl { name, has_self, .. } => {
                        let method_name = interner.resolve(name);
                        if method_name == "origin" {
                            // Associated function: declared without `self`.
                            assert!(!*has_self, "origin should not have self");
                        } else {
                            assert!(*has_self, "{} should have self", method_name);
                        }
                        names.push(method_name);
                    }
                    _ => panic!("expected FnDecl"),
                }
            }

            // Order-insensitive check that exactly the declared methods are present.
            names.sort_unstable();
            assert_eq!(names, ["get_x", "get_y", "origin"]);
        }
        _ => panic!("expected StructDecl"),
    }
}
1560
/// `p.get_x()` lowers to a `MethodCall` naming the method.
///
/// The receiver `p` is held in its own field, so the explicit argument list
/// is empty.
#[test]
fn test_gen_method_call() {
    let source = r#"
        struct Point {
            x: i32,
            fn get_x(self) -> i32 { self.x }
        }
        fn main() -> i32 {
            let p = Point { x: 42 };
            p.get_x()
        }
    "#;
    let (rir, interner) = gen_rir(source);

    let (_, call) = rir
        .iter()
        .find(|(_, inst)| matches!(inst.data, InstData::MethodCall { .. }))
        .expect("Expected MethodCall instruction");

    match &call.data {
        InstData::MethodCall {
            method,
            args_start,
            args_len,
            ..
        } => {
            assert_eq!(interner.resolve(method), "get_x");
            let args = rir.get_call_args(*args_start, *args_len);
            assert!(args.is_empty());
        }
        _ => panic!("expected MethodCall"),
    }
}
1596
/// `Point::origin()` lowers to an `AssocFnCall` carrying the type name, the
/// function name and an empty argument list.
#[test]
fn test_gen_assoc_fn_call() {
    let source = r#"
        struct Point {
            x: i32,
            y: i32,
            fn origin() -> Point { Point { x: 0, y: 0 } }
        }
        fn main() -> i32 {
            let p = Point::origin();
            0
        }
    "#;
    let (rir, interner) = gen_rir(source);

    let (_, call) = rir
        .iter()
        .find(|(_, inst)| matches!(inst.data, InstData::AssocFnCall { .. }))
        .expect("Expected AssocFnCall instruction");

    if let InstData::AssocFnCall {
        type_name,
        function,
        args_start,
        args_len,
    } = &call.data
    {
        assert_eq!(interner.resolve(type_name), "Point");
        assert_eq!(interner.resolve(function), "origin");
        assert!(rir.get_call_args(*args_start, *args_len).is_empty());
    } else {
        panic!("expected AssocFnCall");
    }
}
1634
/// A `match` whose only arm is `_` lowers to a single wildcard pattern.
#[test]
fn test_gen_match_wildcard_pattern() {
    let source = r#"
        fn main() -> i32 {
            let x = 5;
            match x {
                _ => 42,
            }
        }
    "#;
    let (rir, _interner) = gen_rir(source);

    let (_, match_inst) = rir
        .iter()
        .find(|(_, inst)| matches!(inst.data, InstData::Match { .. }))
        .expect("Expected Match instruction");

    let arms = match &match_inst.data {
        InstData::Match {
            arms_start,
            arms_len,
            ..
        } => rir.get_match_arms(*arms_start, *arms_len),
        _ => panic!("expected Match"),
    };
    assert_eq!(arms.len(), 1);
    assert!(matches!(arms[0].0, RirPattern::Wildcard(_)));
}
1668
/// Integer literal arms lower to `Int` patterns in source order, followed by
/// the trailing `_` arm as a wildcard.
#[test]
fn test_gen_match_int_patterns() {
    let source = r#"
        fn main() -> i32 {
            let x = 5;
            match x {
                1 => 10,
                2 => 20,
                _ => 0,
            }
        }
    "#;
    let (rir, _interner) = gen_rir(source);

    let found = rir
        .iter()
        .find(|(_, inst)| matches!(inst.data, InstData::Match { .. }));
    assert!(found.is_some());

    let (_, match_inst) = found.unwrap();
    let arms = match &match_inst.data {
        InstData::Match {
            arms_start,
            arms_len,
            ..
        } => rir.get_match_arms(*arms_start, *arms_len),
        _ => panic!("expected Match"),
    };
    assert_eq!(arms.len(), 3);
    assert!(matches!(arms[0].0, RirPattern::Int(1, _)));
    assert!(matches!(arms[1].0, RirPattern::Int(2, _)));
    assert!(matches!(arms[2].0, RirPattern::Wildcard(_)));
}
1704
/// Negative literal arms keep their sign in the generated `Int` pattern.
#[test]
fn test_gen_match_negative_int_pattern() {
    let source = r#"
        fn main() -> i32 {
            let x: i32 = -5;
            match x {
                -5 => 1,
                -10 => 2,
                _ => 0,
            }
        }
    "#;
    let (rir, _interner) = gen_rir(source);

    let found = rir
        .iter()
        .find(|(_, inst)| matches!(inst.data, InstData::Match { .. }));
    assert!(found.is_some());

    let (_, match_inst) = found.unwrap();
    let arms = match &match_inst.data {
        InstData::Match {
            arms_start,
            arms_len,
            ..
        } => rir.get_match_arms(*arms_start, *arms_len),
        _ => panic!("expected Match"),
    };
    assert_eq!(arms.len(), 3);
    assert!(matches!(arms[0].0, RirPattern::Int(-5, _)));
    assert!(matches!(arms[1].0, RirPattern::Int(-10, _)));
    assert!(matches!(arms[2].0, RirPattern::Wildcard(_)));
}
1740
/// `true` / `false` arms lower to `Bool` patterns in source order.
#[test]
fn test_gen_match_bool_patterns() {
    let source = r#"
        fn main() -> i32 {
            let b = true;
            match b {
                true => 1,
                false => 0,
            }
        }
    "#;
    let (rir, _interner) = gen_rir(source);

    let found = rir
        .iter()
        .find(|(_, inst)| matches!(inst.data, InstData::Match { .. }));
    assert!(found.is_some());

    let (_, match_inst) = found.unwrap();
    let arms = match &match_inst.data {
        InstData::Match {
            arms_start,
            arms_len,
            ..
        } => rir.get_match_arms(*arms_start, *arms_len),
        _ => panic!("expected Match"),
    };
    assert_eq!(arms.len(), 2);
    assert!(matches!(arms[0].0, RirPattern::Bool(true, _)));
    assert!(matches!(arms[1].0, RirPattern::Bool(false, _)));
}
1774
/// `Enum::Variant` arms lower to `Path` patterns, in source order.
///
/// Each pattern carries both the enum type name and the variant name.
#[test]
fn test_gen_match_enum_patterns() {
    let source = r#"
        enum Color { Red, Green, Blue }
        fn main() -> i32 {
            let c = Color::Red;
            match c {
                Color::Red => 1,
                Color::Green => 2,
                Color::Blue => 3,
            }
        }
    "#;
    let (rir, interner) = gen_rir(source);

    let found = rir
        .iter()
        .find(|(_, inst)| matches!(inst.data, InstData::Match { .. }));
    assert!(found.is_some());

    let (_, match_inst) = found.unwrap();
    let arms = match &match_inst.data {
        InstData::Match {
            arms_start,
            arms_len,
            ..
        } => rir.get_match_arms(*arms_start, *arms_len),
        _ => panic!("expected Match"),
    };
    assert_eq!(arms.len(), 3);

    for (i, expected_variant) in ["Red", "Green", "Blue"].iter().enumerate() {
        match &arms[i].0 {
            RirPattern::Path {
                type_name, variant, ..
            } => {
                assert_eq!(interner.resolve(type_name), "Color");
                assert_eq!(interner.resolve(variant), *expected_variant);
            }
            _ => panic!("expected Path pattern"),
        }
    }
}
1841
/// A `self` reference inside a method body lowers to a `VarRef` named `self`.
#[test]
fn test_gen_self_expr() {
    let source = r#"
        struct Point {
            x: i32,
            fn get_x(self) -> i32 { self.x }
        }
        fn main() -> i32 { 0 }
    "#;
    let (rir, interner) = gen_rir(source);

    let has_self_ref = rir.iter().any(|(_, inst)| {
        matches!(&inst.data, InstData::VarRef { name } if interner.resolve(name) == "self")
    });
    assert!(has_self_ref, "Expected self VarRef instruction");
}
1860
/// A `drop fn` declaration lowers to a `DropFnDecl` naming the dropped type.
#[test]
fn test_gen_drop_fn() {
    let source = r#"
        struct Resource { value: i32 }
        drop fn Resource(self) { () }
        fn main() -> i32 { 0 }
    "#;
    let (rir, interner) = gen_rir(source);

    let (_, drop_decl) = rir
        .iter()
        .find(|(_, inst)| matches!(inst.data, InstData::DropFnDecl { .. }))
        .expect("Expected DropFnDecl instruction");

    if let InstData::DropFnDecl { type_name, .. } = &drop_decl.data {
        assert_eq!(interner.resolve(type_name), "Resource");
    } else {
        panic!("expected DropFnDecl");
    }
}
1884
/// `Color::Red` in expression position lowers to an `EnumVariant` carrying
/// the enum type name and the variant name.
#[test]
fn test_gen_enum_variant() {
    let source = r#"
        enum Color { Red, Green, Blue }
        fn main() -> i32 {
            let c = Color::Red;
            0
        }
    "#;
    let (rir, interner) = gen_rir(source);

    let (_, variant_inst) = rir
        .iter()
        .find(|(_, inst)| matches!(inst.data, InstData::EnumVariant { .. }))
        .expect("Expected EnumVariant instruction");

    if let InstData::EnumVariant {
        type_name, variant, ..
    } = &variant_inst.data
    {
        assert_eq!(interner.resolve(type_name), "Color");
        assert_eq!(interner.resolve(variant), "Red");
    } else {
        panic!("expected EnumVariant");
    }
}
1913
/// A method's recorded parameter list contains only its explicit parameters.
///
/// The implicit `self` is excluded from the list and reported through
/// `has_self` instead.
#[test]
fn test_gen_method_with_params() {
    let source = r#"
        struct Counter {
            value: i32,
            fn add(self, amount: i32) -> i32 { self.value + amount }
        }
        fn main() -> i32 { 0 }
    "#;
    let (rir, interner) = gen_rir(source);

    let struct_decl = rir
        .iter()
        .find(|(_, inst)| matches!(inst.data, InstData::StructDecl { .. }));
    assert!(struct_decl.is_some());

    let (_, inst) = struct_decl.unwrap();
    match &inst.data {
        InstData::StructDecl {
            methods_start,
            methods_len,
            ..
        } => {
            let methods = rir.get_inst_refs(*methods_start, *methods_len);
            // Guard the index below so a missing method fails with a clear
            // assertion instead of an out-of-bounds panic.
            assert_eq!(methods.len(), 1);

            let method_inst = rir.get(methods[0]);
            match &method_inst.data {
                InstData::FnDecl {
                    name,
                    params_start,
                    params_len,
                    has_self,
                    ..
                } => {
                    assert_eq!(interner.resolve(name), "add");
                    assert!(*has_self);

                    // Only `amount` is listed; `self` is not a regular parameter.
                    let params = rir.get_params(*params_start, *params_len);
                    assert_eq!(params.len(), 1);
                    assert_eq!(interner.resolve(&params[0].name), "amount");
                }
                _ => panic!("expected FnDecl"),
            }
        }
        _ => panic!("expected StructDecl"),
    }
}
1961
/// The RIR printer renders structs, their methods, functions and the main
/// expression forms under recognizable keywords.
#[test]
fn test_printer_integration() {
    let source = r#"
        struct Point {
            x: i32,
            y: i32,
            fn origin() -> Point { Point { x: 0, y: 0 } }
        }
        fn main() -> i32 {
            let p = Point::origin();
            p.x
        }
    "#;
    let (rir, interner) = gen_rir(source);

    let output = RirPrinter::new(&rir, &interner).to_string();

    let expected_fragments = [
        "struct Point",
        "methods: [",
        "fn origin",
        "fn main",
        "struct_init Point",
        "assoc_fn_call Point::origin",
        "field_get",
    ];
    for fragment in expected_fragments {
        assert!(output.contains(fragment));
    }
}
1990
/// A single function yields exactly one span covering its body and its
/// `FnDecl`.
#[test]
fn test_function_spans_simple() {
    let (rir, interner) = gen_rir("fn main() -> i32 { 42 }");

    assert_eq!(rir.function_count(), 1);

    let spans: Vec<_> = rir.functions().collect();
    assert_eq!(spans.len(), 1);

    let only = &spans[0];
    assert_eq!(interner.resolve(&only.name), "main");
    // IntConst(42) plus the FnDecl itself.
    assert_eq!(only.instruction_count(), 2);
    // `decl` points at the function's declaration instruction.
    assert!(matches!(rir.get(only.decl).data, InstData::FnDecl { .. }));
}
2013
/// Each top-level function gets its own span, listed in declaration order.
///
/// The spans do not overlap.
#[test]
fn test_function_spans_multiple_functions() {
    let source = r#"
        fn helper() -> i32 { 1 }
        fn main() -> i32 { 42 }
    "#;
    let (rir, interner) = gen_rir(source);

    assert_eq!(rir.function_count(), 2);

    let spans: Vec<_> = rir.functions().collect();
    assert_eq!(spans.len(), 2);

    // Each body is one integer constant plus its FnDecl.
    for (span, expected_name) in spans.iter().zip(["helper", "main"]) {
        assert_eq!(interner.resolve(&span.name), expected_name);
        assert_eq!(span.instruction_count(), 2);
    }

    // The first span's last instruction (its FnDecl) precedes the second's body.
    assert!(
        spans[0].decl.as_u32() < spans[1].body_start.as_u32(),
        "helper should end before main starts"
    );
}
2042
/// Struct methods get function spans alongside free functions.
#[test]
fn test_function_spans_with_methods() {
    let source = r#"
        struct Point {
            x: i32,
            fn get_x(self) -> i32 { self.x }
            fn origin() -> Point { Point { x: 0 } }
        }
        fn main() -> i32 { 0 }
    "#;
    let (rir, interner) = gen_rir(source);

    // Two methods plus `main`.
    assert_eq!(rir.function_count(), 3);

    let spans: Vec<_> = rir.functions().collect();
    let names: Vec<_> = spans.iter().map(|s| interner.resolve(&s.name)).collect();

    for expected in ["get_x", "origin", "main"] {
        assert!(names.contains(&expected));
    }
}
2066
/// A function view covers exactly the instructions in its span.
///
/// It also gives direct access to the function's `FnDecl`.
#[test]
fn test_function_view() {
    let source = r#"
        fn helper(x: i32) -> i32 { x + 1 }
        fn main() -> i32 { helper(41) }
    "#;
    let (rir, interner) = gen_rir(source);

    let main_span = rir.find_function(interner.get_or_intern("main")).unwrap();
    let view = rir.function_view(main_span);

    assert_eq!(view.len(), main_span.instruction_count() as usize);

    if let InstData::FnDecl { name, .. } = &view.fn_decl().data {
        assert_eq!(interner.resolve(name), "main");
    } else {
        panic!("Expected FnDecl");
    }

    let found_call = view
        .iter()
        .any(|(_, inst)| matches!(inst.data, InstData::Call { .. }));
    assert!(found_call, "main should contain a call to helper");
}
2102
/// A function with locals and an `if`/`else` gets a span that contains its
/// `Alloc`s and the `Branch`.
#[test]
fn test_function_span_complex_body() {
    let source = r#"
        fn complex() -> i32 {
            let x = 1;
            let y = 2;
            if x < y {
                x + y
            } else {
                x - y
            }
        }
    "#;
    let (rir, interner) = gen_rir(source);

    assert_eq!(rir.function_count(), 1);

    let span = rir
        .find_function(interner.get_or_intern("complex"))
        .unwrap();

    // A lower bound rather than an exact count.
    assert!(
        span.instruction_count() >= 8,
        "Complex function should have at least 8 instructions, got {}",
        span.instruction_count()
    );

    let view = rir.function_view(span);
    // Single pass over the view, tracking both properties at once.
    let (has_alloc, has_branch) =
        view.iter()
            .fold((false, false), |(alloc_seen, branch_seen), (_, inst)| {
                (
                    alloc_seen || matches!(inst.data, InstData::Alloc { .. }),
                    branch_seen || matches!(inst.data, InstData::Branch { .. }),
                )
            });

    assert!(has_alloc, "Function should have Alloc instructions");
    assert!(has_branch, "Function should have Branch instruction");
}
2149
/// `find_function` finds every declared function by name.
///
/// It returns `None` for a name that was never declared.
#[test]
fn test_find_function() {
    let source = r#"
        fn foo() -> i32 { 1 }
        fn bar() -> i32 { 2 }
        fn baz() -> i32 { 3 }
    "#;
    let (rir, interner) = gen_rir(source);

    for declared in ["foo", "bar", "baz"] {
        assert!(rir.find_function(interner.get_or_intern(declared)).is_some());
    }
    assert!(rir
        .find_function(interner.get_or_intern("nonexistent"))
        .is_none());
}
2170
/// Function spans are laid out back to back without overlap.
///
/// Each function's `FnDecl` comes before the next function's body starts.
#[test]
fn test_function_span_ordering() {
    let source = r#"
        fn a() -> i32 { 1 }
        fn b() -> i32 { 2 }
        fn c() -> i32 { 3 }
    "#;
    let (rir, _interner) = gen_rir(source);

    let spans: Vec<_> = rir.functions().collect();
    assert_eq!(spans.len(), 3);

    for (i, pair) in spans.windows(2).enumerate() {
        assert!(
            pair[0].decl.as_u32() < pair[1].body_start.as_u32(),
            "Function {} should end before function {} starts",
            i,
            i + 1
        );
    }
}
2193
/// An anonymous struct returned from a comptime function records its fields
/// in declaration order, along with its inline methods.
///
/// Exactly the declared methods must appear, with `has_self` set only on the
/// one that takes `self`. Unexpected method names are rejected rather than
/// silently ignored.
#[test]
fn test_anon_struct_with_methods() {
    let source = r#"
        fn MakePoint(comptime T: type) -> type {
            struct {
                x: T,
                y: T,

                fn get_x(self) -> T { self.x }
                fn origin() -> Self { Self { x: 0, y: 0 } }
            }
        }
        fn main() -> i32 { 0 }
    "#;
    let (rir, interner) = gen_rir(source);

    let anon_struct = rir
        .iter()
        .find(|(_, inst)| matches!(inst.data, InstData::AnonStructType { .. }));
    assert!(
        anon_struct.is_some(),
        "Expected to find AnonStructType instruction"
    );

    let (_, inst) = anon_struct.unwrap();
    match &inst.data {
        InstData::AnonStructType {
            fields_start,
            fields_len,
            methods_start,
            methods_len,
        } => {
            let fields = rir.get_field_decls(*fields_start, *fields_len);
            assert_eq!(fields.len(), 2);
            assert_eq!(interner.resolve(&fields[0].0), "x");
            assert_eq!(interner.resolve(&fields[1].0), "y");

            assert_eq!(*methods_len, 2);
            let methods = rir.get_inst_refs(*methods_start, *methods_len);
            assert_eq!(methods.len(), 2);

            let mut seen = Vec::new();
            for method_ref in methods {
                let method_inst = rir.get(method_ref);
                match &method_inst.data {
                    InstData::FnDecl { name, has_self, .. } => {
                        let name_str = interner.resolve(name);
                        match name_str {
                            "get_x" => assert!(*has_self, "get_x should have self parameter"),
                            "origin" => {
                                assert!(!*has_self, "origin should not have self parameter")
                            }
                            // Any other name means a method was invented or mangled.
                            other => panic!("unexpected method `{}` on anonymous struct", other),
                        }
                        seen.push(name_str);
                    }
                    _ => panic!("Expected FnDecl for method"),
                }
            }

            // Both declared methods are present (not, e.g., `get_x` twice).
            seen.sort_unstable();
            assert_eq!(seen, ["get_x", "origin"]);
        }
        _ => panic!("Expected AnonStructType"),
    }
}
2259
/// An anonymous struct with only fields records an empty method list.
#[test]
fn test_anon_struct_without_methods() {
    let source = r#"
        fn MakePair(comptime T: type) -> type {
            struct { first: T, second: T }
        }
        fn main() -> i32 { 0 }
    "#;
    let (rir, _interner) = gen_rir(source);

    let (_, anon) = rir
        .iter()
        .find(|(_, inst)| matches!(inst.data, InstData::AnonStructType { .. }))
        .expect("Expected to find AnonStructType instruction");

    if let InstData::AnonStructType { methods_len, .. } = &anon.data {
        assert_eq!(*methods_len, 0, "Expected no methods");
    } else {
        panic!("Expected AnonStructType");
    }
}
2288
/// Methods of an anonymous struct get their own function spans.
///
/// They are counted alongside the comptime function that defines them and
/// can be found by name.
#[test]
fn test_anon_struct_method_function_spans() {
    let source = r#"
        fn Container(comptime T: type) -> type {
            struct {
                value: T,
                fn get(self) -> T { self.value }
                fn set(self, v: T) -> Self { Self { value: v } }
            }
        }
        fn main() -> i32 { 0 }
    "#;
    let (rir, interner) = gen_rir(source);

    assert_eq!(
        rir.function_count(),
        4,
        "Expected 4 functions (Container, get, set, main)"
    );

    for method in ["get", "set"] {
        let sym = interner.get_or_intern(method);
        assert!(
            rir.find_function(sym).is_some(),
            "Should find '{}' method",
            method
        );
    }
}
2323}