Merge branch 'feature/proxy' into 'master'

add support for arbitrary proxies

See merge request famedly/conduit!54
This commit is contained in:
Timo Kösters 2021-07-01 19:46:18 +00:00
commit 5f6b0c673c
7 changed files with 200 additions and 5 deletions

13
Cargo.lock generated
View File

@ -1761,6 +1761,7 @@ dependencies = [
"serde_urlencoded", "serde_urlencoded",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",
"tokio-socks",
"url", "url",
"wasm-bindgen", "wasm-bindgen",
"wasm-bindgen-futures", "wasm-bindgen-futures",
@ -2732,6 +2733,18 @@ dependencies = [
"webpki", "webpki",
] ]
[[package]]
name = "tokio-socks"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51165dfa029d2a65969413a6cc96f354b86b464498702f174a4efa13608fd8c0"
dependencies = [
"either",
"futures-util",
"thiserror",
"tokio",
]
[[package]] [[package]]
name = "tokio-util" name = "tokio-util"
version = "0.6.6" version = "0.6.6"

View File

@ -47,7 +47,7 @@ rand = "0.8.3"
# Used to hash passwords # Used to hash passwords
rust-argon2 = "0.8.3" rust-argon2 = "0.8.3"
# Used to send requests # Used to send requests
reqwest = { version = "0.11.3", default-features = false, features = ["rustls-tls-native-roots"] } reqwest = { version = "0.11.3", default-features = false, features = ["rustls-tls-native-roots", "socks"] }
# Custom TLS verifier # Custom TLS verifier
rustls = { version = "0.19", features = ["dangerous_configuration"] } rustls = { version = "0.19", features = ["dangerous_configuration"] }
rustls-native-certs = "0.5.0" rustls-native-certs = "0.5.0"

View File

@ -41,3 +41,5 @@ trusted_servers = ["matrix.org"]
#workers = 4 # default: cpu core count * 2 #workers = 4 # default: cpu core count * 2
address = "127.0.0.1" # This makes sure Conduit can only be reached using the reverse proxy address = "127.0.0.1" # This makes sure Conduit can only be reached using the reverse proxy
proxy = "none" # more examples can be found at src/database/proxy.rs:6

View File

