Skip to content

Commit

Permalink
Fix response part of rwf2#1067 by importing typed headers from 'heade…
Browse files Browse the repository at this point in the history
…rs' crate
  • Loading branch information
jespersm committed Sep 22, 2024
1 parent 3bf9ef0 commit 970af43
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 1 deletion.
1 change: 1 addition & 0 deletions core/http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ memchr = "2"
stable-pattern = "0.1"
cookie = { version = "0.18", features = ["percent-encode"] }
state = "0.6"
headers = "0.4.0"

[dependencies.serde]
version = "1.0"
Expand Down
145 changes: 145 additions & 0 deletions core/http/src/header/header.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use core::str;
use std::borrow::{Borrow, Cow};
use std::fmt;

use headers::{Header as HHeader, HeaderValue};
use indexmap::IndexMap;

use crate::uncased::{Uncased, UncasedStr};
Expand Down Expand Up @@ -798,10 +800,153 @@ impl From<&cookie::Cookie<'_>> for Header<'static> {
}
}

/// A destination for `HeaderValue`s that can be used to accumulate
/// a single header value using from hyperium headers' decode protocol.
#[derive(Default)]
struct HeaderValueDestination {
value: Option<HeaderValue>,
count: usize,
}

impl <'r>HeaderValueDestination {
fn into_value(self) -> HeaderValue {
if let Some(value) = self.value {
// TODO: if value.count > 1, then log that multiple header values are
// generated by the typed header, but that the dropped.
value
} else {
// Perhaps log that the typed header didn't create any values.
// This won't happen in the current implementation (headers 0.4.0).
HeaderValue::from_static("")
}
}

fn into_header_string(self) -> Cow<'static, str> {
let value = self.into_value();
// TODO: Optimize if we know this is a static reference.
value.to_str().unwrap_or("").to_string().into()
}
}

impl Extend<HeaderValue> for HeaderValueDestination {
fn extend<T: IntoIterator<Item = HeaderValue>>(&mut self, iter: T) {
for value in iter {
self.count += 1;
if self.value.is_none() {
self.value = Some(value)
}
}
}
}

macro_rules! import_typed_headers {
($($name:ident),*) => ($(
pub use headers::$name;

impl ::std::convert::From<self::$name> for Header<'static> {
fn from(header: self::$name) -> Self {
let mut destination = HeaderValueDestination::default();
header.encode(&mut destination);
let name = self::$name::name();
Header::new(name.as_str(), destination.into_header_string())
}
}
)*)
}

macro_rules! import_generic_typed_headers {
($($name:ident<$bound:ident>),*) => ($(
pub use headers::$name;

impl <T1: 'static + $bound>::std::convert::From<self::$name<T1>>
for Header<'static> {
fn from(header: self::$name<T1>) -> Self {
let mut destination = HeaderValueDestination::default();
header.encode(&mut destination);
let name = self::$name::<T1>::name();
Header::new(name.as_str(), destination.into_header_string())
}
}
)*)
}

// The following headers from 'headers' 0.4 are not imported, since they are
// provided by other Rocket features.

// * ContentType, // Content-Type header, defined in RFC7231
// * Cookie, // Cookie header, defined in RFC6265
// * Host, // The Host header.
// * Location, // Location header, defined in RFC7231
// * SetCookie, // Set-Cookie header, defined RFC6265

