1use hmac::{Hmac, KeyInit, Mac};
6use sha2::{Digest, Sha256};
7use std::collections::HashMap;
8use std::time::Duration;
9
/// HMAC instantiated with SHA-256; `hmac_sha256` uses it for every step of
/// the SigV4 signing-key derivation and for the final request signature.
type HmacSha256 = Hmac<Sha256>;
11
/// An S3 object location split out of an `s3://bucket/key` URI.
///
/// Values produced by `parse_s3_uri` always have a non-empty `bucket` and a
/// non-empty `key`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct S3Uri {
    /// Bucket name: the text between `s3://` and the first `/`.
    pub bucket: String,
    /// Object key: everything after the first `/` (may itself contain `/`).
    pub key: String,
}
22
23pub fn parse_s3_uri(uri: &str) -> Result<S3Uri, String> {
27 let rest = uri
28 .strip_prefix("s3://")
29 .ok_or_else(|| format!("S3 URI must start with s3://, got: {uri}"))?;
30
31 let slash = rest
32 .find('/')
33 .ok_or_else(|| format!("S3 URI missing key after bucket: {uri}"))?;
34
35 let bucket = &rest[..slash];
36 let key = &rest[slash + 1..];
37
38 if bucket.is_empty() {
39 return Err(format!("S3 URI has empty bucket: {uri}"));
40 }
41 if key.is_empty() {
42 return Err(format!("S3 URI has empty key: {uri}"));
43 }
44
45 Ok(S3Uri {
46 bucket: bucket.to_string(),
47 key: key.to_string(),
48 })
49}
50
51pub fn detect_bucket_region(bucket: &str) -> String {
67 let host = format!("{bucket}.s3.amazonaws.com");
68 detect_region_at(&host, 80, Duration::from_secs(2))
69}
70
/// Asks the HTTP endpoint at `host:port` which region a bucket lives in.
///
/// Sends `HEAD / HTTP/1.0` with `Host: {host}`. It then reads the
/// `x-amz-bucket-region` response header, matching the name
/// case-insensitively and looking only within the header block.
///
/// Never fails. Every problem yields the fallback region `eu-central-1`:
/// - address resolution, connect or write failure;
/// - a missing header;
/// - an empty header value.
///
/// `timeout` bounds the TCP connect and each individual read/write, not the
/// exchange as a whole. At most 64 KiB of the response is buffered.
pub(crate) fn detect_region_at(host: &str, port: u16, timeout: Duration) -> String {
    use std::io::{Read, Write};
    use std::net::{TcpStream, ToSocketAddrs};

    // Returned whenever the real region cannot be determined.
    const FALLBACK_REGION: &str = "eu-central-1";
    // A HEAD response is headers only; anything beyond this is not worth buffering.
    const MAX_RESPONSE_BYTES: u64 = 64 * 1024;

    // The (host, port) form also handles IPv6 literals, unlike "host:port".
    let Some(sock_addr) = (host, port)
        .to_socket_addrs()
        .ok()
        .and_then(|mut addrs| addrs.next())
    else {
        return FALLBACK_REGION.to_string();
    };

    let Ok(mut stream) = TcpStream::connect_timeout(&sock_addr, timeout) else {
        return FALLBACK_REGION.to_string();
    };
    // Best effort: if the OS rejects the timeouts we still attempt the exchange.
    stream.set_read_timeout(Some(timeout)).ok();
    stream.set_write_timeout(Some(timeout)).ok();

    let request = format!("HEAD / HTTP/1.0\r\nHost: {host}\r\nConnection: close\r\n\r\n");
    if stream.write_all(request.as_bytes()).is_err() {
        return FALLBACK_REGION.to_string();
    }

    // Read until the peer closes, a read times out, or the size cap is hit.
    // Errors are ignored on purpose: bytes received before an error are
    // already in `response` and may contain the header we need.
    let mut response = Vec::new();
    stream.take(MAX_RESPONSE_BYTES).read_to_end(&mut response).ok();

    let text = String::from_utf8_lossy(&response);
    let region = text
        .lines() // also strips the `\r` of CRLF line endings
        .take_while(|line| !line.is_empty()) // a blank line ends the header block
        .filter_map(|line| line.split_once(':'))
        .find(|(name, _)| name.trim().eq_ignore_ascii_case("x-amz-bucket-region"))
        .map(|(_, value)| value.trim())
        // An empty value would later produce a malformed `bucket.s3..amazonaws.com` host.
        .filter(|value| !value.is_empty())
        .map(str::to_string);

    region.unwrap_or_else(|| FALLBACK_REGION.to_string())
}
115
116fn sha256_hex(data: &[u8]) -> String {
121 hex::encode(Sha256::digest(data))
122}
123
124fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
125 let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
126 mac.update(data);
127 mac.finalize().into_bytes().to_vec()
128}
129
130#[allow(clippy::too_many_arguments)]
142pub fn sign_put_request(
143 access_key: &str,
144 secret_key: &str,
145 session_token: &str,
146 region: &str,
147 bucket: &str,
148 key: &str,
149 body_sha256: &str,
150 amz_date: &str,
151 date_stamp: &str,
152) -> String {
153 let host = format!("{bucket}.s3.{region}.amazonaws.com");
154
155 let canonical_headers = format!(
157 "host:{host}\nx-amz-content-sha256:{body_sha256}\nx-amz-date:{amz_date}\nx-amz-security-token:{session_token}\n"
158 );
159 let signed_headers = "host;x-amz-content-sha256;x-amz-date;x-amz-security-token";
160 let canonical_request =
161 format!("PUT\n/{key}\n\n{canonical_headers}\n{signed_headers}\n{body_sha256}");
162
163 let credential_scope = format!("{date_stamp}/{region}/s3/aws4_request");
165 let string_to_sign = format!(
166 "AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{}",
167 sha256_hex(canonical_request.as_bytes())
168 );
169
170 let k_date = hmac_sha256(
172 format!("AWS4{secret_key}").as_bytes(),
173 date_stamp.as_bytes(),
174 );
175 let k_region = hmac_sha256(&k_date, region.as_bytes());
176 let k_service = hmac_sha256(&k_region, b"s3");
177 let k_signing = hmac_sha256(&k_service, b"aws4_request");
178
179 let signature = hex::encode(hmac_sha256(&k_signing, string_to_sign.as_bytes()));
180
181 format!(
182 "AWS4-HMAC-SHA256 Credential={access_key}/{credential_scope}, \
183 SignedHeaders={signed_headers}, Signature={signature}"
184 )
185}
186
/// AWS credentials (access key, secret key, session token and expiry) used to
/// sign S3 uploads.
///
/// `Debug` is implemented by hand so that `secret_access_key` and
/// `session_token` are redacted. Formatting the struct with `{:?}` (for
/// example in logs or error messages) never reveals them.
#[derive(Clone)]
pub struct UploadCredentials {
    /// Access key ID. Sent in clear in the `Credential=` part of the header.
    pub access_key_id: String,
    /// Secret key used to derive the SigV4 signing key; redacted in `Debug`.
    pub secret_access_key: String,
    /// Session token sent as `x-amz-security-token`; redacted in `Debug`.
    pub session_token: String,
    /// Expiry as supplied by the issuer (presumably RFC 3339; not parsed here).
    pub expires_at: String,
}

