From 286b21e940ef49ad7d30f83492f2e938a04b118e Mon Sep 17 00:00:00 2001 From: Keenan Tims Date: Tue, 6 May 2025 11:05:31 -0700 Subject: [PATCH] refactor for DRY and compartmentalization --- src/lib.rs | 172 +++++++++++++++++++--------------------------------- src/main.rs | 18 +++--- 2 files changed, 72 insertions(+), 118 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index dbc6f66..2173902 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,13 +2,13 @@ use crc32fast::{hash, Hasher}; use log::warn; use nom::bytes::complete::{tag, take}; use nom::error::ParseError; -use nom::multi::{many, many0}; +use nom::multi::many; use nom::number::complete::{be_u128, be_u16, be_u32, be_u8}; use nom::{AsBytes, IResult, Parser}; use rand::{self, RngCore}; use serde::Serialize; use std::fmt::{self, Debug}; -use std::net::IpAddr; +use std::net::{IpAddr, SocketAddr}; // https://github.com/tailscale/tailscale/blob/main/net/stun/stun.go @@ -45,15 +45,19 @@ pub struct TxId([u8; 12]); impl Default for TxId { fn default() -> Self { - Self::new() + Self::new([0; 12]) } } impl TxId { - pub fn new() -> Self { + pub fn new(tx_id: [u8; 12]) -> Self { + Self(tx_id) + } + + pub fn random() -> Self { let mut tx_id = [0; 12]; rand::rng().fill_bytes(&mut tx_id); - Self(tx_id) + Self::new(tx_id) } pub fn from_bytes(bytes: &[u8]) -> Self { @@ -65,26 +69,6 @@ impl TxId { pub fn as_bytes(&self) -> &[u8] { &self.0 } - - pub fn make_request(&self) -> Vec { - const LEN_ATTR_SOFTWARE: u16 = 4 + SOFTWARE.len() as u16; - let mut buf = - Vec::with_capacity((HEADER_LEN + LEN_ATTR_SOFTWARE + LEN_FINGERPRINT) as usize); - buf.extend(&BINDING_REQUEST); - buf.extend((LEN_ATTR_SOFTWARE + LEN_FINGERPRINT).to_be_bytes()); - buf.extend(&MAGIC_COOKIE); - buf.extend(self.as_bytes()); - buf.extend(ATTR_NUM_SOFTWARE.to_be_bytes()); - buf.extend((SOFTWARE.len() as u16).to_be_bytes()); - buf.extend(&SOFTWARE); - - let fp = fingerprint(&buf); - buf.extend(ATTR_NUM_FINGERPRINT.to_be_bytes()); - buf.extend(4_u16.to_be_bytes()); - buf.extend(&fp.to_be_bytes()); - - buf - } } impl Serialize for TxId { @@ -138,10 +122,10 @@ pub enum StunMethod { #[derive(Debug, Clone, Serialize)] pub enum StunAttribute { - MappedAddress(AddrPort), - XorMappedAddress(AddrPort), - SourceAddress(AddrPort), - ChangedAddress(AddrPort), + MappedAddress(SocketAddr), + XorMappedAddress(SocketAddr), + SourceAddress(SocketAddr), + ChangedAddress(SocketAddr), Username(String), MessageIntegrity([u8; 20]), Fingerprint((u32, bool)), @@ -150,9 +134,9 @@ pub enum StunAttribute { Nonce(String), UnknownAttributes(Vec), Software(String), - AlternateServer(AddrPort), - ResponseOrigin(AddrPort), - OtherAddress(AddrPort), + AlternateServer(SocketAddr), + ResponseOrigin(SocketAddr), + OtherAddress(SocketAddr), Unknown((u16, Vec)), } @@ -165,43 +149,23 @@ fn addr_family(addr: &IpAddr) -> &'static str { impl fmt::Display for StunAttribute { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // Helper function for attributes with IP and port + fn format_ip_port( + f: &mut fmt::Formatter<'_>, + label: &str, + addr: &SocketAddr, + ) -> fmt::Result { + write!(f, " {} ({}) {}", label, addr_family(&addr.ip()), addr) + } + match self { - StunAttribute::MappedAddress(a) => { - write!( - f, - " MappedAddress ({}) {}:{}", - addr_family(&a.address), - a.address, - a.port - ) - } - StunAttribute::SourceAddress(a) => { - write!( - f, - " SourceAddress ({}) {}:{}", - addr_family(&a.address), - a.address, - a.port - ) - } - StunAttribute::ChangedAddress(a) => { - write!( - f, - " ChangedAddress ({}) {}:{}", - addr_family(&a.address), - a.address, - a.port - ) - } - StunAttribute::XorMappedAddress(a) => { - write!( - f, - " XorMappedAddress ({}) {}:{}", - addr_family(&a.address), - a.address, - a.port - ) - } + StunAttribute::MappedAddress(a) => format_ip_port(f, "MappedAddress", a), + StunAttribute::SourceAddress(a) => format_ip_port(f, "SourceAddress", a), + StunAttribute::ChangedAddress(a) => format_ip_port(f, "ChangedAddress", a), + StunAttribute::XorMappedAddress(a) => format_ip_port(f, "XorMappedAddress", a), + StunAttribute::AlternateServer(a) => format_ip_port(f, "AlternateServer", a), + StunAttribute::ResponseOrigin(a) => format_ip_port(f, "ResponseOrigin", a), + StunAttribute::OtherAddress(a) => format_ip_port(f, "OtherAddress", a), StunAttribute::Username(username) => write!(f, " Username {}", username), StunAttribute::MessageIntegrity(msg_integrity) => { write!(f, " MessageIntegrity {:?}", msg_integrity) @@ -223,33 +187,6 @@ impl fmt::Display for StunAttribute { write!(f, " UnknownAttributes {:?}", unknown_attrs) } StunAttribute::Software(software) => write!(f, " Software {}", software), - StunAttribute::AlternateServer(a) => { - write!( - f, - " AlternateServer ({}) {}:{}", - addr_family(&a.address), - a.address, - a.port - ) - } - StunAttribute::ResponseOrigin(a) => { - write!( - f, - " ResponseOrigin ({}) {}:{}", - addr_family(&a.address), - a.address, - a.port - ) - } - StunAttribute::OtherAddress(a) => { - write!( - f, - " OtherAddress ({}) {}:{}", - addr_family(&a.address), - a.address, - a.port - ) - } StunAttribute::Unknown((attr_type, data)) => { write!(f, " Unknown ({:04x}) {:?}", attr_type, data) } @@ -298,7 +235,7 @@ impl fmt::Display for StunAttributes { } impl StunAttributes { - pub fn mapped_address(&self) -> Option<&AddrPort> { + pub fn mapped_address(&self) -> Option<&SocketAddr> { self.0.iter().find_map(|attr| match attr { StunAttribute::MappedAddress(addr) | StunAttribute::XorMappedAddress(addr) => { Some(addr) @@ -329,6 +266,26 @@ impl StunMessage { } } +pub fn rand_request() -> Vec { + const LEN_ATTR_SOFTWARE: u16 = 4 + SOFTWARE.len() as u16; + let tx_id = TxId::random(); + let mut buf = Vec::with_capacity((HEADER_LEN + LEN_ATTR_SOFTWARE + LEN_FINGERPRINT) as usize); + buf.extend(&BINDING_REQUEST); + buf.extend((LEN_ATTR_SOFTWARE + LEN_FINGERPRINT).to_be_bytes()); + buf.extend(&MAGIC_COOKIE); + buf.extend(tx_id.as_bytes()); + buf.extend(ATTR_NUM_SOFTWARE.to_be_bytes()); + buf.extend((SOFTWARE.len() as u16).to_be_bytes()); + buf.extend(&SOFTWARE); + + let fp = fingerprint(&buf); + buf.extend(ATTR_NUM_FINGERPRINT.to_be_bytes()); + buf.extend(4_u16.to_be_bytes()); + buf.extend(&fp.to_be_bytes()); + + buf +} + fn take_txid>(input: I) -> IResult where I: nom::Input + AsBytes, @@ -377,11 +334,8 @@ where attr } } else { - if hasher.is_some() { - hasher - .as_mut() - .unwrap() - .update(input.take(input.offset(&new_input)).as_bytes()); + if let Some(hasher) = hasher.as_mut() { + hasher.update(input.take(input.offset(&new_input)).as_bytes()); } else { warn!("Received attributes after FINGERPRINT"); } @@ -430,7 +384,7 @@ fn parse_stun_message_type, E: ParseError>( fn parse_stun_address, E: ParseError>( input: I, -) -> IResult { +) -> IResult { let (input, _) = take(1usize)(input)?; let (input, family) = be_u8(input)?; let (input, port) = be_u16(input)?; @@ -448,20 +402,20 @@ fn parse_stun_address, E: ParseError>( Ok((input, addr)) } -fn parse_stun_xor_address>(input: I, tx_id: &TxId) -> IResult +fn parse_stun_xor_address>(input: I, tx_id: &TxId) -> IResult where I: nom::Input, { let (input, addr) = parse_stun_address(input)?; - let xor_port = addr.port ^ 0x2112; - let xor_addr = match addr.address { - IpAddr::V4(v4) => { - let v4 = u32::from(v4); + let xor_port = addr.port() ^ 0x2112; + let xor_addr = match addr { + SocketAddr::V4(v4) => { + let v4 = v4.ip().to_bits(); let xor_v4 = v4 ^ 0x2112a442; IpAddr::V4(xor_v4.into()) } - IpAddr::V6(v6) => { - let v6 = u128::from(v6); + SocketAddr::V6(v6) => { + let v6 = v6.ip().to_bits(); let xor_v6: u128 = v6 ^ (0x2112a442 << 96 | u128::from(tx_id)); IpAddr::V6(xor_v6.into()) } diff --git a/src/main.rs b/src/main.rs index 9b75699..03aff65 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ use clap::ValueEnum; use log::{debug, info}; -use std::net::{IpAddr, ToSocketAddrs, UdpSocket}; -use tailstun::{AddrPort, StunMessage, TxId}; +use std::net::{IpAddr, SocketAddr, ToSocketAddrs, UdpSocket}; +use tailstun::StunMessage; #[derive(Debug, Clone, ValueEnum)] enum OutputFormat { @@ -18,14 +18,14 @@ impl OutputFormat { OutputFormat::Yaml => serde_yaml::to_string(msg).unwrap(), } } - fn format_address(&self, a: &AddrPort) -> String { - let a = match a.address { - IpAddr::V4(_) => a.address, - IpAddr::V6(v6) => { - if let Some(v4) = v6.to_ipv4_mapped() { + fn format_address(&self, a: &SocketAddr) -> String { + let a = match a { + SocketAddr::V4(_) => a.ip(), + SocketAddr::V6(v6) => { + if let Some(v4) = v6.ip().to_ipv4_mapped() { IpAddr::V4(v4) } else { - a.address + a.ip() } } }; @@ -92,7 +92,7 @@ fn main() { socket.local_addr() ); - let req = TxId::new().make_request(); + let req = tailstun::rand_request(); debug!("Sending request {:?}", &req);