socksproto: Use fallible writers.

Also, make private a function that had formerly been `pub`.
This commit is contained in:
Nick Mathewson 2022-07-09 13:33:42 -04:00
parent 5a61a6d73a
commit 9ca301faee
3 changed files with 39 additions and 26 deletions

View File

@ -198,7 +198,9 @@ where
// Send back a SOCKS response, telling the client that it
// successfully connected.
let reply = request.reply(tor_socksproto::SocksStatus::SUCCEEDED, None);
let reply = request
.reply(tor_socksproto::SocksStatus::SUCCEEDED, None)
.context("Encoding socks reply")?;
write_all_and_flush(&mut socks_w, &reply[..]).await?;
let (tor_r, tor_w) = tor_stream.split();
@ -216,10 +218,12 @@ where
Err(e) => return reply_error(&mut socks_w, &request, e).await,
};
if let Some(addr) = addrs.first() {
let reply = request.reply(
tor_socksproto::SocksStatus::SUCCEEDED,
Some(&SocksAddr::Ip(*addr)),
);
let reply = request
.reply(
tor_socksproto::SocksStatus::SUCCEEDED,
Some(&SocksAddr::Ip(*addr)),
)
.context("Encoding socks reply")?;
write_all_and_close(&mut socks_w, &reply[..]).await?;
}
}
@ -229,8 +233,9 @@ where
let addr: IpAddr = match addr.parse() {
Ok(ip) => ip,
Err(e) => {
let reply =
request.reply(tor_socksproto::SocksStatus::ADDRTYPE_NOT_SUPPORTED, None);
let reply = request
.reply(tor_socksproto::SocksStatus::ADDRTYPE_NOT_SUPPORTED, None)
.context("Encoding socks reply")?;
write_all_and_close(&mut socks_w, &reply[..]).await?;
return Err(anyhow!(e));
}
@ -243,14 +248,18 @@ where
// this conversion should never fail, legal DNS names len must be <= 253 but Socks
// names can be up to 255 chars.
let hostname = SocksAddr::Hostname(host.try_into()?);
let reply = request.reply(tor_socksproto::SocksStatus::SUCCEEDED, Some(&hostname));
let reply = request
.reply(tor_socksproto::SocksStatus::SUCCEEDED, Some(&hostname))
.context("Encoding socks reply")?;
write_all_and_close(&mut socks_w, &reply[..]).await?;
}
}
_ => {
// We don't support this SOCKS command.
warn!("Dropping request; {:?} is unsupported", request.command());
let reply = request.reply(tor_socksproto::SocksStatus::COMMAND_NOT_SUPPORTED, None);
let reply = request
.reply(tor_socksproto::SocksStatus::COMMAND_NOT_SUPPORTED, None)
.context("Encoding socks reply")?;
write_all_and_close(&mut socks_w, &reply[..]).await?;
}
};
@ -307,7 +316,8 @@ where
request.reply(tor_socksproto::SocksStatus::TTL_EXPIRED, None)
}
_ => request.reply(tor_socksproto::SocksStatus::GENERAL_FAILURE, None),
};
}
.context("Encoding socks reply")?;
// if writing back the error fail, still return the original error
let _ = write_all_and_close(writer, &reply[..]).await;

View File

View File

