diff --git a/crates/arti/src/dns.rs b/crates/arti/src/dns.rs index 203a20431..c8b00354e 100644 --- a/crates/arti/src/dns.rs +++ b/crates/arti/src/dns.rs @@ -3,13 +3,15 @@ //! A resolver is launched with [`run_dns_resolver()`], which listens for new //! connections and then runs +use futures::lock::Mutex; use futures::stream::StreamExt; use futures::task::SpawnExt; +use std::collections::HashMap; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::sync::Arc; -use tracing::{error, info, warn}; +use tracing::{debug, error, info, warn}; use trust_dns_proto::op::{ - header::MessageType, op_code::OpCode, response_code::ResponseCode, Message, + header::MessageType, op_code::OpCode, response_code::ResponseCode, Message, Query, }; use trust_dns_proto::rr::{DNSClass, Name, RData, Record, RecordType}; use trust_dns_proto::serialize::binary::{BinDecodable, BinEncodable}; @@ -22,18 +24,11 @@ use anyhow::{anyhow, Result}; /// Maximum length for receiving a single datagram const MAX_DATAGRAM_SIZE: usize = 1536; -/// Send an error DNS response with code NotImplemented -async fn not_implemented(id: u16, addr: &SocketAddr, socket: &U) -> Result<()> { - let response = Message::error_msg(id, OpCode::Query, ResponseCode::NotImp); - socket.send(&response.to_bytes()?, addr).await?; - Ok(()) -} - /// A Key used to isolate dns requests. /// /// Composed of an usize (representing which listener socket accepted /// the connection and the source IpAddr of the client) -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] struct DnsIsolationKey(usize, IpAddr); impl arti_client::isolation::IsolationHelper for DnsIsolationKey { @@ -50,31 +45,38 @@ impl arti_client::isolation::IsolationHelper for DnsIsolationKey { } } -/// Given a datagram containing a DNS query, resolve the query over -/// the Tor network and send the response back. -async fn handle_dns_req( - tor_client: TorClient, - socket_id: usize, - packet: &[u8], +/// Identifier for a DNS request, composed of its source IP and transaction ID +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct DnsCacheKey(DnsIsolationKey, Vec); + +/// Target for a DNS response +#[derive(Debug, Clone)] +struct DnsResponseTarget { + /// Transaction ID + id: u16, + /// Address of the client addr: SocketAddr, + /// Socket to send the response through socket: Arc, -) -> Result<()> +} + +/// Run a DNS query over tor, returning either a list of answers, or a DNS error code. +async fn do_query( + tor_client: TorClient, + queries: &[Query], + prefs: &StreamPrefs, +) -> Result, ResponseCode> where R: Runtime, - U: UdpSocket, { - let mut query = Message::from_bytes(packet)?; - let id = query.id(); - let mut answers = Vec::new(); - let mut prefs = StreamPrefs::new(); - prefs.set_isolation(DnsIsolationKey(socket_id, addr.ip())); - - for query in query.queries() { + for query in queries { let mut a = Vec::new(); let mut ptr = Vec::new(); - // TODO maybe support ANY? + + // TODO if there are N questions, this would take N rtt to answer. By joining all futures it + // could take only 1 rtt, but having more than 1 question is actually very rare. match query.query_class() { DNSClass::IN => { match query.query_type() { @@ -83,27 +85,36 @@ where // name would be "torproject.org." without this name.set_fqdn(false); let res = tor_client - .resolve_with_prefs(&name.to_utf8(), &prefs) - .await?; + .resolve_with_prefs(&name.to_utf8(), prefs) + .await + .map_err(|_| ResponseCode::ServFail)?; for ip in res { a.push((query.name().clone(), ip, typ)); } } RecordType::PTR => { - let addr = query.name().parse_arpa_name()?.addr(); - let res = tor_client.resolve_ptr_with_prefs(addr, &prefs).await?; + let addr = query + .name() + .parse_arpa_name() + .map_err(|_| ResponseCode::FormErr)? + .addr(); + let res = tor_client + .resolve_ptr_with_prefs(addr, prefs) + .await + .map_err(|_| ResponseCode::ServFail)?; for domain in res { - let domain = Name::from_utf8(domain)?; + let domain = + Name::from_utf8(domain).map_err(|_| ResponseCode::ServFail)?; ptr.push((query.name().clone(), domain)); } } _ => { - return not_implemented(id, &addr, &*socket).await; + return Err(ResponseCode::NotImp); } } } _ => { - return not_implemented(id, &addr, &*socket).await; + return Err(ResponseCode::NotImp); } } for (name, ip, typ) in a { @@ -122,18 +133,85 @@ where } } - let mut response = Message::new(); - response - .set_id(id) - .set_message_type(MessageType::Response) - .set_op_code(OpCode::Query) - .set_recursion_desired(query.recursion_desired()) - .set_recursion_available(true) - .add_queries(query.take_queries()) - .add_answers(answers); - // TODO maybe add some edns? + Ok(answers) +} - socket.send(&response.to_bytes()?, &addr).await?; +/// Given a datagram containing a DNS query, resolve the query over +/// the Tor network and send the response back. +async fn handle_dns_req( + tor_client: TorClient, + socket_id: usize, + packet: &[u8], + addr: SocketAddr, + socket: Arc, + current_requests: &Mutex>>>, +) -> Result<()> +where + R: Runtime, + U: UdpSocket, +{ + // if we can't parse the request, don't try to answer it. + let mut query = Message::from_bytes(packet)?; + let id = query.id(); + let queries = query.queries(); + let isolation = DnsIsolationKey(socket_id, addr.ip()); + + let request_id = { + let request_id = DnsCacheKey(isolation.clone(), queries.to_vec()); + + let response_target = DnsResponseTarget { id, addr, socket }; + + let mut current_requests = current_requests.lock().await; + + let req = current_requests.entry(request_id.clone()).or_default(); + req.push(response_target); + + if req.len() > 1 { + debug!("Received a query already being served"); + return Ok(()); + } + debug!("Received a new query"); + + request_id + }; + + let mut prefs = StreamPrefs::new(); + prefs.set_isolation(isolation); + + let mut response = match do_query(tor_client, queries, &prefs).await { + Ok(answers) => { + let mut response = Message::new(); + response + .set_message_type(MessageType::Response) + .set_op_code(OpCode::Query) + .set_recursion_desired(query.recursion_desired()) + .set_recursion_available(true) + .add_queries(query.take_queries()) + .add_answers(answers); + // TODO maybe add some edns? + response + } + Err(error_type) => Message::error_msg(id, OpCode::Query, error_type), + }; + + // remove() should never return None, but just in case + let targets = current_requests + .lock() + .await + .remove(&request_id) + .unwrap_or_default(); + + for target in targets { + response.set_id(target.id); + // ignore errors, we want to reply to everybody + let response = if let Ok(r) = response.to_bytes() { + r + } else { + error!("Failed to serialize DNS packet: {:?}", response); + continue; + }; + let _ = target.socket.send(&response, &target.addr).await; + } Ok(()) } @@ -184,6 +262,7 @@ pub async fn run_dns_resolver( }), ); + let pending_requests = Arc::new(Mutex::new(HashMap::new())); while let Some((packet, id)) = incoming.next().await { let (packet, size, addr, socket) = match packet { Ok(packet) => packet, @@ -195,10 +274,21 @@ pub async fn run_dns_resolver( }; let client_ref = tor_client.clone(); - runtime.spawn(async move { - let res = handle_dns_req(client_ref, id, &packet[..size], addr, socket).await; - if let Err(e) = res { - warn!("connection exited with error: {}", e); + runtime.spawn({ + let pending_requests = pending_requests.clone(); + async move { + let res = handle_dns_req( + client_ref, + id, + &packet[..size], + addr, + socket, + &pending_requests, + ) + .await; + if let Err(e) = res { + warn!("connection exited with error: {}", e); + } } })?; }