gruel_air/inference/constraint.rs
1//! Constraint types and substitution for type inference.
2//!
3//! This module provides:
4//! - [`Constraint`] - Type constraints generated during analysis
5//! - [`Substitution`] - Mapping from type variables to resolved types
6
7use super::types::{InferType, TypeVarId};
8use gruel_span::Span;
9use std::cell::RefCell;
10
11/// A type constraint generated during analysis.
12///
13/// Constraints express relationships between types that must hold.
14/// They are collected during constraint generation and then solved
15/// during unification.
16#[derive(Debug, Clone)]
17pub enum Constraint {
18 /// Two types must be equal: τ₁ = τ₂.
19 ///
20 /// This is the primary constraint type. Generated for:
21 /// - Binary operations (both operands must have same type)
22 /// - Assignments (value type must match variable type)
23 /// - Function calls (argument types must match parameter types)
24 /// - Return statements (returned type must match declared return type)
25 Equal(InferType, InferType, Span),
26
27 /// Type must be a signed integer: τ ∈ {i8, i16, i32, i64}.
28 ///
29 /// Generated for unary negation which requires signed types.
30 /// Unsigned integers cannot be negated.
31 IsSigned(InferType, Span),
32
33 /// Type must be an integer (signed or unsigned): τ ∈ {i8, i16, i32, i64, u8, u16, u32, u64}.
34 ///
35 /// Generated for bitwise NOT which works on any integer type.
36 IsInteger(InferType, Span),
37
38 /// Type must be an unsigned integer: τ ∈ {u8, u16, u32, u64}.
39 ///
40 /// Generated for array indexing which requires non-negative indices.
41 IsUnsigned(InferType, Span),
42
43 /// Type must be numeric (integer or float): τ ∈ {i8..i64, u8..u64, isize, usize, f16..f64}.
44 ///
45 /// Generated for arithmetic operators (+, -, *, /, %) which work on both
46 /// integer and floating-point types.
47 IsNumeric(InferType, Span),
48}
49
50impl Constraint {
51 /// Create an equality constraint.
52 pub fn equal(lhs: InferType, rhs: InferType, span: Span) -> Self {
53 Constraint::Equal(lhs, rhs, span)
54 }
55
56 /// Create a "must be signed" constraint.
57 pub fn is_signed(ty: InferType, span: Span) -> Self {
58 Constraint::IsSigned(ty, span)
59 }
60
61 /// Create a "must be integer" constraint.
62 pub fn is_integer(ty: InferType, span: Span) -> Self {
63 Constraint::IsInteger(ty, span)
64 }
65
66 /// Create a "must be unsigned" constraint.
67 pub fn is_unsigned(ty: InferType, span: Span) -> Self {
68 Constraint::IsUnsigned(ty, span)
69 }
70
71 /// Create a "must be numeric" constraint (integer or float).
72 pub fn is_numeric(ty: InferType, span: Span) -> Self {
73 Constraint::IsNumeric(ty, span)
74 }
75
76 /// Get the span for this constraint (for error reporting).
77 pub fn span(&self) -> Span {
78 match self {
79 Constraint::Equal(_, _, span)
80 | Constraint::IsSigned(_, span)
81 | Constraint::IsInteger(_, span)
82 | Constraint::IsUnsigned(_, span)
83 | Constraint::IsNumeric(_, span) => *span,
84 }
85 }
86}
87
88/// A substitution mapping type variables to their resolved types.
89///
90/// The substitution is built incrementally during unification.
91/// It maps type variable IDs to `InferType`s, which may themselves
92/// be type variables (requiring transitive lookup via `apply`).
93///
94/// # Performance
95///
96/// Uses a `Vec<Option<InferType>>` instead of `HashMap<TypeVarId, InferType>` for O(1)
97/// lookups without hashing overhead. This works because `TypeVarId` is a sequential
98/// `u32` starting from 0.
99///
100/// Additionally implements path compression: when following a chain of variable
101/// references, intermediate links are updated to point directly to the final result,
102/// amortizing the cost of chain traversal.
103#[derive(Debug, Default)]
104pub struct Substitution {
105 /// Mapping from type variable index to its resolved type.
106 /// Uses `RefCell` to allow path compression during immutable lookups.
107 mapping: RefCell<Vec<Option<InferType>>>,
108}
109
110impl Substitution {
111 /// Create an empty substitution.
112 pub fn new() -> Self {
113 Substitution {
114 mapping: RefCell::new(Vec::new()),
115 }
116 }
117
118 /// Create a substitution with pre-allocated capacity.
119 ///
120 /// Use when you know approximately how many type variables will be created.
121 pub fn with_capacity(capacity: usize) -> Self {
122 let mut mapping = Vec::with_capacity(capacity);
123 // Pre-fill with None to allow direct indexing
124 mapping.resize(capacity, None);
125 Substitution {
126 mapping: RefCell::new(mapping),
127 }
128 }
129
130 /// Insert a mapping from a type variable to a type.
131 ///
132 /// If the variable is already mapped, the old mapping is replaced.
133 pub fn insert(&mut self, var: TypeVarId, ty: InferType) {
134 let idx = var.index() as usize;
135 let mut mapping = self.mapping.borrow_mut();
136 // Grow the vector if necessary
137 if idx >= mapping.len() {
138 mapping.resize(idx + 1, None);
139 }
140 mapping[idx] = Some(ty);
141 }
142
143 /// Look up a type variable's immediate mapping (without following chains).
144 pub fn get(&self, var: TypeVarId) -> Option<InferType> {
145 let idx = var.index() as usize;
146 let mapping = self.mapping.borrow();
147 if idx < mapping.len() {
148 mapping[idx].clone()
149 } else {
150 None
151 }
152 }
153
154 /// Apply the substitution to a type, following type variable chains
155 /// to their ultimate resolution.
156 ///
157 /// - `Concrete(ty)` → `Concrete(ty)` (unchanged)
158 /// - `Var(id)` → follows chain until concrete or unbound variable
159 /// - `IntLiteral` → `IntLiteral` (unchanged, unless we add IntLiteral
160 /// to variable mappings)
161 /// - `Array { element, length }` → recursively apply to element type
162 ///
163 /// # Path Compression
164 ///
165 /// When following a chain like `v0 -> v1 -> v2 -> i32`, this method
166 /// updates all intermediate links to point directly to the final result.
167 /// This amortizes the cost of repeated lookups.
168 pub fn apply(&self, ty: &InferType) -> InferType {
169 match ty {
170 InferType::Concrete(_) => ty.clone(),
171 InferType::Var(id) => self.apply_var(*id),
172 InferType::IntLiteral | InferType::FloatLiteral => ty.clone(),
173 InferType::Array { element, length } => {
174 let resolved_element = self.apply(element);
175 InferType::Array {
176 element: Box::new(resolved_element),
177 length: *length,
178 }
179 }
180 }
181 }
182
183 /// Apply substitution to a type variable with path compression.
184 fn apply_var(&self, id: TypeVarId) -> InferType {
185 let idx = id.index() as usize;
186
187 // First, get the resolved type without holding the borrow
188 let resolved = {
189 let mapping = self.mapping.borrow();
190 if idx >= mapping.len() {
191 return InferType::Var(id);
192 }
193 match &mapping[idx] {
194 None => return InferType::Var(id),
195 Some(ty) => ty.clone(),
196 }
197 };
198
199 // Recursively resolve
200 let final_type = self.apply(&resolved);
201
202 // Path compression: if we followed a chain to reach a different result,
203 // update this mapping to point directly to the final result.
204 // This avoids traversing the same chain repeatedly.
205 if final_type != resolved {
206 let mut mapping = self.mapping.borrow_mut();
207 if idx < mapping.len() {
208 mapping[idx] = Some(final_type.clone());
209 }
210 }
211
212 final_type
213 }
214
215 /// Check if a type variable occurs in a type (for occurs check).
216 ///
217 /// This prevents creating infinite types like `α = List<α>`.
218 /// Returns `true` if the variable occurs in the type.
219 pub fn occurs_in(&self, var: TypeVarId, ty: &InferType) -> bool {
220 match ty {
221 InferType::Concrete(_) => false,
222 InferType::Var(id) => {
223 if *id == var {
224 return true;
225 }
226 // Check if the variable chain leads to our target
227 match self.get(*id) {
228 Some(resolved) => self.occurs_in(var, &resolved),
229 None => false,
230 }
231 }
232 InferType::IntLiteral | InferType::FloatLiteral => false,
233 InferType::Array { element, .. } => self.occurs_in(var, element),
234 }
235 }
236
237 /// Get the number of mappings in the substitution.
238 ///
239 /// Note: This counts all slots that have values, requiring a full scan.
240 /// For performance-critical code, consider tracking this separately.
241 pub fn len(&self) -> usize {
242 self.mapping.borrow().iter().filter(|c| c.is_some()).count()
243 }
244
245 /// Check if the substitution is empty.
246 pub fn is_empty(&self) -> bool {
247 self.mapping.borrow().iter().all(|c| c.is_none())
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Type;

    #[test]
    fn test_substitution_apply_concrete() {
        let subst = Substitution::new();
        let ty = InferType::Concrete(Type::I64);
        assert_eq!(subst.apply(&ty), InferType::Concrete(Type::I64));
    }

    #[test]
    fn test_substitution_apply_unbound_var() {
        let subst = Substitution::new();
        let v0 = TypeVarId::new(0);
        let ty = InferType::Var(v0);
        // Unbound variable returns itself
        assert_eq!(subst.apply(&ty), InferType::Var(v0));
    }

    #[test]
    fn test_substitution_apply_bound_var() {
        let mut subst = Substitution::new();
        let v0 = TypeVarId::new(0);
        subst.insert(v0, InferType::Concrete(Type::BOOL));

        let ty = InferType::Var(v0);
        assert_eq!(subst.apply(&ty), InferType::Concrete(Type::BOOL));
    }

    #[test]
    fn test_substitution_apply_literals_unchanged() {
        let subst = Substitution::new();
        assert_eq!(subst.apply(&InferType::IntLiteral), InferType::IntLiteral);
        assert_eq!(
            subst.apply(&InferType::FloatLiteral),
            InferType::FloatLiteral
        );
    }

    #[test]
    fn test_substitution_apply_chain() {
        let mut subst = Substitution::new();
        let v0 = TypeVarId::new(0);
        let v1 = TypeVarId::new(1);
        let v2 = TypeVarId::new(2);

        // Create chain: v0 -> v1 -> v2 -> i32
        subst.insert(v0, InferType::Var(v1));
        subst.insert(v1, InferType::Var(v2));
        subst.insert(v2, InferType::Concrete(Type::I32));

        assert_eq!(
            subst.apply(&InferType::Var(v0)),
            InferType::Concrete(Type::I32)
        );
    }

    #[test]
    fn test_substitution_chain_to_unbound_var() {
        let mut subst = Substitution::new();
        let v0 = TypeVarId::new(0);
        let v1 = TypeVarId::new(1);
        let v2 = TypeVarId::new(2);

        // Create chain: v0 -> v1 -> v2, with v2 left unbound
        subst.insert(v0, InferType::Var(v1));
        subst.insert(v1, InferType::Var(v2));

        // The chain resolves to the unbound variable at its end
        assert_eq!(subst.apply(&InferType::Var(v0)), InferType::Var(v2));

        // v0 is compressed to point directly at that variable; v2 stays unbound
        assert_eq!(subst.get(v0), Some(InferType::Var(v2)));
        assert_eq!(subst.get(v2), None);
    }

    #[test]
    fn test_substitution_insert_replaces() {
        let mut subst = Substitution::new();
        let v0 = TypeVarId::new(0);

        subst.insert(v0, InferType::Concrete(Type::I32));
        subst.insert(v0, InferType::Concrete(Type::BOOL));

        assert_eq!(subst.get(v0), Some(InferType::Concrete(Type::BOOL)));
        assert_eq!(subst.len(), 1);
    }

    #[test]
    fn test_occurs_check_simple() {
        let subst = Substitution::new();
        let v0 = TypeVarId::new(0);

        // Variable occurs in itself
        assert!(subst.occurs_in(v0, &InferType::Var(v0)));

        // Variable doesn't occur in different variable
        assert!(!subst.occurs_in(v0, &InferType::Var(TypeVarId::new(1))));

        // Variable doesn't occur in concrete type
        assert!(!subst.occurs_in(v0, &InferType::Concrete(Type::I32)));
    }

    #[test]
    fn test_occurs_check_through_chain() {
        let mut subst = Substitution::new();
        let v0 = TypeVarId::new(0);
        let v1 = TypeVarId::new(1);

        // Create chain: v1 -> v0
        subst.insert(v1, InferType::Var(v0));

        // v0 occurs in v1 (through substitution)
        assert!(subst.occurs_in(v0, &InferType::Var(v1)));
    }

    #[test]
    fn test_occurs_check_unrelated_binding() {
        let mut subst = Substitution::new();
        let v0 = TypeVarId::new(0);
        let v1 = TypeVarId::new(1);

        // v1 is bound to a concrete type, so v0 cannot occur through it
        subst.insert(v1, InferType::Concrete(Type::I32));
        assert!(!subst.occurs_in(v0, &InferType::Var(v1)));
    }

    #[test]
    fn test_constraint_creation() {
        let span = Span::new(10, 20);
        let c1 = Constraint::equal(InferType::Concrete(Type::I32), InferType::IntLiteral, span);
        let c2 = Constraint::is_signed(InferType::Var(TypeVarId::new(0)), span);

        assert_eq!(c1.span(), span);
        assert_eq!(c2.span(), span);
    }

    #[test]
    fn test_constraint_span_all_variants() {
        let span = Span::new(3, 7);
        let var = || InferType::Var(TypeVarId::new(0));
        let constraints = [
            Constraint::equal(var(), InferType::IntLiteral, span),
            Constraint::is_signed(var(), span),
            Constraint::is_integer(var(), span),
            Constraint::is_unsigned(var(), span),
            Constraint::is_numeric(var(), span),
        ];

        for constraint in &constraints {
            assert_eq!(constraint.span(), span);
        }
    }

    #[test]
    fn test_substitution_with_capacity() {
        let subst = Substitution::with_capacity(10);
        // Pre-allocated substitution should be empty (no mappings yet)
        assert!(subst.is_empty());
        assert_eq!(subst.len(), 0);
    }

    #[test]
    fn test_substitution_with_capacity_grows() {
        let mut subst = Substitution::with_capacity(2);
        let far = TypeVarId::new(5);

        // Index beyond the pre-filled slots is simply unbound
        assert_eq!(subst.get(far), None);

        subst.insert(far, InferType::Concrete(Type::I32));
        assert_eq!(subst.get(far), Some(InferType::Concrete(Type::I32)));
        assert_eq!(subst.len(), 1);
    }

    #[test]
    fn test_substitution_path_compression() {
        // Create a long chain: v0 -> v1 -> v2 -> v3 -> v4 -> i32
        let mut subst = Substitution::new();
        let v0 = TypeVarId::new(0);
        let v1 = TypeVarId::new(1);
        let v2 = TypeVarId::new(2);
        let v3 = TypeVarId::new(3);
        let v4 = TypeVarId::new(4);

        subst.insert(v0, InferType::Var(v1));
        subst.insert(v1, InferType::Var(v2));
        subst.insert(v2, InferType::Var(v3));
        subst.insert(v3, InferType::Var(v4));
        subst.insert(v4, InferType::Concrete(Type::I32));

        // First lookup should resolve the chain
        assert_eq!(
            subst.apply(&InferType::Var(v0)),
            InferType::Concrete(Type::I32)
        );

        // After path compression, all intermediate variables should point directly to i32
        // Verify by checking that v0 now directly points to i32
        let resolved = subst.get(v0);
        assert_eq!(resolved, Some(InferType::Concrete(Type::I32)));

        // v1, v2, v3 should also be compressed
        assert_eq!(subst.get(v1), Some(InferType::Concrete(Type::I32)));
        assert_eq!(subst.get(v2), Some(InferType::Concrete(Type::I32)));
        assert_eq!(subst.get(v3), Some(InferType::Concrete(Type::I32)));
    }

    #[test]
    fn test_substitution_len_and_is_empty() {
        let mut subst = Substitution::new();
        assert!(subst.is_empty());
        assert_eq!(subst.len(), 0);

        subst.insert(TypeVarId::new(0), InferType::Concrete(Type::I32));
        assert!(!subst.is_empty());
        assert_eq!(subst.len(), 1);

        // Insert at a higher index - only counts actual mappings
        subst.insert(TypeVarId::new(5), InferType::Concrete(Type::BOOL));
        assert_eq!(subst.len(), 2);
    }
}