use log::debug;
use std::io::{self, BufRead, BufReader, Read, Write};
use std::net::SocketAddr;
use std::net::TcpStream;
use std::ops::Div;
use std::time::Duration;
use std::time::Instant;
use std::{fmt, io::Cursor};
#[cfg(feature = "socks-proxy")]
use socks::{TargetAddr, ToTargetAddr};
use crate::chunked::Decoder as ChunkDecoder;
use crate::error::ErrorKind;
use crate::pool::{PoolKey, PoolReturner};
use crate::proxy::Proxy;
use crate::unit::Unit;
use crate::Response;
use crate::{error::Error, proxy::Proto};
pub trait ReadWrite: Read + Write + Send + Sync + fmt::Debug + 'static {
fn socket(&self) -> Option<&TcpStream>;
}
impl ReadWrite for TcpStream {
fn socket(&self) -> Option<&TcpStream> {
Some(self)
}
}
pub trait TlsConnector: Send + Sync {
fn connect(
&self,
dns_name: &str,
io: Box<dyn ReadWrite>,
) -> Result<Box<dyn ReadWrite>, crate::error::Error>;
}
pub(crate) struct Stream {
inner: BufReader<Box<dyn ReadWrite>>,
pub(crate) remote_addr: SocketAddr,
pool_returner: PoolReturner,
}
impl<T: ReadWrite + ?Sized> ReadWrite for Box<T> {
fn socket(&self) -> Option<&TcpStream> {
ReadWrite::socket(self.as_ref())
}
}
pub(crate) struct DeadlineStream {
stream: Stream,
deadline: Option<Instant>,
}
impl DeadlineStream {
pub(crate) fn new(stream: Stream, deadline: Option<Instant>) -> Self {
DeadlineStream { stream, deadline }
}
pub(crate) fn inner_ref(&self) -> &Stream {
&self.stream
}
pub(crate) fn inner_mut(&mut self) -> &mut Stream {
&mut self.stream
}
}
impl From<DeadlineStream> for Stream {
fn from(deadline_stream: DeadlineStream) -> Stream {
deadline_stream.stream
}
}
impl BufRead for DeadlineStream {
fn fill_buf(&mut self) -> io::Result<&[u8]> {
if let Some(deadline) = self.deadline {
let timeout = time_until_deadline(deadline)?;
if let Some(socket) = self.stream.socket() {
socket.set_read_timeout(Some(timeout))?;
socket.set_write_timeout(Some(timeout))?;
}
}
self.stream.fill_buf().map_err(|e| {
if e.kind() == io::ErrorKind::WouldBlock {
return io_err_timeout("timed out reading response".to_string());
}
e
})
}
fn consume(&mut self, amt: usize) {
self.stream.consume(amt)
}
}
impl Read for DeadlineStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
if !self.stream.inner.buffer().is_empty() {
let n = self.stream.inner.buffer().read(buf)?;
self.stream.inner.consume(n);
return Ok(n);
}
let nread = {
let mut rem = self.fill_buf()?;
rem.read(buf)?
};
self.consume(nread);
Ok(nread)
}
}
fn time_until_deadline(deadline: Instant) -> io::Result<Duration> {
let now = Instant::now();
match deadline.checked_duration_since(now) {
None => Err(io_err_timeout("timed out reading response".to_string())),
Some(duration) => Ok(duration),
}
}
pub(crate) fn io_err_timeout(error: String) -> io::Error {
io::Error::new(io::ErrorKind::TimedOut, error)
}
#[derive(Debug)]
pub(crate) struct ReadOnlyStream(Cursor<Vec<u8>>);
impl ReadOnlyStream {
pub(crate) fn new(v: Vec<u8>) -> Self {
Self(Cursor::new(v))
}
}
impl Read for ReadOnlyStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl std::io::Write for ReadOnlyStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl ReadWrite for ReadOnlyStream {
fn socket(&self) -> Option<&std::net::TcpStream> {
None
}
}
impl fmt::Debug for Stream {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.inner.get_ref().socket() {
Some(_) => write!(f, "Stream({:?})", self.inner.get_ref()),
None => write!(f, "Stream(Test)"),
}
}
}
impl Stream {
pub(crate) fn new(
t: impl ReadWrite,
remote_addr: SocketAddr,
pool_returner: PoolReturner,
) -> Stream {
Stream::logged_create(Stream {
inner: BufReader::new(Box::new(t)),
remote_addr,
pool_returner,
})
}
fn logged_create(stream: Stream) -> Stream {
debug!("created stream: {:?}", stream);
stream
}
pub(crate) fn buffer(&self) -> &[u8] {
self.inner.buffer()
}
fn serverclosed_stream(stream: &std::net::TcpStream) -> io::Result<bool> {
let mut buf = [0; 1];
stream.set_nonblocking(true)?;
let result = match stream.peek(&mut buf) {
Ok(n) => {
debug!(
"peek on reused connection returned {}, not WouldBlock; discarding",
n
);
Ok(true)
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => Ok(false),
Err(e) => Err(e),
};
stream.set_nonblocking(false)?;
result
}
pub(crate) fn server_closed(&self) -> io::Result<bool> {
match self.socket() {
Some(socket) => Stream::serverclosed_stream(socket),
None => Ok(false),
}
}
pub(crate) fn set_unpoolable(&mut self) {
self.pool_returner = PoolReturner::none();
}
pub(crate) fn return_to_pool(mut self) -> io::Result<()> {
self.reset()?;
self.pool_returner.clone().return_to_pool(self);
Ok(())
}
pub(crate) fn reset(&mut self) -> io::Result<()> {
if let Some(socket) = self.socket() {
socket.set_read_timeout(None)?;
socket.set_write_timeout(None)?;
}
Ok(())
}
pub(crate) fn socket(&self) -> Option<&TcpStream> {
self.inner.get_ref().socket()
}
pub(crate) fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
if let Some(socket) = self.socket() {
socket.set_read_timeout(timeout)
} else {
Ok(())
}
}
}
impl Read for Stream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.read(buf)
}
}
impl BufRead for Stream {
fn fill_buf(&mut self) -> io::Result<&[u8]> {
self.inner.fill_buf()
}
fn consume(&mut self, amt: usize) {
self.inner.consume(amt)
}
}
impl<R: Read> From<ChunkDecoder<R>> for Stream
where
R: Read,
Stream: From<R>,
{
fn from(chunk_decoder: ChunkDecoder<R>) -> Stream {
chunk_decoder.into_inner().into()
}
}
impl Write for Stream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.inner.get_mut().write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.inner.get_mut().flush()
}
}
impl Drop for Stream {
fn drop(&mut self) {
debug!("dropping stream: {:?}", self);
}
}
pub(crate) fn connect_http(unit: &Unit, hostname: &str) -> Result<Stream, Error> {
let port = unit.url.port().unwrap_or(80);
let pool_key = PoolKey::from_parts("http", hostname, port);
let pool_returner = PoolReturner::new(&unit.agent, pool_key);
connect_host(unit, hostname, port).map(|(t, r)| Stream::new(t, r, pool_returner))
}
pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result<Stream, Error> {
let port = unit.url.port().unwrap_or(443);
let (sock, remote_addr) = connect_host(unit, hostname, port)?;
let tls_conf = &unit.agent.config.tls_config;
let https_stream = tls_conf.connect(hostname, Box::new(sock))?;
let pool_key = PoolKey::from_parts("https", hostname, port);
let pool_returner = PoolReturner::new(&unit.agent, pool_key);
Ok(Stream::new(https_stream, remote_addr, pool_returner))
}
pub(crate) fn connect_host(
unit: &Unit,
hostname: &str,
port: u16,
) -> Result<(TcpStream, SocketAddr), Error> {
let connect_deadline: Option<Instant> =
if let Some(timeout_connect) = unit.agent.config.timeout_connect {
Instant::now().checked_add(timeout_connect)
} else {
unit.deadline
};
let proxy: Option<Proxy> = unit.agent.config.proxy.clone();
let netloc = match proxy {
Some(ref proxy) => format!("{}:{}", proxy.server, proxy.port),
None => format!("{}:{}", hostname, port),
};
let sock_addrs = unit.resolver().resolve(&netloc).map_err(|e| {
ErrorKind::Dns
.msg(format!("resolve dns name '{}'", netloc))
.src(e)
})?;
if sock_addrs.is_empty() {
return Err(ErrorKind::Dns.msg(format!("No ip address for {}", hostname)));
}
let proto = proxy.as_ref().map(|proxy| proxy.proto);
let mut any_err = None;
let mut any_stream_and_addr = None;
let multiple_addrs = sock_addrs.len() > 1;
for sock_addr in sock_addrs {
let timeout = match connect_deadline {
Some(deadline) => {
let mut deadline = time_until_deadline(deadline)?;
if multiple_addrs {
deadline = deadline.div(2);
}
Some(deadline)
}
None => None,
};
debug!("connecting to {} at {}", netloc, &sock_addr);
#[allow(clippy::unnecessary_unwrap)]
let stream = if proto.is_some() && Some(Proto::HTTP) != proto {
connect_socks(
unit,
proxy.clone().unwrap(),
connect_deadline,
sock_addr,
hostname,
port,
proto.unwrap(),
)
} else if let Some(timeout) = timeout {
TcpStream::connect_timeout(&sock_addr, timeout)
} else {
TcpStream::connect(sock_addr)
};
if let Ok(stream) = stream {
any_stream_and_addr = Some((stream, sock_addr));
break;
} else if let Err(err) = stream {
any_err = Some(err);
}
}
let (mut stream, remote_addr) = if let Some(stream_and_addr) = any_stream_and_addr {
stream_and_addr
} else if let Some(e) = any_err {
return Err(ErrorKind::ConnectionFailed.msg("Connect error").src(e));
} else {
panic!("shouldn't happen: failed to connect to all IPs, but no error");
};
stream.set_nodelay(unit.agent.config.no_delay)?;
if let Some(deadline) = unit.deadline {
stream.set_read_timeout(Some(time_until_deadline(deadline)?))?;
} else {
stream.set_read_timeout(unit.agent.config.timeout_read)?;
}
if let Some(deadline) = unit.deadline {
stream.set_write_timeout(Some(time_until_deadline(deadline)?))?;
} else {
stream.set_write_timeout(unit.agent.config.timeout_write)?;
}
if proto == Some(Proto::HTTP) && unit.url.scheme() == "https" {
if let Some(ref proxy) = proxy {
write!(
stream,
"{}",
proxy.connect(hostname, port, &unit.agent.config.user_agent)
)
.unwrap();
stream.flush()?;
let s = stream.try_clone()?;
let pool_key = PoolKey::from_parts(unit.url.scheme(), hostname, port);
let pool_returner = PoolReturner::new(&unit.agent, pool_key);
let s = Stream::new(s, remote_addr, pool_returner);
let response = Response::do_from_stream(s, unit.clone())?;
Proxy::verify_response(&response)?;
}
}
Ok((stream, remote_addr))
}
#[cfg(feature = "socks-proxy")]
fn socks_local_nslookup(
unit: &Unit,
hostname: &str,
port: u16,
) -> Result<TargetAddr, std::io::Error> {
let addrs: Vec<SocketAddr> = unit
.resolver()
.resolve(&format!("{}:{}", hostname, port))
.map_err(|e| {
std::io::Error::new(io::ErrorKind::NotFound, format!("DNS failure: {}.", e))
})?;
if addrs.is_empty() {
return Err(std::io::Error::new(
io::ErrorKind::NotFound,
"DNS failure: no socket addrs found.",
));
}
match addrs[0].to_target_addr() {
Ok(addr) => Ok(addr),
Err(err) => {
return Err(std::io::Error::new(
io::ErrorKind::NotFound,
format!("DNS failure: {}.", err),
))
}
}
}
#[cfg(feature = "socks-proxy")]
fn connect_socks(
unit: &Unit,
proxy: Proxy,
deadline: Option<Instant>,
proxy_addr: SocketAddr,
host: &str,
port: u16,
proto: Proto,
) -> Result<TcpStream, std::io::Error> {
use socks::TargetAddr::Domain;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::str::FromStr;
let host_addr = if Ipv4Addr::from_str(host).is_ok()
|| Ipv6Addr::from_str(host).is_ok()
|| proto == Proto::SOCKS4
{
match socks_local_nslookup(unit, host, port) {
Ok(addr) => addr,
Err(err) => return Err(err),
}
} else {
Domain(String::from(host), port)
};
#[allow(clippy::mutex_atomic)]
let stream = if let Some(deadline) = deadline {
use std::sync::mpsc::channel;
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
let master_signal = Arc::new((Mutex::new(false), Condvar::new()));
let slave_signal = master_signal.clone();
let (tx, rx) = channel();
thread::spawn(move || {
let (lock, cvar) = &*slave_signal;
if tx .send(if proto == Proto::SOCKS5 {
get_socks5_stream(&proxy, &proxy_addr, host_addr)
} else {
get_socks4_stream(&proxy_addr, host_addr)
})
.is_ok()
{
let mut done = lock.lock().unwrap();
*done = true;
cvar.notify_one();
}
});
let (lock, cvar) = &*master_signal;
let done = lock.lock().unwrap();
let timeout_connect = time_until_deadline(deadline)?;
let done_result = cvar.wait_timeout(done, timeout_connect).unwrap();
let done = done_result.0;
if *done {
rx.recv().unwrap()?
} else {
return Err(io_err_timeout(format!(
"SOCKS proxy: {}:{} timed out connecting after {}ms.",
host,
port,
timeout_connect.as_millis()
)));
}
} else if proto == Proto::SOCKS5 {
get_socks5_stream(&proxy, &proxy_addr, host_addr)?
} else {
get_socks4_stream(&proxy_addr, host_addr)?
};
Ok(stream)
}
#[cfg(feature = "socks-proxy")]
fn get_socks5_stream(
proxy: &Proxy,
proxy_addr: &SocketAddr,
host_addr: TargetAddr,
) -> Result<TcpStream, std::io::Error> {
use socks::Socks5Stream;
if proxy.use_authorization() {
let stream = Socks5Stream::connect_with_password(
proxy_addr,
host_addr,
proxy.user.as_ref().unwrap(),
proxy.password.as_ref().unwrap(),
)?
.into_inner();
Ok(stream)
} else {
match Socks5Stream::connect(proxy_addr, host_addr) {
Ok(socks_stream) => Ok(socks_stream.into_inner()),
Err(err) => Err(err),
}
}
}
#[cfg(feature = "socks-proxy")]
fn get_socks4_stream(
proxy_addr: &SocketAddr,
host_addr: TargetAddr,
) -> Result<TcpStream, std::io::Error> {
match socks::Socks4Stream::connect(proxy_addr, host_addr, "") {
Ok(socks_stream) => Ok(socks_stream.into_inner()),
Err(err) => Err(err),
}
}
#[cfg(not(feature = "socks-proxy"))]
fn connect_socks(
_unit: &Unit,
_proxy: Proxy,
_deadline: Option<Instant>,
_proxy_addr: SocketAddr,
_hostname: &str,
_port: u16,
_proto: Proto,
) -> Result<TcpStream, std::io::Error> {
Err(std::io::Error::new(
io::ErrorKind::Other,
"SOCKS feature disabled.",
))
}
#[cfg(test)]
pub(crate) fn connect_test(unit: &Unit) -> Result<Stream, Error> {
use crate::test;
test::resolve_handler(unit)
}
#[cfg(not(test))]
pub(crate) fn connect_test(unit: &Unit) -> Result<Stream, Error> {
Err(ErrorKind::UnknownScheme.msg(format!("unknown scheme '{}'", unit.url.scheme())))
}
#[cfg(test)]
pub(crate) fn remote_addr_for_test() -> SocketAddr {
use std::net::{Ipv4Addr, SocketAddrV4};
SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0).into()
}
#[cfg(test)]
mod tests {
use super::*;
use std::{
io::Read,
sync::{Arc, Mutex},
};
struct ReadRecorder {
reads: Arc<Mutex<Vec<usize>>>,
}
impl Read for ReadRecorder {
fn read(&mut self, buf: &mut [u8]) -> std::result::Result<usize, std::io::Error> {
self.reads.lock().unwrap().push(buf.len());
buf.fill(0);
Ok(buf.len())
}
}
impl Write for ReadRecorder {
fn write(&mut self, _: &[u8]) -> io::Result<usize> {
unimplemented!()
}
fn flush(&mut self) -> io::Result<()> {
unimplemented!()
}
}
impl fmt::Debug for ReadRecorder {
fn fmt(&self, _: &mut fmt::Formatter<'_>) -> fmt::Result {
unimplemented!()
}
}
impl ReadWrite for ReadRecorder {
fn socket(&self) -> Option<&TcpStream> {
unimplemented!()
}
}
#[test]
fn test_deadline_stream_buffering() {
let reads = Arc::new(Mutex::new(vec![]));
let recorder = ReadRecorder {
reads: reads.clone(),
};
let stream = Stream::new(recorder, remote_addr_for_test(), PoolReturner::none());
let mut deadline_stream = DeadlineStream::new(stream, None);
let mut buf = [0u8; 1];
for _ in 0..8193 {
let _ = deadline_stream.read(&mut buf).unwrap();
}
let reads = reads.lock().unwrap();
assert_eq!(reads.len(), 2);
assert_eq!(reads[0], 8192);
assert_eq!(reads[1], 8192);
}
}