diff --git a/Cargo.lock b/Cargo.lock index 6b51b64..7de4c63 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -260,7 +260,7 @@ dependencies = [ [[package]] name = "boringtun" version = "0.6.0" -source = "git+https://github.com/DefGuard/wireguard-rs?rev=886186c1e088e4805ab8049436c28cf3ea26d727#886186c1e088e4805ab8049436c28cf3ea26d727" +source = "git+https://github.com/DefGuard/wireguard-rs?rev=c4d1f69585b5fc5c4f21d9d98b1fd9cc96d4d24d#c4d1f69585b5fc5c4f21d9d98b1fd9cc96d4d24d" dependencies = [ "aead", "base64", @@ -328,9 +328,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.48" +version = "1.2.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c481bdbf0ed3b892f6f806287d72acd515b352a4ec27a208489b8c1bc839633a" +checksum = "90583009037521a116abf44494efecd645ba48b6622457080f080b85544e2215" dependencies = [ "find-msvc-tools", "jobserver", @@ -573,7 +573,7 @@ dependencies = [ [[package]] name = "defguard_version" version = "0.0.0" -source = "git+https://github.com/DefGuard/defguard.git?rev=8649a9ba225d7bd2066a09c9e1347705c34bd158#8649a9ba225d7bd2066a09c9e1347705c34bd158" +source = "git+https://github.com/DefGuard/defguard.git?rev=640bae9a0aea1e11395f0a29fb8c84eeefd7f115#640bae9a0aea1e11395f0a29fb8c84eeefd7f115" dependencies = [ "axum", "http", @@ -590,7 +590,7 @@ dependencies = [ [[package]] name = "defguard_wireguard_rs" version = "0.9.0" -source = "git+https://github.com/DefGuard/wireguard-rs?rev=886186c1e088e4805ab8049436c28cf3ea26d727#886186c1e088e4805ab8049436c28cf3ea26d727" +source = "git+https://github.com/DefGuard/wireguard-rs?rev=c4d1f69585b5fc5c4f21d9d98b1fd9cc96d4d24d#c4d1f69585b5fc5c4f21d9d98b1fd9cc96d4d24d" dependencies = [ "base64", "boringtun", @@ -604,6 +604,7 @@ dependencies = [ "netlink-packet-wireguard", "netlink-sys", "nix", + "regex", "serde", "thiserror 2.0.17", "windows", @@ -870,9 +871,9 @@ dependencies = [ [[package]] name = "git2" -version = "0.20.2" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2deb07a133b1520dc1a5690e9bd08950108873d7ed5de38dcc74d3b5ebffa110" +checksum = "3e2b37e2f62729cdada11f0e6b3b6fe383c69c29fc619e391223e12856af308c" dependencies = [ "bitflags", "libc", @@ -1289,9 +1290,9 @@ checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" [[package]] name = "libgit2-sys" -version = "0.18.2+1.9.1" +version = "0.18.3+1.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c42fe03df2bd3c53a3a9c7317ad91d80c81cd1fb0caec8d7cc4cd2bfa10c222" +checksum = "c9b3acc4b91781bb0b3386669d325163746af5f6e4f73e6d2d630e09a35f3487" dependencies = [ "cc", "libc", @@ -2329,9 +2330,9 @@ dependencies = [ [[package]] name = "simd-adler32" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" [[package]] name = "siphasher" diff --git a/Cargo.toml b/Cargo.toml index b1ad159..3f5022b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,8 +7,8 @@ edition = "2024" axum = "0.8" base64 = "0.22" clap = { version = "4.5", features = ["derive", "env"] } -defguard_version = { git = "https://github.com/DefGuard/defguard.git", rev = "8649a9ba225d7bd2066a09c9e1347705c34bd158" } -defguard_wireguard_rs = { git = "https://github.com/DefGuard/wireguard-rs", rev = "886186c1e088e4805ab8049436c28cf3ea26d727" } +defguard_version = { git = "https://github.com/DefGuard/defguard.git", rev = "640bae9a0aea1e11395f0a29fb8c84eeefd7f115" } +defguard_wireguard_rs = { git = "https://github.com/DefGuard/wireguard-rs", rev = "c4d1f69585b5fc5c4f21d9d98b1fd9cc96d4d24d" } env_logger = "0.11" gethostname = "1.0" ipnetwork = "0.21" @@ -27,6 +27,7 @@ toml = { version = "0.9", default-features = false, features = [ tonic = { version = "0.14", default-features = false, features = [ "codegen", "gzip", + "router", "tls-native-roots", "tls-ring", ] } diff --git a/proto b/proto index 883487d..7137ff1 160000 --- a/proto +++ b/proto @@ -1 +1 @@ -Subproject commit 883487df67d90fd14fae900737cd8b5ea6c10de3 +Subproject commit 7137ff12807ab8fd807e2439d0812f1d2a5f5055 diff --git a/src/config.rs b/src/config.rs index 00445ff..7bb0050 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,4 +1,4 @@ -use std::{fs, net::IpAddr, path::PathBuf}; +use std::{fs, net::IpAddr, path::PathBuf, time::Duration}; use clap::Parser; use serde::Deserialize; @@ -36,26 +36,23 @@ pub struct Config { #[arg(long, env = "DEFGUARD_GATEWAY_NAME")] pub name: Option, - /// defguard server gRPC endpoint URL - #[arg( - long, - short = 'g', - required_unless_present = "config_path", - env = "DEFGUARD_GRPC_URL", - default_value = "" - )] - #[serde(default)] - pub grpc_url: String, + /// Gateway gRPC server port. + #[arg(long, env = "DEFGUARD_GRPC_PORT", default_value = "50066")] + pub(crate) grpc_port: u16, + + /// Gateway gRPC server certificate. + #[arg(long, env = "DEFGUARD_GATEWAY_GRPC_CERT")] + pub(crate) grpc_cert: Option, + + /// Gateway gRPC server private key. + #[arg(long, env = "DEFGUARD_GATEWAY_GRPC_KEY")] + pub(crate) grpc_key: Option, /// Use userspace WireGuard implementation e.g. wireguard-go #[arg(long, short = 'u', env = "DEFGUARD_USERSPACE")] pub userspace: bool, - /// Path to CA file - #[arg(long, env = "DEFGUARD_GRPC_CA")] - pub grpc_ca: Option, - - /// Defines how often (in seconds) interface statistics are sent to Defguard server + /// Defines how often (in seconds) interface statistics are sent to Defguard Core. #[arg(long, short = 'p', env = "DEFGUARD_STATS_PERIOD", default_value = "30")] pub stats_period: u64, @@ -100,9 +97,9 @@ pub struct Config { /// Command to run after bringing down the interface. #[arg(long, env = "POST_DOWN")] pub post_down: Option, - /// A HTTP port that will expose the REST HTTP gateway health status - /// 200 Gateway is working and is connected to CORE - /// 503 - gateway works but is not connected to CORE + /// HTTP port that will expose the REST Gateway health status endpoint. + /// 200: Gateway is working and is connected to Core + /// 503: Gateway is working, but is not connected to Core #[arg(long, env = "HEALTH_PORT")] pub health_port: Option, @@ -125,15 +122,23 @@ pub struct Config { pub http_bind_address: Option, } +impl Config { + #[must_use] + pub fn stats_period(&self) -> Duration { + Duration::from_secs(self.stats_period) + } +} + impl Default for Config { fn default() -> Self { Self { log_level: "info".into(), token: "TOKEN".into(), name: None, - grpc_url: "http://localhost:50051".into(), + grpc_port: 50066, userspace: false, - grpc_ca: None, + grpc_cert: None, + grpc_key: None, stats_period: 15, ifname: "wg0".into(), pidfile: None, diff --git a/src/gateway.rs b/src/gateway.rs index 78fe18c..637fd51 100644 --- a/src/gateway.rs +++ b/src/gateway.rs @@ -1,35 +1,29 @@ use defguard_version::{ - ComponentInfo, DefguardComponent, Version, client::ClientVersionInterceptor, - get_tracing_variables, + ComponentInfo, DefguardComponent, Version, get_tracing_variables, server::DefguardVersionLayer, }; use defguard_wireguard_rs::{WireguardInterfaceApi, net::IpAddrMask}; use gethostname::gethostname; use std::{ collections::HashMap, fs::read_to_string, + net::{IpAddr, Ipv4Addr, SocketAddr}, str::FromStr, sync::{ Arc, Mutex, - atomic::{AtomicBool, Ordering}, + atomic::{AtomicBool, AtomicU64, Ordering}, }, time::{Duration, SystemTime}, }; -use tokio::{ - select, - sync::mpsc, - task::{JoinHandle, spawn}, - time::{interval, sleep}, -}; +use tokio::{sync::mpsc, time::interval}; use tokio_stream::wrappers::UnboundedReceiverStream; use tonic::{ - Request, Status, Streaming, - codegen::InterceptedService, + Request, Response, Status, Streaming, metadata::{Ascii, MetadataValue}, - service::{Interceptor, InterceptorLayer}, - transport::{Certificate, Channel, ClientTlsConfig, Endpoint}, + service::Interceptor, + transport::{Identity, Server, ServerTlsConfig}, }; use tower::ServiceBuilder; -use tracing::{Instrument, instrument}; +use tracing::instrument; use crate::{ VERSION, @@ -41,15 +35,13 @@ use crate::{ error::GatewayError, execute_command, mask, proto::gateway::{ - Configuration, ConfigurationRequest, Peer, StatsUpdate, Update, - gateway_service_client::GatewayServiceClient, stats_update::Payload, update, + Configuration, ConfigurationRequest, CoreRequest, CoreResponse, Peer, Update, core_request, + core_response, gateway_server, update, }, - version::ensure_core_version_supported, + version::is_core_version_supported, }; -const TEN_SECS: Duration = Duration::from_secs(10); - -// helper struct which stores just the interface config without peers +// Helper struct which stores just the interface config without peers. #[derive(Clone, PartialEq)] struct InterfaceConfiguration { name: String, @@ -75,7 +67,9 @@ impl From for InterfaceConfiguration { } } -/// Intercepts all grpc requests adding authentication and version metadata +type ClientMap = HashMap>>; + +/// Intercepts all gRPC requests adding authentication and version metadata. struct AuthInterceptor { hostname: MetadataValue, token: MetadataValue, @@ -105,7 +99,7 @@ impl AuthInterceptor { impl Interceptor for AuthInterceptor { fn call(&mut self, mut request: Request<()>) -> Result, Status> { - // Add auth headers + // Add authorisation headers. let metadata = request.metadata_mut(); metadata.insert("authorization", self.token.clone()); metadata.insert("hostname", self.hostname.clone()); @@ -115,23 +109,17 @@ impl Interceptor for AuthInterceptor { } type PubKey = String; -type GatewayClientType = GatewayServiceClient< - InterceptedService, ClientVersionInterceptor>, ->; pub struct Gateway { config: Config, interface_configuration: Option, peers: HashMap, wgapi: Arc>, - #[cfg_attr(not(target_os = "linux"), allow(unused))] firewall_api: FirewallApi, - #[cfg_attr(not(target_os = "linux"), allow(unused))] firewall_config: Option, pub connected: Arc, - client: GatewayClientType, - core_info: Option, - stats_thread: Option>, + // TODO: allow only one client. + pub(super) clients: ClientMap, } impl Gateway { @@ -140,22 +128,19 @@ impl Gateway { wgapi: impl WireguardInterfaceApi + Send + Sync + 'static, firewall_api: FirewallApi, ) -> Result { - let client = Self::setup_client(&config)?; Ok(Self { config, interface_configuration: None, peers: HashMap::new(), wgapi: Arc::new(Mutex::new(wgapi)), - connected: Arc::new(AtomicBool::new(false)), - client, - stats_thread: None, firewall_api, firewall_config: None, - core_info: None, + connected: Arc::new(AtomicBool::new(false)), + clients: ClientMap::new(), }) } - // replace current peer map with a new list of peers + // Replace current peer map with a new list of peers. fn replace_peers(&mut self, new_peers: Vec) { debug!("Replacing stored peers with {} new peers", new_peers.len()); let peers = new_peers @@ -165,7 +150,7 @@ impl Gateway { self.peers = peers; } - // check if new received configuration is different than current one + // Check if new received configuration is different than current one. fn is_interface_config_changed( &self, new_interface_configuration: &InterfaceConfiguration, @@ -178,7 +163,7 @@ impl Gateway { true } - // check if new peers are the same as the stored ones + // Check if new peers are the same as the stored ones. fn is_peer_list_changed(&self, new_peers: &[Peer]) -> bool { // check if number of peers is different if self.peers.len() != new_peers.len() { @@ -194,7 +179,7 @@ impl Gateway { return true; } - // check if all IPs are the same + // Check if all IP addresses are the same. !new_peers.iter().all(|peer| { self.peers .get(&peer.pubkey) @@ -202,89 +187,7 @@ impl Gateway { }) } - /// Starts tokio thread collecting stats and sending them to backend service via gRPC. - #[instrument(skip_all)] - fn spawn_stats_thread(&mut self) -> UnboundedReceiverStream { - if let Some(handle) = self.stats_thread.take() { - debug!("Aborting previous stats thread before starting a new one"); - handle.abort(); - } - // Create an async stream that periodically yields WireGuard interface statistics. - let period = Duration::from_secs(self.config.stats_period); - let wgapi = Arc::clone(&self.wgapi); - let (tx, rx) = mpsc::unbounded_channel(); - debug!("Spawning stats thread"); - let handle = spawn( - async move { - // helper map to track if peer data is actually changing - // and avoid sending duplicate stats - let mut peer_map = HashMap::new(); - let mut interval = interval(period); - let mut id = 1; - 'outer: loop { - // wait until next iteration - interval.tick().await; - debug!("Sending active peer stats updates."); - let interface_data = wgapi.lock().unwrap().read_interface_data(); - match interface_data { - Ok(host) => { - let peers = host.peers; - debug!( - "Found {} peers configured on WireGuard interface", - peers.len() - ); - for peer in peers.into_values().filter(|p| { - p.last_handshake - .is_some_and(|lhs| lhs != SystemTime::UNIX_EPOCH) - }) { - let has_changed = peer_map - .get(&peer.public_key) - .is_none_or(|last_peer| *last_peer != peer); - if has_changed { - peer_map.insert(peer.public_key.clone(), peer.clone()); - id += 1; - if tx - .send(StatsUpdate { - id, - payload: Some(Payload::PeerStats((&peer).into())), - }) - .is_err() - { - debug!("Stats stream disappeared"); - break 'outer; - } - } else { - debug!( - "Stats for peer {} have not changed. Skipping.", - peer.public_key - ); - } - } - } - Err(err) => error!("Failed to retrieve WireGuard interface stats: {err}"), - } - debug!("Sent peer stats updates for all peers."); - } - } - .instrument(tracing::Span::current()), - ); - self.stats_thread = Some(handle); - UnboundedReceiverStream::new(rx) - } - - #[instrument(skip_all)] - async fn handle_stats_thread( - mut client: GatewayClientType, - rx: UnboundedReceiverStream, - ) { - let status = client.stats(rx).await; - match status { - Ok(_) => info!("Stats thread terminated successfully."), - Err(err) => error!("Stats thread terminated with error: {err}"), - } - } - - /// Checks whether the firewall config changed + /// Checks whether the firewall config have changed. fn has_firewall_config_changed(&self, new_fw_config: &FirewallConfig) -> bool { if let Some(current_config) = &self.firewall_config { return current_config.default_policy != new_fw_config.default_policy @@ -459,7 +362,7 @@ impl Gateway { ); } - // process received firewall config unless firewall management is disabled + // Process received firewall configuration, unless firewall management is disabled. if self.config.disable_firewall_management { debug!("Firewall management is disabled. Skipping updating firewall configuration"); } else { @@ -476,208 +379,118 @@ impl Gateway { Ok(()) } - /// Continuously tries to connect to gRPC endpoint. Once the connection is established - /// configures the interface, starts the stats thread, connects and returns the updates stream. - async fn connect(&mut self) -> Streaming { - // set diconnected if we are in this function and drop mutex - self.connected.store(false, Ordering::Relaxed); - loop { - debug!( - "Connecting to Defguard gRPC endpoint: {}", - self.config.grpc_url - ); - let (response, stream) = { - let response = self - .client - .config(ConfigurationRequest { - name: self.config.name.clone(), - }) - .await; - let stream = self.client.updates(()).await; - (response, stream) - }; - match (response, stream) { - (Ok(response), Ok(stream)) => { - self.core_info = ComponentInfo::from_metadata(response.metadata()); - let (version, info) = get_tracing_variables(&self.core_info); - let span = tracing::info_span!( - "core_configuration", - component = %DefguardComponent::Core, - version = version.to_string(), - info - ); - let _guard = span.enter(); - - // check core version and exit if it's not supported - let version = self.core_info.as_ref().map(|info| &info.version); - ensure_core_version_supported(version); + /// Send message to all connected clients. + fn broadcast_to_clients(&self, message: &CoreRequest) { + for (addr, tx) in &self.clients { + if tx.send(Ok(message.clone())).is_err() { + debug!("Failed to send message to {addr}"); + } + } + } - if let Err(err) = self.configure(response.into_inner()) { - error!("Interface configuration failed: {err}"); - continue; + #[instrument(skip_all)] + fn handle_updates(&mut self, update: Update) { + debug!("Received update: {update:?}"); + match update.update { + Some(update::Update::Network(configuration)) => { + if let Err(err) = self.configure(configuration) { + error!("Failed to update network configuration: {err}"); + } + } + Some(update::Update::Peer(peer_config)) => { + debug!("Applying peer configuration: {peer_config:?}"); + // UpdateType::Delete + if update.update_type == 2 { + debug!("Deleting peer {peer_config:?}"); + self.peers.remove(&peer_config.pubkey); + if let Err(err) = + self.wgapi.lock().unwrap().remove_peer( + &peer_config.pubkey.as_str().try_into().unwrap_or_default(), + ) + { + error!("Failed to delete peer: {err}"); } - info!( - "Connected to Defguard gRPC endpoint: {}", - self.config.grpc_url - ); - self.connected.store(true, Ordering::Relaxed); - break stream.into_inner(); } - (Err(err), _) => { - error!( - "Couldn't retrieve gateway configuration from the core. Using gRPC URL: \ - {}. Retrying in 10s. Error: {err}", - self.config.grpc_url + // UpdateType::Create, UpdateType::Modify + else { + debug!( + "Updating peer {peer_config:?}, update type: {}", + update.update_type ); + self.peers + .insert(peer_config.pubkey.clone(), peer_config.clone()); + if let Err(err) = self + .wgapi + .lock() + .unwrap() + .configure_peer(&peer_config.into()) + { + error!("Failed to update peer: {err}"); + } } - (_, Err(err)) => { - error!( - "Couldn't establish streaming connection to the core. Using gRPC URL: \ - {}. Retrying in 10s. Error: {err}", - self.config.grpc_url + } + Some(update::Update::FirewallConfig(config)) => { + if self.config.disable_firewall_management { + debug!( + "Received firewall config update, but firewall management is disabled. \ + Skipping processing this update: {config:?}" ); + return; } - } - sleep(TEN_SECS).await; - } - } - fn setup_client(config: &Config) -> Result { - debug!("Preparing gRPC client configuration"); - let tls = ClientTlsConfig::new(); - // Use CA if provided, otherwise load certificates from system. - let tls = if let Some(ca) = &config.grpc_ca { - let ca = read_to_string(ca).map_err(|err| { - error!("Failed to read CA file: {err}"); - GatewayError::InvalidCaFile - })?; - tls.ca_certificate(Certificate::from_pem(ca)) - } else { - tls.with_enabled_roots() - }; - let endpoint = Endpoint::from_shared(config.grpc_url.clone())? - .http2_keep_alive_interval(TEN_SECS) - .tcp_keepalive(Some(TEN_SECS)) - .keep_alive_while_idle(true) - .tls_config(tls)?; - let channel = endpoint.connect_lazy(); - let version_interceptor = ClientVersionInterceptor::new(Version::parse(VERSION)?); - let auth_interceptor = AuthInterceptor::new(&config.token)?; - let channel = ServiceBuilder::new() - .layer(InterceptorLayer::new(version_interceptor)) - .layer(InterceptorLayer::new(auth_interceptor)) - .service(channel); - let client = GatewayServiceClient::new(channel); - - debug!("gRPC client configuration done"); - Ok(client) - } - - #[instrument(skip_all)] - async fn handle_updates(&mut self, updates_stream: &mut Streaming) { - loop { - match updates_stream.message().await { - Ok(Some(update)) => { - debug!("Received update: {update:?}"); - match update.update { - Some(update::Update::Network(configuration)) => { - if let Err(err) = self.configure(configuration) { - error!("Failed to update network configuration: {err}"); - } - } - Some(update::Update::Peer(peer_config)) => { - debug!("Applying peer configuration: {peer_config:?}"); - // UpdateType::Delete - if update.update_type == 2 { - debug!("Deleting peer {peer_config:?}"); - self.peers.remove(&peer_config.pubkey); - if let Err(err) = self.wgapi.lock().unwrap().remove_peer( - &peer_config.pubkey.as_str().try_into().unwrap_or_default(), - ) { - error!("Failed to delete peer: {err}"); - } - } - // UpdateType::Create, UpdateType::Modify - else { - debug!( - "Updating peer {peer_config:?}, update type: {}", - update.update_type - ); - self.peers - .insert(peer_config.pubkey.clone(), peer_config.clone()); - if let Err(err) = self - .wgapi - .lock() - .unwrap() - .configure_peer(&peer_config.into()) - { - error!("Failed to update peer: {err}"); - } - } + debug!("Applying received firewall configuration: {config:?}"); + let config_str = format!("{config:?}"); + match FirewallConfig::from_proto(config) { + Ok(new_firewall_config) => { + debug!( + "Parsed the received firewall configuration: {new_firewall_config:?}, \ + processing it and applying changes" + ); + if let Err(err) = self.process_firewall_changes(Some(&new_firewall_config)) + { + error!("Failed to process received firewall configuration: {err}"); } - Some(update::Update::FirewallConfig(config)) => { - if self.config.disable_firewall_management { - debug!( - "Received firewall config update, but firewall management \ - is disabled. Skipping processing this update: {config:?}" - ); - continue; - } - - debug!("Applying received firewall configuration: {config:?}"); - let config_str = format!("{config:?}"); - match FirewallConfig::from_proto(config) { - Ok(new_firewall_config) => { - debug!( - "Parsed the received firewall configuration: \ - {new_firewall_config:?}, processing it and applying \ - changes" - ); - if let Err(err) = - self.process_firewall_changes(Some(&new_firewall_config)) - { - error!( - "Failed to process received firewall configuration: \ - {err}" - ); - } - } - Err(err) => { - error!( - "Failed to parse received firewall configuration: {err}. \ - Configuration: {config_str}" - ); - } - } - } - Some(update::Update::DisableFirewall(())) => { - if self.config.disable_firewall_management { - debug!( - "Received firewall disable request, but firewall management \ - is disabled. Skipping processing this update" - ); - continue; - } - - debug!("Disabling firewall configuration"); - if let Err(err) = self.process_firewall_changes(None) { - error!("Failed to disable firewall configuration: {err}"); - } - } - _ => warn!("Unsupported kind of update: {update:?}"), + } + Err(err) => { + error!( + "Failed to parse received firewall configuration: {err}. \ + Configuration: {config_str}" + ); } } - Ok(None) => { - break; - } - Err(err) => { - error!( - "Disconnected from Defguard gRPC endoint: {}: {err}", - self.config.grpc_url + } + Some(update::Update::DisableFirewall(())) => { + if self.config.disable_firewall_management { + debug!( + "Received firewall disable request, but firewall management is disabled. \ + Skipping processing this update" ); - break; + return; + } + + debug!("Disabling firewall configuration"); + if let Err(err) = self.process_firewall_changes(None) { + error!("Failed to disable firewall configuration: {err}"); } } + _ => warn!("Unsupported kind of update: {update:?}"), + } + } +} + +pub struct GatewayServer { + auth_token: String, + message_id: AtomicU64, + gateway: Arc>, +} + +impl GatewayServer { + #[must_use] + pub fn new(auth_token: String, gateway: Arc>) -> Self { + Self { + auth_token, + message_id: AtomicU64::new(0), + gateway, } } @@ -685,57 +498,252 @@ impl Gateway { /// * Retrieves configuration and configuration updates from Defguard gRPC server /// * Manages the interface according to configuration and updates /// * Sends interface statistics to Defguard server periodically - pub async fn start(&mut self) -> Result<(), GatewayError> { + pub async fn start(self, config: Config) -> Result<(), GatewayError> { info!( - "Starting Defguard gateway version {VERSION} with configuration: {:?}", - mask!(self.config, token) + "Starting Defguard Gateway version {VERSION} with configuration: {:?}", + mask!(config, token) ); // Try to create network interface for WireGuard. // FIXME: check if the interface already exists, or somehow be more clever. - if let Err(err) = self.wgapi.lock().unwrap().create_interface() { - warn!( - "Couldn't create network interface {}: {err}. Proceeding anyway.", - self.config.ifname - ); - } else { - #[cfg(target_os = "linux")] - if !self.config.disable_firewall_management && self.config.masquerade { - self.firewall_api.begin()?; - self.firewall_api.setup_nat(self.config.masquerade, &[])?; - self.firewall_api.commit()?; + { + let gateway = &self.gateway.lock().expect("gateway mutex poison"); + if let Err(err) = gateway + .wgapi + .lock() + .expect("wgapi mutex poison") + .create_interface() + { + warn!( + "Couldn't create network interface {}: {err}. Proceeding anyway.", + config.ifname + ); + } else { + #[cfg(target_os = "linux")] + if !config.disable_firewall_management && config.masquerade { + gateway.firewall_api.begin()?; + gateway.firewall_api.setup_nat(config.masquerade, &[])?; + &gateway.firewall_api.commit()?; + } } } - info!( - "Trying to connect to {} and obtain the gateway configuration from Defguard.", - self.config.grpc_url + if let Some(post_up) = &config.post_up { + debug!("Executing specified POST_UP command: {post_up}"); + execute_command(post_up)?; + } + + // Optionally, read gRPC TLS certificate and key. + debug!("Configuring certificates for gRPC"); + let grpc_cert = config + .grpc_cert + .as_ref() + .and_then(|path| read_to_string(path).ok()); + let grpc_key = config + .grpc_key + .as_ref() + .and_then(|path| read_to_string(path).ok()); + debug!("Configured certificates for gRPC, cert: {grpc_cert:?}"); + + // Build gRPC server. + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), config.grpc_port); + info!("gRPC server is listening on {addr}"); + let mut builder = if let (Some(cert), Some(key)) = (grpc_cert, grpc_key) { + let identity = Identity::from_pem(cert, key); + Server::builder().tls_config(ServerTlsConfig::new().identity(identity))? + } else { + Server::builder() + }; + + // Start gRPC server. This should run indefinitely. + debug!("Serving gRPC"); + builder + .add_service( + ServiceBuilder::new() + // .layer(InterceptorLayer::new(CoreVersionInterceptor::new( + // MIN_CORE_VERSION, + // incompatible_components, + // ))) + .layer(DefguardVersionLayer::new(Version::parse(VERSION)?)) + .service(gateway_server::GatewayServer::new(self)), + ) + .serve(addr) + .await?; + + Ok(()) + } +} + +#[tonic::async_trait] +impl gateway_server::Gateway for GatewayServer { + type BidiStream = UnboundedReceiverStream>; + + /// Handle bidirectional communication with Defguard Core. + async fn bidi( + &self, + request: Request>, + ) -> Result, Status> { + let Some(address) = request.remote_addr() else { + error!("Failed to determine Defguard Core's address for request: {request:?}"); + return Err(Status::internal( + "Failed to determine Defguard Core's address", + )); + }; + info!("Defguard Core gRPC client connected from {address}"); + + let core_info = ComponentInfo::from_metadata(request.metadata()); + let (version, info) = get_tracing_variables(&core_info); + + // Tracing span. + let span = tracing::info_span!( + "core_communication", + component = %DefguardComponent::Core, + version = version.to_string(), + info ); - loop { - let mut updates_stream = self.connect().await; - if let Some(post_up) = &self.config.post_up { - debug!("Executing specified POST_UP command: {post_up}"); - execute_command(post_up)?; + let _guard = span.enter(); + + // Check Defguard Core's version and exit if it's not supported. + let version = core_info.as_ref().map(|info| &info.version); + if !is_core_version_supported(version) { + return Err(Status::internal("Unsupported Defguard Core version")); + } + + let (tx, rx) = mpsc::unbounded_channel(); + let Ok(hostname) = gethostname().into_string() else { + error!("Unable to get hostname"); + return Err(Status::internal("failed to get hostname")); + }; + + // First, send configuration request. + let payload = ConfigurationRequest { + name: None, // TODO: remove? + auth_token: self.auth_token.clone(), + hostname, + }; + let req = CoreRequest { + id: self.message_id.fetch_add(1, Ordering::Relaxed), + payload: Some(core_request::Payload::ConfigRequest(payload)), + }; + + match tx.send(Ok(req)) { + Ok(()) => info!("Requesting network configuration from {address}"), + Err(err) => { + error!("Unable to send network configuration request to {address}: {err}"); + return Err(Status::internal("failed to send configuration request")); } - let (version, info) = get_tracing_variables(&self.core_info); - let span = tracing::info_span!( - "core_grpc", - component = %DefguardComponent::Core, - version = version.to_string(), - info, - ); - let _guard = span.enter(); - let stats_stream = self.spawn_stats_thread(); - let client = self.client.clone(); - select! { - biased; - () = Self::handle_stats_thread(client, stats_stream) => { - error!("Stats stream aborted; reconnecting"); + } + + self.gateway.lock().unwrap().clients.insert(address, tx); + + let gateway = Arc::clone(&self.gateway); + let mut stream = request.into_inner(); + tokio::spawn(async move { + loop { + match stream.message().await { + Ok(Some(response)) => { + debug!("Received message from Defguard Core: {response:?}"); + // Discard empty payloads. + if let Some(payload) = response.payload { + match payload { + core_response::Payload::Config(configuration) => { + match gateway.lock() { + Ok(mut gw) => { + gw.connected.store(true, Ordering::Relaxed); + let _ = gw.configure(configuration); + } + Err(err) => error!("Lock failed: {err}"), + } + } + core_response::Payload::Update(update) => match gateway.lock() { + Ok(mut gw) => { + gw.handle_updates(update); + } + Err(err) => error!("Lock failed: {err}"), + }, + core_response::Payload::Empty(()) => (), + } + } + } + Ok(None) => { + info!("gRPC stream from Defguard Core has been closed"); + break; + } + Err(err) => { + error!("gRPC stream from Defguard Core failed with error: {err}"); + break; + } } - () = self.handle_updates(&mut updates_stream) => { - error!("Updates stream aborted; reconnecting"); + } + info!("Defguard Core gRPC stream has been disconnected: {address}"); + gateway + .lock() + .unwrap() + .connected + .store(false, Ordering::Relaxed); + gateway.lock().unwrap().clients.remove(&address); + }); + + Ok(Response::new(UnboundedReceiverStream::new(rx))) + } +} + +/// Gather WireGuard statistics and send them to Core through gRPC. +pub async fn run_stats(gateway: Arc>, period: Duration) -> Result<(), GatewayError> { + // Helper map to track if peer data is actually changing to avoid sending duplicate stats. + let mut peer_map = HashMap::new(); + let mut interval = interval(period); + let mut id = 1; + loop { + // Wait until next iteration. + interval.tick().await; + + debug!("Obtaining peer statistics from WireGuard"); + let result = gateway + .lock() + .expect("gateway mutex poison") + .wgapi + .lock() + .expect("wgapi mutex poison") + .read_interface_data(); + match result { + Ok(host) => { + let peers = host.peers; + debug!( + "Found {} peers configured on WireGuard interface", + peers.len() + ); + // Filter out never connected peers. + for peer in peers.into_values().filter(|p| { + p.last_handshake + .map_or(false, |last_hs| last_hs != SystemTime::UNIX_EPOCH) + }) { + let has_changed = match peer_map.get(&peer.public_key) { + Some(last_peer) => *last_peer != peer, + None => true, + }; + if has_changed { + peer_map.insert(peer.public_key.clone(), peer.clone()); + let payload = core_request::Payload::PeerStats((&peer).into()); + let message = CoreRequest { + id, + payload: Some(payload), + }; + id += 1; + gateway + .lock() + .expect("gateway mutex poison") + .broadcast_to_clients(&message); + debug!("Sent statistics for peer {}", peer.public_key); + } else { + debug!( + "Statistics for peer {} have not changed. Skipping.", + peer.public_key + ); + } } } + Err(err) => error!("Failed to retrieve WireGuard interface statistics: {err}"), } } } @@ -790,19 +798,16 @@ mod tests { let wgapi = WG::new("wg0").unwrap(); let config = Config::default(); - let client = Gateway::setup_client(&config).unwrap(); let firewall_api = FirewallApi::new("wg0").unwrap(); let gateway = Gateway { config, interface_configuration: Some(old_config.clone()), peers: old_peers_map, wgapi: Arc::new(Mutex::new(wgapi)), - connected: Arc::new(AtomicBool::new(false)), - client, - stats_thread: None, firewall_api, firewall_config: None, - core_info: None, + connected: Arc::new(AtomicBool::new(false)), + clients: ClientMap::new(), }; // new config is the same @@ -968,29 +973,26 @@ mod tests { let config1 = FirewallConfig { rules: vec![rule1.clone(), rule2.clone()], default_policy: Policy::Allow, - snat_bindings: vec![], + snat_bindings: Vec::new(), }; let config_empty = FirewallConfig { rules: Vec::new(), default_policy: Policy::Allow, - snat_bindings: vec![], + snat_bindings: Vec::new(), }; let wgapi = WG::new("wg0").unwrap(); let config = Config::default(); - let client = Gateway::setup_client(&config).unwrap(); let mut gateway = Gateway { config, interface_configuration: None, peers: HashMap::new(), wgapi: Arc::new(Mutex::new(wgapi)), - connected: Arc::new(AtomicBool::new(false)), - client, - stats_thread: None, firewall_api: FirewallApi::new("test_interface").unwrap(), firewall_config: None, - core_info: None, + connected: Arc::new(AtomicBool::new(false)), + clients: ClientMap::new(), }; // Gateway has no firewall config, new rules are empty @@ -1031,35 +1033,32 @@ mod tests { let config1 = FirewallConfig { rules: Vec::new(), default_policy: Policy::Allow, - snat_bindings: vec![], + snat_bindings: Vec::new(), }; let config2 = FirewallConfig { rules: Vec::new(), default_policy: Policy::Deny, - snat_bindings: vec![], + snat_bindings: Vec::new(), }; let config3 = FirewallConfig { rules: Vec::new(), default_policy: Policy::Allow, - snat_bindings: vec![], + snat_bindings: Vec::new(), }; let wgapi = WG::new("wg0").unwrap(); let config = Config::default(); - let client = Gateway::setup_client(&config).unwrap(); let mut gateway = Gateway { config, interface_configuration: None, peers: HashMap::new(), wgapi: Arc::new(Mutex::new(wgapi)), - connected: Arc::new(AtomicBool::new(false)), - client, - stats_thread: None, firewall_api: FirewallApi::new("test_interface").unwrap(), firewall_config: None, - core_info: None, + connected: Arc::new(AtomicBool::new(false)), + clients: ClientMap::new(), }; // Gateway has no config gateway.firewall_config = None; @@ -1086,7 +1085,7 @@ mod tests { ipv4: true, }], default_policy: Policy::Allow, - snat_bindings: vec![], + snat_bindings: Vec::new(), }; gateway.firewall_config = Some(config1); assert!(gateway.has_firewall_config_changed(&config4)); @@ -1104,7 +1103,7 @@ mod tests { ipv4: false, }], default_policy: Policy::Allow, - snat_bindings: vec![], + snat_bindings: Vec::new(), }; gateway.firewall_config = Some(config4); assert!(gateway.has_firewall_config_changed(&config5)); diff --git a/src/main.rs b/src/main.rs index 7ad00ac..ea2bf2b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,19 @@ -use std::{fs::File, io::Write, process, sync::Arc}; +use std::{ + fs::File, + io::Write, + process, + sync::{Arc, Mutex}, +}; use defguard_gateway::{ - VERSION, config::get_config, enterprise::firewall::api::FirewallApi, error::GatewayError, - execute_command, gateway::Gateway, init_syslog, server::run_server, + VERSION, + config::get_config, + enterprise::firewall::api::FirewallApi, + error::GatewayError, + execute_command, + gateway::{Gateway, GatewayServer, run_stats}, + init_syslog, + server::run_server, }; use defguard_version::Version; #[cfg(not(any(target_os = "macos", target_os = "netbsd")))] @@ -42,7 +53,7 @@ async fn main() -> Result<(), GatewayError> { let ifname = config.ifname.clone(); let firewall_api = FirewallApi::new(&ifname)?; - let mut gateway = if config.userspace { + let gateway = if config.userspace { let wgapi = WGApi::::new(ifname)?; Gateway::new(config.clone(), wgapi, firewall_api)? } else { @@ -58,7 +69,10 @@ async fn main() -> Result<(), GatewayError> { } }; + // Keep track of spawned tasks. let mut tasks = JoinSet::new(); + + // Optionally, launch HTTP server to report gateway's health. if let Some(health_port) = config.health_port { tasks.spawn(run_server( health_port, @@ -66,7 +80,15 @@ async fn main() -> Result<(), GatewayError> { Arc::clone(&gateway.connected), )); } - tasks.spawn(async move { gateway.start().await }); + + // Launch statistics gathering task. + let gateway = Arc::new(Mutex::new(gateway)); + tasks.spawn(run_stats(Arc::clone(&gateway), config.stats_period())); + + // Launch gRPC server. + let gateway_server = GatewayServer::new(config.token.clone(), gateway); + tasks.spawn(gateway_server.start(config.clone())); + while let Some(Ok(result)) = tasks.join_next().await { result?; } diff --git a/src/version.rs b/src/version.rs index 8e0b75f..3120c8d 100644 --- a/src/version.rs +++ b/src/version.rs @@ -1,23 +1,25 @@ use defguard_version::{Version, is_version_lower}; -const MIN_CORE_VERSION: Version = Version::new(1, 5, 0); +const MIN_CORE_VERSION: Version = Version::new(1, 6, 0); -/// Ensures the core version meets minimum version requirements. -/// Terminates the process if it doesn't. -pub(crate) fn ensure_core_version_supported(core_version: Option<&Version>) { +/// Checks if Defguard Core's version meets minimum version requirements. +pub(crate) fn is_core_version_supported(core_version: Option<&Version>) -> bool { let Some(core_version) = core_version else { error!( - "Missing core component version information. This most likely means that core component uses outdated version. Exiting." + "Missing Defguard Core version information. This most likely means that Defguard Core \ + uses outdated version." ); - std::process::exit(1); + return false; }; if is_version_lower(core_version, &MIN_CORE_VERSION) { error!( - "Core version {core_version} is not supported. Minimal supported core version is {MIN_CORE_VERSION}. Exiting." + "Defguard Core version {core_version} is not supported. Minimal supported version is \ + {MIN_CORE_VERSION}." ); - std::process::exit(1); + false + } else { + info!("Defguard Core version {core_version} is supported"); + true } - - info!("Core version {core_version} is supported"); }