impl std::fmt::Debug for UploadCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UploadCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &"<redacted>")
            .field("session_token", &"<redacted>")
            .field("expires_at", &self.expires_at)
            .finish()
    }
}
200
201pub fn s3_put(
206 agent: &ureq::Agent,
207 bucket: &str,
208 key: &str,
209 region: &str,
210 body: &[u8],
211 creds: &UploadCredentials,
212) -> Result<String, String> {
213 let base_url = format!("https://{bucket}.s3.{region}.amazonaws.com");
214 s3_put_to(agent, &base_url, bucket, key, region, body, creds)
215}
216
217pub(crate) fn s3_put_to(
220 agent: &ureq::Agent,
221 base_url: &str,
222 bucket: &str,
223 key: &str,
224 region: &str,
225 body: &[u8],
226 creds: &UploadCredentials,
227) -> Result<String, String> {
228 let now = std::time::SystemTime::now()
229 .duration_since(std::time::UNIX_EPOCH)
230 .unwrap_or_default();
231 let secs = now.as_secs();
232 let amz_date = format_amz_date(secs);
233 let date_stamp = &amz_date[..8];
234
235 let body_sha256 = sha256_hex(body);
236 let authorization = sign_put_request(
237 &creds.access_key_id,
238 &creds.secret_access_key,
239 &creds.session_token,
240 region,
241 bucket,
242 key,
243 &body_sha256,
244 &amz_date,
245 date_stamp,
246 );
247
248 let url = format!("{base_url}/{key}");
249 let result = agent
250 .put(&url)
251 .header("Content-Type", "application/gzip")
252 .header("Content-Length", &body.len().to_string())
253 .header("x-amz-content-sha256", &body_sha256)
254 .header("x-amz-date", &amz_date)
255 .header("x-amz-security-token", &creds.session_token)
256 .header("Authorization", &authorization)
257 .send(body);
258
259 match result {
260 Ok(r) if r.status() == 200 || r.status() == 201 => Ok(format!("s3://{bucket}/{key}")),
261 Ok(r) => Err(format!("S3 PUT returned HTTP {}: {}", r.status(), url)),
262 Err(e) => Err(format!("S3 PUT network error for {url}: {e}")),
263 }
264}
265
266pub fn format_amz_date(unix_secs: u64) -> String {
268 let (y, mo, d, h, mi, s) = epoch_to_utc(unix_secs);
269 format!("{y:04}{mo:02}{d:02}T{h:02}{mi:02}{s:02}Z")
270}
271
/// Converts seconds since the Unix epoch into a UTC
/// `(year, month, day, hour, minute, second)` tuple.
///
/// The month and day are 1-based. The date part uses Howard Hinnant's
/// `civil_from_days` algorithm. It shifts the day count so each 400-year era
/// (146 097 days) starts on 0000-03-01, which makes the leap day the last day
/// of the shifted year.
fn epoch_to_utc(secs: u64) -> (u32, u32, u32, u32, u32, u32) {
    let second = secs % 60;
    let minute = secs / 60 % 60;
    let hour = secs / 3600 % 24;

    // 719 468 = days from 0000-03-01 to 1970-01-01.
    let shifted_days = secs / 86_400 + 719_468;
    let era = shifted_days / 146_097;
    let day_of_era = shifted_days % 146_097;
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    // 0 = March ... 11 = February in the shifted calendar.
    let month_index = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * month_index + 2) / 5 + 1;

    // January and February belong to the following civil year.
    let (month, year_bump) = if month_index < 10 {
        (month_index + 3, 0)
    } else {
        (month_index - 9, 1)
    };
    let year = era * 400 + year_of_era + year_bump;

    (
        year as u32,
        month as u32,
        day as u32,
        hour as u32,
        minute as u32,
        second as u32,
    )
}
298
299pub struct RegionCache(pub(crate) HashMap<String, String>);
306
307impl RegionCache {
308 pub fn new() -> Self {
309 Self(HashMap::new())
310 }
311
312 pub fn get_or_detect(&mut self, bucket: &str) -> String {
315 if let Some(r) = self.0.get(bucket) {
316 return r.clone();
317 }
318 let region = detect_bucket_region(bucket);
319 self.0.insert(bucket.to_string(), region.clone());
320 region
321 }
322}
323
// Unit tests for URI parsing, SigV4 signing, date formatting, region
// detection and uploads. Network paths run against throwaway TCP servers
// bound to 127.0.0.1 on an OS-assigned port, and the region-cache test
// pre-seeds its entry, so no AWS endpoint is contacted.
#[cfg(test)]
mod tests {
    use super::*;
    // Used by the mock servers to read requests and write canned responses.
    use std::io::{Read, Write};

