deduplicate dns requests based on transaction id
This commit is contained in:
parent
75f968017d
commit
266b278c74
|
@ -3,13 +3,15 @@
|
|||
//! A resolver is launched with [`run_dns_resolver()`], which listens for new
|
||||
//! connections and then runs
|
||||
|
||||
use futures::lock::Mutex;
|
||||
use futures::stream::StreamExt;
|
||||
use futures::task::SpawnExt;
|
||||
use std::collections::HashMap;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use tracing::{error, info, warn};
|
||||
use tracing::{debug, error, info, warn};
|
||||
use trust_dns_proto::op::{
|
||||
header::MessageType, op_code::OpCode, response_code::ResponseCode, Message,
|
||||
header::MessageType, op_code::OpCode, response_code::ResponseCode, Message, Query,
|
||||
};
|
||||
use trust_dns_proto::rr::{DNSClass, Name, RData, Record, RecordType};
|
||||
use trust_dns_proto::serialize::binary::{BinDecodable, BinEncodable};
|
||||
|
@ -22,18 +24,11 @@ use anyhow::{anyhow, Result};
|
|||
/// Maximum length for receiving a single datagram
|
||||
const MAX_DATAGRAM_SIZE: usize = 1536;
|
||||
|
||||
/// Send an error DNS response with code NotImplemented
|
||||
async fn not_implemented<U: UdpSocket>(id: u16, addr: &SocketAddr, socket: &U) -> Result<()> {
|
||||
let response = Message::error_msg(id, OpCode::Query, ResponseCode::NotImp);
|
||||
socket.send(&response.to_bytes()?, addr).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// A Key used to isolate dns requests.
|
||||
///
|
||||
/// Composed of an usize (representing which listener socket accepted
|
||||
/// the connection and the source IpAddr of the client)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
struct DnsIsolationKey(usize, IpAddr);
|
||||
|
||||
impl arti_client::isolation::IsolationHelper for DnsIsolationKey {
|
||||
|
@ -50,31 +45,38 @@ impl arti_client::isolation::IsolationHelper for DnsIsolationKey {
|
|||
}
|
||||
}
|
||||
|
||||
/// Given a datagram containing a DNS query, resolve the query over
|
||||
/// the Tor network and send the response back.
|
||||
async fn handle_dns_req<R, U>(
|
||||
tor_client: TorClient<R>,
|
||||
socket_id: usize,
|
||||
packet: &[u8],
|
||||
/// Identifier for a DNS request, composed of its source IP and transaction ID
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
struct DnsCacheKey(DnsIsolationKey, Vec<Query>);
|
||||
|
||||
/// Target for a DNS response
|
||||
#[derive(Debug, Clone)]
|
||||
struct DnsResponseTarget<U> {
|
||||
/// Transaction ID
|
||||
id: u16,
|
||||
/// Address of the client
|
||||
addr: SocketAddr,
|
||||
/// Socket to send the response through
|
||||
socket: Arc<U>,
|
||||
) -> Result<()>
|
||||
}
|
||||
|
||||
/// Run a DNS query over tor, returning either a list of answers, or a DNS error code.
|
||||
async fn do_query<R>(
|
||||
tor_client: TorClient<R>,
|
||||
queries: &[Query],
|
||||
prefs: &StreamPrefs,
|
||||
) -> Result<Vec<Record>, ResponseCode>
|
||||
where
|
||||
R: Runtime,
|
||||
U: UdpSocket,
|
||||
{
|
||||
let mut query = Message::from_bytes(packet)?;
|
||||
let id = query.id();
|
||||
|
||||
let mut answers = Vec::new();
|
||||
|
||||
let mut prefs = StreamPrefs::new();
|
||||
prefs.set_isolation(DnsIsolationKey(socket_id, addr.ip()));
|
||||
|
||||
for query in query.queries() {
|
||||
for query in queries {
|
||||
let mut a = Vec::new();
|
||||
let mut ptr = Vec::new();
|
||||
// TODO maybe support ANY?
|
||||
|
||||
// TODO if there are N questions, this would take N rtt to answer. By joining all futures it
|
||||
// could take only 1 rtt, but having more than 1 question is actually very rare.
|
||||
match query.query_class() {
|
||||
DNSClass::IN => {
|
||||
match query.query_type() {
|
||||
|
@ -83,27 +85,36 @@ where
|
|||
// name would be "torproject.org." without this
|
||||
name.set_fqdn(false);
|
||||
let res = tor_client
|
||||
.resolve_with_prefs(&name.to_utf8(), &prefs)
|
||||
.await?;
|
||||
.resolve_with_prefs(&name.to_utf8(), prefs)
|
||||
.await
|
||||
.map_err(|_| ResponseCode::ServFail)?;
|
||||
for ip in res {
|
||||
a.push((query.name().clone(), ip, typ));
|
||||
}
|
||||
}
|
||||
RecordType::PTR => {
|
||||
let addr = query.name().parse_arpa_name()?.addr();
|
||||
let res = tor_client.resolve_ptr_with_prefs(addr, &prefs).await?;
|
||||
let addr = query
|
||||
.name()
|
||||
.parse_arpa_name()
|
||||
.map_err(|_| ResponseCode::FormErr)?
|
||||
.addr();
|
||||
let res = tor_client
|
||||
.resolve_ptr_with_prefs(addr, prefs)
|
||||
.await
|
||||
.map_err(|_| ResponseCode::ServFail)?;
|
||||
for domain in res {
|
||||
let domain = Name::from_utf8(domain)?;
|
||||
let domain =
|
||||
Name::from_utf8(domain).map_err(|_| ResponseCode::ServFail)?;
|
||||
ptr.push((query.name().clone(), domain));
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return not_implemented(id, &addr, &*socket).await;
|
||||
return Err(ResponseCode::NotImp);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return not_implemented(id, &addr, &*socket).await;
|
||||
return Err(ResponseCode::NotImp);
|
||||
}
|
||||
}
|
||||
for (name, ip, typ) in a {
|
||||
|
@ -122,18 +133,85 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
let mut response = Message::new();
|
||||
response
|
||||
.set_id(id)
|
||||
.set_message_type(MessageType::Response)
|
||||
.set_op_code(OpCode::Query)
|
||||
.set_recursion_desired(query.recursion_desired())
|
||||
.set_recursion_available(true)
|
||||
.add_queries(query.take_queries())
|
||||
.add_answers(answers);
|
||||
// TODO maybe add some edns?
|
||||
Ok(answers)
|
||||
}
|
||||
|
||||
socket.send(&response.to_bytes()?, &addr).await?;
|
||||
/// Given a datagram containing a DNS query, resolve the query over
|
||||
/// the Tor network and send the response back.
|
||||
async fn handle_dns_req<R, U>(
|
||||
tor_client: TorClient<R>,
|
||||
socket_id: usize,
|
||||
packet: &[u8],
|
||||
addr: SocketAddr,
|
||||
socket: Arc<U>,
|
||||
current_requests: &Mutex<HashMap<DnsCacheKey, Vec<DnsResponseTarget<U>>>>,
|
||||
) -> Result<()>
|
||||
where
|
||||
R: Runtime,
|
||||
U: UdpSocket,
|
||||
{
|
||||
// if we can't parse the request, don't try to answer it.
|
||||
let mut query = Message::from_bytes(packet)?;
|
||||
let id = query.id();
|
||||
let queries = query.queries();
|
||||
let isolation = DnsIsolationKey(socket_id, addr.ip());
|
||||
|
||||
let request_id = {
|
||||
let request_id = DnsCacheKey(isolation.clone(), queries.to_vec());
|
||||
|
||||
let response_target = DnsResponseTarget { id, addr, socket };
|
||||
|
||||
let mut current_requests = current_requests.lock().await;
|
||||
|
||||
let req = current_requests.entry(request_id.clone()).or_default();
|
||||
req.push(response_target);
|
||||
|
||||
if req.len() > 1 {
|
||||
debug!("Received a query already being served");
|
||||
return Ok(());
|
||||
}
|
||||
debug!("Received a new query");
|
||||
|
||||
request_id
|
||||
};
|
||||
|
||||
let mut prefs = StreamPrefs::new();
|
||||
prefs.set_isolation(isolation);
|
||||
|
||||
let mut response = match do_query(tor_client, queries, &prefs).await {
|
||||
Ok(answers) => {
|
||||
let mut response = Message::new();
|
||||
response
|
||||
.set_message_type(MessageType::Response)
|
||||
.set_op_code(OpCode::Query)
|
||||
.set_recursion_desired(query.recursion_desired())
|
||||
.set_recursion_available(true)
|
||||
.add_queries(query.take_queries())
|
||||
.add_answers(answers);
|
||||
// TODO maybe add some edns?
|
||||
response
|
||||
}
|
||||
Err(error_type) => Message::error_msg(id, OpCode::Query, error_type),
|
||||
};
|
||||
|
||||
// remove() should never return None, but just in case
|
||||
let targets = current_requests
|
||||
.lock()
|
||||
.await
|
||||
.remove(&request_id)
|
||||
.unwrap_or_default();
|
||||
|
||||
for target in targets {
|
||||
response.set_id(target.id);
|
||||
// ignore errors, we want to reply to everybody
|
||||
let response = if let Ok(r) = response.to_bytes() {
|
||||
r
|
||||
} else {
|
||||
error!("Failed to serialize DNS packet: {:?}", response);
|
||||
continue;
|
||||
};
|
||||
let _ = target.socket.send(&response, &target.addr).await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -184,6 +262,7 @@ pub async fn run_dns_resolver<R: Runtime>(
|
|||
}),
|
||||
);
|
||||
|
||||
let pending_requests = Arc::new(Mutex::new(HashMap::new()));
|
||||
while let Some((packet, id)) = incoming.next().await {
|
||||
let (packet, size, addr, socket) = match packet {
|
||||
Ok(packet) => packet,
|
||||
|
@ -195,10 +274,21 @@ pub async fn run_dns_resolver<R: Runtime>(
|
|||
};
|
||||
|
||||
let client_ref = tor_client.clone();
|
||||
runtime.spawn(async move {
|
||||
let res = handle_dns_req(client_ref, id, &packet[..size], addr, socket).await;
|
||||
if let Err(e) = res {
|
||||
warn!("connection exited with error: {}", e);
|
||||
runtime.spawn({
|
||||
let pending_requests = pending_requests.clone();
|
||||
async move {
|
||||
let res = handle_dns_req(
|
||||
client_ref,
|
||||
id,
|
||||
&packet[..size],
|
||||
addr,
|
||||
socket,
|
||||
&pending_requests,
|
||||
)
|
||||
.await;
|
||||
if let Err(e) = res {
|
||||
warn!("connection exited with error: {}", e);
|
||||
}
|
||||
}
|
||||
})?;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue