pull/192/head
Σrebe - Romain GERARD 2023-10-30 08:13:38 +01:00
parent b478288848
commit 466cb425bc
No known key found for this signature in database
GPG Key ID: 7A42B4B97E0332F4
11 changed files with 159 additions and 320 deletions

3
rustfmt.toml Normal file
View File

@ -0,0 +1,3 @@
edition = "2021"
max_width = 120
fn_call_width = 80

View File

@ -3,14 +3,13 @@ use tokio_rustls::rustls::{Certificate, PrivateKey};
pub static TLS_PRIVATE_KEY: Lazy<PrivateKey> = Lazy::new(|| {
let key = include_bytes!("../certs/key.pem");
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut key.as_slice())
.expect("failed to load embedded tls private key");
let mut keys =
rustls_pemfile::pkcs8_private_keys(&mut key.as_slice()).expect("failed to load embedded tls private key");
PrivateKey(keys.remove(0))
});
pub static TLS_CERTIFICATE: Lazy<Vec<Certificate>> = Lazy::new(|| {
let cert = include_bytes!("../certs/cert.pem");
let certs = rustls_pemfile::certs(&mut cert.as_slice())
.expect("failed to load embedded tls certificate");
let certs = rustls_pemfile::certs(&mut cert.as_slice()).expect("failed to load embedded tls certificate");
certs.into_iter().map(Certificate).collect()
});

View File

