// gruel_air/inference/constraint.rs

//! Constraint types and substitution for type inference.
//!
//! This module provides:
//! - [`Constraint`]: type constraints generated during analysis.
//! - [`Substitution`]: a mapping from type variables to their resolved types.

7use super::types::{InferType, TypeVarId};
8use gruel_span::Span;
9use std::cell::RefCell;
10
/// A type constraint generated during analysis.
///
/// Constraints express relationships between types that must hold.
/// They are collected during constraint generation and then solved
/// during unification.
///
/// Every variant carries the [`Span`] of the construct that produced it;
/// [`Constraint::span`] returns it for error reporting.
#[derive(Debug, Clone)]
pub enum Constraint {
    /// Two types must be equal: τ₁ = τ₂.
    ///
    /// This is the primary constraint type. Generated for:
    /// - Binary operations (both operands must have same type)
    /// - Assignments (value type must match variable type)
    /// - Function calls (argument types must match parameter types)
    /// - Return statements (returned type must match declared return type)
    Equal(InferType, InferType, Span),

    /// Type must be a signed integer: τ ∈ {i8, i16, i32, i64}.
    ///
    /// Generated for unary negation which requires signed types.
    /// Unsigned integers cannot be negated.
    ///
    /// NOTE(review): `IsNumeric` below lists `isize`/`usize`, but this set does not
    /// mention `isize` — confirm the exact set against the unifier.
    IsSigned(InferType, Span),

    /// Type must be an integer (signed or unsigned): τ ∈ {i8, i16, i32, i64, u8, u16, u32, u64}.
    ///
    /// Generated for bitwise NOT which works on any integer type.
    ///
    /// NOTE(review): pointer-sized integers (`isize`/`usize`) are not listed here
    /// although `IsNumeric` mentions them — confirm against the unifier.
    IsInteger(InferType, Span),

    /// Type must be an unsigned integer: τ ∈ {u8, u16, u32, u64}.
    ///
    /// Generated for array indexing which requires non-negative indices.
    ///
    /// NOTE(review): confirm whether `usize` also satisfies this constraint.
    IsUnsigned(InferType, Span),

