Initial commit. Kinda does something.

This commit is contained in:
2024-01-31 00:59:42 -08:00
commit d00f47d004
4 changed files with 1237 additions and 0 deletions

664
src/main.rs Normal file
View File

@ -0,0 +1,664 @@
use std::{
collections::HashMap, error::Error, fmt::Display, fs::read, io::Cursor, net::{IpAddr, SocketAddr}, str::FromStr, sync::Arc
};
use nom::{bytes::complete::take, multi::many_m_n, number::complete::be_u8, IResult};
use nom_derive::{NomBE, Parse};
use proc_bitfield::*;
use rand::prelude::*;
use tokio::task;
use tokio::time;
use tokio::{io, join, task::JoinHandle};
use tokio::{net::UdpSocket, sync::RwLock};
use tokio::{sync::mpsc, time::Instant};
use byteorder::{BigEndian, WriteBytesExt};
const CONTROL_PORT: u16 = 3784;
const ECHO_PORT: u16 = 3785;
#[repr(u8)]
#[derive(ConvRaw, Debug, NomBE, PartialEq, Eq, Clone, Copy)]
pub enum BfdDiagnostic {
None = 0,
TimeExpired = 1,
EchoFailed = 2,
NeighborDown = 3,
FwdPlaneReset = 4,
PathDown = 5,
ConcatPathDown = 6,
AdminDown = 7,
RevConcatPathDown = 8,
Reserved,
}
#[repr(u8)]
#[derive(ConvRaw, Debug, NomBE, PartialEq, Eq, Default, Clone, Copy)]
pub enum BfdState {
AdminDown = 0,
#[default]
Down = 1,
Init = 2,
Up = 3,
}
impl Display for BfdState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(match self {
Self::AdminDown => "AdminDown",
Self::Down => "Down",
Self::Init => "Init",
Self::Up => "Up",
})
}
}
#[repr(u8)]
#[derive(ConvRaw, Debug, NomBE, PartialEq, Eq, Clone, Copy)]
pub enum BfdAuthType {
None = 0,
SimplePassword = 1,
KeyedMD5 = 2,
MetKeyedMD5 = 3,
KeyedSHA1 = 4,
MetKeyedSHA1 = 5,
Reserved,
}
#[derive(Debug)]
pub enum BfdError {
// field, value
InvalidFieldValue(&'static str, &'static str),
}
impl Display for BfdError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::InvalidFieldValue(field, value) => {
write!(f, "invalid value `{}` for field `{}`", field, value)
}
}
}
}
impl Error for BfdError {}
#[derive(Debug, NomBE, PartialEq, Eq, Clone, Copy)]
pub struct BfdDiscriminator(u32);
impl Display for BfdDiscriminator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
#[derive(Debug, NomBE, PartialEq, Eq, Clone, Copy, PartialOrd, Ord)]
pub struct BfdInterval(u32);
impl From<BfdInterval> for time::Duration {
fn from(value: BfdInterval) -> Self {
time::Duration::from_micros(value.0 as u64)
}
}
bitfield! {
#[derive(NomBE)]
pub struct BfdFlags(pub u32): Debug {
pub vers: u8 @ 29..=31,
pub diag: u8 [try_get BfdDiagnostic] @ 24..=28,
pub state: u8 [try_get BfdState] @ 22..=23,
pub poll: bool @ 21,
pub final_: bool @ 20,
pub cpi: bool @ 19,
pub auth_present: bool @ 18,
pub demand: bool @ 17,
pub multipoint: bool @ 16,
pub detect_mult: u8 @ 8..=15,
pub length: u8 @ 0..=7
}
}
#[derive(Debug)]
pub struct BfdAuthSimplePassword(Vec<u8>);
impl<'a> Parse<&'a [u8]> for BfdAuthSimplePassword {
fn parse(i: &'a [u8]) -> IResult<&'a [u8], Self, nom::error::Error<&'a [u8]>> {
let (i, res) = many_m_n(1, 16, be_u8)(i)?;
Ok((i, Self(res)))
}
}
#[derive(Debug, NomBE)]
pub struct BfdAuthKeyedMD5 {
key_id: u8,
_reserved: u8,
seq: u32,
digest: [u8; 16],
}
#[derive(Debug, NomBE)]
pub struct BfdAuthKeyedSHA1 {
key_id: u8,
_reserved: u8,
seq: u32,
hash: [u8; 20],
}
#[derive(Debug, NomBE)]
#[nom(Selector = "BfdAuthType", Complete)]
pub enum BfdAuthData {
#[nom(Selector = "BfdAuthType::SimplePassword")]
SimplePassword(BfdAuthSimplePassword),
#[nom(Selector = "BfdAuthType::KeyedMD5")]
KeyedMD5(BfdAuthKeyedMD5),
#[nom(Selector = "BfdAuthType::MetKeyedMD5")]
MetKeyedMD5(BfdAuthKeyedMD5),
#[nom(Selector = "BfdAuthType::KeyedSHA1")]
KeyedSHA1(BfdAuthKeyedSHA1),
#[nom(Selector = "BfdAuthType::MetKeyedSHA1")]
MetKeyedSHA1(BfdAuthKeyedSHA1),
}
impl BfdAuthData {
fn parse_be_with_length(
i: &[u8],
auth_type: BfdAuthType,
auth_len: u8,
) -> IResult<&[u8], Self> {
let (new_i, data) = take(auth_len)(i)?;
let (_leftovers, retval) = BfdAuthData::parse_be(data, auth_type)?;
Ok((new_i, retval))
}
}
#[derive(Debug, NomBE)]
pub struct BfdAuth {
auth_type: BfdAuthType,
auth_len: u8,
#[nom(Parse = "{ |i| BfdAuthData::parse_be_with_length(i, auth_type, auth_len) }")]
auth_data: BfdAuthData,
}
#[derive(Debug, NomBE)]
pub struct BfdPacket {
flags: BfdFlags,
my_disc: BfdDiscriminator,
your_disc: BfdDiscriminator,
desired_min_tx: BfdInterval,
required_min_rx: BfdInterval,
required_min_echo_rx: BfdInterval,
#[nom(Cond = "flags.auth_present()")]
auth: Option<BfdAuth>,
}
impl BfdPacket {
fn serialize(&self) -> Result<Box<[u8]>, std::io::Error> {
// TODO: serialize auth
let buf = [0u8; 24];
let mut wtr = Cursor::new(buf);
wtr.write_u32::<BigEndian>(self.flags.0)?;
wtr.write_u32::<BigEndian>(self.my_disc.0)?;
wtr.write_u32::<BigEndian>(self.your_disc.0)?;
wtr.write_u32::<BigEndian>(self.desired_min_tx.0)?;
wtr.write_u32::<BigEndian>(self.required_min_rx.0)?;
wtr.write_u32::<BigEndian>(self.required_min_echo_rx.0)?;
Ok(Box::new(wtr.into_inner()))
}
}
#[derive(Debug, Clone)]
struct BfdSessionState {
control_sock: Arc<UdpSocket>,
peer_addr: IpAddr,
session_state: BfdState,
remote_session_state: BfdState,
local_discr: BfdDiscriminator,
remote_discr: BfdDiscriminator,
local_diag: BfdDiagnostic,
desired_min_tx_interval: BfdInterval,
required_min_rx_interval: BfdInterval,
remote_min_rx_interval: BfdInterval,
demand_mode: bool,
remote_demand_mode: bool,
detect_mult: u8,
auth_type: BfdAuthType,
rcv_auth_seq: u32,
xmit_auth_seq: u32,
auth_seq_known: bool,
periodic_cmd_channel: mpsc::Sender<PeriodicControlCommand>,
detection_time: time::Duration,
poll_mode: bool,
}
struct BfdSession {
state: Arc<RwLock<BfdSessionState>>,
}
enum PeriodicControlCommand {
Stop,
Start,
Quit,
SetMinInterval(BfdInterval),
}
enum SessionControlCommand {
RxPacket(Vec<u8>),
Quit,
}
impl BfdSession {
async fn new(
local_addr: IpAddr,
remote_addr: IpAddr,
) -> Result<Self, Box<dyn std::error::Error>> {
let mut rng = rand::thread_rng();
//TODO: select a random unused port instead of pure random
let source_port: u16 = rng.gen_range(49152..=65535);
let control_sock = UdpSocket::bind(SocketAddr::new(local_addr, source_port)).await?;
// control_sock
// .connect(SocketAddr::new(remote_addr, CONTROL_PORT))
// .await?;
// Incoming packets will come over the channel from the mux, since they don't send to the reciprocal port
Ok(Self {
state: Arc::new(RwLock::new(BfdSessionState {
control_sock: Arc::new(control_sock),
peer_addr: remote_addr,
session_state: BfdState::default(),
remote_session_state: BfdState::default(),
local_discr: BfdDiscriminator(rng.gen()),
remote_discr: BfdDiscriminator(0),
local_diag: BfdDiagnostic::None,
desired_min_tx_interval: BfdInterval(1_000_000),
required_min_rx_interval: BfdInterval(300_000),
remote_min_rx_interval: BfdInterval(1),
demand_mode: false,
remote_demand_mode: false,
detect_mult: 3,
auth_type: BfdAuthType::None,
rcv_auth_seq: 0,
xmit_auth_seq: rng.gen(),
auth_seq_known: false,
periodic_cmd_channel: mpsc::channel(1).0,
detection_time: time::Duration::ZERO,
poll_mode: false,
})),
})
}
async fn spawn_control_thread(
self: Arc<Self>,
mut rx: mpsc::Receiver<SessionControlCommand>,
) -> JoinHandle<()> {
task::spawn(async move {
while let Some(cmd) = rx.recv().await {
match cmd {
SessionControlCommand::Quit => return,
SessionControlCommand::RxPacket(buf) => {
if let Ok((_leftover, packet)) = BfdPacket::parse(buf.as_slice()) {
println!("packet: {:?}", packet);
self.clone().receive_control_packet(&packet).await
} else {
eprintln!("Failed to parse packet");
}
}
}
}
})
}
async fn transmit_periodic_packet(self: Arc<Self>) {
let read_guard = self.state.read().await;
let packet = BfdPacket {
flags: BfdFlags(0).with_vers(1).with_diag(read_guard.local_diag.into()).with_state(read_guard.session_state.into()).with_poll(read_guard.poll_mode).with_cpi(true).with_demand(read_guard.session_state == BfdState::Up && read_guard.remote_session_state == BfdState::Up).with_detect_mult(read_guard.detect_mult).with_length(24),
my_disc: read_guard.local_discr,
your_disc: read_guard.remote_discr,
desired_min_tx: read_guard.desired_min_tx_interval,
required_min_rx: read_guard.required_min_rx_interval,
required_min_echo_rx: BfdInterval(0),
auth: None
};
let socket = read_guard.control_sock.clone();
let dest = read_guard.peer_addr;
drop(read_guard);
socket.send_to(packet.serialize().unwrap().as_ref(), SocketAddr::new(dest, CONTROL_PORT)).await.unwrap();
}
async fn spawn_periodic_thread(
self: Arc<Self>,
mut rx: mpsc::Receiver<PeriodicControlCommand>,
interval: BfdInterval,
) -> JoinHandle<()> {
task::spawn(async move {
let mut running = true;
let base_interval = time::Duration::from_micros(interval.0 as u64 * 3 / 4);
let mut clock = time::interval(base_interval);
'MAIN: loop {
if running {
// Get and action all pending commands then wait for interval to tick
while let Ok(cmd) = rx.try_recv() {
match cmd {
PeriodicControlCommand::Quit => return,
PeriodicControlCommand::Stop => {
running = false;
continue 'MAIN;
}
PeriodicControlCommand::Start => running = true,
PeriodicControlCommand::SetMinInterval(i) => {
running = true;
let base_interval = time::Duration::from_micros(i.0 as u64 * 3 / 4);
clock = time::interval_at(
time::Instant::now() + base_interval.into(),
base_interval.into(),
);
}
}
}
// The periodic transmission of BFD Control packets MUST be jittered on a per-packet basis by up to
// 25%, that is, the interval MUST be reduced by a random value of 0 to 25%
//
// We do the equivalent inverse, we wait 75%, then add an additional 0-25%.
let jitter = time::Duration::from_micros(
rand::thread_rng()
.gen_range(0..clock.period().as_micros() / 3)
.try_into()
.unwrap(),
);
clock.tick().await;
time::sleep(jitter).await;
self.clone().transmit_periodic_packet().await;
//
} else {
// Instead we block on incoming commands
if let Some(cmd) = rx.recv().await {
match cmd {
PeriodicControlCommand::Start => {
running = true;
clock.reset_after(clock.period())
}
PeriodicControlCommand::SetMinInterval(i) => {
running = true;
let base_interval = time::Duration::from_micros(i.0 as u64 * 3 / 4);
clock = time::interval_at(
time::Instant::now() + base_interval.into(),
base_interval.into(),
);
}
_ => {} // Other commands don't mutate state or start the clock
}
}
}
}
})
}
// https://datatracker.ietf.org/doc/html/rfc5880#section-6.8.6
async fn receive_control_packet(self: Arc<Self>, p: &BfdPacket) {
let received_state = match p.flags.state() {
Err(_) => {
eprintln!("Invalid state, discarding");
return;
}
Ok(v) => v,
};
// If the version number is not correct (1), the packet MUST be discarded.
if p.flags.vers() != 1 {
eprintln!("Invalid version {}, discarding", p.flags.vers());
return;
}
// If the Length field is less than the minimum correct value (24 if the A bit is clear, or 26 if the A bit is
// set), the packet MUST be discarded.
if p.flags.length() < 24 || (p.flags.length() < 26 && p.flags.auth_present()) {
eprintln!("Invalid packet length {}, discarding", p.flags.length());
return;
}
// TODO: If the Length field is greater than the payload of the encapsulating protocol, the packet MUST be
// discarded.
// If the Detect Mult field is zero, the packet MUST be discarded.
if p.flags.detect_mult() == 0 {
eprintln!("Invalid detect mult {}, discarding", p.flags.detect_mult());
return;
}
//If the Multipoint (M) bit is nonzero, the packet MUST be discarded.
if p.flags.multipoint() {
eprintln!("Invalid multipoint enabled, discarding");
return;
}
// If the My Discriminator field is zero, the packet MUST be discarded.
if p.my_disc == BfdDiscriminator(0) {
eprintln!("Invalid my discriminator {:?}, discarding", p.my_disc);
return;
}
let state_read = self.state.read().await;
// If the Your Discriminator field is nonzero, it MUST be used to select the session with which this BFD packet
// is associated. If no session is found, the packet MUST be discarded.
//
// TODO: actually implement multiplexing
if p.your_disc != BfdDiscriminator(0) && p.your_disc != state_read.local_discr {
eprintln!(
"Received unexpected discriminator {:?}, discarding",
p.your_disc
);
return;
}
// If the Your Discriminator field is zero and the State field is not Down or AdminDown, the packet MUST be
// discarded.
if p.your_disc == BfdDiscriminator(0)
&& (received_state != BfdState::Down && received_state != BfdState::AdminDown)
{
eprintln!(
"Got packet with zero discriminator and invalid state {:?}, discarding",
received_state
);
return;
}
// If the A bit is set and no authentication is in use (bfd.AuthType is zero), the packet MUST be discarded.
if p.flags.auth_present() && state_read.auth_type == BfdAuthType::None {
eprintln!("Got packet with auth enabled when we disagree, discarding");
return;
}
// If the A bit is clear and authentication is in use (bfd.AuthType is nonzero), the packet MUST be discarded.
if !p.flags.auth_present() && state_read.auth_type != BfdAuthType::None {
eprintln!("Got packet without auth when we expect it, discarding");
return;
}
// If the A bit is set, the packet MUST be authenticated under the rules of section 6.7, based on the
// authentication type in use (bfd.AuthType). This may cause the packet to be discarded.
if p.flags.auth_present() {
unimplemented!("Authentication is not implemented");
}
drop(state_read);
let mut state_write = self.state.write().await;
// Set bfd.RemoteDiscr to the value of My Discriminator.
state_write.remote_discr = p.my_disc;
// Set bfd.RemoteState to the value of the State (Sta) field.
state_write.remote_session_state = received_state;
// Set bfd.RemoteDemandMode to the value of the Demand (D) bit.
state_write.remote_demand_mode = p.flags.demand();
// Set bfd.RemoteMinRxInterval to the value of Required Min RX Interval.
state_write.remote_min_rx_interval = p.required_min_rx;
drop(state_write);
// If the Required Min Echo RX Interval field is zero, the transmission of Echo packets, if any, MUST cease.
if p.required_min_echo_rx == BfdInterval(0) {
// TODO: implement echo thread
}
// If a Poll Sequence is being transmitted by the local system and the Final (F) bit in the received packet is
// set, the Poll Sequence MUST be terminated.
//
// TODO: poll stuff
// Update the transmit interval as described in section 6.8.2.
self.clone().update_transmit_interval().await;
// Update the Detection Time as described in section 6.8.4.
self.clone().update_detection_time(&p).await;
// There's not much actual work to do here so just hold a write lock for all of it
let mut state_write = self.state.write().await;
// If bfd.SessionState is AdminDown
// Discard the packet
if state_write.session_state == BfdState::AdminDown {
return;
}
// If received state is AdminDown
if received_state == BfdState::AdminDown {
// If bfd.SessionState is not Down
if state_write.session_state != BfdState::Down {
// Set bfd.LocalDiag to 3 (Neighbor signaled session down)
state_write.local_diag = BfdDiagnostic::NeighborDown;
// Set bfd.SessionState to Down
state_write.session_state = BfdState::Down;
}
} else {
// If bfd.SessionState ...
match state_write.session_state {
BfdState::Down => {
// If received State is Down
if received_state == BfdState::Down {
// Set bfd.SessionState to Init
state_write.session_state = BfdState::Init;
// Else if received State is Init
} else if received_state == BfdState::Init {
// Set bfd.SessionState to Up
state_write.session_state = BfdState::Up;
}
}
BfdState::Init => {
// If received State is Init or Up
if received_state == BfdState::Init || received_state == BfdState::Up {
// Set bfd.SessionState to Up
state_write.session_state = BfdState::Up;
}
}
BfdState::Up => {
// If received State is Down
if received_state == BfdState::Down {
// Set bfd.LocalDiag to 3 (Neighbor signaled session down)
state_write.local_diag = BfdDiagnostic::NeighborDown;
// Set bfd.SessionState to Down
state_write.session_state = BfdState::Down;
}
}
BfdState::AdminDown => unreachable!("unexpected AdminDown"), // AdminDown is discarded earlier
}
drop(state_write);
// Check to see if Demand mode should become active or not (see section 6.6).
// If bfd.RemoteDemandMode is 1, bfd.SessionState is Up, and bfd.RemoteSessionState is Up, Demand mode is
// active on the remote system and the local system MUST cease the periodic transmission of BFD Control
// packets (see section 6.8.7).
// TODO: implement ceasing/restarting of control packets due to demand mode
if p.flags.demand() {
eprintln!("WARNING: Demand mode requested but not implemented");
}
// If the Poll (P) bit is set, send a BFD Control packet to the remote system with the Poll (P) bit clear,
// and the Final (F) bit set (see section 6.8.7).
if p.flags.poll() {
// TODO: Implement sending stuff
}
// If the packet was not discarded, it has been received for purposes of the Detection Time expiration rules
// in section 6.8.4.
}
}
// https://datatracker.ietf.org/doc/html/rfc5880#section-6.8.2
async fn update_transmit_interval(self: Arc<Self>) {
let state = self.state.read().await;
state
.periodic_cmd_channel
.send(PeriodicControlCommand::SetMinInterval(std::cmp::max(
state.desired_min_tx_interval,
state.remote_min_rx_interval,
)))
.await
.unwrap()
}
// https://datatracker.ietf.org/doc/html/rfc5880#section-6.8.4
async fn update_detection_time(self: Arc<Self>, p: &BfdPacket) {
let mut state = self.state.write().await;
state.detection_time = if !state.demand_mode {
time::Duration::from_micros(
p.flags.detect_mult() as u64
* std::cmp::max(state.required_min_rx_interval, p.desired_min_tx).0 as u64,
)
} else {
time::Duration::from_micros(
state.detect_mult as u64
* std::cmp::max(state.desired_min_tx_interval, state.remote_min_rx_interval).0
as u64,
)
}
}
async fn run(self: Arc<Self>, rx: mpsc::Receiver<SessionControlCommand>) {
let (cmd_tx, cmd_rx) = mpsc::channel(32);
self.state.write().await.periodic_cmd_channel = cmd_tx;
let rxt = self.clone().spawn_control_thread(rx).await;
let pxt = self
.clone()
.spawn_periodic_thread(cmd_rx, self.state.read().await.desired_min_tx_interval)
.await;
join!(rxt, pxt).0.unwrap();
}
}
#[tokio::main]
async fn main() -> io::Result<()> {
let local = SocketAddr::from_str("192.168.65.224:3784").unwrap();
let peers = vec![SocketAddr::from_str("127.0.0.1:3784").unwrap()];
let mut sessions = HashMap::new();
for peer in peers {
let (tx, rx) = mpsc::channel(32);
let session = Arc::new(BfdSession::new(local.ip(), peer.ip()).await.unwrap());
let handle = task::spawn(session.clone().run(rx));
sessions.insert((local.ip(), peer.ip()), (tx, session, handle));
}
let control_sock = Arc::new(UdpSocket::bind(local).await?);
// If BFD authentication is not in use on a session, all BFD Control packets for the session MUST be sent with a
// Time to Live (TTL) or Hop Limit value of 255.
control_sock.set_ttl(255).unwrap();
let echo_socket = Arc::new(UdpSocket::bind(SocketAddr::new(local.ip(), ECHO_PORT)).await?);
let rx_thread = task::spawn(async move {
let mut buf = [0; 1024];
loop {
// TODO: All received BFD Control packets that are demultiplexed to the session MUST be discarded if the
// received TTL or Hop Limit is not equal to 255.
let (len, addr) = control_sock.recv_from(&mut buf).await.unwrap(); // TODO: fallibility?
println!("{:?} bytes received from {:?}", len, addr);
if let Some(session) =
sessions.get(&(control_sock.local_addr().unwrap().ip(), addr.ip()))
{
println!("matched to session");
session
.0
.send(SessionControlCommand::RxPacket(buf[0..len].to_vec()))
.await
.unwrap();
}
}
});
tokio::join!(rx_thread)
.0
.expect("Unable to join on the receive thread");
Ok(())
}