@ -6,6 +6,7 @@ pub mod appservice;
pub mod globals; pub mod globals;
pub mod key_backups; pub mod key_backups;
pub mod media; pub mod media;
pub mod proxy;
pub mod pusher; pub mod pusher;
pub mod rooms; pub mod rooms;
pub mod sending; pub mod sending;
@ -28,6 +29,8 @@ use std::{
}; };
use tokio::sync::Semaphore; use tokio::sync::Semaphore;
use self::proxy::ProxyConfig;
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct Config { pub struct Config {
server_name: Box<ServerName>, server_name: Box<ServerName>,
@ -46,6 +49,8 @@ pub struct Config {
allow_federation: bool, allow_federation: bool,
#[serde(default = "false_fn")] #[serde(default = "false_fn")]
pub allow_jaeger: bool, pub allow_jaeger: bool,
#[serde(default)]
proxy: ProxyConfig,
jwt_secret: Option<String>, jwt_secret: Option<String>,
#[serde(default = "Vec::new")] #[serde(default = "Vec::new")]
trusted_servers: Vec<Box<ServerName>>, trusted_servers: Vec<Box<ServerName>>,

View File

@ -125,13 +125,15 @@ impl Globals {
tlsconfig.root_store = tlsconfig.root_store =
rustls_native_certs::load_native_certs().expect("Error loading system certificates"); rustls_native_certs::load_native_certs().expect("Error loading system certificates");
let reqwest_client = reqwest::Client::builder() let mut reqwest_client_builder = reqwest::Client::builder()
.connect_timeout(Duration::from_secs(30)) .connect_timeout(Duration::from_secs(30))
.timeout(Duration::from_secs(60 * 3)) .timeout(Duration::from_secs(60 * 3))
.pool_max_idle_per_host(1) .pool_max_idle_per_host(1)
.use_preconfigured_tls(tlsconfig) .use_preconfigured_tls(tlsconfig);
.build() if let Some(proxy) = config.proxy.to_proxy()? {
.unwrap(); reqwest_client_builder = reqwest_client_builder.proxy(proxy);
}
let reqwest_client = reqwest_client_builder.build().unwrap();
let jwt_decoding_key = config let jwt_decoding_key = config
.jwt_secret .jwt_secret

146
src/database/proxy.rs Normal file
View File

@ -0,0 +1,146 @@
use reqwest::{Proxy, Url};
use serde::Deserialize;
use crate::Result;
/// ## Examples:
/// - No proxy (default):
/// ```toml
/// proxy ="none"
/// ```
/// - Global proxy
/// ```toml
/// [proxy]
/// global = { url = "socks5h://localhost:9050" }
/// ```
/// - Proxy some domains
/// ```toml
/// [proxy]
/// [[proxy.by_domain]]
/// url = "socks5h://localhost:9050"
/// include = ["*.onion", "matrix.myspecial.onion"]
/// exclude = ["*.myspecial.onion"]
/// ```
/// ## Include vs. Exclude
/// If include is an empty list, it is assumed to be `["*"]`.
///
/// If a domain matches both the exclude and include list, the proxy will only be used if it was
/// included because of a more specific rule than it was excluded. In the above example, the proxy
/// would be used for `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ProxyConfig {
None,
Global {
#[serde(deserialize_with = "crate::utils::deserialize_from_str")]
url: Url,
},
ByDomain(Vec<PartialProxyConfig>),
}
impl ProxyConfig {
pub fn to_proxy(&self) -> Result<Option<Proxy>> {
Ok(match self.clone() {
ProxyConfig::None => None,
ProxyConfig::Global { url } => Some(Proxy::all(url)?),
ProxyConfig::ByDomain(proxies) => Some(Proxy::custom(move |url| {
proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy
})),
})
}
}
impl Default for ProxyConfig {
fn default() -> Self {
ProxyConfig::None
}
}
#[derive(Clone, Debug, Deserialize)]
pub struct PartialProxyConfig {
#[serde(deserialize_with = "crate::utils::deserialize_from_str")]
url: Url,
#[serde(default)]
include: Vec<WildCardedDomain>,
#[serde(default)]
exclude: Vec<WildCardedDomain>,
}
impl PartialProxyConfig {
pub fn for_url(&self, url: &Url) -> Option<&Url> {
let domain = url.domain()?;
let mut included_because = None; // most specific reason it was included
let mut excluded_because = None; // most specific reason it was excluded
if self.include.is_empty() {
// treat empty include list as `*`
included_because = Some(&WildCardedDomain::WildCard)
}
for wc_domain in &self.include {
if wc_domain.matches(domain) {
match included_because {
Some(prev) if !wc_domain.more_specific_than(prev) => (),
_ => included_because = Some(wc_domain),
}
}
}
for wc_domain in &self.exclude {
if wc_domain.matches(domain) {
match excluded_because {
Some(prev) if !wc_domain.more_specific_than(prev) => (),
_ => excluded_because = Some(wc_domain),
}
}
}
match (included_because, excluded_because) {
(Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded
(Some(_), None) => Some(&self.url),
_ => None,
}
}
}
/// A domain name, that optionally allows a * as its first subdomain.
#[derive(Clone, Debug)]
pub enum WildCardedDomain {
WildCard,
WildCarded(String),
Exact(String),
}
impl WildCardedDomain {
pub fn matches(&self, domain: &str) -> bool {
match self {
WildCardedDomain::WildCard => true,
WildCardedDomain::WildCarded(d) => domain.ends_with(d),
WildCardedDomain::Exact(d) => domain == d,
}
}
pub fn more_specific_than(&self, other: &Self) -> bool {
match (self, other) {
(WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false,
(_, WildCardedDomain::WildCard) => true,
(WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a),
(WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => {
a != b && a.ends_with(b)
}
_ => false,
}
}
}
impl std::str::FromStr for WildCardedDomain {
type Err = std::convert::Infallible;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
// maybe do some domain validation?
Ok(if s.starts_with("*.") {
WildCardedDomain::WildCarded(s[1..].to_owned())
} else if s == "*" {
WildCardedDomain::WildCarded("".to_owned())
} else {
WildCardedDomain::Exact(s.to_owned())
})
}
}
impl<'de> serde::de::Deserialize<'de> for WildCardedDomain {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::de::Deserializer<'de>,
{
crate::utils::deserialize_from_str(deserializer)
}
}

View File

@ -5,6 +5,7 @@ use ruma::serde::{try_from_json_map, CanonicalJsonError, CanonicalJsonObject};
use std::{ use std::{
cmp, cmp,
convert::TryInto, convert::TryInto,
str::FromStr,
time::{SystemTime, UNIX_EPOCH}, time::{SystemTime, UNIX_EPOCH},
}; };
@ -115,3 +116,29 @@ pub fn to_canonical_object<T: serde::Serialize>(
))), ))),
} }
} }
pub fn deserialize_from_str<
'de,
D: serde::de::Deserializer<'de>,
T: FromStr<Err = E>,
E: std::fmt::Display,
>(
deserializer: D,
) -> std::result::Result<T, D::Error> {
struct Visitor<T: FromStr<Err = E>, E>(std::marker::PhantomData<T>);
impl<'de, T: FromStr<Err = Err>, Err: std::fmt::Display> serde::de::Visitor<'de>
for Visitor<T, Err>
{
type Value = T;
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(formatter, "a parsable string")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
v.parse().map_err(|e| serde::de::Error::custom(e))
}
}
deserializer.deserialize_str(Visitor(std::marker::PhantomData))
}