    /// Type must be numeric (integer or float): τ ∈ {i8..i64, u8..u64, isize, usize, f16..f64}.
    ///
    /// Generated for arithmetic operators (+, -, *, /, %) which work on both
    /// integer and floating-point types.
    IsNumeric(InferType, Span),
}
49
50impl Constraint {
51    /// Create an equality constraint.
52    pub fn equal(lhs: InferType, rhs: InferType, span: Span) -> Self {
53        Constraint::Equal(lhs, rhs, span)
54    }
55
56    /// Create a "must be signed" constraint.
57    pub fn is_signed(ty: InferType, span: Span) -> Self {
58        Constraint::IsSigned(ty, span)
59    }
60
61    /// Create a "must be integer" constraint.
62    pub fn is_integer(ty: InferType, span: Span) -> Self {
63        Constraint::IsInteger(ty, span)
64    }
65
66    /// Create a "must be unsigned" constraint.
67    pub fn is_unsigned(ty: InferType, span: Span) -> Self {
68        Constraint::IsUnsigned(ty, span)
69    }
70
71    /// Create a "must be numeric" constraint (integer or float).
72    pub fn is_numeric(ty: InferType, span: Span) -> Self {
73        Constraint::IsNumeric(ty, span)
74    }
75
76    /// Get the span for this constraint (for error reporting).
77    pub fn span(&self) -> Span {
78        match self {
79            Constraint::Equal(_, _, span)
80            | Constraint::IsSigned(_, span)
81            | Constraint::IsInteger(_, span)
82            | Constraint::IsUnsigned(_, span)
83            | Constraint::IsNumeric(_, span) => *span,
84        }
85    }
86}
87
/// A substitution mapping type variables to their resolved types.
///
/// The substitution is built incrementally during unification.
/// It maps type variable IDs to `InferType`s, which may themselves
/// be type variables (requiring transitive lookup via `apply`).
///
/// # Performance
///
/// Uses a `Vec<Option<InferType>>` instead of `HashMap<TypeVarId, InferType>` for O(1)
/// lookups without hashing overhead. This works because `TypeVarId` is a sequential
/// `u32` starting from 0.
///
/// The vector grows to cover the largest variable index inserted, so inserting a
/// single large index also allocates (unbound, `None`) slots for every smaller index.
///
/// Additionally implements path compression: when following a chain of variable
/// references, intermediate links are updated to point directly to the final result,
/// amortizing the cost of chain traversal.
#[derive(Debug, Default)]
pub struct Substitution {
    /// Mapping from type variable index to its resolved type (`None` = unbound).
    /// Uses `RefCell` to allow path compression during immutable (`&self`) lookups;
    /// as a consequence `Substitution` is not `Sync`.
    mapping: RefCell<Vec<Option<InferType>>>,
}
109
110impl Substitution {
111    /// Create an empty substitution.
112    pub fn new() -> Self {
113        Substitution {
114            mapping: RefCell::new(Vec::new()),
115        }
116    }
117
118    /// Create a substitution with pre-allocated capacity.
119    ///
120    /// Use when you know approximately how many type variables will be created.
121    pub fn with_capacity(capacity: usize) -> Self {
122        let mut mapping = Vec::with_capacity(capacity);
123        // Pre-fill with None to allow direct indexing
124        mapping.resize(capacity, None);
125        Substitution {
126            mapping: RefCell::new(mapping),
127        }
128    }
129
130    /// Insert a mapping from a type variable to a type.
131    ///
132    /// If the variable is already mapped, the old mapping is replaced.
133    pub fn insert(&mut self, var: TypeVarId, ty: InferType) {
134        let idx = var.index() as usize;
135        let mut mapping = self.mapping.borrow_mut();
136        // Grow the vector if necessary
137        if idx >= mapping.len() {
138            mapping.resize(idx + 1, None);
139        }
140        mapping[idx] = Some(ty);
141    }
142
143    /// Look up a type variable's immediate mapping (without following chains).
144    pub fn get(&self, var: TypeVarId) -> Option<InferType> {
145        let idx = var.index() as usize;
146        let mapping = self.mapping.borrow();
147        if idx < mapping.len() {
148            mapping[idx].clone()
149        } else {
150            None
151        }
152    }
153
154    /// Apply the substitution to a type, following type variable chains
155    /// to their ultimate resolution.
156    ///
157    /// - `Concrete(ty)` → `Concrete(ty)` (unchanged)
158    /// - `Var(id)` → follows chain until concrete or unbound variable
159    /// - `IntLiteral` → `IntLiteral` (unchanged, unless we add IntLiteral
160    ///   to variable mappings)
161    /// - `Array { element, length }` → recursively apply to element type
162    ///
163    /// # Path Compression
164    ///
165    /// When following a chain like `v0 -> v1 -> v2 -> i32`, this method
166    /// updates all intermediate links to point directly to the final result.
167    /// This amortizes the cost of repeated lookups.
168    pub fn apply(&self, ty: &InferType) -> InferType {
169        match ty {
170            InferType::Concrete(_) => ty.clone(),
171            InferType::Var(id) => self.apply_var(*id),
172            InferType::IntLiteral | InferType::FloatLiteral => ty.clone(),
173            InferType::Array { element, length } => {
174                let resolved_element = self.apply(element);
175                InferType::Array {
176                    element: Box::new(resolved_element),
177                    length: *length,
178                }
179            }
180        }
181    }
182
183    /// Apply substitution to a type variable with path compression.
184    fn apply_var(&self, id: TypeVarId) -> InferType {
185        let idx = id.index() as usize;
186
187        // First, get the resolved type without holding the borrow
188        let resolved = {
189            let mapping = self.mapping.borrow();
190            if idx >= mapping.len() {
191                return InferType::Var(id);
192            }
193            match &mapping[idx] {
194                None => return InferType::Var(id),
195                Some(ty) => ty.clone(),
196            }
197        };
198
199        // Recursively resolve
200        let final_type = self.apply(&resolved);
201
202        // Path compression: if we followed a chain to reach a different result,
203        // update this mapping to point directly to the final result.
204        // This avoids traversing the same chain repeatedly.
205        if final_type != resolved {
206            let mut mapping = self.mapping.borrow_mut();
207            if idx < mapping.len() {
208                mapping[idx] = Some(final_type.clone());
209            }
210        }
211
212        final_type
213    }
214
215    /// Check if a type variable occurs in a type (for occurs check).
216    ///
217    /// This prevents creating infinite types like `α = List<α>`.
218    /// Returns `true` if the variable occurs in the type.
219    pub fn occurs_in(&self, var: TypeVarId, ty: &InferType) -> bool {
220        match ty {
221            InferType::Concrete(_) => false,
222            InferType::Var(id) => {
223                if *id == var {
224                    return true;
225                }
226                // Check if the variable chain leads to our target
227                match self.get(*id) {
228                    Some(resolved) => self.occurs_in(var, &resolved),
229                    None => false,
230                }
231            }
232            InferType::IntLiteral | InferType::FloatLiteral => false,
233            InferType::Array { element, .. } => self.occurs_in(var, element),
234        }
235    }
236
237    /// Get the number of mappings in the substitution.
238    ///
239    /// Note: This counts all slots that have values, requiring a full scan.
240    /// For performance-critical code, consider tracking this separately.
241    pub fn len(&self) -> usize {
242        self.mapping.borrow().iter().filter(|c| c.is_some()).count()
243    }
244
245    /// Check if the substitution is empty.
246    pub fn is_empty(&self) -> bool {
247        self.mapping.borrow().iter().all(|c| c.is_none())
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Type;

    #[test]
    fn test_substitution_apply_concrete() {
        let subst = Substitution::new();
        let ty = InferType::Concrete(Type::I64);
        assert_eq!(subst.apply(&ty), InferType::Concrete(Type::I64));
    }

    #[test]
    fn test_substitution_apply_unbound_var() {
        let subst = Substitution::new();
        let v0 = TypeVarId::new(0);
        let ty = InferType::Var(v0);
        // Unbound variable returns itself
        assert_eq!(subst.apply(&ty), InferType::Var(v0));
    }

    #[test]
    fn test_substitution_apply_bound_var() {
        let mut subst = Substitution::new();
        let v0 = TypeVarId::new(0);
        subst.insert(v0, InferType::Concrete(Type::BOOL));

        let ty = InferType::Var(v0);
        assert_eq!(subst.apply(&ty), InferType::Concrete(Type::BOOL));
    }

    #[test]
    fn test_substitution_apply_chain() {
        let mut subst = Substitution::new();
        let v0 = TypeVarId::new(0);
        let v1 = TypeVarId::new(1);
        let v2 = TypeVarId::new(2);

        // Create chain: v0 -> v1 -> v2 -> i32
        subst.insert(v0, InferType::Var(v1));
        subst.insert(v1, InferType::Var(v2));
        subst.insert(v2, InferType::Concrete(Type::I32));

        assert_eq!(
            subst.apply(&InferType::Var(v0)),
            InferType::Concrete(Type::I32)
        );
    }

    #[test]
    fn test_substitution_chain_to_unbound_var() {
        let mut subst = Substitution::new();
        let v0 = TypeVarId::new(0);
        let v1 = TypeVarId::new(1);
        let v2 = TypeVarId::new(2);

        // Create chain: v0 -> v1 -> v2, with v2 left unbound
        subst.insert(v0, InferType::Var(v1));
        subst.insert(v1, InferType::Var(v2));

        // Resolution stops at the unbound end of the chain
        assert_eq!(subst.apply(&InferType::Var(v0)), InferType::Var(v2));

        // Path compression points every link at the chain's end; v2 stays unbound
        assert_eq!(subst.get(v0), Some(InferType::Var(v2)));
        assert_eq!(subst.get(v1), Some(InferType::Var(v2)));
        assert_eq!(subst.get(v2), None);
    }

    #[test]
    fn test_substitution_apply_literals_unchanged() {
        let mut subst = Substitution::new();
        subst.insert(TypeVarId::new(0), InferType::Concrete(Type::I32));

        assert_eq!(subst.apply(&InferType::IntLiteral), InferType::IntLiteral);
        assert_eq!(subst.apply(&InferType::FloatLiteral), InferType::FloatLiteral);
    }

    #[test]
    fn test_substitution_insert_overwrites() {
        let mut subst = Substitution::new();
        let v0 = TypeVarId::new(0);

        subst.insert(v0, InferType::Concrete(Type::I32));
        subst.insert(v0, InferType::Concrete(Type::BOOL));

        // The second insert replaces the first rather than adding a mapping
        assert_eq!(subst.get(v0), Some(InferType::Concrete(Type::BOOL)));
        assert_eq!(subst.len(), 1);
    }

    #[test]
    fn test_substitution_get_unmapped() {
        let subst = Substitution::with_capacity(4);
        // Pre-allocated but unbound slot
        assert_eq!(subst.get(TypeVarId::new(2)), None);
        // Index past the end of the storage
        assert_eq!(subst.get(TypeVarId::new(42)), None);
    }

    #[test]
    fn test_substitution_insert_past_capacity() {
        let mut subst = Substitution::with_capacity(2);
        let v7 = TypeVarId::new(7);

        subst.insert(v7, InferType::Concrete(Type::I64));

        assert_eq!(subst.len(), 1);
        assert_eq!(
            subst.apply(&InferType::Var(v7)),
            InferType::Concrete(Type::I64)
        );
    }

    #[test]
    fn test_occurs_check_simple() {
        let subst = Substitution::new();
        let v0 = TypeVarId::new(0);

        // Variable occurs in itself
        assert!(subst.occurs_in(v0, &InferType::Var(v0)));

        // Variable doesn't occur in different variable
        assert!(!subst.occurs_in(v0, &InferType::Var(TypeVarId::new(1))));

        // Variable doesn't occur in concrete type
        assert!(!subst.occurs_in(v0, &InferType::Concrete(Type::I32)));
    }

    #[test]
    fn test_occurs_check_through_chain() {
        let mut subst = Substitution::new();
        let v0 = TypeVarId::new(0);
        let v1 = TypeVarId::new(1);

        // Create chain: v1 -> v0
        subst.insert(v1, InferType::Var(v0));

        // v0 occurs in v1 (through substitution)
        assert!(subst.occurs_in(v0, &InferType::Var(v1)));
    }

    #[test]
    fn test_occurs_check_through_longer_chain() {
        let mut subst = Substitution::new();
        let v0 = TypeVarId::new(0);
        let v1 = TypeVarId::new(1);
        let v2 = TypeVarId::new(2);
        let v3 = TypeVarId::new(3);

        // Create chain: v3 -> v2 -> v1 -> v0
        subst.insert(v3, InferType::Var(v2));
        subst.insert(v2, InferType::Var(v1));
        subst.insert(v1, InferType::Var(v0));

        // Both the chain's end and an intermediate link are found
        assert!(subst.occurs_in(v0, &InferType::Var(v3)));
        assert!(subst.occurs_in(v2, &InferType::Var(v3)));

        // A variable outside the chain is not
        assert!(!subst.occurs_in(TypeVarId::new(9), &InferType::Var(v3)));
    }

    #[test]
    fn test_constraint_creation() {
        let span = Span::new(10, 20);
        let c1 = Constraint::equal(InferType::Concrete(Type::I32), InferType::IntLiteral, span);
        let c2 = Constraint::is_signed(InferType::Var(TypeVarId::new(0)), span);

        assert_eq!(c1.span(), span);
        assert_eq!(c2.span(), span);
    }

    #[test]
    fn test_constraint_span_all_variants() {
        let span = Span::new(3, 7);
        let var = || InferType::Var(TypeVarId::new(0));
        let constraints = [
            Constraint::equal(var(), InferType::IntLiteral, span),
            Constraint::is_signed(var(), span),
            Constraint::is_integer(var(), span),
            Constraint::is_unsigned(var(), span),
            Constraint::is_numeric(var(), span),
        ];

        for constraint in &constraints {
            assert_eq!(constraint.span(), span);
        }
    }

    #[test]
    fn test_substitution_with_capacity() {
        let subst = Substitution::with_capacity(10);
        // Pre-allocated substitution should be empty (no mappings yet)
        assert!(subst.is_empty());
        assert_eq!(subst.len(), 0);
    }

    #[test]
    fn test_substitution_path_compression() {
        // Create a long chain: v0 -> v1 -> v2 -> v3 -> v4 -> i32
        let mut subst = Substitution::new();
        let v0 = TypeVarId::new(0);
        let v1 = TypeVarId::new(1);
        let v2 = TypeVarId::new(2);
        let v3 = TypeVarId::new(3);
        let v4 = TypeVarId::new(4);

        subst.insert(v0, InferType::Var(v1));
        subst.insert(v1, InferType::Var(v2));
        subst.insert(v2, InferType::Var(v3));
        subst.insert(v3, InferType::Var(v4));
        subst.insert(v4, InferType::Concrete(Type::I32));

        // First lookup should resolve the chain
        assert_eq!(
            subst.apply(&InferType::Var(v0)),
            InferType::Concrete(Type::I32)
        );

        // After path compression, all intermediate variables should point directly to i32
        // Verify by checking that v0 now directly points to i32
        let resolved = subst.get(v0);
        assert_eq!(resolved, Some(InferType::Concrete(Type::I32)));

        // v1, v2, v3 should also be compressed
        assert_eq!(subst.get(v1), Some(InferType::Concrete(Type::I32)));
        assert_eq!(subst.get(v2), Some(InferType::Concrete(Type::I32)));
        assert_eq!(subst.get(v3), Some(InferType::Concrete(Type::I32)));
    }

    #[test]
    fn test_substitution_len_and_is_empty() {
        let mut subst = Substitution::new();
        assert!(subst.is_empty());
        assert_eq!(subst.len(), 0);

        subst.insert(TypeVarId::new(0), InferType::Concrete(Type::I32));
        assert!(!subst.is_empty());
        assert_eq!(subst.len(), 1);

        // Insert at a higher index - only counts actual mappings
        subst.insert(TypeVarId::new(5), InferType::Concrete(Type::BOOL));
        assert_eq!(subst.len(), 2);
    }
}