diff --git a/src/bfd.rs b/src/bfd.rs new file mode 100644 index 0000000..b7d4fba --- /dev/null +++ b/src/bfd.rs @@ -0,0 +1,253 @@ +use std::{ffi::{OsStr, OsString}, fmt::{self, Display}, io::Cursor, ops::{Div, Mul}, time}; + +use bytemuck::NoUninit; +use byteorder::{BigEndian, WriteBytesExt}; +use nom::{bytes::complete::take, multi::many_m_n, number::complete::be_u8, IResult}; +use nom_derive::{NomBE, Parse}; +use proc_bitfield::*; + +#[repr(u8)] +#[derive(ConvRaw, Debug, NomBE, PartialEq, Eq, Clone, Copy, NoUninit)] +pub enum BfdDiagnostic { + None = 0, + TimeExpired = 1, + EchoFailed = 2, + NeighborDown = 3, + FwdPlaneReset = 4, + PathDown = 5, + ConcatPathDown = 6, + AdminDown = 7, + RevConcatPathDown = 8, + Reserved, +} +impl Display for BfdDiagnostic { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Self::None => "None", + Self::TimeExpired => "TimeExpired", + Self::EchoFailed => "EchoFailed", + Self::NeighborDown => "NeighborDown", + Self::FwdPlaneReset => "FwdPlaneReset", + Self::PathDown => "PathDown", + Self::ConcatPathDown => "ConcatPathDown", + Self::AdminDown => "AdminDown", + Self::RevConcatPathDown => "RevConcatPathDown", + Self::Reserved => "Reserved", + }) + } +} +#[repr(u8)] +#[derive(ConvRaw, Debug, NomBE, PartialEq, Eq, Default, Clone, Copy, NoUninit)] +pub enum BfdState { + AdminDown = 0, + #[default] + Down = 1, + Init = 2, + Up = 3, +} +impl Display for BfdState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(match self { + Self::AdminDown => "AdminDown", + Self::Down => "Down", + Self::Init => "Init", + Self::Up => "Up", + }) + } +} + +impl Into<&OsStr> for BfdState { + fn into(self) -> &'static OsStr { + match self { + Self::AdminDown => OsStr::new("AdminDown"), + Self::Init => OsStr::new("Init"), + Self::Down => OsStr::new("Down"), + Self::Up => OsStr::new("Up"), + } + } +} + +#[repr(u8)] +#[derive(ConvRaw, Debug, NomBE, PartialEq, Eq, Clone, Copy, NoUninit)] +pub enum BfdAuthType { + None = 0, + SimplePassword = 1, + KeyedMD5 = 2, + MetKeyedMD5 = 3, + KeyedSHA1 = 4, + MetKeyedSHA1 = 5, + Reserved, +} +#[repr(transparent)] +#[derive(Debug, NomBE, PartialEq, Eq, Clone, Copy, NoUninit, Hash)] +pub struct BfdDiscriminator(pub u32); +impl Display for BfdDiscriminator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +#[repr(transparent)] +#[derive(Debug, PartialEq, Eq, Clone, Copy, PartialOrd, Ord, NoUninit, NomBE)] +/// All intervals in BFD are specified in microseconds +pub struct BfdInterval(u32); + +impl BfdInterval { + pub fn from_micros(micros: u32) -> Self { + Self(micros) + } + pub fn from_millis(millis: u32) -> Self { + Self(millis * 1000) + } + pub fn from_secs(secs: u32) -> Self { + Self(secs * 1000000) + } + pub fn from_secs_f32(secs: f32) -> Self { + Self((secs * 1000000.0) as u32) + } +} + +impl From for time::Duration { + fn from(value: BfdInterval) -> Self { + time::Duration::from_micros(value.0 as u64) + } +} +impl From for BfdInterval { + fn from(value: time::Duration) -> Self { + BfdInterval(value.as_micros() as u32) + } +} +impl Display for BfdInterval { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}us", self.0) + } +} +impl Mul for BfdInterval { + type Output = BfdInterval; + fn mul(self, rhs: u32) -> Self::Output { + BfdInterval(self.0 * rhs) + } +} +impl Mul for u32 { + type Output = BfdInterval; + fn mul(self, rhs: BfdInterval) -> Self::Output { + BfdInterval(self * rhs.0) + } +} +impl Div for BfdInterval { + type Output = BfdInterval; + fn div(self, rhs: u32) -> Self::Output { + BfdInterval(self.0 / rhs) + } +} +impl Div for u32 { + type Output = BfdInterval; + fn div(self, rhs: BfdInterval) -> Self::Output { + BfdInterval(self / rhs.0) + } +} + +bitfield! { + #[derive(NomBE)] + pub struct BfdFlags(pub u32): Debug { + pub vers: u8 @ 29..=31, + pub diag: u8 [try_get BfdDiagnostic] @ 24..=28, + pub state: u8 [try_get BfdState] @ 22..=23, + pub poll: bool @ 21, + pub final_: bool @ 20, + pub cpi: bool @ 19, + pub auth_present: bool @ 18, + pub demand: bool @ 17, + pub multipoint: bool @ 16, + pub detect_mult: u8 @ 8..=15, + pub length: u8 @ 0..=7 + } +} + +#[derive(Debug)] +pub struct BfdAuthSimplePassword(Vec); +impl<'a> Parse<&'a [u8]> for BfdAuthSimplePassword { + fn parse(i: &'a [u8]) -> IResult<&'a [u8], Self, nom::error::Error<&'a [u8]>> { + let (i, res) = many_m_n(1, 16, be_u8)(i)?; + Ok((i, Self(res))) + } +} + +#[derive(Debug, NomBE)] +pub struct BfdAuthKeyedMD5 { + key_id: u8, + _reserved: u8, + seq: u32, + digest: [u8; 16], +} + +#[derive(Debug, NomBE)] +pub struct BfdAuthKeyedSHA1 { + key_id: u8, + _reserved: u8, + seq: u32, + hash: [u8; 20], +} + +#[derive(Debug, NomBE)] +#[nom(Selector = "BfdAuthType", Complete)] +pub enum BfdAuthData { + #[nom(Selector = "BfdAuthType::SimplePassword")] + SimplePassword(BfdAuthSimplePassword), + #[nom(Selector = "BfdAuthType::KeyedMD5")] + KeyedMD5(BfdAuthKeyedMD5), + #[nom(Selector = "BfdAuthType::MetKeyedMD5")] + MetKeyedMD5(BfdAuthKeyedMD5), + #[nom(Selector = "BfdAuthType::KeyedSHA1")] + KeyedSHA1(BfdAuthKeyedSHA1), + #[nom(Selector = "BfdAuthType::MetKeyedSHA1")] + MetKeyedSHA1(BfdAuthKeyedSHA1), +} + +impl BfdAuthData { + fn parse_be_with_length( + i: &[u8], + auth_type: BfdAuthType, + auth_len: u8, + ) -> IResult<&[u8], Self> { + let (new_i, data) = take(auth_len)(i)?; + let (_leftovers, retval) = BfdAuthData::parse_be(data, auth_type)?; + Ok((new_i, retval)) + } +} + +#[derive(Debug, NomBE)] +pub struct BfdAuth { + auth_type: BfdAuthType, + auth_len: u8, + #[nom(Parse = "{ |i| BfdAuthData::parse_be_with_length(i, auth_type, auth_len) }")] + auth_data: BfdAuthData, +} + +#[derive(Debug, NomBE)] +pub struct BfdPacket { + pub flags: BfdFlags, + pub my_disc: BfdDiscriminator, + pub your_disc: BfdDiscriminator, + pub desired_min_tx: BfdInterval, + pub required_min_rx: BfdInterval, + pub required_min_echo_rx: BfdInterval, + #[nom(Cond = "flags.auth_present()")] + pub auth: Option, +} + +impl BfdPacket { + pub fn serialize(&self) -> Result, std::io::Error> { + // TODO: serialize auth + let buf = [0u8; 24]; + let mut wtr = Cursor::new(buf); + wtr.write_u32::(self.flags.0)?; + wtr.write_u32::(self.my_disc.0)?; + wtr.write_u32::(self.your_disc.0)?; + wtr.write_u32::(self.desired_min_tx.0)?; + wtr.write_u32::(self.required_min_rx.0)?; + wtr.write_u32::(self.required_min_echo_rx.0)?; + + Ok(Box::new(wtr.into_inner())) + } +} diff --git a/src/main.rs b/src/bin/rust-bfdd.rs similarity index 69% rename from src/main.rs rename to src/bin/rust-bfdd.rs index 591222a..7ea9b84 100644 --- a/src/main.rs +++ b/src/bin/rust-bfdd.rs @@ -1,111 +1,43 @@ use futures::future; +use rust_bfd::events::{EventMessageSink, ScriptHookSink, StateChangeEvent}; +use rust_bfd::{set_ttl_or_hop_limit, BfdSessionStats}; +use rust_bfd::{bfd::*, events::EventMessage}; +use std::net::Ipv6Addr; use std::{ collections::HashMap, error::Error, - fmt::{self, Display}, - fs::read, - io::Cursor, + fmt::Display, + io::ErrorKind, mem::swap, net::{IpAddr, SocketAddr}, - num::NonZeroU32, - pin::Pin, str::FromStr, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - time::SystemTime, + sync::{atomic::Ordering, Arc}, + time::{Duration, Instant}, }; -use tokio::{ - sync::{oneshot, Mutex}, - time::{interval_at, Timeout}, -}; +// Necessary until a hop_limit method is added to UdpSocket +use nix::sys::socket::{setsockopt, sockopt}; + +use tokio::sync::{oneshot, Mutex}; -use byteorder::{BigEndian, WriteBytesExt}; use env_logger::Env; -use itertools::Itertools; use log::{debug, error, info, warn}; -use nom::{bytes::complete::take, multi::many_m_n, number::complete::be_u8, IResult}; -use nom_derive::{NomBE, Parse}; -use proc_bitfield::*; + +use nom_derive::Parse; use rand::prelude::*; +use tokio::io; +use tokio::net::UdpSocket; +use tokio::sync::mpsc; +use tokio::task::{self, JoinHandle}; use tokio::time; -use tokio::{io, join, task::JoinHandle}; -use tokio::{net::UdpSocket, sync::RwLock}; -use tokio::{sync::mpsc, time::Instant}; -use tokio::{task, time::Interval}; use atomic::Atomic; -use bytemuck::{NoUninit, Pod}; const CONTROL_PORT: u16 = 3784; const ECHO_PORT: u16 = 3785; const ORDERING: Ordering = Ordering::Relaxed; -#[repr(u8)] -#[derive(ConvRaw, Debug, NomBE, PartialEq, Eq, Clone, Copy, NoUninit)] -pub enum BfdDiagnostic { - None = 0, - TimeExpired = 1, - EchoFailed = 2, - NeighborDown = 3, - FwdPlaneReset = 4, - PathDown = 5, - ConcatPathDown = 6, - AdminDown = 7, - RevConcatPathDown = 8, - Reserved, -} -impl Display for BfdDiagnostic { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Self::None => "None", - Self::TimeExpired => "TimeExpired", - Self::EchoFailed => "EchoFailed", - Self::NeighborDown => "NeighborDown", - Self::FwdPlaneReset => "FwdPlaneReset", - Self::PathDown => "PathDown", - Self::ConcatPathDown => "ConcatPathDown", - Self::AdminDown => "AdminDown", - Self::RevConcatPathDown => "RevConcatPathDown", - Self::Reserved => "Reserved", - }) - } -} -#[repr(u8)] -#[derive(ConvRaw, Debug, NomBE, PartialEq, Eq, Default, Clone, Copy, NoUninit)] -pub enum BfdState { - AdminDown = 0, - #[default] - Down = 1, - Init = 2, - Up = 3, -} -impl Display for BfdState { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(match self { - Self::AdminDown => "AdminDown", - Self::Down => "Down", - Self::Init => "Init", - Self::Up => "Up", - }) - } -} - -#[repr(u8)] -#[derive(ConvRaw, Debug, NomBE, PartialEq, Eq, Clone, Copy, NoUninit)] -pub enum BfdAuthType { - None = 0, - SimplePassword = 1, - KeyedMD5 = 2, - MetKeyedMD5 = 3, - KeyedSHA1 = 4, - MetKeyedSHA1 = 5, - Reserved, -} - #[derive(Debug)] pub enum BfdError { // field, value @@ -122,134 +54,6 @@ impl Display for BfdError { } impl Error for BfdError {} -#[repr(transparent)] -#[derive(Debug, NomBE, PartialEq, Eq, Clone, Copy, NoUninit, Hash)] -pub struct BfdDiscriminator(u32); -impl Display for BfdDiscriminator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.0.fmt(f) - } -} - -#[repr(transparent)] -#[derive(Debug, PartialEq, Eq, Clone, Copy, PartialOrd, Ord, NoUninit, NomBE)] -/// All intervals in BFD are specified in microseconds -pub struct BfdInterval(u32); -impl From for time::Duration { - fn from(value: BfdInterval) -> Self { - time::Duration::from_micros(value.0 as u64) - } -} -impl From for BfdInterval { - fn from(value: u32) -> Self { - BfdInterval(value) - } -} -bitfield! { - #[derive(NomBE)] - pub struct BfdFlags(pub u32): Debug { - pub vers: u8 @ 29..=31, - pub diag: u8 [try_get BfdDiagnostic] @ 24..=28, - pub state: u8 [try_get BfdState] @ 22..=23, - pub poll: bool @ 21, - pub final_: bool @ 20, - pub cpi: bool @ 19, - pub auth_present: bool @ 18, - pub demand: bool @ 17, - pub multipoint: bool @ 16, - pub detect_mult: u8 @ 8..=15, - pub length: u8 @ 0..=7 - } -} - -#[derive(Debug)] -pub struct BfdAuthSimplePassword(Vec); -impl<'a> Parse<&'a [u8]> for BfdAuthSimplePassword { - fn parse(i: &'a [u8]) -> IResult<&'a [u8], Self, nom::error::Error<&'a [u8]>> { - let (i, res) = many_m_n(1, 16, be_u8)(i)?; - Ok((i, Self(res))) - } -} - -#[derive(Debug, NomBE)] -pub struct BfdAuthKeyedMD5 { - key_id: u8, - _reserved: u8, - seq: u32, - digest: [u8; 16], -} - -#[derive(Debug, NomBE)] -pub struct BfdAuthKeyedSHA1 { - key_id: u8, - _reserved: u8, - seq: u32, - hash: [u8; 20], -} - -#[derive(Debug, NomBE)] -#[nom(Selector = "BfdAuthType", Complete)] -pub enum BfdAuthData { - #[nom(Selector = "BfdAuthType::SimplePassword")] - SimplePassword(BfdAuthSimplePassword), - #[nom(Selector = "BfdAuthType::KeyedMD5")] - KeyedMD5(BfdAuthKeyedMD5), - #[nom(Selector = "BfdAuthType::MetKeyedMD5")] - MetKeyedMD5(BfdAuthKeyedMD5), - #[nom(Selector = "BfdAuthType::KeyedSHA1")] - KeyedSHA1(BfdAuthKeyedSHA1), - #[nom(Selector = "BfdAuthType::MetKeyedSHA1")] - MetKeyedSHA1(BfdAuthKeyedSHA1), -} - -impl BfdAuthData { - fn parse_be_with_length( - i: &[u8], - auth_type: BfdAuthType, - auth_len: u8, - ) -> IResult<&[u8], Self> { - let (new_i, data) = take(auth_len)(i)?; - let (_leftovers, retval) = BfdAuthData::parse_be(data, auth_type)?; - Ok((new_i, retval)) - } -} - -#[derive(Debug, NomBE)] -pub struct BfdAuth { - auth_type: BfdAuthType, - auth_len: u8, - #[nom(Parse = "{ |i| BfdAuthData::parse_be_with_length(i, auth_type, auth_len) }")] - auth_data: BfdAuthData, -} - -#[derive(Debug, NomBE)] -pub struct BfdPacket { - flags: BfdFlags, - my_disc: BfdDiscriminator, - your_disc: BfdDiscriminator, - desired_min_tx: BfdInterval, - required_min_rx: BfdInterval, - required_min_echo_rx: BfdInterval, - #[nom(Cond = "flags.auth_present()")] - auth: Option, -} - -impl BfdPacket { - fn serialize(&self) -> Result, std::io::Error> { - // TODO: serialize auth - let buf = [0u8; 24]; - let mut wtr = Cursor::new(buf); - wtr.write_u32::(self.flags.0)?; - wtr.write_u32::(self.my_disc.0)?; - wtr.write_u32::(self.your_disc.0)?; - wtr.write_u32::(self.desired_min_tx.0)?; - wtr.write_u32::(self.required_min_rx.0)?; - wtr.write_u32::(self.required_min_echo_rx.0)?; - - Ok(Box::new(wtr.into_inner())) - } -} - /// Data structure to store the state of the Bfd machine. The impl does *not* /// implement any Bfd logic, it is merely a thread-safe data structure #[derive(Debug)] @@ -355,7 +159,7 @@ impl BfdSessionState { self.desired_min_tx_interval.load(ORDERING) } fn set_desired_min_tx_interval(&self, value: BfdInterval) { - self.desired_min_tx_interval.store(value, ORDERING) + self.desired_min_tx_interval.store(value, ORDERING); } fn required_min_rx_interval(&self) -> BfdInterval { self.required_min_rx_interval.load(ORDERING) @@ -398,30 +202,9 @@ struct BfdSession { tx: mpsc::Sender, rx_watchdog: Option>, cur_interval: BfdInterval, + event_dispatcher: EventDispatcherHandle, } -#[derive(Debug)] -struct BfdSessionStats { - local_ip: IpAddr, - remote_ip: IpAddr, - local_discr: BfdDiscriminator, - remote_discr: BfdDiscriminator, - state: BfdState, - last_diag: BfdDiagnostic, - control_packets_rx: u64, - control_packets_tx: u64, - last_change: Instant, -} -impl Display for BfdSessionStats { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!( - f, - "{:<20} {:<9} {:>6} {:>6}", - self.remote_ip, self.state, self.control_packets_rx, self.control_packets_tx - )?; - write!(f, " LD:{} RD:{} Diag:{} Last: {}s", self.local_discr, self.remote_discr, self.last_diag, Instant::now().duration_since(self.last_change).as_secs()) - } -} impl From<&BfdSessionState> for BfdSessionStats { fn from(state: &BfdSessionState) -> Self { Self { @@ -434,6 +217,8 @@ impl From<&BfdSessionState> for BfdSessionStats { control_packets_rx: state.control_packets_rx(), control_packets_tx: state.control_packets_tx(), last_change: state.last_state_change(), + detect_time: state.detection_time(), + base_interval: BfdInterval::from_secs(0), // TODO: This is a hack and will be patched up in the session } } } @@ -463,10 +248,16 @@ struct BfdSessionHandle { } impl BfdSessionHandle { - async fn new(local_addr: IpAddr, remote_addr: IpAddr, local_discr: BfdDiscriminator) -> Self { + async fn new( + local_addr: IpAddr, + remote_addr: IpAddr, + local_discr: BfdDiscriminator, + event_dispatcher: EventDispatcherHandle, + ) -> Self { let (tx, rx) = mpsc::channel(32); let new_self = Self { tx }; - let periodic = BfdPeriodicSenderHandle::new(new_self.clone(), BfdInterval(1_000_000)); + let periodic = + BfdPeriodicSenderHandle::new(new_self.clone(), time::Duration::from_secs(1).into()); let mut session = BfdSession::new( rx, @@ -475,6 +266,7 @@ impl BfdSessionHandle { local_addr, remote_addr, local_discr, + event_dispatcher, ) .await; tokio::spawn(async move { session.run().await }); @@ -534,8 +326,8 @@ impl BfdPeriodicSender { } } async fn run(&mut self) { - let base_interval = time::Duration::from_micros(self.cur_interval.0 as u64 * 3 / 4); - let mut clock = time::interval(base_interval); + let base_interval = self.cur_interval * 3 / 4; + let mut clock = time::interval(base_interval.into()); let mut running = true; loop { @@ -560,7 +352,7 @@ impl BfdPeriodicSender { , PeriodicControlCommand::Start => { debug!("Starting periodic packets"); running = true }, PeriodicControlCommand::SetMinInterval(i) => { - let base_interval = time::Duration::from_micros(i.0 as u64 * 3 / 4); + let base_interval: Duration = (i * 3 / 4).into(); debug!("Updating base interval to {}ms (jittering {} to {}ms)", base_interval.as_millis(), base_interval.as_millis(), base_interval.as_millis() + base_interval.as_millis() / 3); clock = time::interval_at( time::Instant::now() + base_interval.into(), @@ -586,6 +378,7 @@ impl BfdSession { local_addr: IpAddr, remote_addr: IpAddr, local_discr: BfdDiscriminator, + event_dispatcher: EventDispatcherHandle, ) -> Self { let mut rng = rand::thread_rng(); @@ -594,7 +387,7 @@ impl BfdSession { let control_sock = UdpSocket::bind(SocketAddr::new(local_addr, source_port)) .await .unwrap(); - control_sock.set_ttl(255).unwrap(); + set_ttl_or_hop_limit(&control_sock, 255).unwrap(); // control_sock // .connect(SocketAddr::new(remote_addr, CONTROL_PORT)) // .await?; @@ -605,7 +398,8 @@ impl BfdSession { tx, periodic, rx_watchdog: None, - cur_interval: BfdInterval(1_000_000), + cur_interval: BfdInterval::from_secs(1), + event_dispatcher, state: BfdSessionState { control_sock: Arc::new(control_sock), peer_addr: remote_addr, @@ -614,9 +408,9 @@ impl BfdSession { local_discr, remote_discr: Atomic::new(BfdDiscriminator(0)), local_diag: Atomic::new(BfdDiagnostic::None), - desired_min_tx_interval: Atomic::new(BfdInterval(1_000_000)), - required_min_rx_interval: Atomic::new(BfdInterval(300_000)), - remote_min_rx_interval: Atomic::new(BfdInterval(1)), + desired_min_tx_interval: Atomic::new(BfdInterval::from_millis(300)), + required_min_rx_interval: Atomic::new(BfdInterval::from_millis(300)), + remote_min_rx_interval: Atomic::new(BfdInterval::from_millis(300)), demand_mode: Atomic::new(false), remote_demand_mode: Atomic::new(false), detect_mult: Atomic::new(3), @@ -624,7 +418,7 @@ impl BfdSession { rcv_auth_seq: Atomic::new(0), xmit_auth_seq: Atomic::new(rng.gen()), auth_seq_known: Atomic::new(false), - detection_time: Atomic::new(BfdInterval(0)), + detection_time: Atomic::new(BfdInterval::from_secs(0)), poll_mode: Atomic::new(false), control_packets_rx: Atomic::new(0), @@ -639,6 +433,16 @@ impl BfdSession { if self.state.session_state() == new_state { return; }; + self.event_dispatcher + .dispatch(EventMessage::StateChange(StateChangeEvent { + local_discr: self.state.local_discr(), + remote_discr: self.state.remote_discr(), + local_ip: self.state.control_sock().local_addr().unwrap().ip(), //TODO: shouldn't unwrap here, it will crash the program + remote_ip: self.state.peer_addr(), + from_state: self.state.session_state(), + to_state: new_state, + })) + .await; info!( "Peer {} state change {} -> {}", self.state.peer_addr(), @@ -655,10 +459,20 @@ impl BfdSession { .send(WatchdogReset::Terminate) .await .unwrap(); - } + }; + // When bfd.SessionState is not Up, the system MUST set bfd.DesiredMinTxInterval to a value of not + // less than one second (1,000,000 microseconds). This is intended to ensure that the bandwidth + // consumed by BFD sessions that are not Up is negligible, particularly in the case where a neighbor + // may not be running BFD. + self.state.set_desired_min_tx_interval(std::cmp::min( + self.state.desired_min_tx_interval(), + BfdInterval::from_secs(1), + )); + self.update_transmit_interval().await; } (BfdState::AdminDown | BfdState::Down | BfdState::Init, BfdState::Up) => { self.start_watchdog().await; + self.state.set_local_diag(BfdDiagnostic::None); // This doesn't seem to be required by the spec but seems sensible and at least FRR does it. } _ => {} } @@ -673,7 +487,9 @@ impl BfdSession { self.receive_control_packet(&packet).await; } SessionControlCommand::GetSessionStats { respond_to } => { - respond_to.send(BfdSessionStats::from(&self.state)).unwrap() + let mut stats = + BfdSessionStats::from(&self.state).with_base_interval(self.cur_interval); + respond_to.send(stats).unwrap() } SessionControlCommand::TxControlPacket { poll_response } => { self.transmit_control_packet(poll_response).await @@ -710,20 +526,28 @@ impl BfdSession { your_disc: self.state.remote_discr(), desired_min_tx: self.state.desired_min_tx_interval(), required_min_rx: self.state.required_min_rx_interval(), - required_min_echo_rx: BfdInterval(0), + required_min_echo_rx: BfdInterval::from_secs(0), auth: None, }; let socket = self.state.control_sock.clone(); let dest = self.state.peer_addr; self.state.control_packets_tx.fetch_add(1, ORDERING); debug!("tx packet: {:?}", packet); - socket + match socket .send_to( packet.serialize().unwrap().as_ref(), SocketAddr::new(dest, CONTROL_PORT), ) .await - .unwrap(); + { + Err(e) => { + warn!( + "Error sending packet to {:?}:{} ({:?})", + dest, CONTROL_PORT, e + ); + } + Ok(_) => {} + } } async fn start_watchdog(&mut self) { @@ -760,7 +584,10 @@ impl BfdSession { let mut temporary = Some(tx); swap(&mut temporary, &mut self.rx_watchdog); if temporary.is_some() { - debug!("Updating old watchdog with new detection time"); + debug!( + "Updating old watchdog with new detection time ({}ms)", + duration.as_millis() + ); temporary .as_mut() .unwrap() @@ -879,7 +706,7 @@ impl BfdSession { self.state.set_last_remote_diag(p.flags.diag().unwrap()); // If the Required Min Echo RX Interval field is zero, the transmission of Echo packets, if any, MUST cease. - if p.required_min_echo_rx == BfdInterval(0) { + if p.required_min_echo_rx == BfdInterval::from_secs(0) { // TODO: implement echo thread } // If a Poll Sequence is being transmitted by the local system and the Final (F) bit in the received packet is @@ -987,16 +814,15 @@ impl BfdSession { let old = self.state.detection_time(); self.state.set_detection_time(if !self.state.demand_mode() { (p.flags.detect_mult() as u32 - * std::cmp::max(self.state.required_min_rx_interval(), p.desired_min_tx).0) - .into() + * std::cmp::max(self.state.required_min_rx_interval(), p.desired_min_tx)) + .into() } else { (self.state.detect_mult() as u32 * std::cmp::max( self.state.desired_min_tx_interval(), self.state.remote_min_rx_interval(), - ) - .0) - .into() + )) + .into() }); if old != self.state.detection_time() { self.start_watchdog().await @@ -1004,100 +830,257 @@ impl BfdSession { } } +#[derive(Debug)] +enum EventDispatcherCommand { + AddSink(Box), + DispatchMessage(EventMessage), +} + +struct EventDispatcher { + event_sinks: Vec>, + event_sink_handles: Vec>, + tx: mpsc::Sender, + rx: mpsc::Receiver, // TODO: there must be a better way to do this +} + +impl EventDispatcher { + fn new( + channel_tx: mpsc::Sender, + channel_rx: mpsc::Receiver, + ) -> Self { + let event_sinks = Vec::new(); + let event_sink_handles = Vec::new(); + + Self { + event_sinks, + event_sink_handles, + tx: channel_tx, + rx: channel_rx, + } + } + fn add_sink(&mut self, sink: Box) { + // Save the transmit channel + self.event_sinks.push(sink.channel()); + // Run the sink + self.event_sink_handles.push(sink.run()); + } + async fn dispatch(&self, event: EventMessage) { + debug!( + "EventDispatcher dispatching event {:?} to sinks {:?}", + event, self.event_sinks + ); + for sink in &self.event_sinks { + sink.send(event.clone()).await.unwrap(); + } + } + async fn run(mut self) { + info!("Event dispatcher running"); + while let Some(cmd) = self.rx.recv().await { + debug!("EventDispatcher got command {:?}", cmd); + match cmd { + EventDispatcherCommand::AddSink(sink) => self.add_sink(sink), + EventDispatcherCommand::DispatchMessage(event) => self.dispatch(event).await, + } + } + } +} + +#[derive(Clone)] +struct EventDispatcherHandle { + channel_tx: mpsc::Sender, +} + +impl EventDispatcherHandle { + async fn new() -> Self { + let (channel_tx, channel_rx) = mpsc::channel(16); + let dispatcher = EventDispatcher::new(channel_tx.clone(), channel_rx); + tokio::spawn(async move { dispatcher.run().await }); + Self { channel_tx } + } + async fn dispatch(&self, event: EventMessage) { + self.channel_tx + .send(EventDispatcherCommand::DispatchMessage(event)) + .await + .unwrap(); + } + async fn add_sink(&self, sink: Box) { + self.channel_tx + .send(EventDispatcherCommand::AddSink(sink)) + .await + .unwrap() + } +} + +struct Bfdd { + local: SocketAddr, + peers: Vec, + sessions_by_ip: HashMap<(IpAddr, IpAddr), BfdDiscriminator>, + sessions_by_discr: HashMap, + + control_socket: Arc, + + event_dispatcher: EventDispatcherHandle, +} + +impl Bfdd { + async fn new() -> Result> { + // let local = SocketAddr::from_str("192.168.255.2:3784").unwrap(); + let local = SocketAddr::new(Ipv6Addr::from_str("2001:db8::2").unwrap().into(), 3784); + let peers = vec![ + // SocketAddr::from_str("192.168.122.132:3784").unwrap(), + // SocketAddr::from_str("192.168.255.1:3784").unwrap(), + SocketAddr::new(Ipv6Addr::from_str("2001:db8::1").unwrap().into(), 3784), + ]; + let sessions_by_ip = HashMap::new(); + let sessions_by_discr = HashMap::new(); + + // TODO: Bind only the sockets we need for the configuration, including multihop and run them in different tasks. + info!("Binding on {:?}", local); + let control_socket = Arc::new(UdpSocket::bind(local).await?); + // If BFD authentication is not in use on a session, all BFD Control packets for the session MUST be sent with a + // Time to Live (TTL) or Hop Limit value of 255. + set_ttl_or_hop_limit(&control_socket, 255)?; + let echo_socket = Arc::new(UdpSocket::bind(SocketAddr::new(local.ip(), ECHO_PORT)).await?); + + let event_dispatcher = EventDispatcherHandle::new().await; + event_dispatcher + .add_sink(Box::new(ScriptHookSink::new())) + .await; + + Ok(Self { + local, + peers, + sessions_by_ip, + sessions_by_discr, + control_socket, + event_dispatcher, + }) + } + async fn start_peers(&mut self) { + let mut rng = rand::thread_rng(); + for peer in &self.peers { + let mut local_discr = BfdDiscriminator(0); + while local_discr.0 == 0 || self.sessions_by_discr.get(&local_discr).is_some() { + local_discr = BfdDiscriminator(rng.gen()); + } + let session = BfdSessionHandle::new( + self.local.ip(), + peer.ip(), + local_discr, + self.event_dispatcher.clone(), + ) + .await; + + self.sessions_by_ip + .insert((self.local.ip(), peer.ip()), local_discr); + self.sessions_by_discr.insert(local_discr, session); + } + } + fn session_for_packet( + &self, + addr: &SocketAddr, + packet: &BfdPacket, + ) -> Option<&BfdSessionHandle> { + // Try to find a configured session for the remote IP if it doesn't send our discriminator, otherwise + // use the discriminator that was sent. + + // https://datatracker.ietf.org/doc/html/rfc5880#section-6.3 + // + // Once the remote end echoes back the local discriminator, all further received packets are + // demultiplexed based on the Your Discriminator field only (which means that, among other things, the + // source address field can change or the interface over which the packets are received can change, but + // the packets will still be associated with the proper session). + let our_disc = if packet.your_disc == BfdDiscriminator(0) { + if let Some(discr) = self + .sessions_by_ip + .get(&(self.control_socket.local_addr().unwrap().ip(), addr.ip())) + { + // TODO: Probably we should destroy the old session here? + debug!( + "Found session for unknown discriminator from {}: {}", + addr.ip(), + discr + ); + discr + } else { + warn!("Unable to match packet to session with {}", addr.ip()); + return None; + } + } else { + &packet.your_disc + }; + self.sessions_by_discr.get(our_disc) + } + async fn rx_packet(&self, addr: SocketAddr, packet: BfdPacket) { + if let Some(session) = self.session_for_packet(&addr, &packet) { + session + .tx + .send(SessionControlCommand::RxPacket(packet)) + .await + .unwrap(); + } else { + warn!("Unable to match packet to session with"); + return; + } + } + async fn run(mut self) -> io::Result<()> { + info!("rust-bfd starting up"); + self.start_peers().await; + info!( + "configured sessions: {}", + future::join_all( + self.sessions_by_discr + .values() + .map(|session| async { session.get_stats().await.unwrap().to_string() }) + ) + .await + .join("\n\t") + ); + + let rx_task = + task::spawn(async move { + let mut buf = [0; 1024]; + loop { + // TODO: All received BFD Control packets that are demultiplexed to the session MUST be discarded if + // the received TTL or Hop Limit is not equal to 255. Need raw sockets for this. + // + // TODO: Maybe it is okay to receive in the session-specific socket instead of this global one? But + // the RFC mentions that we must demux based on discriminator only. + let (len, addr) = self.control_socket.recv_from(&mut buf).await.unwrap(); // TODO: fallibility? + if let Ok((_leftover, packet)) = BfdPacket::parse(&buf[..len]) { + debug!("rx packet ({:?}): {:?}", addr.ip(), packet); + self.rx_packet(addr, packet).await + } else { + warn!("Failed to parse packet"); + } + + info!( + "configured sessions: {}", + future::join_all(self.sessions_by_discr.values().map(|session| async { + session.get_stats().await.unwrap().to_string() + })) + .await + .join("\n\t") + ); + } + }); + + tokio::join!(rx_task) + .0 + .expect("Unable to join on the receive thread"); + + Ok(()) + } +} + #[tokio::main] async fn main() -> io::Result<()> { - env_logger::Builder::from_env(Env::default().default_filter_or("info")).init(); + env_logger::Builder::from_env(Env::default().default_filter_or("info")) + .format_timestamp_micros() + .init(); - info!("rust-bfd starting up"); - - let local = SocketAddr::from_str("192.168.122.1:3784").unwrap(); - let peers = vec![SocketAddr::from_str("192.168.122.132:3784").unwrap()]; - let mut rng = rand::thread_rng(); - let mut sessions_by_ip = HashMap::new(); - let mut sessions_by_discr = HashMap::new(); - - for peer in peers { - let mut local_discr = BfdDiscriminator(0); - while local_discr.0 == 0 || sessions_by_discr.get(&local_discr).is_some() { - local_discr = BfdDiscriminator(rng.gen()); - } - let session = BfdSessionHandle::new(local.ip(), peer.ip(), local_discr).await; - - sessions_by_ip.insert((local.ip(), peer.ip()), local_discr); - sessions_by_discr.insert(local_discr, session); + match Bfdd::new().await { + Ok(bfdd) => bfdd.run().await, + Err(e) => Err(std::io::Error::new(ErrorKind::Other, e.to_string())), } - debug!( - "configured sessions: {}", - future::join_all( - sessions_by_discr - .values() - .map(|session| async { session.get_stats().await.unwrap().to_string() }) - ) - .await - .join("\n\t") - ); - - let control_sock = Arc::new(UdpSocket::bind(local).await?); - // If BFD authentication is not in use on a session, all BFD Control packets for the session MUST be sent with a - // Time to Live (TTL) or Hop Limit value of 255. - control_sock.set_ttl(255).unwrap(); - let echo_socket = Arc::new(UdpSocket::bind(SocketAddr::new(local.ip(), ECHO_PORT)).await?); - - let rx_thread = task::spawn(async move { - let mut buf = [0; 1024]; - loop { - // TODO: All received BFD Control packets that are demultiplexed to the session MUST be discarded if the - // received TTL or Hop Limit is not equal to 255. Need raw sockets for this. - let (len, addr) = control_sock.recv_from(&mut buf).await.unwrap(); // TODO: fallibility? - if let Ok((_leftover, packet)) = BfdPacket::parse(buf.as_slice()) { - debug!("rx packet: {:?}", packet); - let our_disc = if packet.your_disc == BfdDiscriminator(0) { - if let Some(discr) = - sessions_by_ip.get(&(control_sock.local_addr().unwrap().ip(), addr.ip())) - { - debug!( - "Found session for unknown discriminator from {}: {}", - addr.ip(), - discr - ); - discr - } else { - warn!("Unable to match packet to session with {}", addr.ip()); - continue; - } - } else { - &packet.your_disc - }; - if let Some(session) = sessions_by_discr.get(our_disc) { - session - .tx - .send(SessionControlCommand::RxPacket(packet)) - .await - .unwrap(); - } else { - warn!("Unable to match packet to session with {}", our_disc); - continue; - } - } else { - warn!("Failed to parse packet"); - } - - debug!( - "configured sessions: {}", - future::join_all( - sessions_by_discr - .values() - .map(|session| async { session.get_stats().await.unwrap().to_string() }) - ) - .await - .join("\n\t") - ); - } - }); - - tokio::join!(rx_thread) - .0 - .expect("Unable to join on the receive thread"); - - Ok(()) } diff --git a/src/control.rs b/src/control.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/events.rs b/src/events.rs new file mode 100644 index 0000000..ec1937b --- /dev/null +++ b/src/events.rs @@ -0,0 +1,171 @@ +use core::{fmt, str}; +use std::{ + ffi::{OsStr, OsString}, + net::IpAddr, + path::PathBuf, + process::ExitStatus, +}; + +use itertools::Itertools; +use log::{debug, info, warn}; +use nix::unistd::{self, eaccess}; +use tokio::{ + fs, process, + sync::mpsc, + task::{self, JoinHandle}, +}; + +use crate::{BfdDiscriminator, BfdState}; + +#[derive(Debug, Clone)] +pub enum EventMessage { + StateChange(StateChangeEvent), +} + +#[derive(Debug, Clone)] +pub struct StateChangeEvent { + pub local_discr: BfdDiscriminator, + pub remote_discr: BfdDiscriminator, + pub local_ip: IpAddr, + pub remote_ip: IpAddr, + pub from_state: BfdState, + pub to_state: BfdState, +} + +pub trait EventMessageSink: Send + Sync + fmt::Debug { + fn channel(&self) -> mpsc::Sender; + fn run(self: Box) -> JoinHandle<()>; +} + +#[derive(Debug)] +pub struct ScriptHookSink { + script_dir: PathBuf, + tx_channel: mpsc::Sender, + rx_channel: mpsc::Receiver, +} + +impl EventMessageSink for ScriptHookSink { + fn channel(&self) -> mpsc::Sender { + self.tx_channel.clone() + } + fn run(self: Box) -> JoinHandle<()> { + task::spawn(async move { self.event_loop().await }) + } +} + +impl ScriptHookSink { + pub fn new() -> Self { + let (tx_channel, rx_channel) = mpsc::channel(16); + Self { + script_dir: "/etc/rust-bfd/hooks.d/".into(), + tx_channel, + rx_channel, + } + } + async fn script_runner>(script_path: S, args: &[S], environment: &[(S, S)]) { + let mut cmd = process::Command::new("sh"); + cmd.env_clear() + .arg0(&script_path) + .arg(&script_path) + .args(args) + .envs(environment.iter().map(|(k, v)| (k, v))); + debug!("Running command {:?}", cmd); + let result = cmd.output().await; + match result { + Ok(output) => { + info!( + "Script `{} {}` completed with status {}", + cmd.as_std().get_program().to_string_lossy(), + cmd.as_std() + .get_args() + .map(|arg| arg.to_string_lossy()) + .join(" "), + output.status + ); + if !output.status.success() { + if output.stdout.len() > 0 { + warn!( + " stdout: \n{}", + str::from_utf8(output.stdout.as_slice()).unwrap() + ); + } + if output.stderr.len() > 0 { + warn!( + " stderr: \n{}", + str::from_utf8(output.stderr.as_slice()).unwrap() + ); + } + } + } + Err(e) => { + warn!("Failed to execute {:?}, {}", cmd, e); + } + } + } + async fn send_state_change(&self, event: &StateChangeEvent) { + debug!("Got a state change event! {:?}, running scripts", event); + let mut files = match fs::read_dir(&self.script_dir).await { + Ok(files) => files, + Err(e) => { + warn!( + "Error opening script directory `{}` ({}), no hooks executed.", + self.script_dir.as_os_str().to_string_lossy(), + e + ); + return; + } + }; + let mut scripts = Vec::new(); + while let Ok(Some(direntry)) = files.next_entry().await { + // Check if entry is a file + if direntry.file_type().await.is_ok_and(|ft| ft.is_file()) { + // Check if it's executable + let path = direntry.path(); + if unistd::eaccess(&path, unistd::AccessFlags::X_OK).is_ok() { + // Add to scripts array + scripts.push(path); + } + } + } + // Execute in lexicographic order, so sort the list of found scripts + scripts.sort_unstable_by_key(|path| { + path.file_name() + .map_or(OsString::new(), |file| file.to_owned()) + }); + + if scripts.len() > 0 { + // Build context + let mut args = Vec::<&OsStr>::new(); + let mut env = Vec::<(&OsStr, &OsStr)>::new(); + + let remote_ip_str = OsString::from(event.remote_ip.to_string()); + let local_ip_str = OsString::from(event.local_ip.to_string()); + let remote_discr_str = OsString::from(event.remote_discr.to_string()); + let local_discr_str = OsString::from(event.local_discr.to_string()); + + args.push(event.to_state.into()); + args.push(&remote_ip_str); + env.push((OsStr::new("BFD_LAST_STATE"), event.from_state.into())); + env.push((OsStr::new("BFD_STATE"), event.to_state.into())); + env.push((OsStr::new("BFD_LOCAL_ADDR"), &local_ip_str)); + env.push((OsStr::new("BFD_PEER_ADDR"), &remote_ip_str)); + env.push((OsStr::new("BFD_LOCAL_DISCRIMINATOR"), &local_discr_str)); + env.push((OsStr::new("BFD_PEER_DISCRIMINATOR"), &remote_discr_str)); + for script in scripts { + Self::script_runner(script.as_os_str(), &args, &env).await + } + } + } + pub async fn handle_event(&self, event: &EventMessage) { + match event { + EventMessage::StateChange(change) => self.send_state_change(change).await, + } + } + async fn event_loop(mut self) { + info!("ScriptHookSink started"); + while let Some(event) = self.rx_channel.recv().await { + debug!("ScriptHookSink got event {:?}", event); + self.handle_event(&event).await + } + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..0d988c7 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,75 @@ +pub mod bfd; +pub mod control; +pub mod events; + +use std::{ + fmt::{self, Display}, + net::IpAddr, + time::Instant, +}; + +use log::debug; +use nix::sys::socket::{getsockopt, setsockopt, sockopt}; +use tokio::net::UdpSocket; + +use crate::bfd::*; + +#[derive(Debug)] +pub struct BfdSessionStats { + pub local_ip: IpAddr, + pub remote_ip: IpAddr, + pub local_discr: BfdDiscriminator, + pub remote_discr: BfdDiscriminator, + pub state: BfdState, + pub last_diag: BfdDiagnostic, + pub control_packets_rx: u64, + pub control_packets_tx: u64, + pub last_change: Instant, + pub detect_time: BfdInterval, + pub base_interval: BfdInterval, +} +impl BfdSessionStats { + pub fn with_base_interval(mut self, base_interval: BfdInterval) -> Self { + self.base_interval = base_interval; + self + } +} +impl Display for BfdSessionStats { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "{:<20} {:<9} {:>6} {:>6}", + self.remote_ip, self.state, self.control_packets_rx, self.control_packets_tx + )?; + write!( + f, + " LocalDiscr:{} RemoteDiscr:{} DetectTime: {} TxTime: {} Diag:{} Last: {}s", + self.local_discr, + self.remote_discr, + self.detect_time, + self.base_interval, + self.last_diag, + Instant::now().duration_since(self.last_change).as_secs() + ) + } +} + +pub fn set_ttl_or_hop_limit(sock: &UdpSocket, hop_limit: u8) -> Result<(), std::io::Error> { + let addr = sock.local_addr()?; + if addr.is_ipv4() { + debug!("IPv4 socket, setting ttl {}", hop_limit as u32); + sock.set_ttl(hop_limit as u32) + } else if addr.is_ipv6() { + debug!("IPv6 socket, setting hop limit {}", hop_limit as i32); + let i32_hop_limit = hop_limit as i32; + if let Err(e) = setsockopt(sock, sockopt::Ipv6Ttl, &i32_hop_limit) { + Err(std::io::Error::from_raw_os_error(e as i32)) + } else { + let res = getsockopt(sock, sockopt::Ipv6Ttl)?; + debug!("Got {} after setting {}", res, i32_hop_limit); + Ok(()) + } + } else { + panic!("UdpSocket with no address family ?!") + } +}