diff --git a/src/lib.rs b/src/lib.rs index 1936f4f..b805c02 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,10 @@ use crc32fast::hash; -use log::{debug, info, warn}; +use log::warn; use nom::bytes::complete::{tag, take}; use nom::error::ParseError; use nom::multi::{many, many0}; use nom::number::complete::{be_u128, be_u16, be_u32, be_u8}; use nom::{AsBytes, IResult, Parser}; -use rand::seq::{IndexedRandom, SliceRandom}; use rand::{self, RngCore}; use serde::Serialize; use std::fmt::{self, Debug}; @@ -42,6 +41,12 @@ const HEADER_LEN: u16 = 20; #[derive(Debug, Clone)] pub struct TxId([u8; 12]); +impl Default for TxId { + fn default() -> Self { + Self::new() + } +} + impl TxId { pub fn new() -> Self { let mut tx_id = [0; 12]; @@ -101,6 +106,21 @@ fn fingerprint(msg: &[u8]) -> u32 { hash(msg) ^ 0x5354554e } +#[derive(Debug, Clone, Serialize)] +pub struct AddrPort { + pub address: IpAddr, + pub port: u16, +} + +impl From<(IpAddr, u16)> for AddrPort { + fn from(value: (IpAddr, u16)) -> Self { + Self { + address: value.0, + port: value.1, + } + } +} + #[derive(Debug, Clone, Copy, Serialize)] pub enum StunClass { Request = 0, @@ -116,10 +136,10 @@ pub enum StunMethod { #[derive(Debug, Clone, Serialize)] pub enum StunAttribute { - MappedAddress((IpAddr, u16)), - XorMappedAddress((IpAddr, u16)), - SourceAddress((IpAddr, u16)), - ChangedAddress((IpAddr, u16)), + MappedAddress(AddrPort), + XorMappedAddress(AddrPort), + SourceAddress(AddrPort), + ChangedAddress(AddrPort), Username(String), MessageIntegrity([u8; 20]), Fingerprint(u32), @@ -128,7 +148,7 @@ pub enum StunAttribute { Nonce(String), UnknownAttributes(Vec), Software(String), - AlternateServer((IpAddr, u16)), + AlternateServer(AddrPort), Unknown((u16, Vec)), } @@ -142,40 +162,40 @@ fn addr_family(addr: &IpAddr) -> &'static str { impl fmt::Display for StunAttribute { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - StunAttribute::MappedAddress((addr, port)) => { + StunAttribute::MappedAddress(a) => { write!( f, " MappedAddress ({}) {}:{}", - addr_family(addr), - addr, - port + addr_family(&a.address), + a.address, + a.port ) } - StunAttribute::SourceAddress((addr, port)) => { + StunAttribute::SourceAddress(a) => { write!( f, " SourceAddress ({}) {}:{}", - addr_family(addr), - addr, - port + addr_family(&a.address), + a.address, + a.port ) } - StunAttribute::ChangedAddress((addr, port)) => { + StunAttribute::ChangedAddress(a) => { write!( f, " ChangedAddress ({}) {}:{}", - addr_family(addr), - addr, - port + addr_family(&a.address), + a.address, + a.port ) } - StunAttribute::XorMappedAddress((addr, port)) => { + StunAttribute::XorMappedAddress(a) => { write!( f, " XorMappedAddress ({}) {}:{}", - addr_family(addr), - addr, - port + addr_family(&a.address), + a.address, + a.port ) } StunAttribute::Username(username) => writeln!(f, " Username {}", username), @@ -194,13 +214,13 @@ impl fmt::Display for StunAttribute { write!(f, " UnknownAttributes {:?}", unknown_attrs) } StunAttribute::Software(software) => writeln!(f, " Software {}", software), - StunAttribute::AlternateServer((addr, port)) => { + StunAttribute::AlternateServer(a) => { write!( f, " AlternateServer ({}) {}:{}", - addr_family(addr), - addr, - port + addr_family(&a.address), + a.address, + a.port ) } StunAttribute::Unknown((attr_type, data)) => { @@ -240,7 +260,6 @@ impl fmt::Display for StunHeader { #[derive(Debug, Serialize)] pub struct StunAttributes(Vec); - impl fmt::Display for StunAttributes { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, " Attributes")?; @@ -252,9 +271,11 @@ impl fmt::Display for StunAttributes { } impl StunAttributes { - pub fn mapped_address(&self) -> Option<&(IpAddr, u16)> { + pub fn mapped_address(&self) -> Option<&AddrPort> { self.0.iter().find_map(|attr| match attr { - StunAttribute::MappedAddress(addr) | StunAttribute::XorMappedAddress(addr) => Some(addr), + StunAttribute::MappedAddress(addr) | StunAttribute::XorMappedAddress(addr) => { + Some(addr) + } _ => None, }) } @@ -274,7 +295,6 @@ impl fmt::Display for StunMessage { } } - impl StunMessage { pub fn parse(bytes: &[u8]) -> Result>> { let (_, msg) = parse_stun_message(bytes)?; @@ -282,27 +302,27 @@ impl StunMessage { } } -fn take_txid>(bytes: I) -> IResult +fn take_txid>(input: I) -> IResult where I: nom::Input + AsBytes, { - let (bytes, tx_id) = take(12usize)(bytes)?; - Ok((bytes, TxId::from_bytes(tx_id.as_bytes()))) + let (input, tx_id) = take(12usize)(input)?; + Ok((input, TxId::from_bytes(tx_id.as_bytes()))) } fn parse_stun_message<'a, I, E: ParseError>(input: I) -> IResult where I: nom::Input + nom::Compare + nom::Compare<&'a [u8]> + AsBytes + Debug, { - let (bytes, h) = parse_stun_header(input)?; - let (residual, bytes) = take(h.msg_length)(bytes)?; + let (input, h) = parse_stun_header(input)?; + let (residual, input) = take(h.msg_length)(input)?; if residual.input_len() != 0 { warn!("Trailing bytes in STUN message: {:?}", residual); } - let (bytes, attributes) = many0(parse_stun_attribute(&h.tx_id)).parse(bytes)?; - if !bytes.input_len() != 0 { - warn!("Trailing bytes in STUN message attributes: {:?}", bytes); + let (input, attributes) = many0(parse_stun_attribute(&h.tx_id)).parse(input)?; + if !input.input_len() != 0 { + warn!("Trailing bytes in STUN message attributes: {:?}", input); } let attributes = StunAttributes(attributes.iter().filter_map(|i| i.clone()).collect()); @@ -315,10 +335,10 @@ where )) } -fn parse_stun_message_type<'a, I: nom::Input, E: ParseError>( +fn parse_stun_message_type, E: ParseError>( input: I, ) -> IResult { - let (bytes, msg_type_raw) = be_u16(input)?; + let (input, msg_type_raw) = be_u16(input)?; if msg_type_raw & 0b11000000 != 0 { panic!("Invalid STUN message type"); } @@ -335,39 +355,36 @@ fn parse_stun_message_type<'a, I: nom::Input, E: ParseError>( 1 => StunMethod::Binding, _ => panic!("Invalid STUN message method"), }; - Ok((bytes, StunMessageType { class, method })) + Ok((input, StunMessageType { class, method })) } fn parse_stun_address, E: ParseError>( - bytes: I, -) -> IResult { - let (bytes, _) = take(1usize)(bytes)?; - let (bytes, family) = be_u8(bytes)?; - let (bytes, port) = be_u16(bytes)?; - let (bytes, addr) = match family { + input: I, +) -> IResult { + let (input, _) = take(1usize)(input)?; + let (input, family) = be_u8(input)?; + let (input, port) = be_u16(input)?; + let (input, addr) = match family { 0x01 => { - let (bytes, val) = be_u32(bytes)?; - (bytes, (IpAddr::V4(val.into()), port)) + let (input, val) = be_u32(input)?; + (input, (IpAddr::V4(val.into()), port).into()) } 0x02 => { - let (bytes, val) = be_u128(bytes)?; - (bytes, (IpAddr::V6(val.into()), port)) + let (input, val) = be_u128(input)?; + (input, (IpAddr::V6(val.into()), port).into()) } _ => panic!("Invalid address family"), }; - Ok((bytes, addr)) + Ok((input, addr)) } -fn parse_stun_xor_address>( - bytes: I, - tx_id: &TxId, -) -> IResult +fn parse_stun_xor_address>(input: I, tx_id: &TxId) -> IResult where I: nom::Input, { - let (bytes, addr) = parse_stun_address(bytes)?; - let xor_port = addr.1 ^ 0x2112; - let xor_addr = match addr.0 { + let (input, addr) = parse_stun_address(input)?; + let xor_port = addr.port ^ 0x2112; + let xor_addr = match addr.address { IpAddr::V4(v4) => { let v4 = u32::from(v4); let xor_v4 = v4 ^ 0x2112a442; @@ -379,7 +396,7 @@ where IpAddr::V6(xor_v6.into()) } }; - Ok((bytes, (xor_addr, xor_port))) + Ok((input, (xor_addr, xor_port).into())) } fn parse_stun_attribute>( @@ -389,22 +406,22 @@ where I: nom::Input + nom::Compare + AsBytes, { let tx_id = tx_id.clone(); - move |bytes| parse_stun_attribute_impl(bytes, &tx_id) + move |input| parse_stun_attribute_impl(input, &tx_id) } fn parse_stun_attribute_impl>( - bytes: I, + input: I, tx_id: &TxId, ) -> IResult, E> where I: nom::Input + nom::Compare + AsBytes, { - let (bytes, attr_type) = be_u16(bytes)?; - let (bytes, attr_len) = be_u16(bytes)?; - let (bytes, attr_data) = take(attr_len)(bytes)?; + let (input, attr_type) = be_u16(input)?; + let (input, attr_len) = be_u16(input)?; + let (input, attr_data) = take(attr_len)(input)?; if attr_len == 0 { - return Ok((bytes, None)); + return Ok((input, None)); } let attr = match attr_type { @@ -478,19 +495,19 @@ where } }; - Ok((bytes, Some(attr))) + Ok((input, Some(attr))) } fn parse_stun_header<'a, I, E: ParseError>(input: I) -> IResult where I: nom::Input + nom::Compare + nom::Compare<&'a [u8]> + AsBytes, { - let (bytes, msg_type) = parse_stun_message_type(input)?; - let (bytes, msg_length) = be_u16(bytes)?; - let (bytes, _) = tag(MAGIC_COOKIE.as_bytes())(bytes)?; - let (bytes, tx_id) = take_txid(bytes)?; + let (input, msg_type) = parse_stun_message_type(input)?; + let (input, msg_length) = be_u16(input)?; + let (input, _) = tag(MAGIC_COOKIE.as_bytes())(input)?; + let (input, tx_id) = take_txid(input)?; Ok(( - bytes, + input, StunHeader { msg_type, msg_length, diff --git a/src/main.rs b/src/main.rs index bc10bcf..4098ab0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,7 @@ use clap::ValueEnum; -use log::{debug, info, warn}; -use serde::Serialize; -use std::net::{IpAddr, UdpSocket}; -use tailstun::{StunMessage, TxId}; +use log::{debug, info}; +use std::net::{IpAddr, ToSocketAddrs, UdpSocket}; +use tailstun::{AddrPort, StunMessage, TxId}; #[derive(Debug, Clone, ValueEnum)] enum OutputFormat { @@ -19,11 +18,21 @@ impl OutputFormat { OutputFormat::Yaml => serde_yaml::to_string(msg).unwrap(), } } - fn format_address(&self, (addr, port): &(IpAddr, u16)) -> String { + fn format_address(&self, a: &AddrPort) -> String { + let a = match a.address { + IpAddr::V4(_) => a.address, + IpAddr::V6(v6) => { + if let Some(v4) = v6.to_ipv4_mapped() { + IpAddr::V4(v4) + } else { + a.address + } + } + }; match self { - OutputFormat::Text => format!("{}", addr), - OutputFormat::Json => serde_json::to_string_pretty(addr).unwrap(), - OutputFormat::Yaml => serde_yaml::to_string(addr).unwrap(), + OutputFormat::Text => format!("{}", a), + OutputFormat::Json => serde_json::to_string_pretty(&a).unwrap(), + OutputFormat::Yaml => serde_yaml::to_string(&a).unwrap(), } } } @@ -35,13 +44,17 @@ struct Cli { host: String, #[clap(short, long, default_value = "3478")] port: u16, + #[clap(short = '4', conflicts_with = "v6_only", default_value = "false")] + v4_only: bool, + #[clap(short = '6', conflicts_with = "v4_only", default_value = "false")] + v6_only: bool, #[clap(short, long, default_value = "text")] format: OutputFormat, #[clap( short, long, default_value = "false", - help = "Only output the first mapped address" + help = "Only output the first mapped address & convert IPv6-mapped to IPv4" )] address_only: bool, #[command(flatten)] @@ -54,7 +67,20 @@ fn main() { .filter_level(cli.verbose.log_level_filter()) .init(); - let dest = cli.host + ":" + &cli.port.to_string(); + let dest = (cli.host.as_str(), cli.port) + .to_socket_addrs() + .expect("Unable to resolve host") + .find(|a| { + if cli.v4_only { + a.is_ipv4() + } else if cli.v6_only { + a.is_ipv6() + } else { + true + } + }) + .expect("No address found for host"); + let socket = UdpSocket::bind("[::]:0").expect("Unable to bind a UDP socket"); socket .connect(dest) @@ -83,6 +109,7 @@ fn main() { if let Some(addr) = msg.attributes.mapped_address() { println!("{}", cli.format.format_address(addr)); } else { + // No mapped address std::process::exit(1); } } else {