diff --git a/libs/@local/hashql/mir/src/body/operand.rs b/libs/@local/hashql/mir/src/body/operand.rs index 30cd4f2a516..8652917601d 100644 --- a/libs/@local/hashql/mir/src/body/operand.rs +++ b/libs/@local/hashql/mir/src/body/operand.rs @@ -35,6 +35,30 @@ pub enum Operand<'heap> { Constant(Constant<'heap>), } +impl<'heap> Operand<'heap> { + /// Returns the contained [`Place`] if this operand is a place reference. + /// + /// Returns [`None`] if this operand is a constant. + #[must_use] + pub const fn as_place(&self) -> Option<&Place<'heap>> { + match self { + Operand::Place(place) => Some(place), + Operand::Constant(_) => None, + } + } + + /// Returns the contained [`Constant`] if this operand is an immediate value. + /// + /// Returns [`None`] if this operand is a place reference. + #[must_use] + pub const fn as_constant(&self) -> Option<&Constant<'heap>> { + match self { + Operand::Constant(constant) => Some(constant), + Operand::Place(_) => None, + } + } +} + impl<'heap> From> for Operand<'heap> { fn from(place: Place<'heap>) -> Self { Operand::Place(place) diff --git a/libs/@local/hashql/mir/src/pass/transform/cp/mod.rs b/libs/@local/hashql/mir/src/pass/transform/cp/mod.rs new file mode 100644 index 00000000000..36f0bf75f4f --- /dev/null +++ b/libs/@local/hashql/mir/src/pass/transform/cp/mod.rs @@ -0,0 +1,275 @@ +#[cfg(test)] +mod tests; + +use core::{alloc::Allocator, convert::Infallible}; + +use hashql_core::{ + graph::Predecessors as _, + heap::{BumpAllocator, Scratch, TransferInto as _}, + id::IdVec, +}; + +use crate::{ + body::{ + Body, + basic_block::BasicBlockId, + constant::Constant, + local::{Local, LocalVec}, + location::Location, + operand::Operand, + rvalue::RValue, + statement::Assign, + }, + context::MirContext, + intern::Interner, + pass::TransformPass, + visit::{self, VisitorMut, r#mut::filter}, +}; + +/// Propagates constant values through block parameters by analyzing predecessor branches. +/// +/// For each block parameter, examines all predecessor branches that target this block. +/// If all predecessors pass the same constant value (as determined by `eval`), that value +/// is passed to `insert` for the corresponding block parameter. +/// +/// # Type Parameters +/// +/// - `T`: The constant value type being propagated (e.g., `Constant<'heap>`, `Int`). +/// - `A`: The allocator for the scratch buffer. +/// +/// # Parameters +/// +/// - `args`: Scratch buffer for collecting argument values. Must be empty on entry; will be drained +/// before return. +/// - `body`: The MIR body containing the CFG. +/// - `id`: The basic block whose parameters are being analyzed. +/// - `eval`: Closure that evaluates an operand to `Some(T)` if it represents a known constant, or +/// `None` otherwise. This is called for each argument in predecessor branch targets. +/// - `insert`: Closure called for each block parameter that has a consistent constant value across +/// all predecessors. Receives the parameter's local and the constant value. +/// +/// # Algorithm +/// +/// 1. Skips blocks with effectful predecessors (e.g., `GraphRead`), as their arguments are implicit +/// and not inspectable. +/// 2. Collects all explicit branch targets from predecessors that jump to this block. +/// 3. For each parameter position, computes the "meet" of all argument values: if all predecessors +/// pass the same constant, that constant is propagated; otherwise, no constant is recorded. +/// +/// # Limitations +/// +/// - Does not perform fix-point iteration for loops. Constants on back-edges may not be discovered +/// because predecessors forming back-edges have not been visited when the loop header is +/// processed. +/// - Blocks reachable only via implicit edges (entry blocks, effectful continuations) have no +/// explicit targets to analyze. +#[expect( + clippy::iter_on_single_items, + clippy::iter_on_empty_collections, + reason = "impl return type" +)] +pub(crate) fn propagate_block_params<'args, 'heap: 'args, T, A, E>( + args: &'args mut Vec, A>, + body: &Body<'heap>, + id: BasicBlockId, + mut eval: E, +) -> impl IntoIterator + 'args +where + T: Copy + Eq, + A: Allocator, + E: FnMut(Operand<'heap>) -> Option, +{ + let pred = body.basic_blocks.predecessors(id); + + // Effectful terminators (like GraphRead) pass arguments implicitly, where they set the + // block param directly. We cannot inspect those values, so we conservatively skip + // propagation for blocks reachable from effectful predecessors (they have single + // successors). + if pred + .clone() + .any(|pred| body.basic_blocks[pred].terminator.kind.is_effectful()) + { + return None.into_iter().flatten(); + } + + // Collect all predecessor targets that branch to this block. A single predecessor + // may have multiple targets to us (e.g., a switch with two arms to the same block). + let mut targets = pred + .flat_map(|pred| body.basic_blocks[pred].terminator.kind.successor_targets()) + .filter(|&target| target.block == id); + + let Some(first) = targets.next() else { + // No explicit targets means this block is only reachable via implicit edges + // (e.g., entry block or effectful continuations). Nothing to propagate. + return None.into_iter().flatten(); + }; + + // Seed with the first target's argument values. Each position holds `Some(T)` if + // that argument evaluated to a constant, `None` otherwise. + args.extend(first.args.iter().map(|&arg| eval(arg))); + + // Check remaining targets for consensus. If any target passes a different value + // (or non-constant) for a parameter position, clear that position to `None`. + for target in targets { + debug_assert_eq!(args.len(), target.args.len()); + + for (lhs, &rhs) in args.iter_mut().zip(target.args.iter()) { + let rhs = eval(rhs); + if *lhs != rhs { + *lhs = None; + } + } + } + + // Record constants for block parameters where all predecessors agreed. + + let params = body.basic_blocks[id].params; + + Some( + params + .0 + .iter() + .zip(args.drain(..)) + .filter_map(|(&local, constant)| constant.map(|constant| (local, constant))), + ) + .into_iter() + .flatten() +} + +pub struct CopyPropagation { + alloc: A, +} + +impl CopyPropagation { + #[must_use] + pub fn new() -> Self { + Self { + alloc: Scratch::new(), + } + } +} + +impl Default for CopyPropagation { + fn default() -> Self { + Self::new() + } +} + +impl CopyPropagation { + pub const fn new_in(alloc: A) -> Self { + Self { alloc } + } +} + +impl<'env, 'heap, A: BumpAllocator> TransformPass<'env, 'heap> for CopyPropagation { + fn run(&mut self, context: &mut MirContext<'env, 'heap>, body: &mut Body<'heap>) { + self.alloc.reset(); + + let mut visitor = CopyPropagationVisitor { + interner: context.interner, + constants: IdVec::with_capacity_in(body.local_decls.len(), &self.alloc), + }; + + let reverse_postorder = body + .basic_blocks + .reverse_postorder() + .transfer_into(&self.alloc); + + let mut args = Vec::new_in(&self.alloc); + + for &mut id in reverse_postorder { + for (local, constant) in + propagate_block_params(&mut args, body, id, |operand| visitor.try_eval(operand)) + { + visitor.constants.insert(local, constant); + } + + Ok(()) = + visitor.visit_basic_block(id, &mut body.basic_blocks.as_mut_preserving_cfg()[id]); + } + } +} + +struct CopyPropagationVisitor<'env, 'heap, A: Allocator> { + interner: &'env Interner<'heap>, + constants: LocalVec>, A>, +} + +impl<'heap, A: Allocator> CopyPropagationVisitor<'_, 'heap, A> { + /// Attempts to evaluate an operand to a known constant or classify it for simplification. + /// + /// Returns `Int` if the operand is a constant integer or a local known to hold one, + /// `Place` if it's a non-constant place, or `Other` for operands that can't be simplified. + fn try_eval(&self, operand: Operand<'heap>) -> Option> { + if let Operand::Constant(constant) = operand { + return Some(constant); + } + + if let Operand::Place(place) = operand + && place.projections.is_empty() + && let Some(&constant) = self.constants.lookup(place.local) + { + return Some(constant); + } + + None + } +} + +impl<'heap, A: Allocator> VisitorMut<'heap> for CopyPropagationVisitor<'_, 'heap, A> { + type Filter = filter::Deep; + type Residual = Result; + type Result + = Result + where + T: 'heap; + + fn interner(&self) -> &Interner<'heap> { + self.interner + } + + fn visit_operand(&mut self, _: Location, operand: &mut Operand<'heap>) -> Self::Result<()> { + if let Operand::Place(place) = operand + && place.projections.is_empty() + && let Some(&constant) = self.constants.lookup(place.local) + { + *operand = Operand::Constant(constant); + } + + Ok(()) + } + + fn visit_statement_assign( + &mut self, + location: Location, + assign: &mut Assign<'heap>, + ) -> Self::Result<()> { + Ok(()) = visit::r#mut::walk_statement_assign(self, location, assign); + let Assign { lhs, rhs } = assign; + + if !lhs.projections.is_empty() { + // We're not interested in assignments with projections, as that is out of scope for + // copy propagation. + return Ok(()); + } + + let RValue::Load(load) = rhs else { + // copy propagation is only applicable to load values + return Ok(()); + }; + + match load { + Operand::Place(place) if place.projections.is_empty() => { + if let Some(&constant) = self.constants.lookup(place.local) { + self.constants.insert(lhs.local, constant); + } + } + Operand::Place(_) => {} + &mut Operand::Constant(constant) => { + self.constants.insert(lhs.local, constant); + } + } + + Ok(()) + } +} diff --git a/libs/@local/hashql/mir/src/pass/transform/cp/tests.rs b/libs/@local/hashql/mir/src/pass/transform/cp/tests.rs new file mode 100644 index 00000000000..57575fa75f6 --- /dev/null +++ b/libs/@local/hashql/mir/src/pass/transform/cp/tests.rs @@ -0,0 +1,530 @@ +#![expect(clippy::min_ident_chars, reason = "tests")] + +use std::path::PathBuf; + +use bstr::ByteVec as _; +use hashql_core::{ + pretty::Formatter, + r#type::{TypeBuilder, TypeFormatter, TypeFormatterOptions, environment::Environment}, +}; +use hashql_diagnostics::DiagnosticIssues; +use insta::{Settings, assert_snapshot}; + +use super::CopyPropagation; +use crate::{ + body::{ + Body, + operand::Operand, + terminator::{GraphRead, GraphReadHead, GraphReadTail, TerminatorKind}, + }, + builder::{op, scaffold}, + context::MirContext, + def::DefIdSlice, + pass::TransformPass as _, + pretty::TextFormat, +}; + +#[track_caller] +fn assert_cp_pass<'heap>( + name: &'static str, + body: Body<'heap>, + context: &mut MirContext<'_, 'heap>, +) { + let formatter = Formatter::new(context.heap); + let mut formatter = TypeFormatter::new( + &formatter, + context.env, + TypeFormatterOptions::terse().with_qualified_opaque_names(true), + ); + let mut text_format = TextFormat { + writer: Vec::new(), + indent: 4, + sources: (), + types: &mut formatter, + }; + + let mut bodies = [body]; + + text_format + .format(DefIdSlice::from_raw(&bodies), &[]) + .expect("should be able to write bodies"); + + text_format + .writer + .extend(b"\n\n------------------------------------\n\n"); + + CopyPropagation::new().run(context, &mut bodies[0]); + + text_format + .format(DefIdSlice::from_raw(&bodies), &[]) + .expect("should be able to write bodies"); + + let dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let mut settings = Settings::clone_current(); + settings.set_snapshot_path(dir.join("tests/ui/pass/cp")); + settings.set_prepend_module_to_snapshot(false); + + let _drop = settings.bind_to_scope(); + + let value = text_format.writer.into_string_lossy(); + assert_snapshot!(name, value); +} + +/// Tests basic constant propagation through operands. +/// +/// ```text +/// bb0: +/// x = 1 +/// y = x == x +/// return y +/// ``` +/// +/// After copy propagation, uses of `x` should be replaced with `const 1`. +#[test] +fn single_constant() { + scaffold!(heap, interner, builder); + let env = Environment::new(&heap); + + let int_ty = TypeBuilder::synthetic(&env).integer(); + let bool_ty = TypeBuilder::synthetic(&env).boolean(); + + let x = builder.local("x", int_ty); + let y = builder.local("y", bool_ty); + + let const_1 = builder.const_int(1); + + let bb0 = builder.reserve_block([]); + + builder + .build_block(bb0) + .assign_place(x, |rv| rv.load(const_1)) + .assign_place(y, |rv| rv.binary(x, op![==], x)) + .ret(y); + + let body = builder.finish(0, bool_ty); + + assert_cp_pass( + "single_constant", + body, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + ); +} + +/// Tests chain propagation through multiple loads. +/// +/// ```text +/// bb0: +/// x = 1 +/// y = x +/// z = y +/// w = z == z +/// return w +/// ``` +/// +/// All locals in the chain should be tracked, and uses of `z` replaced with `const 1`. +#[test] +fn constant_chain() { + scaffold!(heap, interner, builder); + let env = Environment::new(&heap); + + let int_ty = TypeBuilder::synthetic(&env).integer(); + let bool_ty = TypeBuilder::synthetic(&env).boolean(); + + let x = builder.local("x", int_ty); + let y = builder.local("y", int_ty); + let z = builder.local("z", int_ty); + let w = builder.local("w", bool_ty); + + let const_1 = builder.const_int(1); + + let bb0 = builder.reserve_block([]); + + builder + .build_block(bb0) + .assign_place(x, |rv| rv.load(const_1)) + .assign_place(y, |rv| rv.load(x)) + .assign_place(z, |rv| rv.load(y)) + .assign_place(w, |rv| rv.binary(z, op![==], z)) + .ret(w); + + let body = builder.finish(0, bool_ty); + + assert_cp_pass( + "constant_chain", + body, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + ); +} + +/// Tests block parameter propagation when all predecessors agree on the same constant. +/// +/// ```text +/// bb0: +/// cond = input +/// if cond -> bb1 else bb2 +/// bb1: +/// goto bb3(1) +/// bb2: +/// goto bb3(1) +/// bb3(p): +/// r = p == p +/// return r +/// ``` +/// +/// Both predecessors pass `const 1`, so `p` should be propagated as a constant. +#[test] +fn block_param_unanimous() { + scaffold!(heap, interner, builder); + let env = Environment::new(&heap); + + let int_ty = TypeBuilder::synthetic(&env).integer(); + let bool_ty = TypeBuilder::synthetic(&env).boolean(); + + let cond = builder.local("cond", bool_ty); + let p = builder.local("p", int_ty); + let r = builder.local("r", bool_ty); + + let const_1 = builder.const_int(1); + let const_true = builder.const_bool(true); + + let bb0 = builder.reserve_block([]); + let bb1 = builder.reserve_block([]); + let bb2 = builder.reserve_block([]); + let bb3 = builder.reserve_block([p.local]); + + builder + .build_block(bb0) + .assign_place(cond, |rv| rv.load(const_true)) + .if_else(cond, bb1, [], bb2, []); + + builder.build_block(bb1).goto(bb3, [const_1]); + builder.build_block(bb2).goto(bb3, [const_1]); + + builder + .build_block(bb3) + .assign_place(r, |rv| rv.binary(p, op![==], p)) + .ret(r); + + let body = builder.finish(0, bool_ty); + + assert_cp_pass( + "block_param_unanimous", + body, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + ); +} + +/// Tests that block parameters are not propagated when predecessors disagree. +/// +/// ```text +/// bb0: +/// cond = input +/// if cond -> bb1 else bb2 +/// bb1: +/// goto bb3(1) +/// bb2: +/// goto bb3(2) +/// bb3(p): +/// r = p == p +/// return r +/// ``` +/// +/// Predecessors pass different values, so `p` should not be propagated. +#[test] +fn block_param_disagreement() { + scaffold!(heap, interner, builder); + let env = Environment::new(&heap); + + let int_ty = TypeBuilder::synthetic(&env).integer(); + let bool_ty = TypeBuilder::synthetic(&env).boolean(); + + let cond = builder.local("cond", bool_ty); + let p = builder.local("p", int_ty); + let r = builder.local("r", bool_ty); + + let const_1 = builder.const_int(1); + let const_2 = builder.const_int(2); + let const_true = builder.const_bool(true); + + let bb0 = builder.reserve_block([]); + let bb1 = builder.reserve_block([]); + let bb2 = builder.reserve_block([]); + let bb3 = builder.reserve_block([p.local]); + + builder + .build_block(bb0) + .assign_place(cond, |rv| rv.load(const_true)) + .if_else(cond, bb1, [], bb2, []); + + builder.build_block(bb1).goto(bb3, [const_1]); + builder.build_block(bb2).goto(bb3, [const_2]); + + builder + .build_block(bb3) + .assign_place(r, |rv| rv.binary(p, op![==], p)) + .ret(r); + + let body = builder.finish(0, bool_ty); + + assert_cp_pass( + "block_param_disagreement", + body, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + ); +} + +/// Tests that block parameter propagation resolves locals through the constants map. +/// +/// ```text +/// bb0: +/// x = 1 +/// goto bb1(x) +/// bb1(p): +/// r = p == p +/// return r +/// ``` +/// +/// The predecessor passes local `x` which is known to be `const 1`. The `try_eval` +/// function should resolve this, allowing `p` to be propagated as a constant. +#[test] +fn block_param_via_local() { + scaffold!(heap, interner, builder); + let env = Environment::new(&heap); + + let int_ty = TypeBuilder::synthetic(&env).integer(); + let bool_ty = TypeBuilder::synthetic(&env).boolean(); + + let x = builder.local("x", int_ty); + let p = builder.local("p", int_ty); + let r = builder.local("r", bool_ty); + + let const_1 = builder.const_int(1); + + let bb0 = builder.reserve_block([]); + let bb1 = builder.reserve_block([p.local]); + + builder + .build_block(bb0) + .assign_place(x, |rv| rv.load(const_1)) + .goto(bb1, [x.into()]); + + builder + .build_block(bb1) + .assign_place(r, |rv| rv.binary(p, op![==], p)) + .ret(r); + + let body = builder.finish(0, bool_ty); + + assert_cp_pass( + "block_param_via_local", + body, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + ); +} + +/// Tests that blocks with effectful predecessors are conservatively skipped. +/// +/// ```text +/// bb0: +/// graph_read -> bb2 +/// bb1: +/// goto bb2(1) +/// bb2(p): +/// r = p == p +/// return r +/// ``` +/// +/// Even though bb1 passes `const 1`, bb0 is an effectful predecessor (`GraphRead`) so +/// block parameter propagation is skipped entirely for bb2. The param `p` is NOT +/// propagated as a constant. +#[test] +fn block_param_effectful() { + scaffold!(heap, interner, builder); + let env = Environment::new(&heap); + + let int_ty = TypeBuilder::synthetic(&env).integer(); + let bool_ty = TypeBuilder::synthetic(&env).boolean(); + + let axis = builder.local("axis", TypeBuilder::synthetic(&env).unknown()); + let p = builder.local("p", int_ty); + let r = builder.local("r", bool_ty); + + let const_1 = builder.const_int(1); + + let bb0 = builder.reserve_block([]); + let bb1 = builder.reserve_block([]); + let bb2 = builder.reserve_block([p.local]); + + builder + .build_block(bb0) + .finish_with_terminator(TerminatorKind::GraphRead(GraphRead { + head: GraphReadHead::Entity { + axis: Operand::Place(axis), + }, + body: Vec::new_in(&heap), + tail: GraphReadTail::Collect, + target: bb2, + })); + + builder.build_block(bb1).goto(bb2, [const_1]); + + builder + .build_block(bb2) + .assign_place(r, |rv| rv.binary(p, op![==], p)) + .ret(r); + + let body = builder.finish(1, bool_ty); + + assert_cp_pass( + "block_param_effectful", + body, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + ); +} + +/// Tests that places with projections are not propagated. +/// +/// ```text +/// bb0: +/// x = (1, 2) +/// y = x.0 +/// r = y == y +/// return r +/// ``` +/// +/// Copy propagation only handles simple locals without projections. The projection +/// `x.0` should not be replaced, though `y` (if tracked) could be. +#[test] +fn projection_unchanged() { + scaffold!(heap, interner, builder); + let env = Environment::new(&heap); + + let int_ty = TypeBuilder::synthetic(&env).integer(); + let bool_ty = TypeBuilder::synthetic(&env).boolean(); + let tuple_ty = TypeBuilder::synthetic(&env).tuple([int_ty, int_ty]); + + let x = builder.local("x", tuple_ty); + let y = builder.local("y", int_ty); + let r = builder.local("r", bool_ty); + + let x_0 = builder.place(|place| place.from(x).field(0, int_ty)); + + let const_1 = builder.const_int(1); + let const_2 = builder.const_int(2); + + let bb0 = builder.reserve_block([]); + + builder + .build_block(bb0) + .assign_place(x, |rv| rv.tuple([const_1, const_2])) + .assign_place(y, |rv| rv.load(x_0)) + .assign_place(r, |rv| rv.binary(y, op![==], y)) + .ret(r); + + let body = builder.finish(0, bool_ty); + + assert_cp_pass( + "projection_unchanged", + body, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + ); +} + +/// Tests that constants on loop back-edges are not discovered (no fix-point iteration). +/// +/// ```text +/// bb0: +/// x = 1 +/// goto bb1 +/// bb1: +/// // x comes from bb0 (const 1) or bb1 (const 2) - disagreement +/// r = x == x +/// x = 2 +/// if cond -> bb1 else bb2 +/// bb2: +/// return r +/// ``` +/// +/// This documents the limitation: even though the back-edge always passes `const 2`, +/// we don't discover this because predecessors forming back-edges haven't been visited +/// when the loop header is processed. +#[test] +fn loop_back_edge() { + scaffold!(heap, interner, builder); + let env = Environment::new(&heap); + + let int_ty = TypeBuilder::synthetic(&env).integer(); + let bool_ty = TypeBuilder::synthetic(&env).boolean(); + + let x = builder.local("x", int_ty); + let r = builder.local("r", bool_ty); + let cond = builder.local("cond", bool_ty); + + let const_1 = builder.const_int(1); + let const_2 = builder.const_int(2); + let const_true = builder.const_bool(true); + + let bb0 = builder.reserve_block([]); + let bb1 = builder.reserve_block([x.local]); + let bb2 = builder.reserve_block([]); + + builder + .build_block(bb0) + .assign_place(cond, |rv| rv.load(const_true)) + .goto(bb1, [const_1]); + + builder + .build_block(bb1) + .assign_place(r, |rv| rv.binary(x, op![==], x)) + .if_else(cond, bb1, [const_2], bb2, []); + + builder.build_block(bb2).ret(r); + + let body = builder.finish(0, bool_ty); + + assert_cp_pass( + "loop_back_edge", + body, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + ); +} diff --git a/libs/@local/hashql/mir/src/pass/transform/inst_simplify/mod.rs b/libs/@local/hashql/mir/src/pass/transform/inst_simplify/mod.rs index 59860910abd..2329ac1f50d 100644 --- a/libs/@local/hashql/mir/src/pass/transform/inst_simplify/mod.rs +++ b/libs/@local/hashql/mir/src/pass/transform/inst_simplify/mod.rs @@ -91,17 +91,16 @@ mod tests; use core::{alloc::Allocator, convert::Infallible}; use hashql_core::{ - graph::Predecessors as _, heap::{BumpAllocator, Scratch, TransferInto as _}, id::IdVec, r#type::{environment::Environment, kind::PrimitiveType}, }; use hashql_hir::node::operation::UnOp; +use super::cp::propagate_block_params; use crate::{ body::{ Body, - basic_block::BasicBlockId, constant::{Constant, Int}, local::{LocalDecl, LocalSlice, LocalVec}, location::Location, @@ -169,78 +168,6 @@ impl InstSimplify { pub const fn new_in(alloc: A) -> Self { Self { alloc } } - - /// Propagates constant values through block parameters. - /// - /// For each block parameter, examines all predecessor branches that target this block. - /// If all predecessors pass the same constant value (or a local that evaluates to the - /// same constant), records that constant in `evaluated` for the block parameter. - /// - /// This enables constant folding for values that converge at control flow join points, - /// complementing SROA's structural resolution with runtime constant tracking. - /// - /// # Limitations - /// - /// - Predecessors with effectful terminators (e.g., `GraphRead`) are skipped entirely, as their - /// arguments are implicit rather than explicit. - /// - Back-edges from loops may not have been visited yet, so their contributions are based on - /// whatever constants were discovered in previous iterations (if any). Full loop-carried - /// propagation would require fix-point iteration. - fn propagate_block_params<'heap>( - args: &mut Vec, &A>, - visitor: &mut InstSimplifyVisitor<'_, 'heap, &A>, - body: &Body<'heap>, - id: BasicBlockId, - ) { - let pred = body.basic_blocks.predecessors(id); - - // Effectful terminators (like GraphRead) pass arguments implicitly, where they set the - // block param directly. We cannot inspect those values, so we conservatively skip - // propagation for blocks reachable from effectful predecessors (they have single - // successors). - if pred - .clone() - .any(|pred| body.basic_blocks[pred].terminator.kind.is_effectful()) - { - return; - } - - // Collect all predecessor targets that branch to this block. A single predecessor - // may have multiple targets to us (e.g., a switch with two arms to the same block). - let mut targets = pred - .flat_map(|pred| body.basic_blocks[pred].terminator.kind.successor_targets()) - .filter(|&target| target.block == id); - - let Some(first) = targets.next() else { - // No explicit targets means this block is only reachable via implicit edges - // (e.g., entry block or effectful continuations). Nothing to propagate. - return; - }; - - // Seed with the first target's argument values. Each position holds `Some(int)` if - // that argument evaluated to a constant, `None` otherwise. - args.extend(first.args.iter().map(|&arg| visitor.try_eval(arg).as_int())); - - // Check remaining targets for consensus. If any target passes a different value - // (or non-constant) for a parameter position, clear that position to `None`. - for target in targets { - debug_assert_eq!(args.len(), target.args.len()); - - for (lhs, &rhs) in args.iter_mut().zip(target.args.iter()) { - let rhs = visitor.try_eval(rhs).as_int(); - if *lhs != rhs { - *lhs = None; - } - } - } - - // Record constants for block parameters where all predecessors agreed. - for (&local, constant) in body.basic_blocks[id].params.iter().zip(args.drain(..)) { - if let Some(constant) = constant { - visitor.evaluated.insert(local, constant); - } - } - } } impl<'env, 'heap, A: BumpAllocator> TransformPass<'env, 'heap> for InstSimplify { @@ -263,7 +190,11 @@ impl<'env, 'heap, A: BumpAllocator> TransformPass<'env, 'heap> for InstSimplify< let mut args = Vec::new_in(&self.alloc); for &mut id in reverse_postorder { - Self::propagate_block_params(&mut args, &mut visitor, body, id); + for (local, int) in propagate_block_params(&mut args, body, id, |operand| { + visitor.try_eval(operand).as_int() + }) { + visitor.evaluated.insert(local, int); + } Ok(()) = visitor.visit_basic_block(id, &mut body.basic_blocks.as_mut_preserving_cfg()[id]); @@ -302,7 +233,7 @@ impl<'heap, A: Allocator> InstSimplifyVisitor<'_, 'heap, A> { if let Operand::Place(place) = operand && place.projections.is_empty() - && let Some(&Some(int)) = self.evaluated.get(place.local) + && let Some(&int) = self.evaluated.lookup(place.local) { return OperandKind::Int(int); } @@ -601,9 +532,9 @@ impl<'heap, A: Allocator> VisitorMut<'heap> for InstSimplifyVisitor<'_, 'heap, A // already a constant, we can propagate it. if let RValue::Load(Operand::Place(place)) = trampoline && place.projections.is_empty() - && let Some(&Some(constant)) = self.evaluated.get(place.local) + && let Some(&int) = self.evaluated.lookup(place.local) { - self.evaluated.insert(assign.lhs.local, constant); + self.evaluated.insert(assign.lhs.local, int); } assign.rhs = trampoline; diff --git a/libs/@local/hashql/mir/src/pass/transform/mod.rs b/libs/@local/hashql/mir/src/pass/transform/mod.rs index 5cc9a9f3622..30ad2ae545a 100644 --- a/libs/@local/hashql/mir/src/pass/transform/mod.rs +++ b/libs/@local/hashql/mir/src/pass/transform/mod.rs @@ -1,4 +1,5 @@ mod cfg_simplify; +mod cp; mod dbe; mod dle; mod dse; @@ -8,6 +9,7 @@ mod sroa; mod ssa_repair; pub use self::{ - cfg_simplify::CfgSimplify, dbe::DeadBlockElimination, dle::DeadLocalElimination, - dse::DeadStoreElimination, inst_simplify::InstSimplify, sroa::Sroa, ssa_repair::SsaRepair, + cfg_simplify::CfgSimplify, cp::CopyPropagation, dbe::DeadBlockElimination, + dle::DeadLocalElimination, dse::DeadStoreElimination, inst_simplify::InstSimplify, sroa::Sroa, + ssa_repair::SsaRepair, }; diff --git a/libs/@local/hashql/mir/tests/ui/pass/cp/block_param_disagreement.snap b/libs/@local/hashql/mir/tests/ui/pass/cp/block_param_disagreement.snap new file mode 100644 index 00000000000..8774f3e0959 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/cp/block_param_disagreement.snap @@ -0,0 +1,57 @@ +--- +source: libs/@local/hashql/mir/src/pass/transform/cp/tests.rs +expression: value +--- +fn {intrinsic#4294967040}() -> Boolean { + let %0: Boolean + let %1: Integer + let %2: Boolean + + bb0(): { + %0 = 1 + + switchInt(%0) -> [0: bb2(), 1: bb1()] + } + + bb1(): { + goto -> bb3(1) + } + + bb2(): { + goto -> bb3(2) + } + + bb3(%1): { + %2 = %1 == %1 + + return %2 + } +} + +------------------------------------ + +fn {intrinsic#4294967040}() -> Boolean { + let %0: Boolean + let %1: Integer + let %2: Boolean + + bb0(): { + %0 = 1 + + switchInt(1) -> [0: bb2(), 1: bb1()] + } + + bb1(): { + goto -> bb3(1) + } + + bb2(): { + goto -> bb3(2) + } + + bb3(%1): { + %2 = %1 == %1 + + return %2 + } +} diff --git a/libs/@local/hashql/mir/tests/ui/pass/cp/block_param_effectful.snap b/libs/@local/hashql/mir/tests/ui/pass/cp/block_param_effectful.snap new file mode 100644 index 00000000000..d82a34c135d --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/cp/block_param_effectful.snap @@ -0,0 +1,45 @@ +--- +source: libs/@local/hashql/mir/src/pass/transform/cp/tests.rs +expression: value +--- +fn {intrinsic#4294967040}(%0: ?) -> Boolean { + let %1: Integer + let %2: Boolean + + bb0(): { + graph read entities(%0) + |> collect -> bb2(_) + } + + bb1(): { + goto -> bb2(1) + } + + bb2(%1): { + %2 = %1 == %1 + + return %2 + } +} + +------------------------------------ + +fn {intrinsic#4294967040}(%0: ?) -> Boolean { + let %1: Integer + let %2: Boolean + + bb0(): { + graph read entities(%0) + |> collect -> bb2(_) + } + + bb1(): { + goto -> bb2(1) + } + + bb2(%1): { + %2 = %1 == %1 + + return %2 + } +} diff --git a/libs/@local/hashql/mir/tests/ui/pass/cp/block_param_unanimous.snap b/libs/@local/hashql/mir/tests/ui/pass/cp/block_param_unanimous.snap new file mode 100644 index 00000000000..553b7e7985a --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/cp/block_param_unanimous.snap @@ -0,0 +1,57 @@ +--- +source: libs/@local/hashql/mir/src/pass/transform/cp/tests.rs +expression: value +--- +fn {intrinsic#4294967040}() -> Boolean { + let %0: Boolean + let %1: Integer + let %2: Boolean + + bb0(): { + %0 = 1 + + switchInt(%0) -> [0: bb2(), 1: bb1()] + } + + bb1(): { + goto -> bb3(1) + } + + bb2(): { + goto -> bb3(1) + } + + bb3(%1): { + %2 = %1 == %1 + + return %2 + } +} + +------------------------------------ + +fn {intrinsic#4294967040}() -> Boolean { + let %0: Boolean + let %1: Integer + let %2: Boolean + + bb0(): { + %0 = 1 + + switchInt(1) -> [0: bb2(), 1: bb1()] + } + + bb1(): { + goto -> bb3(1) + } + + bb2(): { + goto -> bb3(1) + } + + bb3(%1): { + %2 = 1 == 1 + + return %2 + } +} diff --git a/libs/@local/hashql/mir/tests/ui/pass/cp/block_param_via_local.snap b/libs/@local/hashql/mir/tests/ui/pass/cp/block_param_via_local.snap new file mode 100644 index 00000000000..a9cc1677e5f --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/cp/block_param_via_local.snap @@ -0,0 +1,41 @@ +--- +source: libs/@local/hashql/mir/src/pass/transform/cp/tests.rs +expression: value +--- +fn {intrinsic#4294967040}() -> Boolean { + let %0: Integer + let %1: Integer + let %2: Boolean + + bb0(): { + %0 = 1 + + goto -> bb1(%0) + } + + bb1(%1): { + %2 = %1 == %1 + + return %2 + } +} + +------------------------------------ + +fn {intrinsic#4294967040}() -> Boolean { + let %0: Integer + let %1: Integer + let %2: Boolean + + bb0(): { + %0 = 1 + + goto -> bb1(1) + } + + bb1(%1): { + %2 = 1 == 1 + + return %2 + } +} diff --git a/libs/@local/hashql/mir/tests/ui/pass/cp/constant_chain.snap b/libs/@local/hashql/mir/tests/ui/pass/cp/constant_chain.snap new file mode 100644 index 00000000000..869bd6aa199 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/cp/constant_chain.snap @@ -0,0 +1,37 @@ +--- +source: libs/@local/hashql/mir/src/pass/transform/cp/tests.rs +expression: value +--- +fn {intrinsic#4294967040}() -> Boolean { + let %0: Integer + let %1: Integer + let %2: Integer + let %3: Boolean + + bb0(): { + %0 = 1 + %1 = %0 + %2 = %1 + %3 = %2 == %2 + + return %3 + } +} + +------------------------------------ + +fn {intrinsic#4294967040}() -> Boolean { + let %0: Integer + let %1: Integer + let %2: Integer + let %3: Boolean + + bb0(): { + %0 = 1 + %1 = 1 + %2 = 1 + %3 = 1 == 1 + + return %3 + } +} diff --git a/libs/@local/hashql/mir/tests/ui/pass/cp/loop_back_edge.snap b/libs/@local/hashql/mir/tests/ui/pass/cp/loop_back_edge.snap new file mode 100644 index 00000000000..dbb947c404d --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/cp/loop_back_edge.snap @@ -0,0 +1,49 @@ +--- +source: libs/@local/hashql/mir/src/pass/transform/cp/tests.rs +expression: value +--- +fn {intrinsic#4294967040}() -> Boolean { + let %0: Integer + let %1: Boolean + let %2: Boolean + + bb0(): { + %2 = 1 + + goto -> bb1(1) + } + + bb1(%0): { + %1 = %0 == %0 + + switchInt(%2) -> [0: bb2(), 1: bb1(2)] + } + + bb2(): { + return %1 + } +} + +------------------------------------ + +fn {intrinsic#4294967040}() -> Boolean { + let %0: Integer + let %1: Boolean + let %2: Boolean + + bb0(): { + %2 = 1 + + goto -> bb1(1) + } + + bb1(%0): { + %1 = %0 == %0 + + switchInt(1) -> [0: bb2(), 1: bb1(2)] + } + + bb2(): { + return %1 + } +} diff --git a/libs/@local/hashql/mir/tests/ui/pass/cp/projection_unchanged.snap b/libs/@local/hashql/mir/tests/ui/pass/cp/projection_unchanged.snap new file mode 100644 index 00000000000..9a6fdabf992 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/cp/projection_unchanged.snap @@ -0,0 +1,33 @@ +--- +source: libs/@local/hashql/mir/src/pass/transform/cp/tests.rs +expression: value +--- +fn {intrinsic#4294967040}() -> Boolean { + let %0: (Integer, Integer) + let %1: Integer + let %2: Boolean + + bb0(): { + %0 = (1, 2) + %1 = %0.0 + %2 = %1 == %1 + + return %2 + } +} + +------------------------------------ + +fn {intrinsic#4294967040}() -> Boolean { + let %0: (Integer, Integer) + let %1: Integer + let %2: Boolean + + bb0(): { + %0 = (1, 2) + %1 = %0.0 + %2 = %1 == %1 + + return %2 + } +} diff --git a/libs/@local/hashql/mir/tests/ui/pass/cp/single_constant.snap b/libs/@local/hashql/mir/tests/ui/pass/cp/single_constant.snap new file mode 100644 index 00000000000..dc817a19b82 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/cp/single_constant.snap @@ -0,0 +1,29 @@ +--- +source: libs/@local/hashql/mir/src/pass/transform/cp/tests.rs +expression: value +--- +fn {intrinsic#4294967040}() -> Boolean { + let %0: Integer + let %1: Boolean + + bb0(): { + %0 = 1 + %1 = %0 == %0 + + return %1 + } +} + +------------------------------------ + +fn {intrinsic#4294967040}() -> Boolean { + let %0: Integer + let %1: Boolean + + bb0(): { + %0 = 1 + %1 = 1 == 1 + + return %1 + } +}