From 6ff8966f11d34e2bf9349fa6d88128d58f37575c Mon Sep 17 00:00:00 2001
From: michaelfeil
Date: Thu, 18 Dec 2025 00:28:16 +0000
Subject: [PATCH 1/3] flash feature refactor

This PR extends the Qwen2 architecture to models other than
`Alibaba-NLP/gte-Qwen2-7B-instruct`, given that the prior implementation only
covered that case, using causal attention on CUDA and relying on the provided
tokenizer rather than patching it.

# What does this PR do?

- makes flash-attn-3 and flash-attn-cpu easier to add.

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/text-embeddings-inference/blob/main/CONTRIBUTING.md)?
- [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs).
- [ ] Did you write any new necessary tests? If applicable, did you include or update the `insta` snapshots?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed.
Feel free to tag members/contributors who may be interested in your PR.
---
 backends/candle/src/lib.rs | 83 +++++++++++++------------------------
 1 file changed, 28 insertions(+), 55 deletions(-)

diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs
index ff824f555..7e1ab15a5 100644
--- a/backends/candle/src/lib.rs
+++ b/backends/candle/src/lib.rs
@@ -3,6 +3,7 @@ mod alibi;
 mod compute_cap;
 #[cfg(feature = "cuda")]
 mod flash_attn;
+mod flash_attn_cpu;
 mod layers;
 mod models;
 
@@ -115,6 +116,23 @@ enum Config {
     XlmRoberta(BertConfig),
 }
 
+/// Helper function to check if flash attention should be used
+/// require_fa2_capabilities: true for models that need flash attention v2, false for v1/v2 compatible models
+fn should_use_flash_attention(require_fa2_capabilities: bool) -> bool {
+    let flash_attn_enabled = &std::env::var("USE_FLASH_ATTENTION").unwrap_or("true".to_string()).to_lowercase() == "true";
+
+    #cfg!(feature = "cuda") else {
+        // if not cuda support, always false for now.
+        return false;
+    };
+    // cuda
+    if require_fa2_capabilities {
+        cfg!(feature = "flash-attn") && flash_attn_enabled
+    } else {
+        cfg!(any(feature = "flash-attn", feature = "flash-attn-v1")) && flash_attn_enabled
+    }
+}
+
 pub struct CandleBackend {
     device: Device,
     model: Box<dyn Model + Send>,
@@ -307,12 +325,7 @@ impl CandleBackend {
             }
             #[cfg(feature = "cuda")]
             (Config::Bert(config), Device::Cuda(_)) => {
-                if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                    && dtype == DType::F16
-                    // Allow disabling because of flash attention v1 precision problems
-                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
-                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
-                {
+                if dtype == DType::F16 && should_use_flash_attention(false) {
                     match config {
                         BertConfigWrapper::JinaBert(config) => {
                             tracing::info!("Starting FlashJinaBert model on {:?}", device);
@@ -354,13 +367,8 @@ impl CandleBackend {
             (
                 Config::Camembert(config) | Config::Roberta(config) | Config::XlmRoberta(config),
                 Device::Cuda(_),
-            ) => {
-                if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                    && dtype == DType::F16
-                    // Allow disabling because of flash attention v1 precision problems
-                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
-                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
-                {
+            ) => {
+                if dtype == DType::F16 && should_use_flash_attention(false) {
                     tracing::info!("Starting FlashBert model on {:?}", device);
                     Ok(Box::new(
                         FlashBertModel::load_roberta(vb, &config, model_type).s()?,
@@ -374,12 +382,7 @@ impl CandleBackend {
             }
             #[cfg(feature = "cuda")]
             (Config::DistilBert(config), Device::Cuda(_)) => {
-                if cfg!(feature = "flash-attn")
-                    && dtype == DType::F16
-                    && &std::env::var("USE_FLASH_ATTENTION")
-                        .unwrap_or("True".to_string())
-                        .to_lowercase()
-                        == "true"
+                if dtype == DType::F16 && should_use_flash_attention(true)
                 {
                     tracing::info!("Starting FlashDistilBert model on {:?}", device);
                     Ok(Box::new(
@@ -405,12 +408,7 @@ impl CandleBackend {
             }
             #[cfg(feature = "cuda")]
             (Config::Gte(config), Device::Cuda(_)) => {
-                if dtype != DType::F16
-                    || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                    || &std::env::var("USE_FLASH_ATTENTION")
-                        .unwrap_or("True".to_string())
-                        .to_lowercase()
-                        != "true"
+                if dtype != DType::F16 || !should_use_flash_attention(false)
                 {
                     tracing::info!("Starting GTE model on {:?}", device);
                     Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
@@ -421,13 +419,7 @@ impl CandleBackend {
             }
             #[cfg(feature = "cuda")]
             (Config::Mistral(config), Device::Cuda(_)) => {
-                if dtype != DType::F16
-                    || !cfg!(feature = "flash-attn")
-                    || get_runtime_compute_cap().unwrap() < 80
-                    || &std::env::var("USE_FLASH_ATTENTION")
-                        .unwrap_or("True".to_string())
-                        .to_lowercase()
-                        != "true"
+                if dtype != DType::F16 || !should_use_flash_attention(true)
                 {
                     return Err(BackendError::Start("Mistral is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
                 }
@@ -438,11 +430,7 @@ impl CandleBackend {
             }
             #[cfg(feature = "cuda")]
             (Config::ModernBert(config), Device::Cuda(_)) => {
-                if cfg!(feature = "flash-attn")
-                    && dtype == DType::F16
-                    // Allow disabling because of flash attention v1 precision problems
-                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
-                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
+                if dtype == DType::F16 && should_use_flash_attention(true)
                 {
tracing::info!("Starting FlashModernBert model on {:?}", device); Ok(Box::new( @@ -459,12 +447,7 @@ impl CandleBackend { } #[cfg(feature = "cuda")] (Config::NomicBert(config), Device::Cuda(_)) => { - if cfg!(feature = "flash-attn") - && dtype == DType::F16 - && &std::env::var("USE_FLASH_ATTENTION") - .unwrap_or("True".to_string()) - .to_lowercase() - == "true" + if dtype == DType::F16 && should_use_flash_attention(true) { tracing::info!("Starting FlashNomicBert model on {:?}", device); Ok(Box::new( @@ -477,12 +460,7 @@ impl CandleBackend { } #[cfg(feature = "cuda")] (Config::Qwen2(config), Device::Cuda(_)) => { - if dtype != DType::F16 - || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1")) - || &std::env::var("USE_FLASH_ATTENTION") - .unwrap_or("True".to_string()) - .to_lowercase() - != "true" + if dtype != DType::F16 || !should_use_flash_attention(false) { return Err(BackendError::Start("Qwen2 is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string())); } @@ -493,12 +471,7 @@ impl CandleBackend { } #[cfg(feature = "cuda")] (Config::Qwen3(config), Device::Cuda(_)) => { - if dtype != DType::F16 - || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1")) - || &std::env::var("USE_FLASH_ATTENTION") - .unwrap_or("True".to_string()) - .to_lowercase() - != "true" + if dtype != DType::F16 || !should_use_flash_attention(false) { tracing::info!("Starting Qwen3 model on {:?}", device); Ok(Box::new(Qwen3Model::load(vb, &config, model_type).s()?)) From 6679919462d31055f415c26e1971aa55ffb89ff9 Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Thu, 18 Dec 2025 00:28:58 +0000 Subject: [PATCH 2/3] add varlen attention interface --- backends/candle/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index 7e1ab15a5..1d4f77a83 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -121,7 +121,7 @@ enum Config { fn should_use_flash_attention(require_fa2_capabilities: bool) -> bool { let flash_attn_enabled = &std::env::var("USE_FLASH_ATTENTION").unwrap_or("true".to_string()).to_lowercase() == "true"; - #cfg!(feature = "cuda") else { + if cfg!(not(feature = "cuda")) { // if not cuda support, always false for now. return false; }; From 6e5b16393983bf53bc1d4624cd6e9b722718eaef Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Wed, 17 Dec 2025 16:37:16 -0800 Subject: [PATCH 3/3] Update lib.rs --- backends/candle/src/lib.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index 1d4f77a83..206e56e3a 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -3,7 +3,6 @@ mod alibi; mod compute_cap; #[cfg(feature = "cuda")] mod flash_attn; -mod flash_attn_cpu; mod layers; mod models;
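
For reference, below is a minimal standalone sketch of how the gating helper reads once all three patches are applied (patch 2 fixes the `#cfg!` syntax, patch 3 removes the unused `flash_attn_cpu` module declaration). The `main` harness is purely illustrative and not part of the series, and the `cuda`, `flash-attn` and `flash-attn-v1` cfg flags are Cargo features of the candle backend crate, so a standalone build simply evaluates them as disabled.

```rust
/// Returns true when a flash attention kernel should be used.
/// `require_fa2_capabilities`: true for models that need flash attention v2,
/// false for models that also work with flash attention v1.
fn should_use_flash_attention(require_fa2_capabilities: bool) -> bool {
    // Opt-out switch, kept because of the flash-attn v1 precision issues
    // tracked in https://github.com/huggingface/text-embeddings-inference/issues/37
    let flash_attn_enabled = std::env::var("USE_FLASH_ATTENTION")
        .unwrap_or("true".to_string())
        .to_lowercase()
        == "true";

    if cfg!(not(feature = "cuda")) {
        // Without CUDA support there is no flash attention path (for now).
        return false;
    }

    if require_fa2_capabilities {
        cfg!(feature = "flash-attn") && flash_attn_enabled
    } else {
        cfg!(any(feature = "flash-attn", feature = "flash-attn-v1")) && flash_attn_enabled
    }
}

fn main() {
    // Hypothetical harness mirroring the call sites in `CandleBackend::new`:
    // Bert, Roberta, GTE, Qwen2 and Qwen3 accept either flash-attn v1 or v2,
    let bert_like = should_use_flash_attention(false);
    // while DistilBert, ModernBert, NomicBert and Mistral require v2.
    let fa2_only = should_use_flash_attention(true);
    println!("flash-attn (v1/v2 ok): {bert_like}, flash-attn (v2 only): {fa2_only}");
}
```

The helper centralizes the `USE_FLASH_ATTENTION` environment-variable check that was previously duplicated in every CUDA match arm; models that need flash attention v2 capabilities pass `true`, while the remaining architectures accept either v1 or v2.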