import_typed_headers! {
AcceptRanges, // Accept-Ranges header, defined in RFC7233
AccessControlAllowCredentials, // Access-Control-Allow-Credentials header, part of CORS
AccessControlAllowHeaders, // Access-Control-Allow-Headers header, part of CORS
AccessControlAllowMethods, // Access-Control-Allow-Methods header, part of CORS
AccessControlAllowOrigin, // The Access-Control-Allow-Origin response header, part of CORS
AccessControlExposeHeaders, // Access-Control-Expose-Headers header, part of CORS
AccessControlMaxAge, // Access-Control-Max-Age header, part of CORS
AccessControlRequestHeaders, // Access-Control-Request-Headers header, part of CORS
AccessControlRequestMethod, // Access-Control-Request-Method header, part of CORS
Age, // Age header, defined in RFC7234
Allow, // Allow header, defined in RFC7231
CacheControl, // Cache-Control header, defined in RFC7234 with extensions in RFC8246
Connection, // Connection header, defined in RFC7230
ContentDisposition, // A Content-Disposition header, (re)defined in RFC6266.
ContentEncoding, // Content-Encoding header, defined in RFC7231
ContentLength, // Content-Length header, defined in RFC7230
ContentLocation, // Content-Location header, defined in RFC7231
ContentRange, // Content-Range, described in RFC7233
Date, // Date header, defined in RFC7231
ETag, // ETag header, defined in RFC7232
Expect, // The Expect header.
Expires, // Expires header, defined in RFC7234
IfMatch, // If-Match header, defined in RFC7232
IfModifiedSince, // If-Modified-Since header, defined in RFC7232
IfNoneMatch, // If-None-Match header, defined in RFC7232
IfRange, // If-Range header, defined in RFC7233
IfUnmodifiedSince, // If-Unmodified-Since header, defined in RFC7232
LastModified, // Last-Modified header, defined in RFC7232
Origin, // The Origin header.
Pragma, // The Pragma header defined by HTTP/1.0.
Range, // Range header, defined in RFC7233
Referer, // Referer header, defined in RFC7231
ReferrerPolicy, // Referrer-Policy header, part of Referrer Policy
RetryAfter, // The Retry-After header.
SecWebsocketAccept, // The Sec-Websocket-Accept header.
SecWebsocketKey, // The Sec-Websocket-Key header.
SecWebsocketVersion, // The Sec-Websocket-Version header.
Server, // Server header, defined in RFC7231
StrictTransportSecurity, // StrictTransportSecurity header, defined in RFC6797
Te, // TE header, defined in RFC7230
TransferEncoding, // Transfer-Encoding header, defined in RFC7230
Upgrade, // Upgrade header, defined in RFC7230
UserAgent, // User-Agent header, defined in RFC7231
Vary // Vary header, defined in RFC7231
}

import_generic_typed_headers! {
Authorization<Credentials>, // Authorization header, defined in RFC7235
ProxyAuthorization<Credentials> // Proxy-Authorization header, defined in RFC7235
}

pub use headers::authorization::Credentials;

#[cfg(test)]
mod tests {
use std::time::SystemTime;

use super::HeaderMap;

#[test]
fn add_typed_header() {
use super::LastModified;
let mut map = HeaderMap::new();
map.add(LastModified::from(SystemTime::now()));
assert!(map.get_one("last-modified").unwrap().contains("GMT"));
}

#[test]
fn case_insensitive_add_get() {
let mut map = HeaderMap::new();
Expand Down
13 changes: 12 additions & 1 deletion core/http/src/header/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,18 @@ mod proxy_proto;
pub use self::content_type::ContentType;
pub use self::accept::{Accept, QMediaType};
pub use self::media_type::MediaType;
pub use self::header::{Header, HeaderMap};
pub use self::header::{
Header, HeaderMap, AcceptRanges, AccessControlAllowCredentials,
AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlAllowOrigin,
AccessControlExposeHeaders, AccessControlMaxAge, AccessControlRequestHeaders,
AccessControlRequestMethod, Age, Allow, CacheControl, Connection, ContentDisposition,
ContentEncoding, ContentLength, ContentLocation, ContentRange, Date, ETag, Expect,
Expires, IfMatch, IfModifiedSince, IfNoneMatch, IfRange, IfUnmodifiedSince,
LastModified, Origin, Pragma, Range, Referer, ReferrerPolicy, RetryAfter,
SecWebsocketAccept, SecWebsocketKey, SecWebsocketVersion, Server, StrictTransportSecurity,
Te, TransferEncoding, Upgrade, UserAgent, Vary, Authorization, ProxyAuthorization,
Credentials
};
pub use self::proxy_proto::ProxyProto;

pub(crate) use self::media_type::Source;
30 changes: 30 additions & 0 deletions core/lib/tests/typed-headers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#[macro_use]
extern crate rocket;

use std::time::{Duration, SystemTime};
use rocket::http::Expires;

#[derive(Responder)]
struct MyResponse {
body: String,
expires: Expires,
}

#[get("/")]
fn index() -> MyResponse {
let some_future_time =
SystemTime::UNIX_EPOCH.checked_add(Duration::from_secs(60 * 60 * 24 * 365 * 100)).unwrap();

MyResponse {
body: "Hello, world!".into(),
expires: Expires::from(some_future_time)
}
}

#[test]
fn typed_header() {
let rocket = rocket::build().mount("/", routes![index]);
let client = rocket::local::blocking::Client::debug(rocket).unwrap();
let response = client.get("/").dispatch();
assert_eq!(response.headers().get_one("Expires").unwrap(), "Sat, 07 Dec 2069 00:00:00 GMT");
}

0 comments on commit 970af43

Please sign in to comment.