commit e0e3e507862560e133336a90828c5a42026f7641 Author: Glenn Griffin Date: Fri Dec 6 16:39:17 2019 -0800 First in-progress version of httptest diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6936990 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/target +**/*.rs.bk +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..99bbc9a --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "httptest" +version = "0.1.0" +authors = ["Glenn Griffin "] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +hyper = {version = "=0.13.0-alpha.4", features = ["unstable-stream"]} +futures-preview = {version = "=0.3.0-alpha.19", features = ["std", "async-await"]} +tokio = "=0.2.0-alpha.6" +crossbeam-channel = "0.4.0" +http = "0.1.18" +log = "0.4.8" +bstr = "0.2.8" +regex = "1.3.1" +url = "2.1.0" +serde_json = "1.0.44" +serde = "1.0.103" + +[dev-dependencies] +pretty_env_logger = "0.3.1" diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..b01c6ef --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,29 @@ +#[macro_export] +macro_rules! all_of { + ($($x:expr),*) => ($crate::mappers::all_of($crate::vec_of_boxes![$($x),*])); + ($($x:expr,)*) => ($crate::all_of![$($x),*]); +} + +#[macro_export] +macro_rules! any_of { + ($($x:expr),*) => ($crate::mappers::any_of($crate::vec_of_boxes![$($x),*])); + ($($x:expr,)*) => ($crate::any_of![$($x),*]); +} + +#[macro_export] +macro_rules! vec_of_boxes { + ($($x:expr),*) => (std::vec![$(std::boxed::Box::new($x)),*]); + ($($x:expr,)*) => ($crate::vec_of_boxes![$($x),*]); +} + +pub mod mappers; +pub mod responders; +pub mod server; + +pub type FullRequest = hyper::Request>; +pub type FullResponse = hyper::Response>; +pub use mappers::Matcher; + +pub use server::Expectation; +pub use server::Server; +pub use server::Times; diff --git a/src/mappers.rs b/src/mappers.rs new file mode 100644 index 0000000..4ba7565 --- /dev/null +++ b/src/mappers.rs @@ -0,0 +1,349 @@ +use std::borrow::Borrow; +use std::fmt; +use std::marker::PhantomData; + +// import the any_of and all_of macros from crate root so they are accessible if +// people glob import this module. +pub use crate::all_of; +pub use crate::any_of; +pub mod request; +pub mod response; + +pub trait Mapper: Send + fmt::Debug +where + IN: ?Sized, +{ + type Out; + + fn map(&mut self, input: &IN) -> Self::Out; +} + +// Matcher is just a special case of Mapper that returns a boolean. Simply +// provides the `matches` method rather than `map` as that reads a little +// better. +pub trait Matcher: Send + fmt::Debug +where + IN: ?Sized, +{ + fn matches(&mut self, input: &IN) -> bool; +} +impl Matcher for T +where + T: Mapper, +{ + fn matches(&mut self, input: &IN) -> bool { + self.map(input) + } +} + +pub fn any() -> impl Mapper { + Any +} +#[derive(Debug)] +pub struct Any; +impl Mapper for Any { + type Out = bool; + + fn map(&mut self, _input: &IN) -> bool { + true + } +} + +pub fn contains(value: T) -> impl Mapper +where + T: AsRef<[u8]> + fmt::Debug + Send, + IN: AsRef<[u8]> + ?Sized, +{ + Contains(value) +} +#[derive(Debug)] +pub struct Contains(T); +impl Mapper for Contains +where + T: AsRef<[u8]> + fmt::Debug + Send, + IN: AsRef<[u8]> + ?Sized, +{ + type Out = bool; + + fn map(&mut self, input: &IN) -> bool { + use bstr::ByteSlice; + input.as_ref().contains_str(self.0.as_ref()) + } +} + +pub fn eq(value: T) -> impl Mapper +where + T: Borrow + fmt::Debug + Send, + IN: PartialEq + ?Sized, +{ + Eq(value) +} +#[derive(Debug)] +pub struct Eq(T); +impl Mapper for Eq +where + T: Borrow + fmt::Debug + Send, + IN: PartialEq + ?Sized, +{ + type Out = bool; + + fn map(&mut self, input: &IN) -> bool { + self.0.borrow() == input + } +} + +pub fn matches(value: &str) -> impl Mapper +where + IN: AsRef<[u8]> + ?Sized, +{ + let regex = regex::bytes::Regex::new(value).expect("failed to create regex"); + Matches(regex) +} +#[derive(Debug)] +pub struct Matches(regex::bytes::Regex); +impl Mapper for Matches +where + IN: AsRef<[u8]> + ?Sized, +{ + type Out = bool; + + fn map(&mut self, input: &IN) -> bool { + self.0.is_match(input.as_ref()) + } +} + +pub fn not(inner: C) -> impl Mapper +where + C: Mapper, + IN: ?Sized, +{ + Not(inner, PhantomData) +} +pub struct Not(C, PhantomData) +where + IN: ?Sized; +impl Mapper for Not +where + C: Mapper, + IN: ?Sized, +{ + type Out = bool; + + fn map(&mut self, input: &IN) -> bool { + !self.0.map(input) + } +} +impl fmt::Debug for Not +where + C: Mapper, + IN: ?Sized, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "Not({:?})", &self.0) + } +} + +pub fn all_of(inner: Vec>>) -> impl Mapper +where + IN: fmt::Debug + ?Sized, +{ + AllOf(inner) +} + +#[derive(Debug)] +pub struct AllOf(Vec>>) +where + IN: ?Sized; +impl Mapper for AllOf +where + IN: fmt::Debug + ?Sized, +{ + type Out = bool; + + fn map(&mut self, input: &IN) -> bool { + self.0.iter_mut().all(|maper| maper.map(input)) + } +} + +pub fn any_of(inner: Vec>>) -> impl Mapper +where + IN: fmt::Debug + ?Sized, +{ + AnyOf(inner) +} +#[derive(Debug)] +pub struct AnyOf(Vec>>) +where + IN: ?Sized; +impl Mapper for AnyOf +where + IN: fmt::Debug + ?Sized, +{ + type Out = bool; + + fn map(&mut self, input: &IN) -> bool { + self.0.iter_mut().any(|maper| maper.map(input)) + } +} + +pub fn uri_decoded(inner: C) -> impl Mapper +where + IN: AsRef<[u8]> + ?Sized, + C: Mapper<[(String, String)]>, +{ + UriDecoded(inner) +} +#[derive(Debug)] +pub struct UriDecoded(C); +impl Mapper for UriDecoded +where + IN: AsRef<[u8]> + ?Sized, + C: Mapper<[(String, String)]>, +{ + type Out = C::Out; + + fn map(&mut self, input: &IN) -> C::Out { + let decoded: Vec<(String, String)> = url::form_urlencoded::parse(input.as_ref()) + .into_owned() + .collect(); + self.0.map(&decoded) + } +} + +pub fn json_decoded(inner: C) -> impl Mapper +where + IN: AsRef<[u8]> + ?Sized, + C: Mapper, +{ + JsonDecoded(inner) +} +#[derive(Debug)] +pub struct JsonDecoded(C); +impl Mapper for JsonDecoded +where + IN: AsRef<[u8]> + ?Sized, + C: Mapper, +{ + type Out = C::Out; + + fn map(&mut self, input: &IN) -> C::Out { + let json_value: serde_json::Value = + serde_json::from_slice(input.as_ref()).unwrap_or(serde_json::Value::Null); + self.0.map(&json_value) + } +} + +pub fn lowercase(inner: C) -> impl Mapper +where + IN: AsRef<[u8]> + ?Sized, + C: Mapper<[u8]>, +{ + Lowercase(inner) +} +#[derive(Debug)] +pub struct Lowercase(C); +impl Mapper for Lowercase +where + IN: AsRef<[u8]> + ?Sized, + C: Mapper<[u8]>, +{ + type Out = C::Out; + + fn map(&mut self, input: &IN) -> C::Out { + use bstr::ByteSlice; + self.0.map(&input.as_ref().to_lowercase()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_contains() { + let mut c = contains("foo"); + assert_eq!(true, c.map("foobar")); + assert_eq!(true, c.map("bazfoobar")); + assert_eq!(false, c.map("bar")); + } + + #[test] + fn test_eq() { + let mut c = eq("foo"); + assert_eq!(false, c.map("foobar")); + assert_eq!(false, c.map("bazfoobar")); + assert_eq!(false, c.map("bar")); + assert_eq!(true, c.map("foo")); + } + + #[test] + fn test_matches() { + let mut c = matches(r#"^foo\d*bar$"#); + assert_eq!(true, c.map("foobar")); + assert_eq!(true, c.map("foo99bar")); + assert_eq!(false, c.map("foo99barz")); + assert_eq!(false, c.map("bat")); + } + + #[test] + fn test_not() { + let mut c = not(matches(r#"^foo\d*bar$"#)); + assert_eq!(false, c.map("foobar")); + assert_eq!(false, c.map("foo99bar")); + assert_eq!(true, c.map("foo99barz")); + assert_eq!(true, c.map("bat")); + } + + #[test] + fn test_all_of() { + let mut c = all_of![contains("foo"), contains("bar")]; + assert_eq!(true, c.map("foobar")); + assert_eq!(true, c.map("barfoo")); + assert_eq!(false, c.map("foo")); + assert_eq!(false, c.map("bar")); + } + + #[test] + fn test_any_of() { + let mut c = any_of![contains("foo"), contains("bar")]; + assert_eq!(true, c.map("foobar")); + assert_eq!(true, c.map("barfoo")); + assert_eq!(true, c.map("foo")); + assert_eq!(true, c.map("bar")); + assert_eq!(false, c.map("baz")); + } + + #[test] + fn test_uri_decoded() { + let expected = vec![ + ("key 1".to_owned(), "value 1".to_owned()), + ("key2".to_owned(), "".to_owned()), + ]; + let mut c = request::query(uri_decoded(eq(expected))); + let req = http::Request::get("https://example.com/path?key%201=value%201&key2") + .body(Vec::new()) + .unwrap(); + + assert_eq!(true, c.map(&req)); + } + + #[test] + fn test_json_decoded() { + let mut c = json_decoded(eq(serde_json::json!({ + "foo": 1, + "bar": 99, + }))); + assert_eq!(true, c.map(r#"{"foo": 1, "bar": 99}"#)); + assert_eq!(true, c.map(r#"{"bar": 99, "foo": 1}"#)); + assert_eq!(false, c.map(r#"{"foo": 1, "bar": 100}"#)); + } + + #[test] + fn test_lowercase() { + let mut c = lowercase(contains("foo")); + assert_eq!(true, c.map("FOO")); + assert_eq!(true, c.map("FoOBar")); + assert_eq!(true, c.map("foobar")); + assert_eq!(false, c.map("bar")); + } +} diff --git a/src/mappers/request.rs b/src/mappers/request.rs new file mode 100644 index 0000000..e18effa --- /dev/null +++ b/src/mappers/request.rs @@ -0,0 +1,177 @@ +use super::Mapper; +use crate::FullRequest; + +pub fn method(inner: C) -> impl Mapper +where + C: Mapper, +{ + Method(inner) +} +#[derive(Debug)] +pub struct Method(C); +impl Mapper for Method +where + C: Mapper, +{ + type Out = C::Out; + + fn map(&mut self, input: &FullRequest) -> C::Out { + self.0.map(input.method().as_str()) + } +} + +pub fn path(inner: C) -> impl Mapper +where + C: Mapper, +{ + Path(inner) +} +#[derive(Debug)] +pub struct Path(C); +impl Mapper for Path +where + C: Mapper, +{ + type Out = C::Out; + + fn map(&mut self, input: &FullRequest) -> C::Out { + self.0.map(input.uri().path()) + } +} + +pub fn query(inner: C) -> impl Mapper +where + C: Mapper, +{ + Query(inner) +} +#[derive(Debug)] +pub struct Query(C); +impl Mapper for Query +where + C: Mapper, +{ + type Out = C::Out; + + fn map(&mut self, input: &FullRequest) -> C::Out { + self.0.map(input.uri().query().unwrap_or("")) + } +} + +pub fn headers(inner: C) -> impl Mapper +where + C: Mapper<[(Vec, Vec)]>, +{ + Headers(inner) +} +#[derive(Debug)] +pub struct Headers(C); +impl Mapper for Headers +where + C: Mapper<[(Vec, Vec)]>, +{ + type Out = C::Out; + + fn map(&mut self, input: &FullRequest) -> C::Out { + let headers: Vec<(Vec, Vec)> = input + .headers() + .iter() + .map(|(k, v)| (k.as_str().into(), v.as_bytes().into())) + .collect(); + self.0.map(&headers) + } +} + +pub fn body(inner: C) -> impl Mapper +where + C: Mapper<[u8]>, +{ + Body(inner) +} +#[derive(Debug)] +pub struct Body(C); +impl Mapper for Body +where + C: Mapper<[u8]>, +{ + type Out = C::Out; + + fn map(&mut self, input: &FullRequest) -> C::Out { + self.0.map(input.body()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::mappers::*; + + #[test] + fn test_path() { + let req = hyper::Request::get("https://example.com/foo") + .body(Vec::new()) + .unwrap(); + assert!(path(eq("/foo")).map(&req)); + + let req = hyper::Request::get("https://example.com/foobar") + .body(Vec::new()) + .unwrap(); + assert!(path(eq("/foobar")).map(&req)) + } + + #[test] + fn test_query() { + let req = hyper::Request::get("https://example.com/path?foo=bar&baz=bat") + .body(Vec::new()) + .unwrap(); + assert!(query(eq("foo=bar&baz=bat")).map(&req)); + let req = hyper::Request::get("https://example.com/path?search=1") + .body(Vec::new()) + .unwrap(); + assert!(query(eq("search=1")).map(&req)); + } + + #[test] + fn test_method() { + let req = hyper::Request::get("https://example.com/foo") + .body(Vec::new()) + .unwrap(); + assert!(method(eq("GET")).map(&req)); + let req = hyper::Request::post("https://example.com/foobar") + .body(Vec::new()) + .unwrap(); + assert!(method(eq("POST")).map(&req)); + } + + #[test] + fn test_headers() { + let expected = vec![ + (Vec::from("host"), Vec::from("example.com")), + (Vec::from("content-length"), Vec::from("101")), + ]; + let mut req = hyper::Request::get("https://example.com/path?key%201=value%201&key2") + .body(Vec::new()) + .unwrap(); + req.headers_mut().extend(vec![ + ( + hyper::header::HOST, + hyper::header::HeaderValue::from_static("example.com"), + ), + ( + hyper::header::CONTENT_LENGTH, + hyper::header::HeaderValue::from_static("101"), + ), + ]); + + assert!(headers(eq(expected)).map(&req)); + } + + #[test] + fn test_body() { + use bstr::{ByteVec, B}; + let req = hyper::Request::get("https://example.com/foo") + .body(Vec::from_slice("my request body")) + .unwrap(); + assert!(body(eq(B("my request body"))).map(&req)); + } +} diff --git a/src/mappers/response.rs b/src/mappers/response.rs new file mode 100644 index 0000000..2c45a00 --- /dev/null +++ b/src/mappers/response.rs @@ -0,0 +1,109 @@ +use super::Mapper; +use crate::FullResponse; + +pub fn status_code(inner: C) -> impl Mapper +where + C: Mapper, +{ + StatusCode(inner) +} +#[derive(Debug)] +pub struct StatusCode(C); +impl Mapper for StatusCode +where + C: Mapper, +{ + type Out = C::Out; + + fn map(&mut self, input: &FullResponse) -> C::Out { + self.0.map(&input.status().as_u16()) + } +} + +pub fn headers(inner: C) -> impl Mapper +where + C: Mapper<[(Vec, Vec)]>, +{ + Headers(inner) +} +#[derive(Debug)] +pub struct Headers(C); +impl Mapper for Headers +where + C: Mapper<[(Vec, Vec)]>, +{ + type Out = C::Out; + + fn map(&mut self, input: &FullResponse) -> C::Out { + let headers: Vec<(Vec, Vec)> = input + .headers() + .iter() + .map(|(k, v)| (k.as_str().into(), v.as_bytes().into())) + .collect(); + self.0.map(&headers) + } +} + +pub fn body(inner: C) -> impl Mapper +where + C: Mapper<[u8]>, +{ + Body(inner) +} +#[derive(Debug)] +pub struct Body(C); +impl Mapper for Body +where + C: Mapper<[u8]>, +{ + type Out = C::Out; + + fn map(&mut self, input: &FullResponse) -> C::Out { + self.0.map(input.body()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::mappers::*; + + #[test] + fn test_status_code() { + let resp = hyper::Response::builder() + .status(hyper::StatusCode::NOT_FOUND) + .body(Vec::new()) + .unwrap(); + assert!(status_code(eq(404)).map(&resp)); + + let resp = hyper::Response::builder() + .status(hyper::StatusCode::OK) + .body(Vec::new()) + .unwrap(); + assert!(status_code(eq(200)).map(&resp)); + } + + #[test] + fn test_headers() { + let expected = vec![ + (Vec::from("host"), Vec::from("example.com")), + (Vec::from("content-length"), Vec::from("101")), + ]; + let resp = hyper::Response::builder() + .header("host", "example.com") + .header("content-length", 101) + .body(Vec::new()) + .unwrap(); + + assert!(headers(eq(expected)).map(&resp)); + } + + #[test] + fn test_body() { + use bstr::{ByteVec, B}; + let resp = hyper::Response::builder() + .body(Vec::from_slice("my request body")) + .unwrap(); + assert!(body(eq(B("my request body"))).map(&resp)); + } +} diff --git a/src/responders.rs b/src/responders.rs new file mode 100644 index 0000000..fd6a77e --- /dev/null +++ b/src/responders.rs @@ -0,0 +1,85 @@ +use std::fmt; +use std::future::Future; +use std::pin::Pin; + +pub trait Responder: Send + fmt::Debug { + fn respond(&mut self) -> Pin> + Send>>; +} + +pub fn status_code(code: u16) -> impl Responder { + StatusCode(code) +} +#[derive(Debug)] +pub struct StatusCode(u16); +impl Responder for StatusCode { + fn respond(&mut self) -> Pin> + Send>> { + async fn _respond(status_code: u16) -> http::Response { + hyper::Response::builder() + .status(status_code) + .body(hyper::Body::empty()) + .unwrap() + } + Box::pin(_respond(self.0)) + } +} + +pub fn json_encoded(data: T) -> impl Responder +where + T: serde::Serialize, +{ + JsonEncoded(serde_json::to_vec(&data).unwrap()) +} +#[derive(Debug)] +pub struct JsonEncoded(Vec); +impl Responder for JsonEncoded { + fn respond(&mut self) -> Pin> + Send>> { + async fn _respond(body: Vec) -> http::Response { + hyper::Response::builder() + .status(200) + .header("Content-Type", "application/json") + .body(body.into()) + .unwrap() + } + Box::pin(_respond(self.0.clone())) + } +} + +impl Responder for crate::FullResponse { + fn respond(&mut self) -> Pin> + Send>> { + async fn _respond(resp: http::Response) -> http::Response { + resp + } + // Turn &hyper::Response> into a hyper::Response + let mut builder = hyper::Response::builder(); + builder + .status(self.status().clone()) + .version(self.version().clone()); + *builder.headers_mut().unwrap() = self.headers().clone(); + let resp = builder.body(self.body().clone().into()).unwrap(); + + Box::pin(_respond(resp)) + } +} + +// TODO: make a macro for this to avoid the vec![Box::new] dance. +pub fn cycle(responders: Vec>) -> impl Responder { + if responders.is_empty() { + panic!("empty vector provided to cycle"); + } + Cycle { idx: 0, responders } +} +#[derive(Debug)] +pub struct Cycle { + idx: usize, + responders: Vec>, +} +impl Responder for Cycle { + fn respond(&mut self) -> Pin> + Send>> { + let response = self.responders[self.idx].respond(); + self.idx = (self.idx + 1) % self.responders.len(); + response + } +} + +#[cfg(test)] +mod tests {} diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..5b747ee --- /dev/null +++ b/src/server.rs @@ -0,0 +1,286 @@ +use crate::responders::Responder; +use crate::{FullRequest, Matcher}; +use std::future::Future; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; + +pub struct Server { + trigger_shutdown: Option>, + join_handle: Option>, + addr: SocketAddr, + state: ServerState, +} + +impl Server { + pub fn run() -> Self { + use futures::future::FutureExt; + use hyper::{ + service::{make_service_fn, service_fn}, + Error, + }; + let bind_addr = ([127, 0, 0, 1], 0).into(); + // And a MakeService to handle each connection... + let state = ServerState::default(); + let make_service = make_service_fn({ + let state = state.clone(); + move |_| { + let state = state.clone(); + async move { + let state = state.clone(); + Ok::<_, Error>(service_fn({ + let state = state.clone(); + move |req: hyper::Request| { + let state = state.clone(); + async move { + // read the full body into memory prior to handing it to mappers. + let (head, body) = req.into_parts(); + use futures::TryStreamExt; + let full_body = body.try_concat().await?; + let req = hyper::Request::from_parts(head, full_body.to_vec()); + log::debug!("Received Request: {:?}", req); + let resp = on_req(state, req).await; + log::debug!("Sending Response: {:?}", resp); + hyper::Result::Ok(resp) + } + } + })) + } + } + }); + // Then bind and serve... + let server = hyper::Server::bind(&bind_addr).serve(make_service); + let addr = server.local_addr(); + let (trigger_shutdown, shutdown_received) = futures::channel::oneshot::channel(); + let join_handle = std::thread::spawn(move || { + let mut runtime = tokio::runtime::current_thread::Runtime::new().unwrap(); + runtime.block_on(async move { + futures::select! { + _ = server.fuse() => {}, + _ = shutdown_received.fuse() => {}, + } + }); + }); + + Server { + trigger_shutdown: Some(trigger_shutdown), + join_handle: Some(join_handle), + addr, + state, + } + } + + pub fn addr(&self) -> SocketAddr { + self.addr + } + + pub fn url(&self, path_and_query: T) -> http::Uri + where + http::uri::PathAndQuery: http::HttpTryFrom, + { + http::Uri::builder() + .scheme("http") + .authority(format!("{}", &self.addr).as_str()) + .path_and_query(path_and_query) + .build() + .unwrap() + } + + pub fn expect(&self, expectation: Expectation) { + self.state.push_expectation(expectation); + } + + pub fn verify_and_clear(&mut self) { + let mut state = self.state.lock(); + for expectation in state.expected.iter() { + let is_valid_cardinality = match &expectation.cardinality { + Times::AnyNumber => true, + Times::AtLeast(lower_bound) if expectation.hit_count >= *lower_bound => true, + Times::AtLeast(_) => false, + Times::AtMost(limit) if expectation.hit_count <= *limit => true, + Times::AtMost(_) => false, + Times::Between(range) + if expectation.hit_count <= *range.end() + && expectation.hit_count >= *range.start() => + { + true + } + Times::Between(_) => false, + Times::Exactly(limit) if expectation.hit_count == *limit => true, + Times::Exactly(_) => false, + }; + if !is_valid_cardinality { + panic!(format!( + "Unexpected number of requests for matcher '{:?}'; received {}; expected {:?}", + &expectation.matcher, expectation.hit_count, &expectation.cardinality, + )); + } + } + state.expected.clear(); + if !state.unexpected_requests.is_empty() { + // TODO: format and print the requests. + panic!("unexpected requests received"); + } + } +} + +impl Drop for Server { + fn drop(&mut self) { + // drop the trigger_shutdown channel to tell the server to shutdown. + // Then wait for the shutdown to complete. + self.trigger_shutdown = None; + let _ = self.join_handle.take().unwrap().join(); + self.verify_and_clear(); + } +} + +async fn on_req(state: ServerState, req: FullRequest) -> http::Response { + let response_future = { + let mut state = state.lock(); + let mut iter = state.expected.iter_mut(); + let response_future = loop { + let expectation = match iter.next() { + None => break None, + Some(expectation) => expectation, + }; + if expectation.matcher.matches(&req) { + log::debug!("found matcher: {:?}", &expectation.matcher); + expectation.hit_count += 1; + let is_valid_cardinality = match &expectation.cardinality { + Times::AnyNumber => true, + Times::AtLeast(_) => true, + Times::AtMost(limit) if expectation.hit_count <= *limit => true, + Times::AtMost(_) => false, + Times::Between(range) if expectation.hit_count <= *range.end() => true, + Times::Between(_) => false, + Times::Exactly(limit) if expectation.hit_count <= *limit => true, + Times::Exactly(_) => false, + }; + if is_valid_cardinality { + break Some(expectation.responder.respond()); + } else { + break Some(Box::pin(cardinality_error( + &*expectation.matcher as &dyn Matcher, + &expectation.cardinality, + expectation.hit_count, + ))); + } + } + }; + if response_future.is_none() { + // TODO: provide real request id. + state.unexpected_requests.push(RequestID(1)); + } + response_future + }; + if let Some(f) = response_future { + f.await + } else { + http::Response::builder() + .status(http::StatusCode::INTERNAL_SERVER_ERROR) + .body(hyper::Body::empty()) + .unwrap() + } +} + +#[derive(Debug, Clone)] +pub enum Times { + AnyNumber, + AtLeast(usize), + AtMost(usize), + Between(std::ops::RangeInclusive), + Exactly(usize), +} + +pub struct Expectation { + matcher: Box>, + cardinality: Times, + responder: Box, + hit_count: usize, +} + +impl Expectation { + pub fn matching(matcher: impl Matcher + 'static) -> ExpectationBuilder { + ExpectationBuilder { + matcher: Box::new(matcher), + cardinality: Times::Exactly(1), + } + } +} + +pub struct ExpectationBuilder { + matcher: Box>, + cardinality: Times, +} + +impl ExpectationBuilder { + pub fn times(self, cardinality: Times) -> ExpectationBuilder { + ExpectationBuilder { + cardinality, + ..self + } + } + + pub fn respond_with(self, responder: impl Responder + 'static) -> Expectation { + Expectation { + matcher: self.matcher, + cardinality: self.cardinality, + responder: Box::new(responder), + hit_count: 0, + } + } +} + +#[derive(Debug, Clone, Copy)] +struct RequestID(u64); + +#[derive(Clone)] +struct ServerState(Arc>); + +impl ServerState { + fn lock(&self) -> std::sync::MutexGuard { + self.0.lock().expect("mutex poisoned") + } + + fn push_expectation(&self, expectation: Expectation) { + let mut inner = self.lock(); + inner.expected.push(expectation); + } +} + +impl Default for ServerState { + fn default() -> Self { + ServerState(Default::default()) + } +} + +struct ServerStateInner { + unexpected_requests: Vec, + expected: Vec, +} + +impl Default for ServerStateInner { + fn default() -> Self { + ServerStateInner { + unexpected_requests: Default::default(), + expected: Default::default(), + } + } +} + +fn cardinality_error( + matcher: &dyn Matcher, + cardinality: &Times, + hit_count: usize, +) -> Pin> + Send + 'static>> { + let body = hyper::Body::from(format!( + "Unexpected number of requests for matcher '{:?}'; received {}; expected {:?}", + matcher, hit_count, cardinality, + )); + Box::pin(async move { + http::Response::builder() + .status(http::StatusCode::INTERNAL_SERVER_ERROR) + .body(body) + .unwrap() + }) +} diff --git a/tests/tests.rs b/tests/tests.rs new file mode 100644 index 0000000..d8c2fce --- /dev/null +++ b/tests/tests.rs @@ -0,0 +1,145 @@ +use httptest::{mappers::*, responders::*, Expectation, Times}; + +async fn read_response_body(resp: hyper::Response) -> hyper::Response> { + use futures::stream::TryStreamExt; + let (head, body) = resp.into_parts(); + let body = body.try_concat().await.unwrap().to_vec(); + hyper::Response::from_parts(head, body) +} + +#[tokio::test] +async fn test_server() { + let _ = pretty_env_logger::try_init(); + + // Setup a server to expect a single GET /foo request. + let server = httptest::Server::run(); + server.expect( + Expectation::matching(all_of![ + request::method(eq("GET")), + request::path(eq("/foo")) + ]) + .times(Times::Exactly(1)) + .respond_with(status_code(200)), + ); + + // Issue the GET /foo to the server and verify it returns a 200. + let client = hyper::Client::new(); + let resp = read_response_body(client.get(server.url("/foo")).await.unwrap()).await; + assert!(response::status_code(eq(200)).matches(&resp)); + + // The Drop impl of the server will assert that all expectations were satisfied or else it will panic. +} + +#[tokio::test] +#[should_panic] +async fn test_expectation_cardinality_not_reached() { + let _ = pretty_env_logger::try_init(); + + // Setup a server to expect a single GET /foo request. + let server = httptest::Server::run(); + server.expect( + Expectation::matching(all_of![ + request::method(eq("GET")), + request::path(eq("/foo")) + ]) + .times(Times::Exactly(1)) + .respond_with(status_code(200)), + ); + + // Don't send any requests. Should panic. +} + +#[tokio::test] +#[should_panic] +async fn test_expectation_cardinality_exceeded() { + let _ = pretty_env_logger::try_init(); + + // Setup a server to expect a single GET /foo request. + let server = httptest::Server::run(); + server.expect( + Expectation::matching(all_of![ + request::method(eq("GET")), + request::path(eq("/foo")) + ]) + .times(Times::Exactly(1)) + .respond_with( + http::Response::builder() + .status(http::StatusCode::INTERNAL_SERVER_ERROR) + .body(Vec::new()) + .unwrap(), + ), + ); + + // Issue the GET /foo to the server and verify it returns a 200. + let client = hyper::Client::new(); + let resp = read_response_body(client.get(server.url("/foo")).await.unwrap()).await; + assert!(response::status_code(eq(200)).matches(&resp)); + + // Issue a second GET /foo and verify it returns a 500 because the cardinality of the expectation has been exceeded. + let resp = read_response_body(client.get(server.url("/foo")).await.unwrap()).await; + assert!(response::status_code(eq(500)).matches(&resp)); + + // Should panic on Server drop. +} + +#[tokio::test] +async fn test_json() { + let _ = pretty_env_logger::try_init(); + + let my_data = serde_json::json!({ + "foo": "bar", + "baz": [1, 2, 3], + }); + + // Setup a server to expect a single GET /foo request and respond with a + // json encoding of my_data. + let server = httptest::Server::run(); + server.expect( + Expectation::matching(all_of![ + request::method(eq("GET")), + request::path(eq("/foo")) + ]) + .times(Times::Exactly(1)) + .respond_with(json_encoded(my_data.clone())), + ); + + // Issue the GET /foo to the server and verify it returns a 200. + let client = hyper::Client::new(); + let resp = read_response_body(client.get(server.url("/foo")).await.unwrap()).await; + assert!(all_of![ + response::status_code(eq(200)), + response::body(json_decoded(eq(my_data))), + ] + .matches(&resp)); +} + +#[tokio::test] +async fn test_cycle() { + let _ = pretty_env_logger::try_init(); + + // Setup a server to expect a single GET /foo request and respond with a + // json encoding of my_data. + let server = httptest::Server::run(); + server.expect( + Expectation::matching(all_of![ + request::method(eq("GET")), + request::path(eq("/foo")) + ]) + .times(Times::Exactly(4)) + .respond_with(cycle(vec![ + Box::new(status_code(200)), + Box::new(status_code(404)), + ])), + ); + + // Issue the GET /foo to the server and verify it returns a 200. + let client = hyper::Client::new(); + let resp = read_response_body(client.get(server.url("/foo")).await.unwrap()).await; + assert!(response::status_code(eq(200)).matches(&resp)); + let resp = read_response_body(client.get(server.url("/foo")).await.unwrap()).await; + assert!(response::status_code(eq(404)).matches(&resp)); + let resp = read_response_body(client.get(server.url("/foo")).await.unwrap()).await; + assert!(response::status_code(eq(200)).matches(&resp)); + let resp = read_response_body(client.get(server.url("/foo")).await.unwrap()).await; + assert!(response::status_code(eq(404)).matches(&resp)); +}