    /// Happy path: the bucket is the text before the first `/`; the key keeps
    /// its inner slashes.
    #[test]
    fn test_parse_valid_s3_uri() {
        let uri = parse_s3_uri("s3://my-bucket/path/to/obj.csv.gz").unwrap();
        assert_eq!(uri.bucket, "my-bucket");
        assert_eq!(uri.key, "path/to/obj.csv.gz");
    }

    /// Only the `s3://` scheme is accepted.
    #[test]
    fn test_parse_https_uri_is_error() {
        assert!(parse_s3_uri("https://bucket/path").is_err());
    }

    /// A trailing slash with nothing after it is an empty key.
    #[test]
    fn test_parse_empty_key_is_error() {
        assert!(parse_s3_uri("s3://bucket/").is_err());
    }

    /// A bucket with no `/` separator has no key at all.
    #[test]
    fn test_parse_missing_slash_is_error() {
        assert!(parse_s3_uri("s3://bucket-only").is_err());
    }

    /// Checks the shape of the SigV4 header: credential scope, the fixed
    /// signed-header list, and a 64-character hex signature. Despite the name
    /// it does not compare against an exact expected signature value.
    #[test]
    fn test_sig_v4_golden_value() {
        let auth = sign_put_request(
            "AKIAIOSFODNN7EXAMPLE",
            "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
            "AQoDYXdzEJr//////////token",
            "us-east-1",
            "examplebucket",
            "test/object.csv.gz",
            // SHA-256 of the empty payload.
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
            "20130524T000000Z",
            "20130524",
        );

        assert!(auth.starts_with("AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20130524/us-east-1/s3/aws4_request"),
            "unexpected auth header start: {auth}");
        assert!(
            auth.contains(
                "SignedHeaders=host;x-amz-content-sha256;x-amz-date;x-amz-security-token"
            ),
            "missing SignedHeaders: {auth}"
        );

        // The signature is the final component of the header value.
        let sig = auth.split("Signature=").nth(1).unwrap_or("");
        assert_eq!(
            sig.len(),
            64,
            "signature should be 64 hex chars, got: {sig}"
        );
        assert!(
            sig.chars().all(|c| c.is_ascii_hexdigit()),
            "non-hex char in signature: {sig}"
        );
    }

