diff --git a/crates/account-abstraction-core/core/src/mempool.rs b/crates/account-abstraction-core/core/src/mempool.rs index 3cc274f..637bf9c 100644 --- a/crates/account-abstraction-core/core/src/mempool.rs +++ b/crates/account-abstraction-core/core/src/mempool.rs @@ -1,8 +1,9 @@ -use crate::types::{PoolOperation, UserOpHash}; +use crate::types::{UserOpHash, WrappedUserOperation}; +use alloy_primitives::Address; +use std::cmp::Ordering; use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; -use std::sync::atomic::AtomicU64; -use std::sync::atomic::Ordering; +use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering}; pub struct PoolConfig { minimum_max_fee_per_gas: u128, @@ -10,63 +11,107 @@ pub struct PoolConfig { #[derive(Eq, PartialEq, Clone, Debug)] pub struct OrderedPoolOperation { - pub pool_operation: PoolOperation, + pub pool_operation: WrappedUserOperation, pub submission_id: u64, - pub priority_order: u64, } -impl Ord for OrderedPoolOperation { - /// TODO: There can be invalid opperations, where base fee, + expected gas price - /// is greater that the maximum gas, in that case we don't include it in the mempool as such mempool changes. - fn cmp(&self, other: &Self) -> std::cmp::Ordering { +impl OrderedPoolOperation { + pub fn from_wrapped(operation: &WrappedUserOperation, submission_id: u64) -> Self { + Self { + pool_operation: operation.clone(), + submission_id, + } + } + + pub fn sender(&self) -> Address { + self.pool_operation.operation.sender() + } +} + +/// Ordering by max priority fee (desc) then submission id, then hash to ensure total order +#[derive(Clone, Debug)] +pub struct ByMaxFeeAndSubmissionId(pub OrderedPoolOperation); + +impl PartialEq for ByMaxFeeAndSubmissionId { + fn eq(&self, other: &Self) -> bool { + self.0.pool_operation.hash == other.0.pool_operation.hash + } +} +impl Eq for ByMaxFeeAndSubmissionId {} + +impl PartialOrd for ByMaxFeeAndSubmissionId { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for ByMaxFeeAndSubmissionId { + fn cmp(&self, other: &Self) -> Ordering { other + .0 .pool_operation .operation .max_priority_fee_per_gas() - .cmp(&self.pool_operation.operation.max_priority_fee_per_gas()) - .then_with(|| self.submission_id.cmp(&other.submission_id)) + .cmp(&self.0.pool_operation.operation.max_priority_fee_per_gas()) + .then_with(|| self.0.submission_id.cmp(&other.0.submission_id)) + .then_with(|| self.0.pool_operation.hash.cmp(&other.0.pool_operation.hash)) + } +} + +/// Ordering by nonce (asc), then submission id, then hash to ensure total order +#[derive(Clone, Debug)] +pub struct ByNonce(pub OrderedPoolOperation); + +impl PartialEq for ByNonce { + fn eq(&self, other: &Self) -> bool { + self.0.pool_operation.hash == other.0.pool_operation.hash } } +impl Eq for ByNonce {} -impl PartialOrd for OrderedPoolOperation { - fn partial_cmp(&self, other: &Self) -> Option { +impl PartialOrd for ByNonce { + fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl OrderedPoolOperation { - pub fn create_from_pool_operation(operation: &PoolOperation, submission_id: u64) -> Self { - Self { - pool_operation: operation.clone(), - priority_order: submission_id, - submission_id, - } +impl Ord for ByNonce { + /// TODO: There can be invalid opperations, where base fee, + expected gas price + /// is greater that the maximum gas, in that case we don't include it in the mempool as such mempool changes. + fn cmp(&self, other: &Self) -> Ordering { + self.0 + .pool_operation + .operation + .nonce() + .cmp(&other.0.pool_operation.operation.nonce()) + .then_with(|| self.0.submission_id.cmp(&other.0.submission_id)) } } pub trait Mempool { fn add_operation( &mut self, - operation: &PoolOperation, + operation: &WrappedUserOperation, ) -> Result, anyhow::Error>; - fn get_top_operations(&self, n: usize) -> impl Iterator>; + fn get_top_operations(&self, n: usize) -> impl Iterator>; fn remove_operation( &mut self, operation_hash: &UserOpHash, - ) -> Result, anyhow::Error>; + ) -> Result, anyhow::Error>; } pub struct MempoolImpl { config: PoolConfig, - best: BTreeSet, + best: BTreeSet, hash_to_operation: HashMap, + operations_by_account: HashMap>, submission_id_counter: AtomicU64, } impl Mempool for MempoolImpl { fn add_operation( &mut self, - operation: &PoolOperation, + operation: &WrappedUserOperation, ) -> Result, anyhow::Error> { if operation.operation.max_fee_per_gas() < self.config.minimum_max_fee_per_gas { return Err(anyhow::anyhow!( @@ -77,19 +122,47 @@ impl Mempool for MempoolImpl { Ok(ordered_operation_result) } - fn get_top_operations(&self, n: usize) -> impl Iterator> { + fn get_top_operations(&self, n: usize) -> impl Iterator> { + // TODO: There is a case where we skip operations that are not the lowest nonce for an account. + // But we still have not given the N number of operations, meaning we don't return those operations. + self.best .iter() + .filter_map(|op_by_fee| { + let lowest = self + .operations_by_account + .get(&op_by_fee.0.sender()) + .and_then(|set| set.first()); + + match lowest { + Some(lowest) + if lowest.0.pool_operation.hash == op_by_fee.0.pool_operation.hash => + { + Some(Arc::new(op_by_fee.0.pool_operation.clone())) + } + Some(_) => None, + None => { + println!( + "No operations found for account: {} but one was found in the best set", + op_by_fee.0.sender() + ); + None + } + } + }) .take(n) - .map(|o| Arc::new(o.pool_operation.clone())) } fn remove_operation( &mut self, operation_hash: &UserOpHash, - ) -> Result, anyhow::Error> { + ) -> Result, anyhow::Error> { if let Some(ordered_operation) = self.hash_to_operation.remove(operation_hash) { - self.best.remove(&ordered_operation); + self.best + .remove(&ByMaxFeeAndSubmissionId(ordered_operation.clone())); + self.operations_by_account + .get_mut(&ordered_operation.sender()) + .map(|set| set.remove(&ByNonce(ordered_operation.clone()))); Ok(Some(ordered_operation.pool_operation)) } else { Ok(None) @@ -97,31 +170,35 @@ impl Mempool for MempoolImpl { } } +// When user opperation is added to the mempool we need to check + impl MempoolImpl { fn handle_add_operation( &mut self, - operation: &PoolOperation, + operation: &WrappedUserOperation, ) -> Result, anyhow::Error> { - if let Some(old_ordered_operation) = self.hash_to_operation.get(&operation.hash) { - if operation.should_replace(&old_ordered_operation.pool_operation) { - self.best.remove(old_ordered_operation); - self.hash_to_operation.remove(&operation.hash); - } else { - return Ok(None); - } + // Account + if self.hash_to_operation.contains_key(&operation.hash) { + return Ok(None); } let order = self.get_next_order_id(); - let ordered_operation = OrderedPoolOperation::create_from_pool_operation(operation, order); + let ordered_operation = OrderedPoolOperation::from_wrapped(operation, order); - self.best.insert(ordered_operation.clone()); + self.best + .insert(ByMaxFeeAndSubmissionId(ordered_operation.clone())); + self.operations_by_account + .entry(ordered_operation.sender()) + .or_default() + .insert(ByNonce(ordered_operation.clone())); self.hash_to_operation .insert(operation.hash, ordered_operation.clone()); Ok(Some(ordered_operation)) } fn get_next_order_id(&self) -> u64 { - self.submission_id_counter.fetch_add(1, Ordering::SeqCst) + self.submission_id_counter + .fetch_add(1, AtomicOrdering::SeqCst) } pub fn new(config: PoolConfig) -> Self { @@ -129,6 +206,7 @@ impl MempoolImpl { config, best: BTreeSet::new(), hash_to_operation: HashMap::new(), + operations_by_account: HashMap::new(), submission_id_counter: AtomicU64::new(0), } } @@ -142,7 +220,7 @@ mod tests { use alloy_rpc_types::erc4337; fn create_test_user_operation(max_priority_fee_per_gas: u128) -> VersionedUserOperation { VersionedUserOperation::UserOperation(erc4337::UserOperation { - sender: Address::ZERO, + sender: Address::random(), nonce: Uint::from(0), init_code: Default::default(), call_data: Default::default(), @@ -156,8 +234,11 @@ mod tests { }) } - fn create_pool_operation(max_priority_fee_per_gas: u128, hash: UserOpHash) -> PoolOperation { - PoolOperation { + fn create_wrapped_operation( + max_priority_fee_per_gas: u128, + hash: UserOpHash, + ) -> WrappedUserOperation { + WrappedUserOperation { operation: create_test_user_operation(max_priority_fee_per_gas), hash, } @@ -174,7 +255,7 @@ mod tests { fn test_add_operation_success() { let mut mempool = create_test_mempool(1000); let hash = FixedBytes::from([1u8; 32]); - let operation = create_pool_operation(2000, hash); + let operation = create_wrapped_operation(2000, hash); let result = mempool.add_operation(&operation); @@ -194,7 +275,7 @@ mod tests { fn test_add_operation_below_minimum_gas() { let mut mempool = create_test_mempool(2000); let hash = FixedBytes::from([1u8; 32]); - let operation = create_pool_operation(1000, hash); + let operation = create_wrapped_operation(1000, hash); let result = mempool.add_operation(&operation); @@ -207,76 +288,25 @@ mod tests { ); } - // Tests adding an operation with the same hash but higher gas price - #[test] - fn test_add_operation_duplicate_hash_higher_gas() { - let mut mempool = create_test_mempool(1000); - let hash = FixedBytes::from([1u8; 32]); - - let operation1 = create_pool_operation(2000, hash); - let result1 = mempool.add_operation(&operation1); - assert!(result1.is_ok()); - assert!(result1.unwrap().is_some()); - - let operation2 = create_pool_operation(3000, hash); - let result2 = mempool.add_operation(&operation2); - assert!(result2.is_ok()); - assert!(result2.unwrap().is_some()); - } - - // Tests adding an operation with the same hash but lower gas price - #[test] - fn test_add_operation_duplicate_hash_lower_gas() { - let mut mempool = create_test_mempool(1000); - let hash = FixedBytes::from([1u8; 32]); - - let operation1 = create_pool_operation(3000, hash); - let result1 = mempool.add_operation(&operation1); - assert!(result1.is_ok()); - assert!(result1.unwrap().is_some()); - - let operation2 = create_pool_operation(2000, hash); - let result2 = mempool.add_operation(&operation2); - assert!(result2.is_ok()); - assert!(result2.unwrap().is_none()); - } - - // Tests adding an operation with the same hash and equal gas price - #[test] - fn test_add_operation_duplicate_hash_equal_gas() { - let mut mempool = create_test_mempool(1000); - let hash = FixedBytes::from([1u8; 32]); - - let operation1 = create_pool_operation(2000, hash); - let result1 = mempool.add_operation(&operation1); - assert!(result1.is_ok()); - assert!(result1.unwrap().is_some()); - - let operation2 = create_pool_operation(2000, hash); - let result2 = mempool.add_operation(&operation2); - assert!(result2.is_ok()); - assert!(result2.unwrap().is_none()); - } - // Tests adding multiple operations with different hashes #[test] - fn test_add_multiple_operations_with_different_hashes() { + fn test_add_multiple_operations() { let mut mempool = create_test_mempool(1000); let hash1 = FixedBytes::from([1u8; 32]); - let operation1 = create_pool_operation(2000, hash1); + let operation1 = create_wrapped_operation(2000, hash1); let result1 = mempool.add_operation(&operation1); assert!(result1.is_ok()); assert!(result1.unwrap().is_some()); let hash2 = FixedBytes::from([2u8; 32]); - let operation2 = create_pool_operation(3000, hash2); + let operation2 = create_wrapped_operation(3000, hash2); let result2 = mempool.add_operation(&operation2); assert!(result2.is_ok()); assert!(result2.unwrap().is_some()); let hash3 = FixedBytes::from([3u8; 32]); - let operation3 = create_pool_operation(1500, hash3); + let operation3 = create_wrapped_operation(1500, hash3); let result3 = mempool.add_operation(&operation3); assert!(result3.is_ok()); assert!(result3.unwrap().is_some()); @@ -301,7 +331,7 @@ mod tests { fn test_remove_operation_exists() { let mut mempool = create_test_mempool(1000); let hash = FixedBytes::from([1u8; 32]); - let operation = create_pool_operation(2000, hash); + let operation = create_wrapped_operation(2000, hash); mempool.add_operation(&operation).unwrap(); @@ -319,7 +349,7 @@ mod tests { fn test_remove_operation_and_check_best() { let mut mempool = create_test_mempool(1000); let hash = FixedBytes::from([1u8; 32]); - let operation = create_pool_operation(2000, hash); + let operation = create_wrapped_operation(2000, hash); mempool.add_operation(&operation).unwrap(); @@ -341,15 +371,15 @@ mod tests { let mut mempool = create_test_mempool(1000); let hash1 = FixedBytes::from([1u8; 32]); - let operation1 = create_pool_operation(2000, hash1); + let operation1 = create_wrapped_operation(2000, hash1); mempool.add_operation(&operation1).unwrap(); let hash2 = FixedBytes::from([2u8; 32]); - let operation2 = create_pool_operation(3000, hash2); + let operation2 = create_wrapped_operation(3000, hash2); mempool.add_operation(&operation2).unwrap(); let hash3 = FixedBytes::from([3u8; 32]); - let operation3 = create_pool_operation(1500, hash3); + let operation3 = create_wrapped_operation(1500, hash3); mempool.add_operation(&operation3).unwrap(); let best: Vec<_> = mempool.get_top_operations(10).collect(); @@ -365,15 +395,15 @@ mod tests { let mut mempool = create_test_mempool(1000); let hash1 = FixedBytes::from([1u8; 32]); - let operation1 = create_pool_operation(2000, hash1); + let operation1 = create_wrapped_operation(2000, hash1); mempool.add_operation(&operation1).unwrap(); let hash2 = FixedBytes::from([2u8; 32]); - let operation2 = create_pool_operation(3000, hash2); + let operation2 = create_wrapped_operation(3000, hash2); mempool.add_operation(&operation2).unwrap(); let hash3 = FixedBytes::from([3u8; 32]); - let operation3 = create_pool_operation(1500, hash3); + let operation3 = create_wrapped_operation(1500, hash3); mempool.add_operation(&operation3).unwrap(); let best: Vec<_> = mempool.get_top_operations(2).collect(); @@ -388,11 +418,11 @@ mod tests { let mut mempool = create_test_mempool(1000); let hash1 = FixedBytes::from([1u8; 32]); - let operation1 = create_pool_operation(2000, hash1); + let operation1 = create_wrapped_operation(2000, hash1); mempool.add_operation(&operation1).unwrap().unwrap(); let hash2 = FixedBytes::from([2u8; 32]); - let operation2 = create_pool_operation(2000, hash2); + let operation2 = create_wrapped_operation(2000, hash2); mempool.add_operation(&operation2).unwrap().unwrap(); let best: Vec<_> = mempool.get_top_operations(2).collect(); @@ -400,4 +430,42 @@ mod tests { assert_eq!(best[0].hash, hash1); assert_eq!(best[1].hash, hash2); } + + #[test] + fn test_get_top_operations_should_return_the_lowest_nonce_operation_for_each_account() { + let mut mempool = create_test_mempool(1000); + let hash1 = FixedBytes::from([1u8; 32]); + let test_user_operation = create_test_user_operation(2000); + + // Destructure to the inner struct, then update nonce + let base_op = match test_user_operation.clone() { + VersionedUserOperation::UserOperation(op) => op, + _ => panic!("expected UserOperation variant"), + }; + + let operation1 = WrappedUserOperation { + operation: VersionedUserOperation::UserOperation(erc4337::UserOperation { + nonce: Uint::from(0), + max_fee_per_gas: Uint::from(2000), + ..base_op.clone() + }), + hash: hash1, + }; + + mempool.add_operation(&operation1).unwrap().unwrap(); + let hash2 = FixedBytes::from([2u8; 32]); + let operation2 = WrappedUserOperation { + operation: VersionedUserOperation::UserOperation(erc4337::UserOperation { + nonce: Uint::from(1), + max_fee_per_gas: Uint::from(10_000), + ..base_op.clone() + }), + hash: hash2, + }; + mempool.add_operation(&operation2).unwrap().unwrap(); + + let best: Vec<_> = mempool.get_top_operations(2).collect(); + assert_eq!(best.len(), 1); + assert_eq!(best[0].operation.nonce(), Uint::from(0)); + } } diff --git a/crates/account-abstraction-core/core/src/types.rs b/crates/account-abstraction-core/core/src/types.rs index 04e3c9a..4600839 100644 --- a/crates/account-abstraction-core/core/src/types.rs +++ b/crates/account-abstraction-core/core/src/types.rs @@ -26,8 +26,20 @@ impl VersionedUserOperation { VersionedUserOperation::PackedUserOperation(op) => op.max_priority_fee_per_gas, } } -} + pub fn nonce(&self) -> U256 { + match self { + VersionedUserOperation::UserOperation(op) => op.nonce, + VersionedUserOperation::PackedUserOperation(op) => op.nonce, + } + } + pub fn sender(&self) -> Address { + match self { + VersionedUserOperation::UserOperation(op) => op.sender, + VersionedUserOperation::PackedUserOperation(op) => op.sender, + } + } +} #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct UserOperationRequest { pub user_operation: VersionedUserOperation, @@ -127,13 +139,13 @@ pub struct AggregatorInfo { pub type UserOpHash = FixedBytes<32>; #[derive(Eq, PartialEq, Clone, Debug)] -pub struct PoolOperation { +pub struct WrappedUserOperation { pub operation: VersionedUserOperation, pub hash: UserOpHash, } -impl PoolOperation { - pub fn should_replace(&self, other: &PoolOperation) -> bool { +impl WrappedUserOperation { + pub fn has_higher_max_fee(&self, other: &WrappedUserOperation) -> bool { self.operation.max_fee_per_gas() > other.operation.max_fee_per_gas() } }