deduplicate dns requests based on transaction id

This commit is contained in:
trinity-1686a 2022-05-30 09:52:11 +00:00 committed by Ian Jackson
parent 75f968017d
commit 266b278c74
1 changed files with 138 additions and 48 deletions

View File

@ -3,13 +3,15 @@
//! A resolver is launched with [`run_dns_resolver()`], which listens for new
//! connections and then runs
use futures::lock::Mutex;
use futures::stream::StreamExt;
use futures::task::SpawnExt;
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use tracing::{error, info, warn};
use tracing::{debug, error, info, warn};
use trust_dns_proto::op::{
header::MessageType, op_code::OpCode, response_code::ResponseCode, Message,
header::MessageType, op_code::OpCode, response_code::ResponseCode, Message, Query,
};
use trust_dns_proto::rr::{DNSClass, Name, RData, Record, RecordType};
use trust_dns_proto::serialize::binary::{BinDecodable, BinEncodable};
@ -22,18 +24,11 @@ use anyhow::{anyhow, Result};
/// Maximum length for receiving a single datagram
const MAX_DATAGRAM_SIZE: usize = 1536;
/// Send an error DNS response with code NotImplemented
async fn not_implemented<U: UdpSocket>(id: u16, addr: &SocketAddr, socket: &U) -> Result<()> {
let response = Message::error_msg(id, OpCode::Query, ResponseCode::NotImp);
socket.send(&response.to_bytes()?, addr).await?;
Ok(())
}
/// A Key used to isolate dns requests.
///
/// Composed of an usize (representing which listener socket accepted
/// the connection and the source IpAddr of the client)
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct DnsIsolationKey(usize, IpAddr);
impl arti_client::isolation::IsolationHelper for DnsIsolationKey {
@ -50,31 +45,38 @@ impl arti_client::isolation::IsolationHelper for DnsIsolationKey {
}
}
/// Given a datagram containing a DNS query, resolve the query over
/// the Tor network and send the response back.
async fn handle_dns_req<R, U>(
tor_client: TorClient<R>,
socket_id: usize,
packet: &[u8],
/// Identifier for a DNS request, composed of its source IP and transaction ID
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct DnsCacheKey(DnsIsolationKey, Vec<Query>);
/// Target for a DNS response
#[derive(Debug, Clone)]
struct DnsResponseTarget<U> {
/// Transaction ID
id: u16,
/// Address of the client
addr: SocketAddr,
/// Socket to send the response through
socket: Arc<U>,
) -> Result<()>
}
/// Run a DNS query over tor, returning either a list of answers, or a DNS error code.
async fn do_query<R>(
tor_client: TorClient<R>,
queries: &[Query],
prefs: &StreamPrefs,
) -> Result<Vec<Record>, ResponseCode>
where
R: Runtime,
U: UdpSocket,
{
let mut query = Message::from_bytes(packet)?;
let id = query.id();
let mut answers = Vec::new();
let mut prefs = StreamPrefs::new();
prefs.set_isolation(DnsIsolationKey(socket_id, addr.ip()));
for query in query.queries() {
for query in queries {
let mut a = Vec::new();
let mut ptr = Vec::new();
// TODO maybe support ANY?
// TODO if there are N questions, this would take N rtt to answer. By joining all futures it
// could take only 1 rtt, but having more than 1 question is actually very rare.
match query.query_class() {
DNSClass::IN => {
match query.query_type() {
@ -83,27 +85,36 @@ where
// name would be "torproject.org." without this
name.set_fqdn(false);
let res = tor_client
.resolve_with_prefs(&name.to_utf8(), &prefs)
.await?;
.resolve_with_prefs(&name.to_utf8(), prefs)
.await
.map_err(|_| ResponseCode::ServFail)?;
for ip in res {
a.push((query.name().clone(), ip, typ));
}
}
RecordType::PTR => {
let addr = query.name().parse_arpa_name()?.addr();
let res = tor_client.resolve_ptr_with_prefs(addr, &prefs).await?;
let addr = query
.name()
.parse_arpa_name()
.map_err(|_| ResponseCode::FormErr)?
.addr();
let res = tor_client
.resolve_ptr_with_prefs(addr, prefs)
.await
.map_err(|_| ResponseCode::ServFail)?;
for domain in res {
let domain = Name::from_utf8(domain)?;
let domain =
Name::from_utf8(domain).map_err(|_| ResponseCode::ServFail)?;
ptr.push((query.name().clone(), domain));
}
}
_ => {
return not_implemented(id, &addr, &*socket).await;
return Err(ResponseCode::NotImp);
}
}
}
_ => {
return not_implemented(id, &addr, &*socket).await;
return Err(ResponseCode::NotImp);
}
}
for (name, ip, typ) in a {
@ -122,18 +133,85 @@ where
}
}
let mut response = Message::new();
response
.set_id(id)
.set_message_type(MessageType::Response)
.set_op_code(OpCode::Query)
.set_recursion_desired(query.recursion_desired())
.set_recursion_available(true)
.add_queries(query.take_queries())
.add_answers(answers);
// TODO maybe add some edns?
Ok(answers)
}
socket.send(&response.to_bytes()?, &addr).await?;
/// Given a datagram containing a DNS query, resolve the query over
/// the Tor network and send the response back.
async fn handle_dns_req<R, U>(
tor_client: TorClient<R>,
socket_id: usize,
packet: &[u8],
addr: SocketAddr,
socket: Arc<U>,
current_requests: &Mutex<HashMap<DnsCacheKey, Vec<DnsResponseTarget<U>>>>,
) -> Result<()>
where
R: Runtime,
U: UdpSocket,
{
// if we can't parse the request, don't try to answer it.
let mut query = Message::from_bytes(packet)?;
let id = query.id();
let queries = query.queries();
let isolation = DnsIsolationKey(socket_id, addr.ip());
let request_id = {
let request_id = DnsCacheKey(isolation.clone(), queries.to_vec());
let response_target = DnsResponseTarget { id, addr, socket };
let mut current_requests = current_requests.lock().await;
let req = current_requests.entry(request_id.clone()).or_default();
req.push(response_target);
if req.len() > 1 {
debug!("Received a query already being served");
return Ok(());
}
debug!("Received a new query");
request_id
};
let mut prefs = StreamPrefs::new();
prefs.set_isolation(isolation);
let mut response = match do_query(tor_client, queries, &prefs).await {
Ok(answers) => {
let mut response = Message::new();
response
.set_message_type(MessageType::Response)
.set_op_code(OpCode::Query)
.set_recursion_desired(query.recursion_desired())
.set_recursion_available(true)
.add_queries(query.take_queries())
.add_answers(answers);
// TODO maybe add some edns?
response
}
Err(error_type) => Message::error_msg(id, OpCode::Query, error_type),
};
// remove() should never return None, but just in case
let targets = current_requests
.lock()
.await
.remove(&request_id)
.unwrap_or_default();
for target in targets {
response.set_id(target.id);
// ignore errors, we want to reply to everybody
let response = if let Ok(r) = response.to_bytes() {
r
} else {
error!("Failed to serialize DNS packet: {:?}", response);
continue;
};
let _ = target.socket.send(&response, &target.addr).await;
}
Ok(())
}
@ -184,6 +262,7 @@ pub async fn run_dns_resolver<R: Runtime>(
}),
);
let pending_requests = Arc::new(Mutex::new(HashMap::new()));
while let Some((packet, id)) = incoming.next().await {
let (packet, size, addr, socket) = match packet {
Ok(packet) => packet,
@ -195,10 +274,21 @@ pub async fn run_dns_resolver<R: Runtime>(
};
let client_ref = tor_client.clone();
runtime.spawn(async move {
let res = handle_dns_req(client_ref, id, &packet[..size], addr, socket).await;
if let Err(e) = res {
warn!("connection exited with error: {}", e);
runtime.spawn({
let pending_requests = pending_requests.clone();
async move {
let res = handle_dns_req(
client_ref,
id,
&packet[..size],
addr,
socket,
&pending_requests,
)
.await;
if let Err(e) = res {
warn!("connection exited with error: {}", e);
}
}
})?;
}