    /// A pre-seeded entry must be served from the cache. A miss would call
    /// `detect_bucket_region` and go to the network.
    #[test]
    fn test_region_cache_skips_network_on_hit() {
        let mut cache = RegionCache::new();
        // Seed the crate-visible inner map directly.
        cache
            .0
            .insert("my-bucket".to_string(), "ap-southeast-1".to_string());

        // Both calls must hit the seeded entry.
        let r1 = cache.get_or_detect("my-bucket");
        let r2 = cache.get_or_detect("my-bucket");
        assert_eq!(r1, "ap-southeast-1");
        assert_eq!(r2, "ap-southeast-1");
        // No additional entry was inserted.
        assert_eq!(cache.0.len(), 1);
    }

    /// Region detection reads `x-amz-bucket-region` from a canned 403 reply.
    #[test]
    fn test_detect_region_from_mock_server() {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();

        // One-shot server: accept one connection, consume the request, reply,
        // then close by dropping the stream.
        std::thread::spawn(move || {
            if let Ok((mut stream, _)) = listener.accept() {
                let mut buf = [0u8; 256];
                // Drain the HEAD request; its contents are not checked.
                let _ = stream.read(&mut buf);
                stream
                    .write_all(
                        b"HTTP/1.0 403 Forbidden\r\n\
                          x-amz-bucket-region: eu-west-1\r\n\
                          Content-Length: 0\r\n\r\n",
                    )
                    .ok();
            }
        });

        let region = detect_region_at("127.0.0.1", port, Duration::from_secs(2));
        assert_eq!(region, "eu-west-1");
    }

    /// Uploads to a local mock via `s3_put_to` and checks both the returned
    /// URI and the `Content-Type` header that went over the wire.
    #[test]
    fn test_s3_put_to_mock_server_returns_uri() {
        use std::sync::mpsc;

        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();

        // Carries the raw request bytes from the server thread to the test.
        let (tx, rx) = mpsc::channel::<Vec<u8>>();

        std::thread::spawn(move || {
            if let Ok((mut stream, _)) = listener.accept() {
                let mut buf = vec![0u8; 8192];
                // NOTE(review): a single read assumes the request headers arrive
                // in one chunk; split delivery could make the header check flaky.
                let n = stream.read(&mut buf).unwrap_or(0);
                buf.truncate(n);
                tx.send(buf).ok();
                stream
                    .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
                    .ok();
            }
        });

        let agent = ureq::config::Config::builder()
            .timeout_global(Some(Duration::from_secs(30)))
            .build()
            .new_agent();

        // Dummy credentials; the mock never verifies the signature.
        let creds = UploadCredentials {
            access_key_id: "AKID".to_string(),
            secret_access_key: "SECRET".to_string(),
            session_token: "TOKEN".to_string(),
            expires_at: "2099-01-01T00:00:00Z".to_string(),
        };

        // Plain-HTTP base URL pointing at the mock instead of the S3 host.
        let base_url = format!("http://127.0.0.1:{port}");
        let result = s3_put_to(
            &agent,
            &base_url,
            "test-bucket",
            "run-1/000001.csv.gz",
            "us-east-1",
            b"fake-gzip-content",
            &creds,
        );

        assert!(result.is_ok(), "expected Ok, got: {result:?}");
        assert_eq!(result.unwrap(), "s3://test-bucket/run-1/000001.csv.gz");

        let raw_request = rx
            .recv()
            .expect("mock server did not send captured request");
        // Header names are case-insensitive, so compare in lowercase.
        let raw_str = String::from_utf8_lossy(&raw_request).to_ascii_lowercase();
        assert!(
            raw_str.contains("content-type: application/gzip"),
            "expected 'content-type: application/gzip' in request headers, got:\n{raw_str}"
        );
    }

    /// 1_369_353_600 seconds after the epoch is 2013-05-24T00:00:00Z.
    #[test]
    fn test_format_amz_date_known_timestamp() {
        assert_eq!(format_amz_date(1_369_353_600), "20130524T000000Z");
    }

    /// Second zero maps to 1970-01-01 00:00:00.
    #[test]
    fn test_epoch_to_utc_unix_epoch() {
        assert_eq!(epoch_to_utc(0), (1970, 1, 1, 0, 0, 0));
    }

    /// A recent date with a non-zero time of day.
    #[test]
    fn test_epoch_to_utc_known_date() {
        assert_eq!(epoch_to_utc(1_775_046_896), (2026, 4, 1, 12, 34, 56));
    }
}