82 changes: 27 additions & 55 deletions backends/candle/src/lib.rs
@@ -115,6 +115,23 @@ enum Config {
XlmRoberta(BertConfig),
}

/// Helper function to check whether flash attention should be used.
/// `require_fa2_capabilities`: `true` for models that require flash attention v2,
/// `false` for models compatible with either v1 or v2.
fn should_use_flash_attention(require_fa2_capabilities: bool) -> bool {
    // Allow disabling at runtime because of flash attention v1 precision problems.
    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
    let flash_attn_enabled = std::env::var("USE_FLASH_ATTENTION")
        .unwrap_or_else(|_| "true".to_string())
        .to_lowercase()
        == "true";

    // Without CUDA support, flash attention is never available.
    if cfg!(not(feature = "cuda")) {
        return false;
    }

    if require_fa2_capabilities {
        cfg!(feature = "flash-attn") && flash_attn_enabled
    } else {
        cfg!(any(feature = "flash-attn", feature = "flash-attn-v1")) && flash_attn_enabled
    }
}

pub struct CandleBackend {
device: Device,
model: Box<dyn Model + Send>,
@@ -307,12 +324,7 @@ impl CandleBackend {
}
#[cfg(feature = "cuda")]
(Config::Bert(config), Device::Cuda(_)) => {
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
&& dtype == DType::F16
// Allow disabling because of flash attention v1 precision problems
// See: https://github.com/huggingface/text-embeddings-inference/issues/37
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
{
if dtype == DType::F16 && should_use_flash_attention(false) {
match config {
BertConfigWrapper::JinaBert(config) => {
tracing::info!("Starting FlashJinaBert model on {:?}", device);
@@ -354,13 +366,8 @@ impl CandleBackend {
(
Config::Camembert(config) | Config::Roberta(config) | Config::XlmRoberta(config),
Device::Cuda(_),
) => {
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
&& dtype == DType::F16
// Allow disabling because of flash attention v1 precision problems
// See: https://github.com/huggingface/text-embeddings-inference/issues/37
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
{
) => {
if dtype == DType::F16 && should_use_flash_attention(false) {
tracing::info!("Starting FlashBert model on {:?}", device);
Ok(Box::new(
FlashBertModel::load_roberta(vb, &config, model_type).s()?,
@@ -374,12 +381,7 @@ impl CandleBackend {
}
#[cfg(feature = "cuda")]
(Config::DistilBert(config), Device::Cuda(_)) => {
if cfg!(feature = "flash-attn")
&& dtype == DType::F16
&& &std::env::var("USE_FLASH_ATTENTION")
.unwrap_or("True".to_string())
.to_lowercase()
== "true"
if dtype == DType::F16 && should_use_flash_attention(true)
{
tracing::info!("Starting FlashDistilBert model on {:?}", device);
Ok(Box::new(
@@ -405,12 +407,7 @@ impl CandleBackend {
}
#[cfg(feature = "cuda")]
(Config::Gte(config), Device::Cuda(_)) => {
if dtype != DType::F16
|| !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
|| &std::env::var("USE_FLASH_ATTENTION")
.unwrap_or("True".to_string())
.to_lowercase()
!= "true"
if dtype != DType::F16 || !should_use_flash_attention(false)
{
tracing::info!("Starting GTE model on {:?}", device);
Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
@@ -421,13 +418,7 @@ impl CandleBackend {
}
#[cfg(feature = "cuda")]
(Config::Mistral(config), Device::Cuda(_)) => {
if dtype != DType::F16
|| !cfg!(feature = "flash-attn")
|| get_runtime_compute_cap().unwrap() < 80
|| &std::env::var("USE_FLASH_ATTENTION")
.unwrap_or("True".to_string())
.to_lowercase()
!= "true"
if dtype != DType::F16 || !should_use_flash_attention(true)
{
return Err(BackendError::Start("Mistral is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
}
@@ -438,11 +429,7 @@ impl CandleBackend {
}
#[cfg(feature = "cuda")]
(Config::ModernBert(config), Device::Cuda(_)) => {
if cfg!(feature = "flash-attn")
&& dtype == DType::F16
// Allow disabling because of flash attention v1 precision problems
// See: https://github.com/huggingface/text-embeddings-inference/issues/37
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
if dtype == DType::F16 && should_use_flash_attention(true)
{
tracing::info!("Starting FlashModernBert model on {:?}", device);
Ok(Box::new(
@@ -459,12 +446,7 @@ impl CandleBackend {
}
#[cfg(feature = "cuda")]
(Config::NomicBert(config), Device::Cuda(_)) => {
if cfg!(feature = "flash-attn")
&& dtype == DType::F16
&& &std::env::var("USE_FLASH_ATTENTION")
.unwrap_or("True".to_string())
.to_lowercase()
== "true"
if dtype == DType::F16 && should_use_flash_attention(true)
{
tracing::info!("Starting FlashNomicBert model on {:?}", device);
Ok(Box::new(
@@ -477,12 +459,7 @@ impl CandleBackend {
}
#[cfg(feature = "cuda")]
(Config::Qwen2(config), Device::Cuda(_)) => {
if dtype != DType::F16
|| !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
|| &std::env::var("USE_FLASH_ATTENTION")
.unwrap_or("True".to_string())
.to_lowercase()
!= "true"
if dtype != DType::F16 || !should_use_flash_attention(false)
{
return Err(BackendError::Start("Qwen2 is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
}
@@ -493,12 +470,7 @@ impl CandleBackend {
}
#[cfg(feature = "cuda")]
(Config::Qwen3(config), Device::Cuda(_)) => {
if dtype != DType::F16
|| !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
|| &std::env::var("USE_FLASH_ATTENTION")
.unwrap_or("True".to_string())
.to_lowercase()
!= "true"
if dtype != DType::F16 || !should_use_flash_attention(false)
{
tracing::info!("Starting Qwen3 model on {:?}", device);
Ok(Box::new(Qwen3Model::load(vb, &config, model_type).s()?))
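
Below is a minimal, standalone sketch (not part of the diff above) of how the new should_use_flash_attention helper behaves. It reproduces the helper's logic from this PR together with a small driver; the main() function, and building it in a crate that declares the cuda, flash-attn, and flash-attn-v1 features, are assumptions made for illustration only.

// Illustrative sketch only -- mirrors the helper added in this PR, plus a tiny driver.
fn should_use_flash_attention(require_fa2_capabilities: bool) -> bool {
    // USE_FLASH_ATTENTION defaults to "true" and is compared case-insensitively,
    // so exporting USE_FLASH_ATTENTION=false disables flash attention at runtime.
    let flash_attn_enabled = std::env::var("USE_FLASH_ATTENTION")
        .unwrap_or_else(|_| "true".to_string())
        .to_lowercase()
        == "true";

    // Without CUDA support there is no flash attention kernel to dispatch to.
    if cfg!(not(feature = "cuda")) {
        return false;
    }

    if require_fa2_capabilities {
        // Models that need flash attention v2 require the `flash-attn` feature.
        cfg!(feature = "flash-attn") && flash_attn_enabled
    } else {
        // v1/v2-compatible models accept either `flash-attn` or `flash-attn-v1`.
        cfg!(any(feature = "flash-attn", feature = "flash-attn-v1")) && flash_attn_enabled
    }
}

fn main() {
    // With the default environment, the result depends only on the compile-time features;
    // setting USE_FLASH_ATTENTION=false forces both calls to return false.
    println!("fa2-only models:   {}", should_use_flash_attention(true));
    println!("v1/v2-able models: {}", should_use_flash_attention(false));
}

Note that each CUDA call site in the diff still checks dtype == DType::F16 itself; the helper only consolidates the compile-time feature checks and the USE_FLASH_ATTENTION override.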