@ -67,13 +67,7 @@ struct Client {
/// This option set the maximum number of connection that will be kept open.
/// This is useful if you plan to create/destroy a lot of tunnel (i.e: with socks5 to navigate with a browser)
/// It will avoid the latency of doing tcp + tls handshake with the server
#[arg(
short = 'c',
long,
value_name = "INT",
default_value = "0",
verbatim_doc_comment
)]
#[arg(short = 'c', long, value_name = "INT", default_value = "0", verbatim_doc_comment)]
connection_min_idle: u32,
/// Domain name that will be use as SNI during TLS handshake
@ -88,12 +82,7 @@ struct Client {
tls_verify_certificate: bool,
/// If set, will use this http proxy to connect to the server
#[arg(
short = 'p',
long,
value_name = "http://USER:PASS@HOST:PORT",
verbatim_doc_comment
)]
#[arg(short = 'p', long, value_name = "http://USER:PASS@HOST:PORT", verbatim_doc_comment)]
http_proxy: Option<Url>,
/// Use a specific prefix that will show up in the http path during the upgrade request.
@ -241,9 +230,7 @@ fn parse_local_bind(arg: &str) -> Result<(SocketAddr, &str), io::Error> {
}
#[allow(clippy::type_complexity)]
fn parse_tunnel_dest(
remaining: &str,
) -> Result<(Host<String>, u16, BTreeMap<String, String>), io::Error> {
fn parse_tunnel_dest(remaining: &str) -> Result<(Host<String>, u16, BTreeMap<String, String>), io::Error> {
use std::io::Error;
let Ok(remote) = Url::parse(&format!("fake://{}", remaining)) else {
@ -290,13 +277,7 @@ fn parse_tunnel_arg(arg: &str) -> Result<LocalToRemote, io::Error> {
let timeout = options
.get("timeout_sec")
.and_then(|x| x.parse::<u64>().ok())
.map(|d| {
if d == 0 {
None
} else {
Some(Duration::from_secs(d))
}
})
.map(|d| if d == 0 { None } else { Some(Duration::from_secs(d)) })
.unwrap_or(Some(Duration::from_secs(30)));
Ok(LocalToRemote {
@ -355,10 +336,7 @@ fn parse_http_headers(arg: &str) -> Result<(HeaderName, HeaderValue), io::Error>
Err(err) => {
return Err(io::Error::new(
ErrorKind::InvalidInput,
format!(
"cannot parse http header value from {} due to {:?}",
value, err
),
format!("cannot parse http header value from {} due to {:?}", value, err),
))
}
};
@ -394,10 +372,7 @@ fn parse_server_url(arg: &str) -> Result<Url, io::Error> {
}
if url.host().is_none() {
return Err(io::Error::new(
ErrorKind::InvalidInput,
format!("invalid server host {}", arg),
));
return Err(io::Error::new(ErrorKind::InvalidInput, format!("invalid server host {}", arg)));
}
Ok(url)
@ -474,15 +449,9 @@ impl WsClientConfig {
}
pub fn tls_server_name(&self) -> ServerName {
match self
.tls
.as_ref()
.and_then(|tls| tls.tls_sni_override.as_ref())
{
match self.tls.as_ref().and_then(|tls| tls.tls_sni_override.as_ref()) {
None => match &self.remote_addr.0 {
Host::Domain(domain) => {
ServerName::DnsName(DnsName::try_from(domain.clone()).unwrap())
}
Host::Domain(domain) => ServerName::DnsName(DnsName::try_from(domain.clone()).unwrap()),
Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(*ip)),
Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(*ip)),
},
@ -529,12 +498,11 @@ async fn main() {
};
// Extract host header from http_headers
let host_header =
if let Some((_, host_val)) = args.http_headers.iter().find(|(h, _)| *h == HOST) {
host_val.clone()
} else {
HeaderValue::from_str(&args.remote_addr.host().unwrap().to_string()).unwrap()
};
let host_header = if let Some((_, host_val)) = args.http_headers.iter().find(|(h, _)| *h == HOST) {
host_val.clone()
} else {
HeaderValue::from_str(&args.remote_addr.host().unwrap().to_string()).unwrap()
};
let mut client_config = WsClientConfig {
remote_addr: (
args.remote_addr.host().unwrap().to_owned(),
@ -544,16 +512,10 @@ async fn main() {
tls,
http_upgrade_path_prefix: args.http_upgrade_path_prefix,
http_upgrade_credentials: args.http_upgrade_credentials,
http_headers: args
.http_headers
.into_iter()
.filter(|(k, _)| k != HOST)
.collect(),
http_headers: args.http_headers.into_iter().filter(|(k, _)| k != HOST).collect(),
http_header_host: host_header,
timeout_connect: Duration::from_secs(10),
websocket_ping_frequency: args
.websocket_ping_frequency_sec
.unwrap_or(Duration::from_secs(30)),
websocket_ping_frequency: args.websocket_ping_frequency_sec.unwrap_or(Duration::from_secs(30)),
websocket_mask_frame: args.websocket_mask_frame,
http_proxy: args.http_proxy,
cnx_pool: None,
@ -579,16 +541,12 @@ async fn main() {
let remote = tunnel.remote.clone();
let server = tcp::run_server(tunnel.local)
.await
.unwrap_or_else(|err| {
panic!("Cannot start TCP server on {}: {}", tunnel.local, err)
})
.unwrap_or_else(|err| panic!("Cannot start TCP server on {}: {}", tunnel.local, err))
.map_err(anyhow::Error::new)
.map_ok(move |stream| (stream.into_split(), remote.clone()));
tokio::spawn(async move {
if let Err(err) =
tunnel::client::run_tunnel(client_config, tunnel, server).await
{
if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await {
error!("{:?}", err);
}
});
@ -597,16 +555,12 @@ async fn main() {
let remote = tunnel.remote.clone();
let server = udp::run_server(tunnel.local, *timeout)
.await
.unwrap_or_else(|err| {
panic!("Cannot start UDP server on {}: {}", tunnel.local, err)
})
.unwrap_or_else(|err| panic!("Cannot start UDP server on {}: {}", tunnel.local, err))
.map_err(anyhow::Error::new)
.map_ok(move |stream| (tokio::io::split(stream), remote.clone()));
tokio::spawn(async move {
if let Err(err) =
tunnel::client::run_tunnel(client_config, tunnel, server).await
{
if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await {
error!("{:?}", err);
}
});
@ -614,15 +568,11 @@ async fn main() {
LocalProtocol::Socks5 => {
let server = socks5::run_server(tunnel.local)
.await
.unwrap_or_else(|err| {
panic!("Cannot start Socks5 server on {}: {}", tunnel.local, err)
})
.unwrap_or_else(|err| panic!("Cannot start Socks5 server on {}: {}", tunnel.local, err))
.map_ok(|(stream, remote_dest)| (stream.into_split(), remote_dest));
tokio::spawn(async move {
if let Err(err) =
tunnel::client::run_tunnel(client_config, tunnel, server).await
{
if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await {
error!("{:?}", err);
}
});
@ -656,8 +606,7 @@ async fn main() {
Commands::Server(args) => {
let tls_config = if args.remote_addr.scheme() == "wss" {
let tls_certificate = if let Some(cert_path) = args.tls_certificate {
tls::load_certificates_from_pem(&cert_path)
.expect("Cannot load tls certificate")
tls::load_certificates_from_pem(&cert_path).expect("Cannot load tls certificate")
} else {
embedded_certificate::TLS_CERTIFICATE.clone()
};

View File

@ -19,10 +19,7 @@ pub struct Socks5Listener {
impl Stream for Socks5Listener {
type Item = anyhow::Result<(TcpStream, (Host, u16))>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
unsafe { self.map_unchecked_mut(|x| &mut x.stream) }.poll_next(cx)
}
}

View File

@ -14,12 +14,9 @@ use tracing::log::info;
use url::{Host, Url};
fn configure_socket(socket: &mut TcpSocket, so_mark: &Option<i32>) -> Result<(), anyhow::Error> {
socket.set_nodelay(true).with_context(|| {
format!(
"cannot set no_delay on socket: {}",
io::Error::last_os_error()
)
})?;
socket
.set_nodelay(true)
.with_context(|| format!("cannot set no_delay on socket: {}", io::Error::last_os_error()))?;
#[cfg(target_os = "linux")]
if let Some(so_mark) = so_mark {
@ -35,10 +32,7 @@ fn configure_socket(socket: &mut TcpSocket, so_mark: &Option<i32>) -> Result<(),
);
if ret != 0 {
return Err(anyhow!(
"Cannot set SO_MARK on the connection {:?}",
io::Error::last_os_error()
));
return Err(anyhow!("Cannot set SO_MARK on the connection {:?}", io::Error::last_os_error()));
}
}
}
@ -117,17 +111,14 @@ pub async fn connect_with_http_proxy(
let mut socket = connect(&proxy_host, proxy_port, so_mark, connect_timeout).await?;
info!("Connected to http proxy {}:{}", proxy_host, proxy_port);
let authorization =
if let Some((user, password)) = proxy.password().map(|p| (proxy.username(), p)) {
let creds =
base64::engine::general_purpose::STANDARD.encode(format!("{}:{}", user, password));
format!("Proxy-Authorization: Basic {}\r\n", creds)
} else {
"".to_string()
};
let authorization = if let Some((user, password)) = proxy.password().map(|p| (proxy.username(), p)) {
let creds = base64::engine::general_purpose::STANDARD.encode(format!("{}:{}", user, password));
format!("Proxy-Authorization: Basic {}\r\n", creds)
} else {
"".to_string()
};
let connect_request =
format!("CONNECT {host}:{port} HTTP/1.0\r\nHost: {host}:{port}\r\n{authorization}\r\n");
let connect_request = format!("CONNECT {host}:{port} HTTP/1.0\r\nHost: {host}:{port}\r\n{authorization}\r\n");
socket.write_all(connect_request.as_bytes()).await?;
let mut buf = BytesMut::with_capacity(1024);
@ -136,16 +127,15 @@ pub async fn connect_with_http_proxy(
match nb_bytes {
Ok(Ok(0)) => {
return Err(anyhow!(
"Cannot connect to http proxy. Proxy closed the connection without returning any response"));
"Cannot connect to http proxy. Proxy closed the connection without returning any response"
));
}
Ok(Ok(_)) => {}
Ok(Err(err)) => {
return Err(anyhow!("Cannot connect to http proxy. {err}"));
}
Err(_) => {
return Err(anyhow!(
"Cannot connect to http proxy. Proxy took too long to connect"
));
return Err(anyhow!("Cannot connect to http proxy. Proxy took too long to connect"));
}
};
@ -225,8 +215,7 @@ mod tests {
let server = TcpListener::bind(server_addr).await.unwrap();
let docker = testcontainers::clients::Cli::default();
let mitm_proxy: RunnableImage<MitmProxy> =
RunnableImage::from(MitmProxy {}).with_network("host".to_string());
let mitm_proxy: RunnableImage<MitmProxy> = RunnableImage::from(MitmProxy {}).with_network("host".to_string());
let _node = docker.run(mitm_proxy);
let mut client = connect_with_http_proxy(
@ -239,10 +228,7 @@ mod tests {
.await
.unwrap();
client
.write_all(b"GET / HTTP/1.1\r\n\r\n".as_slice())
.await
.unwrap();
client.write_all(b"GET / HTTP/1.1\r\n\r\n".as_slice()).await.unwrap();
let client_srv = server.accept().await.unwrap().0;
pin_mut!(client_srv);

View File

@ -45,21 +45,15 @@ pub fn load_private_key_from_file(path: &Path) -> anyhow::Result<PrivateKey> {
match keys.len() {
0 => Err(anyhow!("No PKCS8-encoded private key found in {path:?}")),
1 => Ok(PrivateKey(keys.remove(0))),
_ => Err(anyhow!(
"More than one PKCS8-encoded private key found in {path:?}"
)),
_ => Err(anyhow!("More than one PKCS8-encoded private key found in {path:?}")),
}
}
fn tls_connector(
tls_cfg: &TlsClientConfig,
alpn_protocols: Option<Vec<Vec<u8>>>,
) -> anyhow::Result<TlsConnector> {
fn tls_connector(tls_cfg: &TlsClientConfig, alpn_protocols: Option<Vec<Vec<u8>>>) -> anyhow::Result<TlsConnector> {
let mut root_store = rustls::RootCertStore::empty();
// Load system certificates and add them to the root store
let certs = rustls_native_certs::load_native_certs()
.with_context(|| "Cannot load system certificates")?;
let certs = rustls_native_certs::load_native_certs().with_context(|| "Cannot load system certificates")?;
for cert in certs {
root_store.add(&Certificate(cert.0))?;
}
@ -71,9 +65,7 @@ fn tls_connector(
// To bypass certificate verification
if !tls_cfg.tls_verify_certificate {
config
.dangerous()
.set_certificate_verifier(Arc::new(NullVerifier));
config.dangerous().set_certificate_verifier(Arc::new(NullVerifier));
}
if let Some(alpn_protocols) = alpn_protocols {
@ -83,10 +75,7 @@ fn tls_connector(
Ok(tls_connector)
}
pub fn tls_acceptor(
tls_cfg: &TlsServerConfig,
alpn_protocols: Option<Vec<Vec<u8>>>,
) -> anyhow::Result<TlsAcceptor> {
pub fn tls_acceptor(tls_cfg: &TlsServerConfig, alpn_protocols: Option<Vec<Vec<u8>>>) -> anyhow::Result<TlsAcceptor> {
let mut config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
@ -114,12 +103,7 @@ pub async fn connect(
let tls_stream = tls_connector
.connect(sni, tcp_stream)
.await
.with_context(|| {
format!(
"failed to do TLS handshake with the server {:?}",
client_cfg.remote_addr
)
})?;
.with_context(|| format!("failed to do TLS handshake with the server {:?}", client_cfg.remote_addr))?;
Ok(tls_stream)
}

View File

@ -44,9 +44,7 @@ pub async fn connect(
) -> anyhow::Result<WebSocket<Upgraded>> {
let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
Ok(tcp_stream) => tcp_stream,
Err(err) => Err(anyhow!(
"failed to get a connection to the server from the pool: {err:?}"
))?,
Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}"))?,
};
let mut req = Request::builder()
@ -80,12 +78,7 @@ pub async fn connect(
let transport = pooled_cnx.deref_mut().take().unwrap();
let (ws, _) = fastwebsockets::handshake::client(&SpawnExecutor, req, transport)
.await
.with_context(|| {
format!(
"failed to do websocket handshake with the server {:?}",
client_cfg.remote_addr
)
})?;
.with_context(|| format!("failed to do websocket handshake with the server {:?}", client_cfg.remote_addr))?;
Ok(ws)
}
@ -109,10 +102,7 @@ where
// Forward local tx to websocket tx
let ping_frequency = client_cfg.websocket_ping_frequency;
tokio::spawn(
super::io::propagate_read(local_rx, ws_tx, close_tx, ping_frequency)
.instrument(Span::current()),
);
tokio::spawn(super::io::propagate_read(local_rx, ws_tx, close_tx, ping_frequency).instrument(Span::current()));
// Forward websocket rx to local rx
let _ = super::io::propagate_write(local_tx, ws_rx, close_rx).await;

View File

@ -1,12 +1,12 @@
use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite};
use futures_util::pin_mut;
use hyper::upgrade::Upgraded;
use std::pin::Pin;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
use tokio::select;
use tokio::sync::oneshot;
use tokio::time::timeout;
use tracing::log::debug;
use tracing::{error, info, trace, warn};
@ -20,7 +20,14 @@ pub(super) async fn propagate_read(
info!("Closing local tx ==> websocket tx tunnel");
});
let mut buffer = vec![0u8; 8 * 1024];
static JUMBO_FRAME_SIZE: usize = 9 * 1024; // enough for a jumbo frame
let mut buffer = vec![0u8; JUMBO_FRAME_SIZE];
// We do our own pin_mut! to avoid shadowing timeout and be able to reset it, on next loop iteration
// We reuse the future to avoid creating a timer in the tight loop
let mut timeout_unpin = tokio::time::sleep(ping_frequency);
let mut timeout = unsafe { Pin::new_unchecked(&mut timeout_unpin) };
pin_mut!(local_rx);
loop {
let read_len = select! {
@ -30,9 +37,12 @@ pub(super) async fn propagate_read(
_ = close_tx.closed() => break,
_ = timeout(ping_frequency, futures_util::future::pending::<()>()) => {
_ = &mut timeout => {
debug!("sending ping to keep websocket connection alive");
ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))).await?;
timeout_unpin = tokio::time::sleep(ping_frequency);
timeout = unsafe { Pin::new_unchecked(&mut timeout_unpin) };
continue;
}
};
@ -41,32 +51,30 @@ pub(super) async fn propagate_read(
Ok(0) => break,
Ok(read_len) => read_len,
Err(err) => {
warn!(
"error while reading incoming bytes from local tx tunnel {}",
err
);
warn!("error while reading incoming bytes from local tx tunnel {}", err);
break;
}
};
trace!("read {} bytes", read_len);
match ws_tx
if let Err(err) = ws_tx
.write_frame(Frame::binary(Payload::BorrowedMut(&mut buffer[..read_len])))
.await
{
Ok(_) => {}
Err(err) => {
warn!("error while writing to websocket tx tunnel {}", err);
break;
}
warn!("error while writing to websocket tx tunnel {}", err);
break;
}
// If the buffer has been completely filled with previous read, Double it !
// For the buffer to not be a bottleneck when the TCP window scale
// For udp, the buffer will never grows.
if buffer.capacity() == read_len {
buffer.clear();
buffer.resize(buffer.capacity() * 2, 0);
}
}
// Send normal close
let _ = ws_tx.write_frame(Frame::close(1000, &[])).await;
Ok(())
@ -104,20 +112,15 @@ pub(super) async fn propagate_write(
trace!("receive ws frame {:?} {:?}", msg.opcode, msg.payload);
let ret = match msg.opcode {
OpCode::Continuation | OpCode::Text | OpCode::Binary => {
local_tx.write_all(msg.payload.as_ref()).await
}
OpCode::Continuation | OpCode::Text | OpCode::Binary => local_tx.write_all(msg.payload.as_ref()).await,
OpCode::Close => break,
OpCode::Ping => Ok(()),
OpCode::Pong => Ok(()),
};
match ret {
Ok(_) => {}
Err(err) => {
error!("error while writing bytes to local for rx tunnel {}", err);
break;
}
if let Err(err) = ret {
error!("error while writing bytes to local for rx tunnel {}", err);
break;
}
}

View File

@ -42,12 +42,8 @@ impl JwtTunnelConfig {
}
static JWT_SECRET: &[u8; 15] = b"champignonfrais";
static JWT_KEY: Lazy<(Header, EncodingKey)> = Lazy::new(|| {
(
Header::new(Algorithm::HS256),
EncodingKey::from_secret(JWT_SECRET),
)
});
static JWT_KEY: Lazy<(Header, EncodingKey)> =
Lazy::new(|| (Header::new(Algorithm::HS256), EncodingKey::from_secret(JWT_SECRET)));
static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| {
let mut validation = Validation::new(Algorithm::HS256);
@ -61,11 +57,7 @@ pub enum TransportStream {
}
impl AsyncRead for TransportStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
match self.get_mut() {
TransportStream::Plain(cnx) => Pin::new(cnx).poll_read(cx, buf),
TransportStream::Tls(cnx) => Pin::new(cnx).poll_read(cx, buf),
@ -74,11 +66,7 @@ impl AsyncRead for TransportStream {
}
impl AsyncWrite for TransportStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, Error>> {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
match self.get_mut() {
TransportStream::Plain(cnx) => Pin::new(cnx).poll_write(cx, buf),
TransportStream::Tls(cnx) => Pin::new(cnx).poll_write(cx, buf),

View File

@ -46,15 +46,8 @@ async fn from_query(
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
if let Some(allowed_dests) = &server_config.restrict_to {
let requested_dest = format!("{}:{}", jwt.claims.r, jwt.claims.rp);
if allowed_dests
.iter()
.any(|dest| dest == &requested_dest)
.not()
{
warn!(
"Rejecting connection with not allowed destination: {}",
requested_dest
);
if allowed_dests.iter().any(|dest| dest == &requested_dest).not() {
warn!("Rejecting connection with not allowed destination: {}", requested_dest);
return Err(anyhow::anyhow!("Invalid upgrade request"));
}
}
@ -75,14 +68,9 @@ async fn from_query(
LocalProtocol::Tcp { .. } => {
let host = Host::parse(&jwt.claims.r)?;
let port = jwt.claims.rp;
let (rx, tx) = tcp::connect(
&host,
port,
&server_config.socket_so_mark,
Duration::from_secs(10),
)
.await?
.into_split();
let (rx, tx) = tcp::connect(&host, port, &server_config.socket_so_mark, Duration::from_secs(10))
.await?
.into_split();
Ok((jwt.claims.p, host, port, Box::pin(rx), Box::pin(tx)))
}
@ -100,10 +88,7 @@ async fn server_upgrade(
}
if !req.uri().path().ends_with("/events") {
warn!(
"Rejecting connection with bad upgrade request: {}",
req.uri()
);
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
return Ok(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Invalid upgrade request"))
@ -118,10 +103,7 @@ async fn server_upgrade(
|| &path[min_len..max_len] != path_prefix.as_str()
|| !path[max_len..].starts_with('/')
{
warn!(
"Rejecting connection with bad path prefix in upgrade request: {}",
req.uri()
);
warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri());
return Ok(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Invalid upgrade request"))
@ -133,11 +115,7 @@ async fn server_upgrade(
match from_query(&server_config, req.uri().query().unwrap_or_default()).await {
Ok(ret) => ret,
Err(err) => {
warn!(
"Rejecting connection with bad upgrade request: {} {}",
err,
req.uri()
);
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
return Ok(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from(format!("Invalid upgrade request: {:?}", err)))
@ -149,11 +127,7 @@ async fn server_upgrade(
let (response, fut) = match fastwebsockets::upgrade::upgrade(&mut req) {
Ok(ret) => ret,
Err(err) => {
warn!(
"Rejecting connection with bad upgrade request: {} {}",
err,
req.uri()
);
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
return Ok(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from(format!("Invalid upgrade request: {:?}", err)))
@ -171,14 +145,10 @@ async fn server_upgrade(
}
};
let (close_tx, close_rx) = oneshot::channel::<()>();
let ping_frequency = server_config
.websocket_ping_frequency
.unwrap_or(Duration::MAX);
let ping_frequency = server_config.websocket_ping_frequency.unwrap_or(Duration::MAX);
ws_tx.set_auto_apply_mask(server_config.websocket_mask_frame);
tokio::task::spawn(
super::io::propagate_write(local_tx, ws_rx, close_rx).instrument(Span::current()),
);
tokio::task::spawn(super::io::propagate_write(local_tx, ws_rx, close_rx).instrument(Span::current()));
let _ = super::io::propagate_read(local_rx, ws_tx, close_tx, ping_frequency).await;
}
@ -189,10 +159,7 @@ async fn server_upgrade(
}
pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()> {
info!(
"Starting wstunnel server listening on {}",
server_config.bind
);
info!("Starting wstunnel server listening on {}", server_config.bind);
let config = server_config.clone();
let upgrade_fn = move |req: Request<Body>| server_upgrade(config.clone(), req);

View File

@ -1,18 +1,17 @@
use anyhow::Context;
use futures_util::{stream, Stream};
use parking_lot::{Mutex, RwLock};
use parking_lot::RwLock;
use pin_project::{pin_project, pinned_drop};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::future::Future;
use std::io;
use std::io::{Error, ErrorKind};
use std::net::SocketAddr;
use std::ops::DerefMut;
use std::pin::{pin, Pin};
use std::sync::{Arc, Weak};
use std::task::{ready, Poll, Waker};
use std::task::{ready, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::UdpSocket;
@ -23,8 +22,7 @@ use tokio::time::Sleep;
use tracing::{debug, error, info};
struct IoInner {
has_data_to_read: &'static Notify,
waker: Mutex<Option<Waker>>,
has_data_to_read: Notify,
has_read_data: Notify,
}
struct UdpServer {
@ -43,6 +41,7 @@ impl UdpServer {
cnx_timeout: timeout,
}
}
#[inline]
fn clean_dead_keys(&mut self) {
let nb_key_to_delete = self.keys_to_delete.read().len();
if nb_key_to_delete == 0 {
@ -52,16 +51,7 @@ impl UdpServer {
debug!("Cleaning {} dead udp peers", nb_key_to_delete);
let mut keys_to_delete = self.keys_to_delete.write();
for key in keys_to_delete.iter() {
let Some(peer) = self.peers.remove(key) else {
continue;
};
#[allow(mutable_transmutes)]
unsafe {
let _ = Box::from_raw(std::mem::transmute::<&Notify, &mut Notify>(
peer.has_data_to_read,
));
}
self.peers.remove(key);
}
keys_to_delete.clear();
}
@ -90,7 +80,42 @@ impl PinnedDrop for UdpStream {
keys_to_delete.write().push(self.peer);
}
self.io.has_read_data.notify_one();
// safety: we are dropping the notification as we extend its lifetime to 'static unsafely
// So it must be gone before we drop its parent. It should never happen but in case
let mut project = self.project();
project.pending_notification.as_mut().set(None);
project.io.has_read_data.notify_one();
}
}
impl UdpStream {
fn new(
socket: Arc<UdpSocket>,
peer: SocketAddr,
deadline: Option<Sleep>,
keys_to_delete: Weak<RwLock<Vec<SocketAddr>>>,
) -> (Self, Arc<IoInner>) {
let has_data_to_read = Notify::new();
let has_read_data = Notify::new();
has_data_to_read.notify_one();
let io = Arc::new(IoInner {
has_data_to_read,
has_read_data,
});
let mut s = Self {
socket,
peer,
deadline,
has_been_notified: false,
pending_notification: None,
io: io.clone(),
keys_to_delete,
};
let pending_notification = unsafe { std::mem::transmute(s.io.has_data_to_read.notified()) };
s.pending_notification = Some(pending_notification);
(s, io)
}
}
@ -111,43 +136,28 @@ impl AsyncRead for UdpStream {
}
if let Some(notified) = project.pending_notification.as_mut().as_pin_mut() {
if !notified.poll(cx).is_ready() {
project.io.waker.lock().replace(cx.waker().clone());
return Poll::Pending;
}
ready!(notified.poll(cx));
project.pending_notification.as_mut().set(None);
}
let _ = ready!(project.socket.poll_recv(cx, obuf));
project
.pending_notification
.as_mut()
.set(Some(project.io.has_data_to_read.notified()));
let notified: Notified<'static> = unsafe { std::mem::transmute(project.io.has_data_to_read.notified()) };
project.pending_notification.as_mut().set(Some(notified));
project.io.has_read_data.notify_one();
Poll::Ready(Ok(()))
}
}
impl AsyncWrite for UdpStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, Error>> {
fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
self.socket.poll_send_to(cx, buf, self.peer)
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Error>> {
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
self.socket.poll_send_ready(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Error>> {
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(()))
}
}
@ -178,39 +188,22 @@ pub async fn run_server(
}
};
match server.peers.entry(peer_addr) {
Entry::Occupied(mut peer) => {
let io = peer.get_mut();
match server.peers.get(&peer_addr) {
Some(io) => {
io.has_read_data.notified().await;
io.has_data_to_read.notify_one();
let waker = io.waker.lock().deref_mut().take();
if let Some(waker) = waker {
waker.wake();
}
}
Entry::Vacant(peer) => {
let has_data_to_read: &'static Notify = Box::leak(Box::new(Notify::new()));
let pending_notification = has_data_to_read.notified();
let has_read_data = Notify::new();
has_data_to_read.notify_one();
let io = Arc::new(IoInner {
has_data_to_read,
waker: Mutex::new(None),
has_read_data,
});
peer.insert(io.clone());
let udp_client = UdpStream {
socket: server.clone_socket(),
peer: peer_addr,
deadline: server
None => {
let (udp_client, io) = UdpStream::new(
server.clone_socket(),
peer_addr,
server
.cnx_timeout
.and_then(|timeout| tokio::time::Instant::now().checked_add(timeout))
.map(tokio::time::sleep_until),
keys_to_delete: Arc::downgrade(&server.keys_to_delete),
has_been_notified: false,
pending_notification: Some(pending_notification),
io,
};
Arc::downgrade(&server.keys_to_delete),
);
server.peers.insert(peer_addr, io);
return Some((Ok(udp_client), (server)));
}
}
@ -231,11 +224,7 @@ impl MyUdpSocket {
}
impl AsyncRead for MyUdpSocket {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
fn poll_read(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
unsafe { self.map_unchecked_mut(|x| &mut x.socket) }
.poll_recv_from(cx, buf)
.map(|x| x.map(|_| ()))
@ -243,25 +232,15 @@ impl AsyncRead for MyUdpSocket {
}
impl AsyncWrite for MyUdpSocket {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, Error>> {
fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
unsafe { self.map_unchecked_mut(|x| &mut x.socket) }.poll_send(cx, buf)
}
fn poll_flush(
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Error>> {
fn poll_flush(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Error>> {
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(()))
}
}
@ -331,10 +310,7 @@ mod tests {
assert!(client.send_to(b"aaaaa".as_ref(), server_addr).await.is_ok());
let client2 = UdpSocket::bind("[::1]:0").await.unwrap();
assert!(client2
.send_to(b"bbbbb".as_ref(), server_addr)
.await
.is_ok());
assert!(client2.send_to(b"bbbbb".as_ref(), server_addr).await.is_ok());
// Should have a new connection
let fut = timeout(Duration::from_millis(100), server.next()).await;
@ -360,10 +336,7 @@ mod tests {
assert_eq!(&buf[..6], b"bbbbb\0");
assert!(client.send_to(b"ccccc".as_ref(), server_addr).await.is_ok());
assert!(client2
.send_to(b"ddddd".as_ref(), server_addr)
.await
.is_ok());
assert!(client2.send_to(b"ddddd".as_ref(), server_addr).await.is_ok());
// Server need to be polled to feed the stream with need data
let _ = timeout(Duration::from_millis(100), server.next()).await;