mirror of
https://git.asonix.dog/asonix/pict-rs
synced 2024-12-22 03:11:24 +00:00
postgres: allow connecting to TLS-enabled databases
This commit is contained in:
parent
f6087d65be
commit
19147e2035
7 changed files with 184 additions and 19 deletions
28
Cargo.lock
generated
28
Cargo.lock
generated
|
@ -1838,6 +1838,8 @@ dependencies = [
|
|||
"reqwest",
|
||||
"reqwest-middleware",
|
||||
"reqwest-tracing",
|
||||
"rustls 0.22.2",
|
||||
"rustls-pemfile 2.0.0",
|
||||
"rusty-s3",
|
||||
"serde",
|
||||
"serde-tuple-vec-map",
|
||||
|
@ -1864,6 +1866,7 @@ dependencies = [
|
|||
"tracing-subscriber",
|
||||
"url",
|
||||
"uuid",
|
||||
"webpki-roots 0.26.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -2233,7 +2236,7 @@ dependencies = [
|
|||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"rustls 0.21.10",
|
||||
"rustls-pemfile",
|
||||
"rustls-pemfile 1.0.4",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
|
@ -2247,7 +2250,7 @@ dependencies = [
|
|||
"wasm-bindgen-futures",
|
||||
"wasm-streams",
|
||||
"web-sys",
|
||||
"webpki-roots",
|
||||
"webpki-roots 0.25.3",
|
||||
"winreg",
|
||||
]
|
||||
|
||||
|
@ -2370,6 +2373,8 @@ version = "0.22.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41"
|
||||
dependencies = [
|
||||
"log",
|
||||
"ring 0.17.7",
|
||||
"rustls-pki-types",
|
||||
"rustls-webpki 0.102.1",
|
||||
"subtle",
|
||||
|
@ -2385,6 +2390,16 @@ dependencies = [
|
|||
"base64 0.21.7",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pemfile"
|
||||
version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "35e4980fa29e4c4b212ffb3db068a564cbf560e51d3944b7c88bd8bf5bec64f4"
|
||||
dependencies = [
|
||||
"base64 0.21.7",
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pki-types"
|
||||
version = "1.1.0"
|
||||
|
@ -3522,6 +3537,15 @@ version = "0.25.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10"
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "0.26.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0de2cfda980f21be5a7ed2eadb3e6fe074d56022bea2cdeb1a62eb220fc04188"
|
||||
dependencies = [
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "whoami"
|
||||
version = "1.4.1"
|
||||
|
|
|
@ -46,6 +46,10 @@ refinery = { version = "0.8.10", features = ["tokio-postgres", "postgres"] }
|
|||
reqwest = { version = "0.11.18", default-features = false, features = ["json", "rustls-tls", "stream"] }
|
||||
reqwest-middleware = "0.2.2"
|
||||
reqwest-tracing = { version = "0.4.5" }
|
||||
# pinned to tokio-postgres-rustls
|
||||
rustls = "0.22.0"
|
||||
# pinned to rustls
|
||||
rustls-pemfile = "2.0.0"
|
||||
rusty-s3 = "0.5.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde-tuple-vec-map = "1.0.1"
|
||||
|
@ -81,6 +85,8 @@ tracing-subscriber = { version = "0.3.0", features = [
|
|||
] }
|
||||
url = { version = "2.2", features = ["serde"] }
|
||||
uuid = { version = "1", features = ["serde", "std", "v4", "v7"] }
|
||||
# pinned to rustls
|
||||
webpki-roots = "0.26.0"
|
||||
|
||||
[dependencies.tracing-actix-web]
|
||||
version = "0.7.8"
|
||||
|
|
|
@ -1415,6 +1415,14 @@ pub(super) struct Postgres {
|
|||
/// The URL of the postgres database
|
||||
#[arg(short, long)]
|
||||
pub(super) url: Url,
|
||||
|
||||
/// whether to connect to postgres via TLS
|
||||
#[arg(short, long)]
|
||||
pub(super) use_tls: bool,
|
||||
|
||||
/// The path to the root certificate for postgres' CA
|
||||
#[arg(short, long)]
|
||||
pub(super) certificate_file: Option<PathBuf>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Parser, serde::Serialize)]
|
||||
|
|
|
@ -389,7 +389,11 @@ impl From<crate::config::commandline::Sled> for crate::config::file::Sled {
|
|||
|
||||
impl From<crate::config::commandline::Postgres> for crate::config::file::Postgres {
|
||||
fn from(value: crate::config::commandline::Postgres) -> Self {
|
||||
crate::config::file::Postgres { url: value.url }
|
||||
crate::config::file::Postgres {
|
||||
url: value.url,
|
||||
use_tls: value.use_tls,
|
||||
certificate_file: value.certificate_file,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -458,4 +458,6 @@ pub(crate) struct Sled {
|
|||
#[serde(rename_all = "snake_case")]
|
||||
pub(crate) struct Postgres {
|
||||
pub(crate) url: Url,
|
||||
pub(crate) use_tls: bool,
|
||||
pub(crate) certificate_file: Option<PathBuf>,
|
||||
}
|
||||
|
|
|
@ -802,8 +802,13 @@ impl Repo {
|
|||
|
||||
Ok(Self::Sled(repo))
|
||||
}
|
||||
config::Repo::Postgres(config::Postgres { url }) => {
|
||||
let repo = self::postgres::PostgresRepo::connect(url).await?;
|
||||
config::Repo::Postgres(config::Postgres {
|
||||
url,
|
||||
use_tls,
|
||||
certificate_file,
|
||||
}) => {
|
||||
let repo =
|
||||
self::postgres::PostgresRepo::connect(url, use_tls, certificate_file).await?;
|
||||
|
||||
Ok(Self::Postgres(repo))
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ mod schema;
|
|||
|
||||
use std::{
|
||||
collections::{BTreeSet, VecDeque},
|
||||
path::PathBuf,
|
||||
sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
Arc, Weak,
|
||||
|
@ -22,7 +23,8 @@ use diesel_async::{
|
|||
};
|
||||
use futures_core::Stream;
|
||||
use tokio::sync::Notify;
|
||||
use tokio_postgres::{tls::NoTlsStream, AsyncMessage, Connection, NoTls, Notification, Socket};
|
||||
use tokio_postgres::{AsyncMessage, Connection, NoTls, Notification, Socket};
|
||||
use tokio_postgres_rustls::MakeRustlsConnect;
|
||||
use tracing::Instrument;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
@ -82,6 +84,9 @@ pub(crate) enum ConnectPostgresError {
|
|||
#[error("Failed to connect to postgres for migrations")]
|
||||
ConnectForMigration(#[source] tokio_postgres::Error),
|
||||
|
||||
#[error("Failed to build TLS configuration")]
|
||||
Tls(#[source] TlsError),
|
||||
|
||||
#[error("Failed to run migrations")]
|
||||
Migration(#[source] refinery::Error),
|
||||
|
||||
|
@ -116,6 +121,21 @@ pub(crate) enum PostgresError {
|
|||
DbTimeout,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum TlsError {
|
||||
#[error("Couldn't read configured certificate file")]
|
||||
ReadCertificate(#[source] std::io::Error),
|
||||
|
||||
#[error("Couldn't parse configured certificate file: {0:?}")]
|
||||
ParseCertificate(rustls_pemfile::Error),
|
||||
|
||||
#[error("Configured certificate file is not a certificate")]
|
||||
NotCertificate,
|
||||
|
||||
#[error("Couldn't add certificate to root store")]
|
||||
AddCertificate(#[source] rustls::Error),
|
||||
}
|
||||
|
||||
impl PostgresError {
|
||||
pub(super) const fn error_code(&self) -> ErrorCode {
|
||||
match self {
|
||||
|
@ -146,13 +166,90 @@ impl PostgresError {
|
|||
}
|
||||
}
|
||||
|
||||
impl PostgresRepo {
|
||||
pub(crate) async fn connect(postgres_url: Url) -> Result<Self, ConnectPostgresError> {
|
||||
let (mut client, conn) = tokio_postgres::connect(postgres_url.as_str(), NoTls)
|
||||
async fn build_tls_connector(
|
||||
certificate_file: Option<PathBuf>,
|
||||
) -> Result<MakeRustlsConnect, TlsError> {
|
||||
let mut cert_store = rustls::RootCertStore {
|
||||
roots: Vec::from(webpki_roots::TLS_SERVER_ROOTS),
|
||||
};
|
||||
|
||||
if let Some(certificate_file) = certificate_file {
|
||||
let bytes = tokio::fs::read(certificate_file)
|
||||
.await
|
||||
.map_err(TlsError::ReadCertificate)?;
|
||||
|
||||
let opt =
|
||||
rustls_pemfile::read_one_from_slice(&bytes).map_err(TlsError::ParseCertificate)?;
|
||||
let (item, _remainder) = opt.ok_or(TlsError::NotCertificate)?;
|
||||
|
||||
let cert = if let rustls_pemfile::Item::X509Certificate(cert) = item {
|
||||
cert
|
||||
} else {
|
||||
return Err(TlsError::NotCertificate);
|
||||
};
|
||||
|
||||
cert_store.add(cert).map_err(TlsError::AddCertificate)?;
|
||||
}
|
||||
|
||||
let config = rustls::ClientConfig::builder()
|
||||
.with_root_certificates(cert_store)
|
||||
.with_no_client_auth();
|
||||
|
||||
let tls = MakeRustlsConnect::new(config);
|
||||
|
||||
Ok(tls)
|
||||
}
|
||||
|
||||
async fn connect_for_migrations(
|
||||
postgres_url: &Url,
|
||||
tls_connector: Option<MakeRustlsConnect>,
|
||||
) -> Result<
|
||||
(
|
||||
tokio_postgres::Client,
|
||||
DropHandle<Result<(), tokio_postgres::Error>>,
|
||||
),
|
||||
ConnectPostgresError,
|
||||
> {
|
||||
let tup = if let Some(connector) = tls_connector {
|
||||
let (client, conn) = tokio_postgres::connect(postgres_url.as_str(), connector)
|
||||
.await
|
||||
.map_err(ConnectPostgresError::ConnectForMigration)?;
|
||||
|
||||
let handle = crate::sync::abort_on_drop(crate::sync::spawn("postgres-migrations", conn));
|
||||
(
|
||||
client,
|
||||
crate::sync::abort_on_drop(crate::sync::spawn("postgres-connection", conn)),
|
||||
)
|
||||
} else {
|
||||
let (client, conn) = tokio_postgres::connect(postgres_url.as_str(), NoTls)
|
||||
.await
|
||||
.map_err(ConnectPostgresError::ConnectForMigration)?;
|
||||
|
||||
(
|
||||
client,
|
||||
crate::sync::abort_on_drop(crate::sync::spawn("postgres-connection", conn)),
|
||||
)
|
||||
};
|
||||
|
||||
Ok(tup)
|
||||
}
|
||||
|
||||
impl PostgresRepo {
|
||||
pub(crate) async fn connect(
|
||||
postgres_url: Url,
|
||||
use_tls: bool,
|
||||
certificate_file: Option<PathBuf>,
|
||||
) -> Result<Self, ConnectPostgresError> {
|
||||
let connector = if use_tls {
|
||||
Some(
|
||||
build_tls_connector(certificate_file)
|
||||
.await
|
||||
.map_err(ConnectPostgresError::Tls)?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let (mut client, handle) = connect_for_migrations(&postgres_url, connector.clone()).await?;
|
||||
|
||||
embedded::migrations::runner()
|
||||
.run_async(&mut client)
|
||||
|
@ -169,7 +266,7 @@ impl PostgresRepo {
|
|||
let (tx, rx) = flume::bounded(10);
|
||||
|
||||
let mut config = ManagerConfig::default();
|
||||
config.custom_setup = build_handler(tx);
|
||||
config.custom_setup = build_handler(tx, connector);
|
||||
|
||||
let mgr = AsyncDieselConnectionManager::<AsyncPgConnection>::new_with_config(
|
||||
postgres_url,
|
||||
|
@ -388,22 +485,39 @@ async fn delegate_notifications(
|
|||
tracing::warn!("Notification delegator shutting down");
|
||||
}
|
||||
|
||||
fn build_handler(sender: flume::Sender<Notification>) -> ConfigFn {
|
||||
fn build_handler(
|
||||
sender: flume::Sender<Notification>,
|
||||
connector: Option<MakeRustlsConnect>,
|
||||
) -> ConfigFn {
|
||||
Box::new(
|
||||
move |config: &str| -> BoxFuture<'_, ConnectionResult<AsyncPgConnection>> {
|
||||
let sender = sender.clone();
|
||||
let connector = connector.clone();
|
||||
|
||||
let connect_span = tracing::trace_span!(parent: None, "connect future");
|
||||
|
||||
Box::pin(
|
||||
async move {
|
||||
let (client, conn) =
|
||||
tokio_postgres::connect(config, tokio_postgres::tls::NoTls)
|
||||
let client = if let Some(connector) = connector {
|
||||
let (client, conn) = tokio_postgres::connect(config, connector)
|
||||
.await
|
||||
.map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
|
||||
|
||||
// not very cash money (structured concurrency) of me
|
||||
spawn_db_notification_task(sender, conn);
|
||||
// not very cash money (structured concurrency) of me
|
||||
spawn_db_notification_task(sender, conn);
|
||||
|
||||
client
|
||||
} else {
|
||||
let (client, conn) =
|
||||
tokio_postgres::connect(config, tokio_postgres::tls::NoTls)
|
||||
.await
|
||||
.map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
|
||||
|
||||
// not very cash money (structured concurrency) of me
|
||||
spawn_db_notification_task(sender, conn);
|
||||
|
||||
client
|
||||
};
|
||||
|
||||
AsyncPgConnection::try_from(client).await
|
||||
}
|
||||
|
@ -413,10 +527,12 @@ fn build_handler(sender: flume::Sender<Notification>) -> ConfigFn {
|
|||
)
|
||||
}
|
||||
|
||||
fn spawn_db_notification_task(
|
||||
fn spawn_db_notification_task<S>(
|
||||
sender: flume::Sender<Notification>,
|
||||
mut conn: Connection<Socket, NoTlsStream>,
|
||||
) {
|
||||
mut conn: Connection<Socket, S>,
|
||||
) where
|
||||
S: tokio_postgres::tls::TlsStream + Unpin + 'static,
|
||||
{
|
||||
crate::sync::spawn("postgres-notifications", async move {
|
||||
while let Some(res) = std::future::poll_fn(|cx| conn.poll_message(cx)).await {
|
||||
tracing::trace!("db_notification_task: looping");
|
||||
|
|
Loading…
Reference in a new issue