@ -264,7 +264,7 @@ impl SocksRequest {
///
/// Note that an address should be provided only when the request
/// was for a RESOLVE.
pub fn reply(&self, status: SocksStatus, addr: Option<&SocksAddr>) -> Vec<u8> {
pub fn reply(&self, status: SocksStatus, addr: Option<&SocksAddr>) -> EncodeResult<Vec<u8>> {
match self.version() {
SocksVersion::V4 => self.s4(status, addr),
SocksVersion::V5 => self.s5(status, addr),
@ -272,38 +272,38 @@ impl SocksRequest {
}
/// Format a SOCKS4 reply.
fn s4(&self, status: SocksStatus, addr: Option<&SocksAddr>) -> Vec<u8> {
fn s4(&self, status: SocksStatus, addr: Option<&SocksAddr>) -> EncodeResult<Vec<u8>> {
let mut w = Vec::new();
w.write_u8(0);
w.write_u8(status.into_socks4_status());
match addr {
Some(SocksAddr::Ip(IpAddr::V4(ip))) => {
w.write_u16(self.port());
w.write_infallible(ip);
w.write(ip)?;
}
_ => {
w.write_u16(0);
w.write_u32(0);
}
}
w
Ok(w)
}
/// Format a SOCKS5 reply.
pub fn s5(&self, status: SocksStatus, addr: Option<&SocksAddr>) -> Vec<u8> {
fn s5(&self, status: SocksStatus, addr: Option<&SocksAddr>) -> EncodeResult<Vec<u8>> {
let mut w = Vec::new();
w.write_u8(5);
w.write_u8(status.into());
w.write_u8(0); // reserved.
if let Some(a) = addr {
w.write_infallible(a);
w.write(a)?;
w.write_u16(self.port());
} else {
// TODO: sometimes I think we want to answer with ::, not 0.0.0.0
w.write_infallible(&SocksAddr::Ip(std::net::Ipv4Addr::UNSPECIFIED.into()));
w.write(&SocksAddr::Ip(std::net::Ipv4Addr::UNSPECIFIED.into()))?;
w.write_u16(0);
}
w
Ok(w)
}
}
@ -340,11 +340,11 @@ impl Writeable for SocksAddr {
match self {
SocksAddr::Ip(IpAddr::V4(ip)) => {
w.write_u8(1);
w.write_infallible(ip);
w.write(ip)?;
}
SocksAddr::Ip(IpAddr::V6(ip)) => {
w.write_u8(4);
w.write_infallible(ip);
w.write(ip)?;
}
SocksAddr::Hostname(h) => {
let h = h.as_ref();
@ -352,7 +352,7 @@ impl Writeable for SocksAddr {
let hlen = h.len() as u8;
w.write_u8(3);
w.write_u8(hlen);
w.write_infallible(h.as_bytes());
w.write(h.as_bytes())?;
}
}
Ok(())
@ -386,7 +386,8 @@ mod test {
req.reply(
SocksStatus::GENERAL_FAILURE,
Some(&SocksAddr::Ip("127.0.0.1".parse().unwrap()))
),
)
.unwrap(),
hex!("00 5B 0050 7f000001")
);
}
@ -411,7 +412,7 @@ mod test {
assert_eq!(req.command(), SocksCmd::CONNECT);
assert_eq!(
req.reply(SocksStatus::SUCCEEDED, None),
req.reply(SocksStatus::SUCCEEDED, None).unwrap(),
hex!("00 5A 0000 00000000")
);
}
@ -487,7 +488,8 @@ mod test {
Some(&SocksAddr::Hostname(
"foo.example.com".to_string().try_into().unwrap()
))
),
)
.unwrap(),
hex!("05 04 00 03 0f 666f6f2e6578616d706c652e636f6d 1f90")
);
}
@ -515,7 +517,8 @@ mod test {
assert_eq!(req.auth(), &SocksAuth::NoAuth);
assert_eq!(
req.reply(SocksStatus::GENERAL_FAILURE, Some(req.addr())),
req.reply(SocksStatus::GENERAL_FAILURE, Some(req.addr()))
.unwrap(),
hex!("05 01 00 04 f000 0000 0000 0000 0000 0000 0000 ff11 1f90")
);
}
@ -541,7 +544,7 @@ mod test {
assert_eq!(req.auth(), &SocksAuth::NoAuth);
assert_eq!(
req.reply(SocksStatus::SUCCEEDED, None),
req.reply(SocksStatus::SUCCEEDED, None).unwrap(),
hex!("05 00 00 01 00000000 0000")
);
}