refactor for DRY and compartmentalization

This commit is contained in:
Keenan Tims 2025-05-06 11:05:31 -07:00
parent 1488dba92d
commit 286b21e940
No known key found for this signature in database
GPG Key ID: B8FDD4AD6B193F06
2 changed files with 72 additions and 118 deletions

View File

@ -2,13 +2,13 @@ use crc32fast::{hash, Hasher};
use log::warn; use log::warn;
use nom::bytes::complete::{tag, take}; use nom::bytes::complete::{tag, take};
use nom::error::ParseError; use nom::error::ParseError;
use nom::multi::{many, many0}; use nom::multi::many;
use nom::number::complete::{be_u128, be_u16, be_u32, be_u8}; use nom::number::complete::{be_u128, be_u16, be_u32, be_u8};
use nom::{AsBytes, IResult, Parser}; use nom::{AsBytes, IResult, Parser};
use rand::{self, RngCore}; use rand::{self, RngCore};
use serde::Serialize; use serde::Serialize;
use std::fmt::{self, Debug}; use std::fmt::{self, Debug};
use std::net::IpAddr; use std::net::{IpAddr, SocketAddr};
// https://github.com/tailscale/tailscale/blob/main/net/stun/stun.go // https://github.com/tailscale/tailscale/blob/main/net/stun/stun.go
@ -45,15 +45,19 @@ pub struct TxId([u8; 12]);
impl Default for TxId { impl Default for TxId {
fn default() -> Self { fn default() -> Self {
Self::new() Self::new([0; 12])
} }
} }
impl TxId { impl TxId {
pub fn new() -> Self { pub fn new(tx_id: [u8; 12]) -> Self {
Self(tx_id)
}
pub fn random() -> Self {
let mut tx_id = [0; 12]; let mut tx_id = [0; 12];
rand::rng().fill_bytes(&mut tx_id); rand::rng().fill_bytes(&mut tx_id);
Self(tx_id) Self::new(tx_id)
} }
pub fn from_bytes(bytes: &[u8]) -> Self { pub fn from_bytes(bytes: &[u8]) -> Self {
@ -65,26 +69,6 @@ impl TxId {
pub fn as_bytes(&self) -> &[u8] { pub fn as_bytes(&self) -> &[u8] {
&self.0 &self.0
} }
pub fn make_request(&self) -> Vec<u8> {
const LEN_ATTR_SOFTWARE: u16 = 4 + SOFTWARE.len() as u16;
let mut buf =
Vec::with_capacity((HEADER_LEN + LEN_ATTR_SOFTWARE + LEN_FINGERPRINT) as usize);
buf.extend(&BINDING_REQUEST);
buf.extend((LEN_ATTR_SOFTWARE + LEN_FINGERPRINT).to_be_bytes());
buf.extend(&MAGIC_COOKIE);
buf.extend(self.as_bytes());
buf.extend(ATTR_NUM_SOFTWARE.to_be_bytes());
buf.extend((SOFTWARE.len() as u16).to_be_bytes());
buf.extend(&SOFTWARE);
let fp = fingerprint(&buf);
buf.extend(ATTR_NUM_FINGERPRINT.to_be_bytes());
buf.extend(4_u16.to_be_bytes());
buf.extend(&fp.to_be_bytes());
buf
}
} }
impl Serialize for TxId { impl Serialize for TxId {
@ -138,10 +122,10 @@ pub enum StunMethod {
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
pub enum StunAttribute { pub enum StunAttribute {
MappedAddress(AddrPort), MappedAddress(SocketAddr),
XorMappedAddress(AddrPort), XorMappedAddress(SocketAddr),
SourceAddress(AddrPort), SourceAddress(SocketAddr),
ChangedAddress(AddrPort), ChangedAddress(SocketAddr),
Username(String), Username(String),
MessageIntegrity([u8; 20]), MessageIntegrity([u8; 20]),
Fingerprint((u32, bool)), Fingerprint((u32, bool)),
@ -150,9 +134,9 @@ pub enum StunAttribute {
Nonce(String), Nonce(String),
UnknownAttributes(Vec<u16>), UnknownAttributes(Vec<u16>),
Software(String), Software(String),
AlternateServer(AddrPort), AlternateServer(SocketAddr),
ResponseOrigin(AddrPort), ResponseOrigin(SocketAddr),
OtherAddress(AddrPort), OtherAddress(SocketAddr),
Unknown((u16, Vec<u8>)), Unknown((u16, Vec<u8>)),
} }
@ -165,43 +149,23 @@ fn addr_family(addr: &IpAddr) -> &'static str {
impl fmt::Display for StunAttribute { impl fmt::Display for StunAttribute {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// Helper function for attributes with IP and port
fn format_ip_port(
f: &mut fmt::Formatter<'_>,
label: &str,
addr: &SocketAddr,
) -> fmt::Result {
write!(f, " {} ({}) {}", label, addr_family(&addr.ip()), addr)
}
match self { match self {
StunAttribute::MappedAddress(a) => { StunAttribute::MappedAddress(a) => format_ip_port(f, "MappedAddress", a),
write!( StunAttribute::SourceAddress(a) => format_ip_port(f, "SourceAddress", a),
f, StunAttribute::ChangedAddress(a) => format_ip_port(f, "ChangedAddress", a),
" MappedAddress ({}) {}:{}", StunAttribute::XorMappedAddress(a) => format_ip_port(f, "XorMappedAddress", a),
addr_family(&a.address), StunAttribute::AlternateServer(a) => format_ip_port(f, "AlternateServer", a),
a.address, StunAttribute::ResponseOrigin(a) => format_ip_port(f, "ResponseOrigin", a),
a.port StunAttribute::OtherAddress(a) => format_ip_port(f, "OtherAddress", a),
)
}
StunAttribute::SourceAddress(a) => {
write!(
f,
" SourceAddress ({}) {}:{}",
addr_family(&a.address),
a.address,
a.port
)
}
StunAttribute::ChangedAddress(a) => {
write!(
f,
" ChangedAddress ({}) {}:{}",
addr_family(&a.address),
a.address,
a.port
)
}
StunAttribute::XorMappedAddress(a) => {
write!(
f,
" XorMappedAddress ({}) {}:{}",
addr_family(&a.address),
a.address,
a.port
)
}
StunAttribute::Username(username) => write!(f, " Username {}", username), StunAttribute::Username(username) => write!(f, " Username {}", username),
StunAttribute::MessageIntegrity(msg_integrity) => { StunAttribute::MessageIntegrity(msg_integrity) => {
write!(f, " MessageIntegrity {:?}", msg_integrity) write!(f, " MessageIntegrity {:?}", msg_integrity)
@ -223,33 +187,6 @@ impl fmt::Display for StunAttribute {
write!(f, " UnknownAttributes {:?}", unknown_attrs) write!(f, " UnknownAttributes {:?}", unknown_attrs)
} }
StunAttribute::Software(software) => write!(f, " Software {}", software), StunAttribute::Software(software) => write!(f, " Software {}", software),
StunAttribute::AlternateServer(a) => {
write!(
f,
" AlternateServer ({}) {}:{}",
addr_family(&a.address),
a.address,
a.port
)
}
StunAttribute::ResponseOrigin(a) => {
write!(
f,
" ResponseOrigin ({}) {}:{}",
addr_family(&a.address),
a.address,
a.port
)
}
StunAttribute::OtherAddress(a) => {
write!(
f,
" OtherAddress ({}) {}:{}",
addr_family(&a.address),
a.address,
a.port
)
}
StunAttribute::Unknown((attr_type, data)) => { StunAttribute::Unknown((attr_type, data)) => {
write!(f, " Unknown ({:04x}) {:?}", attr_type, data) write!(f, " Unknown ({:04x}) {:?}", attr_type, data)
} }
@ -298,7 +235,7 @@ impl fmt::Display for StunAttributes {
} }
impl StunAttributes { impl StunAttributes {
pub fn mapped_address(&self) -> Option<&AddrPort> { pub fn mapped_address(&self) -> Option<&SocketAddr> {
self.0.iter().find_map(|attr| match attr { self.0.iter().find_map(|attr| match attr {
StunAttribute::MappedAddress(addr) | StunAttribute::XorMappedAddress(addr) => { StunAttribute::MappedAddress(addr) | StunAttribute::XorMappedAddress(addr) => {
Some(addr) Some(addr)
@ -329,6 +266,26 @@ impl StunMessage {
} }
} }
pub fn rand_request() -> Vec<u8> {
const LEN_ATTR_SOFTWARE: u16 = 4 + SOFTWARE.len() as u16;
let tx_id = TxId::random();
let mut buf = Vec::with_capacity((HEADER_LEN + LEN_ATTR_SOFTWARE + LEN_FINGERPRINT) as usize);
buf.extend(&BINDING_REQUEST);
buf.extend((LEN_ATTR_SOFTWARE + LEN_FINGERPRINT).to_be_bytes());
buf.extend(&MAGIC_COOKIE);
buf.extend(tx_id.as_bytes());
buf.extend(ATTR_NUM_SOFTWARE.to_be_bytes());
buf.extend((SOFTWARE.len() as u16).to_be_bytes());
buf.extend(&SOFTWARE);
let fp = fingerprint(&buf);
buf.extend(ATTR_NUM_FINGERPRINT.to_be_bytes());
buf.extend(4_u16.to_be_bytes());
buf.extend(&fp.to_be_bytes());
buf
}
fn take_txid<I, E: ParseError<I>>(input: I) -> IResult<I, TxId, E> fn take_txid<I, E: ParseError<I>>(input: I) -> IResult<I, TxId, E>
where where
I: nom::Input<Item = u8> + AsBytes, I: nom::Input<Item = u8> + AsBytes,
@ -377,11 +334,8 @@ where
attr attr
} }
} else { } else {
if hasher.is_some() { if let Some(hasher) = hasher.as_mut() {
hasher hasher.update(input.take(input.offset(&new_input)).as_bytes());
.as_mut()
.unwrap()
.update(input.take(input.offset(&new_input)).as_bytes());
} else { } else {
warn!("Received attributes after FINGERPRINT"); warn!("Received attributes after FINGERPRINT");
} }
@ -430,7 +384,7 @@ fn parse_stun_message_type<I: nom::Input<Item = u8>, E: ParseError<I>>(
fn parse_stun_address<I: nom::Input<Item = u8>, E: ParseError<I>>( fn parse_stun_address<I: nom::Input<Item = u8>, E: ParseError<I>>(
input: I, input: I,
) -> IResult<I, AddrPort, E> { ) -> IResult<I, SocketAddr, E> {
let (input, _) = take(1usize)(input)?; let (input, _) = take(1usize)(input)?;
let (input, family) = be_u8(input)?; let (input, family) = be_u8(input)?;
let (input, port) = be_u16(input)?; let (input, port) = be_u16(input)?;
@ -448,20 +402,20 @@ fn parse_stun_address<I: nom::Input<Item = u8>, E: ParseError<I>>(
Ok((input, addr)) Ok((input, addr))
} }
fn parse_stun_xor_address<I, E: ParseError<I>>(input: I, tx_id: &TxId) -> IResult<I, AddrPort, E> fn parse_stun_xor_address<I, E: ParseError<I>>(input: I, tx_id: &TxId) -> IResult<I, SocketAddr, E>
where where
I: nom::Input<Item = u8>, I: nom::Input<Item = u8>,
{ {
let (input, addr) = parse_stun_address(input)?; let (input, addr) = parse_stun_address(input)?;
let xor_port = addr.port ^ 0x2112; let xor_port = addr.port() ^ 0x2112;
let xor_addr = match addr.address { let xor_addr = match addr {
IpAddr::V4(v4) => { SocketAddr::V4(v4) => {
let v4 = u32::from(v4); let v4 = v4.ip().to_bits();
let xor_v4 = v4 ^ 0x2112a442; let xor_v4 = v4 ^ 0x2112a442;
IpAddr::V4(xor_v4.into()) IpAddr::V4(xor_v4.into())
} }
IpAddr::V6(v6) => { SocketAddr::V6(v6) => {
let v6 = u128::from(v6); let v6 = v6.ip().to_bits();
let xor_v6: u128 = v6 ^ (0x2112a442 << 96 | u128::from(tx_id)); let xor_v6: u128 = v6 ^ (0x2112a442 << 96 | u128::from(tx_id));
IpAddr::V6(xor_v6.into()) IpAddr::V6(xor_v6.into())
} }

View File

@ -1,7 +1,7 @@
use clap::ValueEnum; use clap::ValueEnum;
use log::{debug, info}; use log::{debug, info};
use std::net::{IpAddr, ToSocketAddrs, UdpSocket}; use std::net::{IpAddr, SocketAddr, ToSocketAddrs, UdpSocket};
use tailstun::{AddrPort, StunMessage, TxId}; use tailstun::StunMessage;
#[derive(Debug, Clone, ValueEnum)] #[derive(Debug, Clone, ValueEnum)]
enum OutputFormat { enum OutputFormat {
@ -18,14 +18,14 @@ impl OutputFormat {
OutputFormat::Yaml => serde_yaml::to_string(msg).unwrap(), OutputFormat::Yaml => serde_yaml::to_string(msg).unwrap(),
} }
} }
fn format_address(&self, a: &AddrPort) -> String { fn format_address(&self, a: &SocketAddr) -> String {
let a = match a.address { let a = match a {
IpAddr::V4(_) => a.address, SocketAddr::V4(_) => a.ip(),
IpAddr::V6(v6) => { SocketAddr::V6(v6) => {
if let Some(v4) = v6.to_ipv4_mapped() { if let Some(v4) = v6.ip().to_ipv4_mapped() {
IpAddr::V4(v4) IpAddr::V4(v4)
} else { } else {
a.address a.ip()
} }
} }
}; };
@ -92,7 +92,7 @@ fn main() {
socket.local_addr() socket.local_addr()
); );
let req = TxId::new().make_request(); let req = tailstun::rand_request();
debug!("Sending request {:?}", &req); debug!("Sending request {:?}", &req);