diff --git a/tor-netdoc/src/doc/routerdesc.rs b/tor-netdoc/src/doc/routerdesc.rs index db224bdb2..8033ff8b0 100644 --- a/tor-netdoc/src/doc/routerdesc.rs +++ b/tor-netdoc/src/doc/routerdesc.rs @@ -1,4 +1,3 @@ -//! Parsing implementation for Tor router descriptors. //! //! A "router descriptor" is a signed statment that a relay makes //! about itself, explaining its keys, its capabilities, its location, @@ -525,7 +524,11 @@ impl RouterDesc { let ipv4_policy = { let mut pol = AddrPolicy::new(); for ruletok in body.slice(POLICY).iter() { - let accept = ruletok.kwd_str() == "accept"; + let accept = match ruletok.kwd_str() { + "accept" => RuleKind::Accept, + "reject" => RuleKind::Reject, + _ => return Err(Error::Internal(ruletok.pos())), + }; let pat: AddrPortPattern = ruletok .args_as_str() .parse() diff --git a/tor-netdoc/src/types/policy.rs b/tor-netdoc/src/types/policy.rs index 909d9acd7..c5559cd54 100644 --- a/tor-netdoc/src/types/policy.rs +++ b/tor-netdoc/src/types/policy.rs @@ -23,7 +23,7 @@ use std::fmt::Display; use std::str::FromStr; use thiserror::Error; -pub use addrpolicy::{AddrPolicy, AddrPortPattern}; +pub use addrpolicy::{AddrPolicy, AddrPortPattern, RuleKind}; pub use portpolicy::PortPolicy; /// Error from an unpareasble or invalid policy. diff --git a/tor-netdoc/src/types/policy/addrpolicy.rs b/tor-netdoc/src/types/policy/addrpolicy.rs index 3f691e7e3..a139d9618 100644 --- a/tor-netdoc/src/types/policy/addrpolicy.rs +++ b/tor-netdoc/src/types/policy/addrpolicy.rs @@ -1,36 +1,51 @@ /// Implements address policies, based on a series of accept/reject /// rules. use std::fmt::Display; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::str::FromStr; use super::{PolicyError, PortRange}; /// A sequence of rules that are applied to an address:port until one /// matches. +#[derive(Clone, Debug)] pub struct AddrPolicy { rules: Vec, } +/// A kind of policy rule: either accepts or rejects addresses +/// matching a pattern. +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum RuleKind { + /// A rule that accepts matching address:port combinations. + Accept, + /// A rule that rejects matching address:port combinations. + Reject, +} + impl AddrPolicy { /// Apply this policy to an address:port combination /// /// We do this by applying each rule in sequence, until one - /// matches. If that rule is accept, we return Some(true). If - /// that rule is reject, we return Some(false). + /// matches. /// /// Returns None if no rule matches. - pub fn allows(&self, addr: &IpAddr, port: u16) -> Option { + pub fn allows(&self, addr: &IpAddr, port: u16) -> Option { match self .rules .iter() .find(|rule| rule.pattern.matches(addr, port)) { - Some(AddrPolicyRule { accept, .. }) => Some(*accept), + Some(AddrPolicyRule { kind, .. }) => Some(*kind), None => None, } } + /// As allows, but accept a SocketAddr. + pub fn allows_sockaddr(&self, addr: &SocketAddr) -> Option { + self.allows(&addr.ip(), addr.port()) + } + /// Create a new AddrPolicy that matches nothing. pub fn new() -> Self { AddrPolicy { rules: Vec::new() } @@ -43,8 +58,8 @@ impl AddrPolicy { /// /// If accept is true, the rule is to accept addresses that match; /// if accept is false, the rule rejects such addresses. - pub fn push(&mut self, accept: bool, pattern: AddrPortPattern) { - self.rules.push(AddrPolicyRule { accept, pattern }) + pub fn push(&mut self, kind: RuleKind, pattern: AddrPortPattern) { + self.rules.push(AddrPolicyRule { kind, pattern }) } } @@ -57,19 +72,25 @@ impl Default for AddrPolicy { /// A single rule in an address policy. /// /// Contains a pattern and what to do with things that match it. +#[derive(Clone, Debug)] struct AddrPolicyRule { /// What do we do with items that match the pattern? - accept: bool, + kind: RuleKind, /// What pattern are we trying to match? pattern: AddrPortPattern, } +/* impl Display for AddrPolicyRule { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let cmd = if self.accept { "accept" } else { "reject" }; + let cmd = match self.kind { + RuleKind::Accept => "accept", + RuleKind::Reject => "reject", + }; write!(f, "{} {}", cmd, self.pattern) } } +*/ /// A pattern that may or may not match an address and port. /// @@ -89,6 +110,7 @@ impl Display for AddrPolicyRule { /// assert!(pat.matches(&localhost, 22)); /// assert!(! pat.matches(¬_localhost, 22)); /// ``` +#[derive(Clone, Debug)] pub struct AddrPortPattern { pattern: IpPattern, ports: PortRange, @@ -99,6 +121,10 @@ impl AddrPortPattern { pub fn matches(&self, addr: &IpAddr, port: u16) -> bool { self.pattern.matches(addr) && self.ports.contains(port) } + /// As matches, but accept a SocketAddr. + pub fn matches_sockaddr(&self, addr: &SocketAddr) -> bool { + self.matches(&addr.ip(), addr.port()) + } } impl Display for AddrPortPattern { @@ -128,6 +154,7 @@ impl FromStr for AddrPortPattern { } /// A pattern that matches one or more IP addresses. +#[derive(Clone, Debug)] enum IpPattern { /// Match all addresses. Star, @@ -188,6 +215,20 @@ impl Display for IpPattern { } } +/// Helper: try to parse a plain ipv4 address, or an IPv6 address +/// wrapped in brackets. +fn parse_addr(mut s: &str) -> Result { + let bracketed = s.starts_with('[') && s.ends_with(']'); + if bracketed { + s = &s[1..s.len() - 1]; + } + let addr: IpAddr = s.parse().map_err(|_| PolicyError::InvalidAddress)?; + if addr.is_ipv6() != bracketed { + return Err(PolicyError::InvalidAddress); + } + Ok(addr) +} + impl FromStr for IpPattern { type Err = PolicyError; fn from_str(s: &str) -> Result { @@ -199,15 +240,120 @@ impl FromStr for IpPattern { ("*", Some(_)) => Err(PolicyError::MaskWithStar), ("*", None) => Ok(IpPattern::Star), (s, Some(m)) => { - let a: IpAddr = s.parse().map_err(|_| PolicyError::InvalidAddress)?; + let a: IpAddr = parse_addr(s)?; let m: u8 = m.parse().map_err(|_| PolicyError::InvalidMask)?; IpPattern::from_addr_and_mask(a, m) } (s, None) => { - let a: IpAddr = s.parse().map_err(|_| PolicyError::InvalidAddress)?; + let a: IpAddr = parse_addr(s)?; let m = if a.is_ipv4() { 32 } else { 128 }; IpPattern::from_addr_and_mask(a, m) } } } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_roundtrip_rules() { + fn check(inp: &str, outp: &str) { + let policy = inp.parse::().unwrap(); + assert_eq!(format!("{}", policy), outp); + } + + check("127.0.0.2/32:77-10000", "127.0.0.2:77-10000"); + check("127.0.0.2/32:*", "127.0.0.2:*"); + check("127.0.0.0/16:9-100", "127.0.0.0/16:9-100"); + check("127.0.0.0/0:443", "*:443"); + check("*:443", "*:443"); + check("[::1]:443", "[::1]:443"); + check("[ffaa::]/16:80", "[ffaa::]/16:80"); + check("[ffaa::77]/128:80", "[ffaa::77]:80"); + } + + #[test] + fn test_bad_rules() { + fn check(s: &str) { + assert!(s.parse::().is_err()); + } + + check("marzipan:80"); + check("1.2.3.4:90-80"); + check("1.2.3.4/100:8888"); + check("[1.2.3.4]/16:80"); + check("[::1]/130:8888"); + } + + #[test] + fn test_rule_matches() { + fn check(addr: &str, yes: &[&str], no: &[&str]) { + use std::net::SocketAddr; + let policy = addr.parse::().unwrap(); + for s in yes { + let sa = s.parse::().unwrap(); + assert!(policy.matches_sockaddr(&sa)); + } + for s in no { + let sa = s.parse::().unwrap(); + assert!(!policy.matches_sockaddr(&sa)); + } + } + + check( + "1.2.3.4/16:80", + &["1.2.3.4:80", "1.2.44.55:80"], + &["9.9.9.9:80", "1.3.3.4:80", "1.2.3.4:81"], + ); + check( + "*:443-8000", + &["1.2.3.4:443", "[::1]:500"], + &["9.0.0.0:80", "[::1]:80"], + ); + check( + "[face::]/8:80", + &["[fab0::7]:80"], + &["[dd00::]:80", "[face::7]:443"], + ); + + check("0.0.0.0/0:*", &["127.0.0.1:80"], &["[f00b::]:80"]); + check("[::]/0:*", &["[f00b::]:80"], &["127.0.0.1:80"]); + } + + #[test] + fn test_policy_matches() -> Result<(), PolicyError> { + let mut policy = AddrPolicy::default(); + policy.push(RuleKind::Accept, "*:443".parse()?); + policy.push(RuleKind::Accept, "[::1]:80".parse()?); + policy.push(RuleKind::Reject, "*:80".parse()?); + + let policy = policy; // drop mut + assert_eq!( + policy.allows_sockaddr(&"[::6]:443".parse().unwrap()), + Some(RuleKind::Accept) + ); + assert_eq!( + policy.allows_sockaddr(&"127.0.0.1:443".parse().unwrap()), + Some(RuleKind::Accept) + ); + assert_eq!( + policy.allows_sockaddr(&"[::1]:80".parse().unwrap()), + Some(RuleKind::Accept) + ); + assert_eq!( + policy.allows_sockaddr(&"[::2]:80".parse().unwrap()), + Some(RuleKind::Reject) + ); + assert_eq!( + policy.allows_sockaddr(&"127.0.0.1:80".parse().unwrap()), + Some(RuleKind::Reject) + ); + assert_eq!( + policy.allows_sockaddr(&"127.0.0.1:66".parse().unwrap()), + None + ); + Ok(()) + } +} diff --git a/tor-netdoc/src/types/policy/portpolicy.rs b/tor-netdoc/src/types/policy/portpolicy.rs index 30457ec70..34f9afc67 100644 --- a/tor-netdoc/src/types/policy/portpolicy.rs +++ b/tor-netdoc/src/types/policy/portpolicy.rs @@ -22,11 +22,11 @@ use super::{PolicyError, PortRange}; /// assert!(! policy.allows_port(1024)); /// assert!(! policy.allows_port(9000)); /// ``` -#[derive(Clone)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct PortPolicy { /// A list of port ranges that this policy allows. /// - /// These ranges are sorted and disjoint. + /// These ranges sorted, disjoint, and compact. allowed: Vec, } @@ -70,6 +70,26 @@ impl PortPolicy { } self.allowed = new_allowed; } + /// Helper: add a new range to the end of this portpolicy. + /// + /// gives an error if this range cannot appear next in sequence. + fn push_policy(&mut self, item: PortRange) -> Result<(), PolicyError> { + if let Some(prev) = self.allowed.last() { + if prev.hi >= item.lo { + // Or should this be ">"? TODO XXXX + return Err(PolicyError::InvalidPolicy); + } else if prev.hi == item.lo - 1 { + // We compress a-b,(b+1)-c into a-c. + let r = PortRange::new_unchecked(prev.lo, item.hi); + self.allowed.pop(); + self.allowed.push(r); + return Ok(()); + } + } + + self.allowed.push(item); + Ok(()) + } /// Return true iff `port` is allowed by this policy. pub fn allows_port(&self, port: u16) -> bool { self.allowed @@ -94,13 +114,7 @@ impl FromStr for PortPolicy { s = &s[7..]; for item in s.split(',') { let r: PortRange = item.parse()?; - if let Some(prev) = result.allowed.last() { - if r.lo <= prev.hi { - // Or should this be "<"? TODO XXXX - return Err(PolicyError::InvalidPolicy); - } - } - result.allowed.push(r); + result.push_policy(r)?; } if invert { result.invert(); @@ -108,3 +122,67 @@ impl FromStr for PortPolicy { Ok(result) } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_roundtrip() { + fn check(inp: &str, outp: &str, allow: &[u16], deny: &[u16]) { + let policy = inp.parse::().unwrap(); + assert_eq!(format!("{}", policy), outp); + for p in allow { + assert!(policy.allows_port(*p)); + } + for p in deny { + assert!(!policy.allows_port(*p)); + } + } + + check( + "accept 1-10,30-50,600", + "accept 1-10,30-50,600", + &[1, 10, 35, 600], + &[0, 11, 55, 599, 601], + ); + check("accept 1-10,11-20", "accept 1-20", &[], &[]); + check( + "reject 1-30", + "accept 31-65535", + &[31, 10001, 65535], + &[0, 1, 30], + ); + check( + "reject 300-500", + "accept 1-299,501-65535", + &[31, 10001, 65535], + &[300, 301, 500], + ); + check("reject 10,11,12,13,15", "accept 1-9,14,16-65535", &[], &[]); + check( + "reject 1-65535", + "reject 1-65535", + &[], + &[1, 300, 301, 500, 10001, 65535], + ); + } + + #[test] + fn test_bad() { + for s in &[ + "ignore 1-10", + "allow 1-100", + "accept", + "reject", + "accept x-y", + "accept 1-20,19-30", + "accept 1-20,20-30", + "reject 1,1,1,1", + "reject 1,2,foo,4", + "reject 5,4,3,2", + ] { + assert!(s.parse::().is_err()); + } + } +}