diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index be6ca521..37857c7d 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -25,8 +25,8 @@ jobs: override: true - name: Check ohkami's buildability - run: cargo check --features rt_${{ matrix.rt }}${{ matrix.x == '' && '' || ',' }}${{ matrix.x }}${{ matrix.toolchain == 'nightly' && ',nightly' || '' }},DEBUG + run: cargo check --features rt_${{ matrix.rt }}${{ matrix.x == '' && '' || ',' }}${{ matrix.x }}${{ matrix.toolchain == 'nightly' && ',nightly' || '' }},DEBUG - name: Check examples' buildablity working-directory: examples - run: cargo check + run: cargo check diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8b631c50..76fffef0 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -24,7 +24,9 @@ jobs: profile: minimal override: true - - uses: actions-rs/cargo@v1 - with: - command: test - args: --features rt_${{ matrix.rt }}${{ matrix.x == '' && '' || ',' }}${{ matrix.x }}${{ matrix.toolchain == 'nightly' && ',nightly' || '' }},DEBUG + - name: Run ohkami's tests + run: cargo test --features rt_${{ matrix.rt }}${{ matrix.x == '' && '' || ',' }}${{ matrix.x }}${{ matrix.toolchain == 'nightly' && ',nightly' || '' }},DEBUG + + - name: Run tests of examples + working-directory: examples + run: cargo check diff --git a/ohkami/Cargo.toml b/ohkami/Cargo.toml index 86885ec9..c4e2886a 100644 --- a/ohkami/Cargo.toml +++ b/ohkami/Cargo.toml @@ -13,7 +13,7 @@ categories = ["asynchronous", "web-programming::http-server"] license = "MIT" [package.metadata.docs.rs] -features = ["rt_tokio"] +features = ["rt_tokio", "websocket", "nightly"] [dependencies] ohkami_macros = { version = "0.5", path = "../ohkami_macros" } @@ -40,5 +40,5 @@ DEBUG = [ # "rt_tokio", # #"rt_async-std", # "DEBUG", -# # "nightly" -#] \ No newline at end of file +# "nightly", +#]# \ No newline at end of file diff --git a/ohkami/src/layer0_lib/base64.rs b/ohkami/src/layer0_lib/base64.rs new file mode 100644 index 00000000..1a361ff9 --- /dev/null +++ b/ohkami/src/layer0_lib/base64.rs @@ -0,0 +1,302 @@ +#[inline(always)] pub fn encode(src: impl AsRef<[u8]>) -> String { + encode_by( + src.as_ref(), + b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", + Some(b'='), + ) +} + +#[inline(always)] pub fn encode_url(src: impl AsRef<[u8]>) -> String { + encode_by( + src.as_ref(), + b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_", + None, + ) +} + +#[cfg(test)] +#[inline(always)] pub fn decode(encoded: &[u8]) -> Vec { + decode_by( + encoded, + b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", + Some(b'='), + ) +} + +#[inline(always)] pub fn decode_url(encoded: &str) -> Vec { + decode_by( + encoded.as_bytes(), + b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_", + None, + ) +} + + +fn encode_by(src: &[u8], encode_map: &[u8; 64], padding: Option) -> String { + let src = src.as_ref(); + let src_len = src.len(); + + if src_len == 0 { + return String::new() + } + + let mut dst = { + let encoded_len = + if padding.is_none() { + src_len / 3 * 4 + (src_len % 3 * 8 + 5) / 6 + } else { + (src_len + 2) / 3 * 4 + }; + vec![u8::default(); encoded_len] + }; + + let (mut di, mut si) = (0, 0); + let n = (src_len / 3) * 3; // `n` is `src_len - (src_len % 3)` + while si < n { + let val = (src[si+0] as usize)<<16 | (src[si+1] as usize)<<8 | (src[si+2] as usize); + + dst[di+0] = encode_map[val>>18&0x3F]; + dst[di+1] = encode_map[val>>12&0x3F]; + dst[di+2] = encode_map[val>>6&0x3F]; + dst[di+3] = encode_map[val&0x3F]; + + si += 3; + di += 4; + } + + let remain = src_len - si; // `remain` is `src_len % 3` + if remain == 0 { + return (|| unsafe {String::from_utf8_unchecked(dst)})() + } + + let mut val = (src[si+0] as usize) << 16; + if remain == 2 { + val |= (src[si+1] as usize) << 8; + } + + dst[di+0] = encode_map[val>>18&0x3F]; + dst[di+1] = encode_map[val>>12&0x3F]; + + match remain { + 2 => { + dst[di+2] = encode_map[val>>6&0x3F]; + if let Some(p) = padding { + dst[di+3] = p; + } + } + 1 => if let Some(p) = padding { + dst[di+2] = p; + dst[di+3] = p; + } + _ => unsafe {std::hint::unreachable_unchecked()} + } + + unsafe {String::from_utf8_unchecked(dst)} +} + +fn decode_by(encoded: &[u8], encode_map: &[u8; 64], padding: Option) -> Vec { + #[inline] fn assemble64(n: [u8; 8]) -> Option { + let [n1, n2, n3, n4, n5, n6, n7, n8] = n.map(>::into); + (n1|n2|n3|n4|n5|n6|n7|n8 != 0xff).then_some( + n1<<58 | n2<<52 | n3<<46 | n4<<40 | n5<<34 | n6<<28 | n7<<22 | n8<<16 + ) + } + #[inline] fn assemble32(n: [u8; 4]) -> Option { + let [n1, n2, n3, n4] = n.map(>::into); + (n1|n2|n3|n4 != 0xff).then_some( + n1<<26 | n2<<20 | n3<<14 | n4<<8 + ) + } + fn decode_quantum( + dst: &mut [u8], + encoded: &[u8], + mut si: usize, + decode_map: &[u8; 256], + padding: Option, + ) -> (/*new si*/usize, /*n increase*/usize) { + let mut d_len = 4; + let mut d_buf = [u8::default(); 4]; + + let mut i = 0; + while i < d_buf.len() { + if encoded.len() == si { + if i == 0 { + return (si, 0) + } else if i == 1 || padding.is_some() { + unreachable!("Illegal base64 data at input byte {}", si - i) + } + + d_len = i; + break + } + + let input = encoded[si]; + si += 1; + + let output = decode_map[input as usize]; + if output != 0xff { + d_buf[i] = output; + + i += 1; continue + } + + if matches!(input, b'\r' | b'\n') { + /* With no increase of `i` */ continue + } + + if padding != Some(input) { + unreachable!("Illegal base64 data at input byte {}", si - 1) + } + + /* We've reached the end and there's padding */ + match i { + 0 | 1 => unreachable!("Illegal base64 data at input byte {}: incorrect padding", si - 1), + 2 => {/* "==" is expected, the first "=" is already consumed. */ + /* skip over newlines */ + while si < encoded.len() && matches!(encoded[si], b'\r' | b'\n') {si += 1} + + if si == encoded.len() { + unreachable!("Illegal base64 data at input byte {}: not enough padding", encoded.len()) + } else if padding != Some(encoded[si]) { + unreachable!("Illegal base64 data at input byte {}: incorrect padding", si - 1) + } + + si += 1 + } + _ => () + } + + /* skip over newlines */ + while si < encoded.len() && matches!(encoded[si], b'\r' | b'\n') {si += 1} + if si < encoded.len() { + unreachable!("Illegal base64 data at input byte {}: trailing garbage", si) + } + d_len = i; + break + } + + let val = (d_buf[0] as usize)<<18 | (d_buf[1] as usize)<<12 | (d_buf[2] as usize)<<6 | (d_buf[3] as usize); + (d_buf[2], d_buf[1], d_buf[0]) = ((val>>0) as u8, (val>>8) as u8, (val>>16) as u8); + if d_len >= 4 { + dst[2] = d_buf[2]; + d_buf[2] = 0; + } + if d_len >= 3 { + dst[1] = d_buf[1]; + d_buf[1] = 0; + } + if d_len >= 2 { + dst[0] = d_buf[0]; + } + + (si, d_len - 1) + } + + // ================================================== + + let mut decoded = { + let max_len = encoded.len() / 4 * 3 + + padding.is_none().then_some(encoded.len() % 4 * 6 / 8).unwrap_or(0); + vec![u8::default(); max_len] + }; + if decoded.is_empty() {return decoded} + + let decode_map = { + let mut map =[0xff; 256]; + for (i, &byte) in encode_map.iter().enumerate() { + map[byte as usize] = i as u8 + } + map + }; + + let mut si = 0; + let mut n = 0; + + #[cfg(target_pointer_width = "64")] + while encoded.len() - si >= 8 && decoded.len() - n >= 8 { + let encoded2: [_; 8] = encoded[si..(si + 8)].try_into().unwrap(); + if let Some(dn) = assemble64(encoded2.map(|byte| decode_map[byte as usize])) { + decoded[n..(n + 8)].copy_from_slice(&dn.to_be_bytes()); + si += 8; + n += 6; + } else { + let (new_si, n_inc) = decode_quantum(&mut decoded[n..], encoded, si, &decode_map, padding); + si = new_si; + n += n_inc; + } + } + + while encoded.len() - si >= 4 && decoded.len() - n >= 4 { + let encoded2: [_; 4] = encoded[si..(si + 4)].try_into().unwrap(); + if let Some(dn) = assemble32(encoded2.map(|byte| decode_map[byte as usize])) { + decoded[n..(n + 4)].copy_from_slice(&dn.to_be_bytes()); + si += 4; + n += 3; + } else { + let (new_si, n_inc) = decode_quantum(&mut decoded[n..], encoded, si, &decode_map, padding); + si = new_si; + n += n_inc; + } + } + + while si < encoded.len() { + let (new_si, n_inc) = decode_quantum(&mut decoded[n..], encoded, si, &decode_map, padding); + si = new_si; + n += n_inc; + } + + decoded.truncate(n); + decoded +} + + + + +#[cfg(test)] mod test { + type Src = &'static [u8]; + type Encoded = &'static str; + + const CASES: &[(Src, Encoded)] = &[ + // RFC 3548 examples + (b"\x14\xfb\x9c\x03\xd9\x7e", "FPucA9l+"), + (b"\x14\xfb\x9c\x03\xd9", "FPucA9k="), + (b"\x14\xfb\x9c\x03", "FPucAw=="), + + // RFC 4648 examples + (b"", ""), + (b"f", "Zg=="), + (b"fo", "Zm8="), + (b"foo", "Zm9v"), + (b"foob", "Zm9vYg=="), + (b"fooba", "Zm9vYmE="), + (b"foobar", "Zm9vYmFy"), + + // Wikipedia examples + (b"sure.", "c3VyZS4="), + (b"sure", "c3VyZQ=="), + (b"sur", "c3Vy"), + (b"su", "c3U="), + (b"leasure.", "bGVhc3VyZS4="), + (b"easure.", "ZWFzdXJlLg=="), + (b"asure.", "YXN1cmUu"), + ]; + + #[test] fn test_encode() { + for (src, encoded) in CASES { + assert_eq!(super::encode(src), *encoded); + } + } + + #[test] fn test_decode() { + for (original, encoded) in CASES { + let (actual, expected) = (super::decode(encoded.as_bytes()), original); + if actual != *expected { + panic!("\n\ + \0 actual: `{}`\n\ + expected: `{}`\n\ + ", actual.escape_ascii(), expected.escape_ascii()) + } + } + } +} + diff --git a/ohkami/src/layer0_lib/hmac_sha256.rs b/ohkami/src/layer0_lib/hmac_sha256.rs new file mode 100644 index 00000000..6cba7c10 --- /dev/null +++ b/ohkami/src/layer0_lib/hmac_sha256.rs @@ -0,0 +1,500 @@ +use std::borrow::Cow; + +const CHUNK: usize = 64; +const SIZE: usize = 32/* 256 bits */; +const BLOCK_SIZE: usize = 64; + +#[allow(non_camel_case_types)] +pub struct HMAC_SHA256 { + opad: Vec, + ipad: Vec, + outer: SHA256, + inner: SHA256, +} + +impl HMAC_SHA256 { + pub fn new(secret_key: impl AsRef<[u8]>) -> Self { + let mut secret_key = Cow::<'_, [u8]>::Borrowed(secret_key.as_ref()); + + let mut this = HMAC_SHA256 { + opad: vec![0; BLOCK_SIZE], + ipad: vec![0; BLOCK_SIZE], + outer: SHA256::new(), + inner: SHA256::new(), + }; + + if secret_key.len() > BLOCK_SIZE { + this.outer.write(&secret_key); + secret_key = Cow::Owned(this.outer.clone().sum().to_vec()); + } + this.ipad[..secret_key.len()].copy_from_slice(&secret_key); + this.opad[..secret_key.len()].copy_from_slice(&secret_key); + for p in &mut this.ipad { + *p ^= 0x36; + } + for p in &mut this.opad { + *p ^= 0x5c; + } + this.inner.write(&this.ipad); + + this + } + + #[inline] pub fn write(&mut self, p: &[u8]) { + self.inner.write(p) + } + + #[inline] pub fn sum(self) -> [u8; SIZE] { + let Self { opad, ipad:_, mut outer, inner } = self; + + let in_sum = inner.sum(); + + outer.reset(); + outer.write(&opad); + + outer.write(&in_sum); + outer.sum() + } +} + + +#[derive(Clone)] +pub struct SHA256 { + h: [u32; 8], + x: [u8; CHUNK], + nx: usize, + len: usize, +} + +impl SHA256 { + #[inline] pub const fn new() -> Self { + Self { + h: [ + 0x6A09E667, + 0xBB67AE85, + 0x3C6EF372, + 0xA54FF53A, + 0x510E527F, + 0x9B05688C, + 0x1F83D9AB, + 0x5BE0CD19, + ], + x: [0; CHUNK], + nx: 0, + len: 0, + } + } + + pub fn write(&mut self, mut p: &[u8]) { + let nn = p.len(); + + self.len += nn; + + if self.nx > 0 { + let n = usize::min(self.x.len() - self.nx, p.len()); + self.x[self.nx..(self.nx + n)].copy_from_slice(&p[..n]); + self.nx += n; + if self.nx == CHUNK { + self.block(&self.x.clone()); + self.nx = 0; + } + p = &p[n..]; + } + if p.len() >= CHUNK { + let n = p.len() & (!(CHUNK - 1)); + self.block(&p[..n]); + p = &p[n..]; + } + if p.len() > 0 { + let n = usize::min(self.x.len(), p.len()); + self.x[..n].copy_from_slice(&p[..n]); + self.nx = n; + } + } + + #[inline] pub fn sum(mut self) -> [u8; SIZE] { + let mut len = self.len; + + let mut tmp = [u8::default(); 64+8]; + tmp[0] = 0x80; + + let t = if len%64 < 56 { + 56 - len%64 + } else { + 64 + 56 - len%64 + }; + + len <<= 3; // Length in bits + tmp[t..t+8].copy_from_slice(&(len as u64).to_be_bytes()); + self.write(&tmp[..t+8]); + + debug_assert!(self.nx == 0, "`self.nx` is not 0"); + + let mut digest = [u8::default(); 32]; + digest[0.. 4 ].copy_from_slice(&self.h[0].to_be_bytes()); + digest[4.. 8 ].copy_from_slice(&self.h[1].to_be_bytes()); + digest[8.. 12].copy_from_slice(&self.h[2].to_be_bytes()); + digest[12..16].copy_from_slice(&self.h[3].to_be_bytes()); + digest[16..20].copy_from_slice(&self.h[4].to_be_bytes()); + digest[20..24].copy_from_slice(&self.h[5].to_be_bytes()); + digest[24..28].copy_from_slice(&self.h[6].to_be_bytes()); + digest[28..32].copy_from_slice(&self.h[7].to_be_bytes()); + digest + } +} + +impl SHA256 { + #[inline] fn reset(&mut self) { + *self = Self { + x: self.x, + ..Self::new() + } + } +} + +impl SHA256 { + fn block(&mut self, mut p: &[u8]) { + const K: [u32; 64] = [ + 0x428a2f98, + 0x71374491, + 0xb5c0fbcf, + 0xe9b5dba5, + 0x3956c25b, + 0x59f111f1, + 0x923f82a4, + 0xab1c5ed5, + 0xd807aa98, + 0x12835b01, + 0x243185be, + 0x550c7dc3, + 0x72be5d74, + 0x80deb1fe, + 0x9bdc06a7, + 0xc19bf174, + 0xe49b69c1, + 0xefbe4786, + 0x0fc19dc6, + 0x240ca1cc, + 0x2de92c6f, + 0x4a7484aa, + 0x5cb0a9dc, + 0x76f988da, + 0x983e5152, + 0xa831c66d, + 0xb00327c8, + 0xbf597fc7, + 0xc6e00bf3, + 0xd5a79147, + 0x06ca6351, + 0x14292967, + 0x27b70a85, + 0x2e1b2138, + 0x4d2c6dfc, + 0x53380d13, + 0x650a7354, + 0x766a0abb, + 0x81c2c92e, + 0x92722c85, + 0xa2bfe8a1, + 0xa81a664b, + 0xc24b8b70, + 0xc76c51a3, + 0xd192e819, + 0xd6990624, + 0xf40e3585, + 0x106aa070, + 0x19a4c116, + 0x1e376c08, + 0x2748774c, + 0x34b0bcb5, + 0x391c0cb3, + 0x4ed8aa4a, + 0x5b9cca4f, + 0x682e6ff3, + 0x748f82ee, + 0x78a5636f, + 0x84c87814, + 0x8cc70208, + 0x90befffa, + 0xa4506ceb, + 0xbef9a3f7, + 0xc67178f2, + ]; + + let mut w = [u32::default(); 64]; + let [mut h0, mut h1, mut h2, mut h3, mut h4, mut h5, mut h6, mut h7] = self.h; + + while p.len() >= CHUNK { + for i in 0..16 { + let j = i * 4; + w[i] = (p[j] as u32)<<24 | (p[j+1] as u32)<<16 | (p[j+2] as u32)<<8 | (p[j+3] as u32); + } + for i in 16..64 { + let v1 = w[i-2]; + let t1 = v1.rotate_right(17) ^ v1.rotate_right(19) ^ (v1 >> 10); + + let v2 = w[i-15]; + let t2 = v2.rotate_right(7) ^ v2.rotate_right(18) ^ (v2 >> 3); + + w[i] = (t1).wrapping_add(w[i-7]).wrapping_add(t2).wrapping_add(w[i-16]); + } + + let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = [h0, h1, h2, h3, h4, h5, h6, h7]; + + for i in 0..64 { + let t1 = (h) + .wrapping_add(e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25)) + .wrapping_add((e & f) ^ (!e & g)) + .wrapping_add(K[i]) + .wrapping_add(w[i]); + let t2 = (a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22)) + .wrapping_add((a & b) ^ (a & c) ^ (b & c)); + + h = g; + g = f; + f = e; + e = (d).wrapping_add(t1); + d = c; + c = b; + b = a; + a = (t1).wrapping_add(t2); + } + + h0 = (h0).wrapping_add(a); + h1 = (h1).wrapping_add(b); + h2 = (h2).wrapping_add(c); + h3 = (h3).wrapping_add(d); + h4 = (h4).wrapping_add(e); + h5 = (h5).wrapping_add(f); + h6 = (h6).wrapping_add(g); + h7 = (h7).wrapping_add(h); + + p = &p[CHUNK..]; + } + + self.h = [h0, h1, h2, h3, h4, h5, h6, h7]; + } +} + + + + +#[cfg(test)] mod test { + use super::{SHA256, HMAC_SHA256}; + + #[test] fn test_sha256() { + for (expected/* hex literal */, input) in [ + ("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", ""), + ("ca978112ca1bbdcafac231b39a23dc4da786eff8147c4e72b9807785afee48bb", "a"), + ("fb8e20fc2e4c3f248c60c39bd652f3c1347298bb977b8b4d5903b85055620603", "ab"), + ("ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad", "abc"), + ("88d4266fd4e6338d13b845fcf289579d209c897823b9217da3e161936f031589", "abcd"), + ("36bbe50ed96841d10443bcb670d6554f0a34b761be67ec9c4a8ad2c0c44ca42c", "abcde"), + ("bef57ec7f53a6d40beb640a780a639c83bc29ac8a9816f1fc6c5c6dcd93c4721", "abcdef"), + ("7d1a54127b222502f5b79b5fb0803061152a44f92b37e23c6527baf665d4da9a", "abcdefg"), + ("9c56cc51b374c3ba189210d5b6d4bf57790d351c96c47c02190ecf1e430635ab", "abcdefgh"), + ("19cc02f26df43cc571bc9ed7b0c4d29224a3ec229529221725ef76d021c8326f", "abcdefghi"), + ("72399361da6a7754fec986dca5b7cbaf1c810a28ded4abaf56b2106d06cb78b0", "abcdefghij"), + ("a144061c271f152da4d151034508fed1c138b8c976339de229c3bb6d4bbb4fce", "Discard medicine more than two years old."), + ("6dae5caa713a10ad04b46028bf6dad68837c581616a1589a265a11288d4bb5c4", "He who has a shady past knows that nice guys finish last."), + ("ae7a702a9509039ddbf29f0765e70d0001177914b86459284dab8b348c2dce3f", "I wouldn't marry him with a ten foot pole."), + ("6748450b01c568586715291dfa3ee018da07d36bb7ea6f180c1af6270215c64f", "Free! Free!/A trip/to Mars/for 900/empty jars/Burma Shave"), + ("14b82014ad2b11f661b5ae6a99b75105c2ffac278cd071cd6c05832793635774", "The days of the digital watch are numbered. -Tom Stoppard"), + ("7102cfd76e2e324889eece5d6c41921b1e142a4ac5a2692be78803097f6a48d8", "Nepal premier won't resign."), + ("23b1018cd81db1d67983c5f7417c44da9deb582459e378d7a068552ea649dc9f", "For every action there is an equal and opposite government program."), + ("8001f190dfb527261c4cfcab70c98e8097a7a1922129bc4096950e57c7999a5a", "His money is twice tainted: 'taint yours and 'taint mine."), + ("8c87deb65505c3993eb24b7a150c4155e82eee6960cf0c3a8114ff736d69cad5", "There is no reason for any individual to have a computer in their home. -Ken Olsen, 1977"), + ("bfb0a67a19cdec3646498b2e0f751bddc41bba4b7f30081b0b932aad214d16d7", "It's a tiny change to the code and not completely disgusting. - Bob Manchek"), + ("7f9a0b9bf56332e19f5a0ec1ad9c1425a153da1c624868fda44561d6b74daf36", "size: a.out: bad magic"), + ("b13f81b8aad9e3666879af19886140904f7f429ef083286195982a7588858cfc", "The major problem is with sendmail. -Mark Horton"), + ("b26c38d61519e894480c70c8374ea35aa0ad05b2ae3d6674eec5f52a69305ed4", "Give me a rock, paper and scissors and I will move the world. CCFestoon"), + ("049d5e26d4f10222cd841a119e38bd8d2e0d1129728688449575d4ff42b842c1", "If the enemy is within range, then so are you."), + ("0e116838e3cc1c1a14cd045397e29b4d087aa11b0853fc69ec82e90330d60949", "It's well we cannot hear the screams/That we create in others' dreams."), + ("4f7d8eb5bcf11de2a56b971021a444aa4eafd6ecd0f307b5109e4e776cd0fe46", "You remind me of a TV show, but that's all right: I watch it anyway."), + ("61c0cc4c4bd8406d5120b3fb4ebc31ce87667c162f29468b3c779675a85aebce", "C is as portable as Stonehedge!!"), + ("1fb2eb3688093c4a3f80cd87a5547e2ce940a4f923243a79a2a1e242220693ac", "Even if I could be Shakespeare, I think I should still choose to be Faraday. - A. Huxley"), + ("395585ce30617b62c80b93e8208ce866d4edc811a177fdb4b82d3911d8696423", "The fugacity of a constituent in a mixture of gases at a given temperature is proportional to its mole fraction. Lewis-Randall Rule"), + ("4f9b189a13d030838269dce846b16a1ce9ce81fe63e65de2f636863336a98fe6", "How can you write a big system without C++? -Paul Glick"), + ] { + let sum = std::array::from_fn(|i| i).map(|i| + [expected.as_bytes()[2*i], expected.as_bytes()[2*i+1]].map(|b| match b { + b'0'..=b'9' => b - b'0', + b'a'..=b'f' => 10 + b - b'a', + _ => unreachable!() + }).into_iter().fold(0, |byte, b| byte * 2u8.pow(4) + b) + ); + + let mut s = SHA256::new(); + s.write(input.as_bytes()); + assert_eq!(s.sum(), sum); + } + } + + #[test] fn test_hmac_sha256() { + for (key, input, output_hexliteral) in [ + // Tests from RFC 4231 + ( + [ + 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, + 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, + 0x0b, 0x0b, 0x0b, 0x0b, + ].as_slice(), + "Hi There".as_bytes(), + "b0344c61d8db38535ca8afceaf0bf12b881dc200c9833da726e9376c2e32cff7".as_bytes(), + ), + ( + "Jefe".as_bytes(), + "what do ya want for nothing?".as_bytes(), + "5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843".as_bytes(), + ), + ( + [ + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, + ].as_slice(), + [ + 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, + 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, + 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, + 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, + 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, + 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, + 0xdd, 0xdd, + ].as_slice(), + "773ea91e36800e46854db8ebd09181a72959098b3ef8c122d9635514ced565fe".as_bytes(), + ), + ( + [ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, + 0x19, + ].as_slice(), + [ + 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, + 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, + 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, + 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, + 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, + 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, 0xcd, + 0xcd, 0xcd, + ].as_slice(), + "82558a389a443c0ea4cc819899f2083a85f0faa3e578f8077a2e3ff46729665b".as_bytes(), + ), + ( + [ + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, + ].as_slice(), + "Test Using Larger Than Block-Size Key - Hash Key First".as_bytes(), + "60e431591ee0b67f0d8a26aacbf5b77f8e0bc6213728c5140546040f0ee37f54".as_bytes(), + ), + ( + [ + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + 0xaa, 0xaa, 0xaa, + ].as_slice(), + "This is a test using a larger than block-size key \ + and a larger than block-size data. The key needs to \ + be hashed before being used by the HMAC algorithm.".as_bytes(), + "9b09ffa71b942fcb27635fbcd5b0e944bfdc63644f0713938a7f51535c3a35e2".as_bytes(), + ), + + // Tests from https://csrc.nist.gov/groups/ST/toolkit/examples.html + // (truncated tag tests are left out) + ( + [ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, + 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, + 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, + ].as_slice(), + "Sample message for keylen=blocklen".as_bytes(), + "8bb9a1db9806f20df7f77b82138c7914d174d59e13dc4d0169c9057b133e1d62".as_bytes(), + ), + ( + [ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + ].as_slice(), + "Sample message for keylen b - b'0', + b'a'..=b'f' => 10 + b - b'a', + _ => unreachable!() + }).into_iter().fold(0, |byte, b| byte * 2u8.pow(4) + b) + ); + + let mut hs = HMAC_SHA256::new(key); + hs.write(input); + assert_eq!(hs.sum(), expected); + } + } +} diff --git a/ohkami/src/layer0_lib/mod.rs b/ohkami/src/layer0_lib/mod.rs index fcd9737d..4b892e51 100644 --- a/ohkami/src/layer0_lib/mod.rs +++ b/ohkami/src/layer0_lib/mod.rs @@ -4,9 +4,14 @@ pub(crate) use list::List; mod slice; pub(crate) use slice::{Slice, CowSlice}; +pub(crate) mod base64; + mod status; pub use status::Status; +mod hmac_sha256; +pub(crate) use hmac_sha256::{HMAC_SHA256}; + mod method; pub use method::Method; diff --git a/ohkami/src/layer2_context/mod.rs b/ohkami/src/layer2_context/mod.rs index 81d04553..79b1748b 100644 --- a/ohkami/src/layer2_context/mod.rs +++ b/ohkami/src/layer2_context/mod.rs @@ -1,5 +1,9 @@ #![allow(non_snake_case)] +mod store; + +use store::Store; + use crate::{ layer0_lib::{Status, server_header}, layer1_req_res::{Response}, @@ -33,12 +37,20 @@ pub struct Context { pub(crate) upgrade_id: Option, pub headers: server_header::Headers, + store: Store, } impl Context { #[inline(always)] pub fn set_headers(&mut self) -> server_header::SetHeaders<'_> { self.headers.set() } + + #[inline] pub fn store(&mut self, value: Value) { + self.store.insert(value) + } + #[inline] pub fn get(&self) -> Option<&Value> { + self.store.get() + } } impl Context { #[inline] pub(crate) fn new() -> Self { @@ -47,6 +59,7 @@ impl Context { upgrade_id: None, headers: server_header::Headers::new(), + store: Store::new(), } } } diff --git a/ohkami/src/layer2_context/store.rs b/ohkami/src/layer2_context/store.rs new file mode 100644 index 00000000..087779e2 --- /dev/null +++ b/ohkami/src/layer2_context/store.rs @@ -0,0 +1,51 @@ +use std::{ + any::{Any, TypeId}, + collections::HashMap, + hash::{Hasher, BuildHasherDefault}, +}; + + +pub struct Store( + Option, + BuildHasherDefault, + > + >> +); + +#[derive(Default)] +struct TypeIDHasger(u64); +impl Hasher for TypeIDHasger { + fn write(&mut self, _: &[u8]) { + unsafe {std::hint::unreachable_unchecked()} + } + + #[inline] fn write_u64(&mut self, type_id_value: u64) { + self.0 = type_id_value + } + #[inline] fn finish(&self) -> u64 { + self.0 + } +} + + +impl Store { + pub(super) fn new() -> Self { + Self(None) + } +} + +impl Store { + pub fn insert(&mut self, value: Value) { + self.0.get_or_insert_with(|| Box::new(HashMap::default())) + .insert(TypeId::of::(), Box::new(value)); + } + + pub fn get(&self) -> Option<&Value> { + self.0.as_ref() + .and_then(|map| map.get(&TypeId::of::())) + .and_then(|boxed| boxed.downcast_ref()) + } +} diff --git a/ohkami/src/layer3_fang_handler/fang/mod.rs b/ohkami/src/layer3_fang_handler/fang/mod.rs index 87855286..89501371 100644 --- a/ohkami/src/layer3_fang_handler/fang/mod.rs +++ b/ohkami/src/layer3_fang_handler/fang/mod.rs @@ -44,10 +44,14 @@ impl Fang { /// /// ## available `f` signatures /// +///
+/// /// #### To make *back fang*: /// - `Fn(&Response)` /// - `Fn(Response) -> Response` /// +///
+/// /// #### To make *front fang*: /// - `Fn( {&/&mut Context} )` /// - `Fn( {&/&mut Request} )` diff --git a/ohkami/src/layer5_ohkami/with_fangs.rs b/ohkami/src/layer5_ohkami/with_fangs.rs index 01766f57..7b039f1d 100644 --- a/ohkami/src/layer5_ohkami/with_fangs.rs +++ b/ohkami/src/layer5_ohkami/with_fangs.rs @@ -22,10 +22,15 @@ use crate::{ ///
/// /// ## fang schema +/// +///
+/// /// #### To make *back fang*: /// - `Fn(&Response)` /// - `Fn(Response) -> Response` /// +///
+/// /// #### To make *front fang*: /// - `Fn( {&/&mut Context} )` /// - `Fn( {&/&mut Request} )` diff --git a/ohkami/src/lib.rs b/ohkami/src/lib.rs index 83b28c34..2bd11b2d 100644 --- a/ohkami/src/lib.rs +++ b/ohkami/src/lib.rs @@ -277,7 +277,7 @@ pub mod http { } pub mod utils { - pub use crate::x_utils ::{now, CORS}; + pub use crate::x_utils ::{now, CORS, JWT}; pub use crate::layer0_lib ::append; pub use crate::layer1_req_res::File; pub use ohkami_macros ::{Query, Payload}; diff --git a/ohkami/src/x_utils/cors.rs b/ohkami/src/x_utils/cors.rs index a2368c99..10dfbb21 100644 --- a/ohkami/src/x_utils/cors.rs +++ b/ohkami/src/x_utils/cors.rs @@ -1,12 +1,15 @@ #![allow(non_snake_case)] -use crate::{http::Method, IntoFang, Fang, Context, Response, Request}; +pub use internal::CORS; +pub const fn CORS(AllowOrigin: &'static str) -> internal::CORS { + #[cfg(test)] { + const fn assert_into_fang() {} + assert_into_fang::(); + } -#[allow(non_snake_case)] -pub const fn CORS(AllowOrigin: &'static str) -> CORS { - CORS { - AllowOrigin: AccessControlAllowOrigin::from_literal(AllowOrigin), + internal::CORS { + AllowOrigin: internal::AccessControlAllowOrigin::from_literal(AllowOrigin), AllowCredentials: false, AllowMethods: None, AllowHeaders: None, @@ -15,131 +18,136 @@ pub const fn CORS(AllowOrigin: &'static str) -> CORS { } } -pub struct CORS { - pub(crate) AllowOrigin: AccessControlAllowOrigin, - pub(crate) AllowCredentials: bool, - pub(crate) AllowMethods: Option>, - pub(crate) AllowHeaders: Option>, - pub(crate) ExposeHeaders: Option>, - pub(crate) MaxAge: Option, -} +mod internal { + use crate::{http::Method, IntoFang, Fang, Context, Response, Request}; -pub(crate) enum AccessControlAllowOrigin { - Any, - Only(&'static str), -} impl AccessControlAllowOrigin { - #[inline(always)] pub(crate) const fn is_any(&self) -> bool { - match self { - Self::Any => true, - _ => false, - } + + pub struct CORS { + pub(crate) AllowOrigin: AccessControlAllowOrigin, + pub(crate) AllowCredentials: bool, + pub(crate) AllowMethods: Option>, + pub(crate) AllowHeaders: Option>, + pub(crate) ExposeHeaders: Option>, + pub(crate) MaxAge: Option, } - #[inline(always)] pub(crate) const fn from_literal(lit: &'static str) -> Self { - match lit.as_bytes() { - b"*" => Self::Any, - origin => Self::Only(unsafe{std::str::from_utf8_unchecked(origin)}), + pub(crate) enum AccessControlAllowOrigin { + Any, + Only(&'static str), + } impl AccessControlAllowOrigin { + #[inline(always)] pub(crate) const fn is_any(&self) -> bool { + match self { + Self::Any => true, + _ => false, + } } - } - #[inline(always)] pub(crate) const fn as_str(&self) -> &'static str { - match self { - Self::Any => "*", - Self::Only(origin) => origin, + #[inline(always)] pub(crate) const fn from_literal(lit: &'static str) -> Self { + match lit.as_bytes() { + b"*" => Self::Any, + origin => Self::Only(unsafe{std::str::from_utf8_unchecked(origin)}), + } } - } - #[inline(always)] pub(crate) fn matches(&self, origin: &str) -> bool { - match self { - Self::Any => true, - Self::Only(o) => *o == origin, + #[inline(always)] pub(crate) const fn as_str(&self) -> &'static str { + match self { + Self::Any => "*", + Self::Only(origin) => origin, + } } - } -} -impl CORS { - pub const fn AllowCredentials(mut self) -> Self { - if self.AllowOrigin.is_any() { - panic!("\ - The value of the 'Access-Control-Allow-Origin' header in the response \ - must not be the wildcard '*' when the request's credentials mode is 'include'.\ - ") + #[inline(always)] pub(crate) fn matches(&self, origin: &str) -> bool { + match self { + Self::Any => true, + Self::Only(o) => *o == origin, + } } - self.AllowCredentials = true; - self - } - pub fn AllowMethods(mut self, methods: [Method; N]) -> Self { - self.AllowMethods = Some(methods.to_vec()); - self - } - pub fn AllowHeaders(mut self, headers: [&'static str; N]) -> Self { - self.AllowHeaders = Some(headers.to_vec()); - self } - pub fn ExposeHeaders(mut self, headers: [&'static str; N]) -> Self { - self.ExposeHeaders = Some(headers.to_vec()); - self - } - pub fn MaxAge(mut self, delta_seconds: u32) -> Self { - self.MaxAge = Some(delta_seconds); - self - } -} -impl IntoFang for CORS { - const METHODS: &'static [Method] = &[Method::OPTIONS]; - - fn into_fang(self) -> Fang { - #[cold] fn __forbid_cors(c: &Context) -> Result<(), Response> { - Err(c.Forbidden()) + impl CORS { + pub const fn AllowCredentials(mut self) -> Self { + if self.AllowOrigin.is_any() { + panic!("\ + The value of the 'Access-Control-Allow-Origin' header in the response \ + must not be the wildcard '*' when the request's credentials mode is 'include'.\ + ") + } + self.AllowCredentials = true; + self + } + pub fn AllowMethods(mut self, methods: [Method; N]) -> Self { + self.AllowMethods = Some(methods.to_vec()); + self + } + pub fn AllowHeaders(mut self, headers: [&'static str; N]) -> Self { + self.AllowHeaders = Some(headers.to_vec()); + self + } + pub fn ExposeHeaders(mut self, headers: [&'static str; N]) -> Self { + self.ExposeHeaders = Some(headers.to_vec()); + self } + pub fn MaxAge(mut self, delta_seconds: u32) -> Self { + self.MaxAge = Some(delta_seconds); + self + } + } - Fang(move |c: &mut Context, req: &mut Request| -> Result<(), Response> { - c.set_headers() - .AccessControlAllowOrigin(self.AllowOrigin.as_str()) - .AccessControlAllowCredentials(if self.AllowCredentials {"true"} else {"false"}); - if let Some(methods) = &self.AllowMethods { - c.set_headers() - .AccessControlAllowMethods(methods.iter().map(Method::as_str).collect::>().join(",")); - } - if let Some(headers) = &self.AllowHeaders { - c.set_headers() - .AccessControlAllowHeaders(headers.join(",")); - } - if let Some(headers) = &self.ExposeHeaders { - c.set_headers() - .AccessControlExposeHeaders(headers.join(",")); - } + impl IntoFang for CORS { + const METHODS: &'static [Method] = &[Method::OPTIONS]; - let origin = req.headers.Origin().ok_or_else(|| c.BadRequest())?; - if !self.AllowOrigin.matches(origin) { - return __forbid_cors(c) + fn into_fang(self) -> Fang { + #[cold] fn __forbid_cors(c: &Context) -> Result<(), Response> { + Err(c.Forbidden()) } - if req.headers.Authorization().is_some() { - if !self.AllowCredentials { - return __forbid_cors(c) + Fang(move |c: &mut Context, req: &mut Request| -> Result<(), Response> { + c.set_headers() + .AccessControlAllowOrigin(self.AllowOrigin.as_str()) + .AccessControlAllowCredentials(if self.AllowCredentials {"true"} else {"false"}); + if let Some(methods) = &self.AllowMethods { + c.set_headers() + .AccessControlAllowMethods(methods.iter().map(Method::as_str).collect::>().join(",")); + } + if let Some(headers) = &self.AllowHeaders { + c.set_headers() + .AccessControlAllowHeaders(headers.join(",")); + } + if let Some(headers) = &self.ExposeHeaders { + c.set_headers() + .AccessControlExposeHeaders(headers.join(",")); } - } - if let Some(request_method) = req.headers.AccessControlRequestMethod() { - let request_method = Method::from_bytes(request_method.as_bytes()); - let allow_methods = self.AllowMethods.as_ref().ok_or_else(|| c.Forbidden())?; - if !allow_methods.contains(&request_method) { + let origin = req.headers.Origin().ok_or_else(|| c.BadRequest())?; + if !self.AllowOrigin.matches(origin) { return __forbid_cors(c) } - } - if let Some(request_headers) = req.headers.AccessControlRequestHeaders() { - let request_headers = request_headers.split(',').map(|h| h.trim()); - let allow_headers = self.AllowHeaders.as_ref().ok_or_else(|| c.Forbidden())?; - if !request_headers.into_iter().all(|h| allow_headers.contains(&h)) { - return __forbid_cors(c) + if req.headers.Authorization().is_some() { + if !self.AllowCredentials { + return __forbid_cors(c) + } } - } - c.set_headers().Vary("Origin"); - Ok(()) - }) + if let Some(request_method) = req.headers.AccessControlRequestMethod() { + let request_method = Method::from_bytes(request_method.as_bytes()); + let allow_methods = self.AllowMethods.as_ref().ok_or_else(|| c.Forbidden())?; + if !allow_methods.contains(&request_method) { + return __forbid_cors(c) + } + } + + if let Some(request_headers) = req.headers.AccessControlRequestHeaders() { + let request_headers = request_headers.split(',').map(|h| h.trim()); + let allow_headers = self.AllowHeaders.as_ref().ok_or_else(|| c.Forbidden())?; + if !request_headers.into_iter().all(|h| allow_headers.contains(&h)) { + return __forbid_cors(c) + } + } + + c.set_headers().Vary("Origin"); + Ok(()) + }) + } } } diff --git a/ohkami/src/x_utils/jwt.rs b/ohkami/src/x_utils/jwt.rs new file mode 100644 index 00000000..f5c2d1e0 --- /dev/null +++ b/ohkami/src/x_utils/jwt.rs @@ -0,0 +1,416 @@ +#![allow(non_snake_case, non_camel_case_types)] + +pub use internal::JWT; + +pub fn JWT(secret: impl Into) -> internal::JWT { + internal::JWT::new(secret) +} + +mod internal { + use crate::layer0_lib::{base64, HMAC_SHA256}; + use crate::{Context, Request, Response}; + + + pub struct JWT { + secret: String, + } + impl JWT { + pub fn new(secret: impl Into) -> Self { + Self { + secret: secret.into(), + } + } + } + + impl JWT { + pub fn issue(self, payload: impl ::serde::Serialize) -> String { + let unsigned_token = { + let mut ut = base64::encode_url("{\"typ\":\"JWT\",\"alg\":\"HS256\"}"); + ut.push('.'); + ut.push_str(&base64::encode_url(::serde_json::to_vec(&payload).expect("Failed to serialze payload"))); + ut + }; + + let signature = { + let mut s = HMAC_SHA256::new(self.secret); + s.write(unsigned_token.as_bytes()); + s.sum() + }; + + let mut token = unsigned_token; + token.push('.'); + token.push_str(&base64::encode_url(signature)); + token + } + } + + impl JWT { + /// Verify JWT in requests' `Authorization` header and early return error response if + /// it's missing or malformed. + pub fn verify(&self, c: &Context, req: &Request) -> Result<(), Response> { + self.verified::<()>(c, req) + } + + /// Verify JWT in requests' `Authorization` header and early return error response if + /// it's missing or malformed. + /// + /// Then it's valid, this returns decoded paylaod of the JWT as `Payload`. + pub fn verified serde::Deserialize<'d>>(&self, c: &Context, req: &Request) -> Result { + const UNAUTHORIZED_MESSAGE: &str = "missing or malformed jwt"; + + type Header = ::serde_json::Value; + type Payload = ::serde_json::Value; + + let mut parts = req + .headers.Authorization().ok_or_else(|| c.Unauthorized().text(UNAUTHORIZED_MESSAGE))? + .strip_prefix("Bearer ").ok_or_else(|| c.BadRequest())? + .split('.'); + + let header_part = parts.next() + .ok_or_else(|| c.BadRequest())?; + let header: Header = ::serde_json::from_slice(&base64::decode_url(header_part)) + .map_err(|_| c.InternalServerError())?; + if header.get("typ").is_some_and(|typ| !typ.as_str().unwrap_or_default().eq_ignore_ascii_case("JWT")) { + return Err(c.BadRequest()) + } + if header.get("cty").is_some_and(|cty| !cty.as_str().unwrap_or_default().eq_ignore_ascii_case("JWT")) { + return Err(c.BadRequest()) + } + if header.get("alg").ok_or_else(|| c.BadRequest())? != "HS256" { + return Err(c.BadRequest()) + } + + let payload_part = parts.next() + .ok_or_else(|| c.BadRequest())?; + let payload: Payload = ::serde_json::from_slice(&base64::decode_url(payload_part)) + .map_err(|_| c.InternalServerError())?; + let now = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(); + if payload.get("nbf").is_some_and(|nbf| nbf.as_u64().unwrap_or_default() > now) { + return Err(c.Unauthorized().text(UNAUTHORIZED_MESSAGE)) + } + if payload.get("exp").is_some_and(|exp| exp.as_u64().unwrap_or_default() <= now) { + return Err(c.Unauthorized().text(UNAUTHORIZED_MESSAGE)) + } + if payload.get("iat").is_some_and(|iat| iat.as_u64().unwrap_or_default() > now) { + return Err(c.Unauthorized().text(UNAUTHORIZED_MESSAGE)) + } + + let signature_part = parts.next() + .ok_or_else(|| c.BadRequest())?; + let requested_signature = base64::decode_url(signature_part); + let actual_signature = { + let mut hs = HMAC_SHA256::new(&self.secret); + hs.write(header_part.as_bytes()); + hs.write(b"."); + hs.write(payload_part.as_bytes()); + hs.sum() + }; + if requested_signature != actual_signature { + return Err(c.Unauthorized().text(UNAUTHORIZED_MESSAGE)) + } + + let payload = ::serde_json::from_value(payload).map_err(|_| c.InternalServerError())?; + Ok(payload) + } + } +} + + + + +#[cfg(test)] mod test { + use serde::Deserialize; + + use super::JWT; + use crate::__rt__::test; + + #[test] async fn test_jwt_issue() { + /* NOTE: + `serde_json::to_vec` automatically sorts original object's keys + in alphabetical order. e.t., here + + ``` + json!({"name":"kanarus","id":42,"iat":1516239022}) + ``` + is serialzed to + + ```raw literal + {"iat":1516239022,"id":42,"name":"kanarus"} + ``` + */ + assert_eq! { + JWT("secret").issue(::serde_json::json!({"name":"kanarus","id":42,"iat":1516239022})), + "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpYXQiOjE1MTYyMzkwMjIsImlkIjo0MiwibmFtZSI6ImthbmFydXMifQ.dt43rLwmy4_GA_84LMC1m5CwVc59P9as_nRFldVCH7g" + } + } + + #[test] async fn test_jwt_verify() { + use crate::prelude::*; + use crate::{testing::*, http::*}; + + use std::{sync::OnceLock, collections::HashMap, borrow::Cow}; + use crate::__rt__::Mutex; + + + fn my_jwt() -> JWT { + JWT("myverysecretjwtsecretkey") + } + + #[derive(serde::Serialize, Deserialize)] + struct MyJWTPayload { + iat: u64, + user_id: usize, + } + + fn issue_jwt_for_user(user: &User) -> String { + use std::time::{UNIX_EPOCH, SystemTime}; + + my_jwt().issue(MyJWTPayload { + user_id: user.id, + iat: SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(), + }) + } + + + async fn repository() -> &'static Mutex> { + static REPOSITORY: OnceLock>> = OnceLock::new(); + + REPOSITORY.get_or_init(|| Mutex::new(HashMap::new())) + } + + #[derive(Clone)] + #[derive(Debug, PartialEq) /* for test */] + struct User { + id: usize, + first_name: String, + familly_name: String, + } impl User { + fn profile(&self) -> Profile { + Profile { + id: self.id, + first_name: &self.first_name, + familly_name: &self.familly_name, + } + } + } + + + #[derive(serde::Serialize, Deserialize, Debug, PartialEq)] + struct Profile<'p> { + id: usize, + first_name: &'p str, + familly_name: &'p str, + } + + #[cfg(feature="nightly")] + async fn get_profile(c: Context) -> Response { + let r = &mut *repository().await.lock().await; + + let jwt_payload = c.get::() + .ok_or_else(|| c.InternalServerError())?; + + let user = r.get(&jwt_payload.user_id) + .ok_or_else(|| c.BadRequest().text("User doesn't exist"))?; + + c.OK().json(user.profile()) + } + #[cfg(not(feature="nightly"))] + async fn get_profile(c: Context) -> Response { + let r = &mut *repository().await.lock().await; + + let Some(jwt_payload) = c.get::() else { + return (|| c.InternalServerError())() + }; + + let Some(user) = r.get(&jwt_payload.user_id) else { + return (|| c.BadRequest().text("User doesn't exist"))() + }; + + c.OK().json(user.profile()) + } + + #[derive(serde::Deserialize, serde::Serialize/* for test */)] + struct SigninRequest<'s> { + first_name: &'s str, + familly_name: &'s str, + } impl<'req> crate::FromRequest<'req> for SigninRequest<'req> { + type Error = std::borrow::Cow<'static, str>; + fn parse(req: &'req Request) -> Result { + serde_json::from_slice( + req.payload().ok_or_else(|| std::borrow::Cow::Borrowed("No payload found"))? + ).map_err(|e| std::borrow::Cow::Owned(e.to_string())) + } + } + + async fn signin(c: Context, body: SigninRequest<'_>) -> Response { + let r = &mut *repository().await.lock().await; + + let user: Cow<'_, User> = match r.iter().find(|(_, u)| + u.first_name == body.first_name && + u.familly_name == body.familly_name + ) { + Some((_, u)) => Cow::Borrowed(u), + None => { + let new_user_id = match r.keys().max() { + Some(max) => max + 1, + None => 1, + }; + + let new_user = User { + id: new_user_id, + first_name: body.first_name.to_string(), + familly_name: body.familly_name.to_string(), + }; + + r.insert(new_user_id, new_user.clone()); + + Cow::Owned(new_user) + } + }; + + c.OK().text(issue_jwt_for_user(&user)) + } + + + struct MyJWTFang(JWT); + impl IntoFang for MyJWTFang { + fn into_fang(self) -> Fang { + Fang(move |c: &mut Context, req: &Request| { + let jwt_payload = self.0.verified::(c, req)?; + c.store(jwt_payload); + Ok(()) + }) + } + } + + let t = Ohkami::new(( + "/signin".By(Ohkami::new( + "/".PUT(signin), + )), + "/profile".By(Ohkami::with(( + MyJWTFang(my_jwt()), + ), ( + "/".GET(get_profile), + ))), + )); + + + let req = TestRequest::PUT("/signin"); + let res = t.oneshot(req).await; + assert_eq!(res.status(), Status::BadRequest); + + let req = TestRequest::GET("/profile"); + let res = t.oneshot(req).await; + assert_eq!(res.status(), Status::Unauthorized); + assert_eq!(res.text(), Some("missing or malformed jwt")); + + + let req = TestRequest::PUT("/signin") + .json(SigninRequest { + first_name: "ohkami", + familly_name: "framework", + }); + let res = t.oneshot(req).await; + assert_eq!(res.status(), Status::OK); + let jwt_1 = dbg!(res.text().unwrap()); + + let req = TestRequest::GET("/profile") + .header("Authorization", format!("Bearer {jwt_1}")); + let res = t.oneshot(req).await; + assert_eq!(res.status(), Status::OK); + assert_eq!(res.json::().unwrap().unwrap(), Profile { + id: 1, + first_name: "ohkami", + familly_name: "framework", + }); + + let req = TestRequest::GET("/profile") + .header("Authorization", format!("Bearer {jwt_1}x")); + let res = t.oneshot(req).await; + assert_eq!(res.status(), Status::Unauthorized); + assert_eq!(res.text(), Some("missing or malformed jwt")); + + + assert_eq! { + &*repository().await.lock().await, + &HashMap::from([ + (1, User { + id: 1, + first_name: format!("ohkami"), + familly_name: format!("framework"), + }), + ]) + } + + + let req = TestRequest::PUT("/signin") + .json(SigninRequest { + first_name: "Leonhard", + familly_name: "Euler", + }); + let res = t.oneshot(req).await; + assert_eq!(res.status(), Status::OK); + let jwt_2 = dbg!(res.text().unwrap()); + + let req = TestRequest::GET("/profile") + .header("Authorization", format!("Bearer {jwt_2}")); + let res = t.oneshot(req).await; + assert_eq!(res.status(), Status::OK); + assert_eq!(res.json::().unwrap().unwrap(), Profile { + id: 2, + first_name: "Leonhard", + familly_name: "Euler", + }); + + + assert_eq! { + &*repository().await.lock().await, + &HashMap::from([ + (1, User { + id: 1, + first_name: format!("ohkami"), + familly_name: format!("framework"), + }), + (2, User { + id: 2, + first_name: format!("Leonhard"), + familly_name: format!("Euler"), + }), + ]) + } + + + let req = TestRequest::GET("/profile") + .header("Authorization", format!("Bearer {jwt_1}")); + let res = t.oneshot(req).await; + assert_eq!(res.status(), Status::OK); + assert_eq!(res.json::().unwrap().unwrap(), Profile { + id: 1, + first_name: "ohkami", + familly_name: "framework", + }); + + let req = TestRequest::GET("/profile") + .header("Authorization", format!("Bearer {jwt_2}0000")); + let res = t.oneshot(req).await; + assert_eq!(res.status(), Status::Unauthorized); + assert_eq!(res.text(), Some("missing or malformed jwt")); + + + assert_eq! { + &*repository().await.lock().await, + &HashMap::from([ + (1, User { + id: 1, + first_name: format!("ohkami"), + familly_name: format!("framework"), + }), + (2, User { + id: 2, + first_name: format!("Leonhard"), + familly_name: format!("Euler"), + }), + ]) + } + } +} diff --git a/ohkami/src/x_utils/mod.rs b/ohkami/src/x_utils/mod.rs index 452deb02..093c136b 100644 --- a/ohkami/src/x_utils/mod.rs +++ b/ohkami/src/x_utils/mod.rs @@ -1,5 +1,8 @@ mod now; pub use now::now; +mod jwt; +pub use jwt::JWT; + mod cors; pub use cors::CORS; diff --git a/ohkami/src/x_websocket/context.rs b/ohkami/src/x_websocket/context.rs index 6262a856..b33eb544 100644 --- a/ohkami/src/x_websocket/context.rs +++ b/ohkami/src/x_websocket/context.rs @@ -1,9 +1,10 @@ use std::{future::Future, borrow::Cow}; use super::websocket::Config; -use super::{WebSocket, sign}; +use super::{WebSocket, sign::Sha1}; use crate::{Response, Context, Request}; use crate::__rt__::{task}; use crate::http::{Method}; +use crate::layer0_lib::{base64}; use super::assume_upgradable; @@ -99,10 +100,10 @@ impl WebSocketContext { handler: impl Fn(WebSocket) -> Fut + Send + Sync + 'static ) -> Response { fn sign(sec_websocket_key: &str) -> String { - let mut sha1 = sign::Sha1::new(); + let mut sha1 = Sha1::new(); sha1.write(sec_websocket_key.as_bytes()); sha1.write(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); - sign::Base64::<{sign::SHA1_SIZE}>::encode(sha1.sum()) + base64::encode(sha1.sum()) } let Self { diff --git a/ohkami/src/x_websocket/sign.rs b/ohkami/src/x_websocket/sign.rs index c0dd7ef0..f1fe49d5 100644 --- a/ohkami/src/x_websocket/sign.rs +++ b/ohkami/src/x_websocket/sign.rs @@ -1,5 +1,5 @@ -mod sha1; pub use sha1:: {Sha1, SIZE as SHA1_SIZE}; -mod base64; pub use base64::{Base64}; +mod sha1; +pub(crate) use sha1::{Sha1}; #[cfg(test)] mod sign_test { use super::*; @@ -52,30 +52,4 @@ mod base64; pub use base64::{Base64}; assert_eq!(s.sum(), expected); } } - - #[test] fn test_base64() {// https://github.com/golang/go/blob/master/src/encoding/base64/base64_test.go - // RFC 3548 examples - assert_eq!(Base64::<6>::encode(*b"\x14\xfb\x9c\x03\xd9\x7e"), "FPucA9l+"); - assert_eq!(Base64::<5>::encode(*b"\x14\xfb\x9c\x03\xd9"), "FPucA9k="); - assert_eq!(Base64::<4>::encode(*b"\x14\xfb\x9c\x03"), "FPucAw=="); - - // RFC 4648 examples - assert_eq!(Base64::<0>::encode(*b""), ""); - assert_eq!(Base64::<1>::encode(*b"f"), "Zg=="); - assert_eq!(Base64::<2>::encode(*b"fo"), "Zm8="); - assert_eq!(Base64::<3>::encode(*b"foo"), "Zm9v"); - assert_eq!(Base64::<4>::encode(*b"foob"), "Zm9vYg=="); - assert_eq!(Base64::<5>::encode(*b"fooba"), "Zm9vYmE="); - assert_eq!(Base64::<6>::encode(*b"foobar"), "Zm9vYmFy"); - - // Wikipedia examples - assert_eq!(Base64::<5>::encode(*b"sure."), "c3VyZS4="); - assert_eq!(Base64::<4>::encode(*b"sure"), "c3VyZQ=="); - assert_eq!(Base64::<3>::encode(*b"sur"), "c3Vy"); - assert_eq!(Base64::<2>::encode(*b"su"), "c3U="); - assert_eq!(Base64::<8>::encode(*b"leasure."), "bGVhc3VyZS4="); - assert_eq!(Base64::<7>::encode(*b"easure."), "ZWFzdXJlLg=="); - assert_eq!(Base64::<6>::encode(*b"asure."), "YXN1cmUu"); - assert_eq!(Base64::<5>::encode(*b"sure."), "c3VyZS4="); - } } diff --git a/ohkami/src/x_websocket/sign/base64.rs b/ohkami/src/x_websocket/sign/base64.rs deleted file mode 100644 index 379bad8b..00000000 --- a/ohkami/src/x_websocket/sign/base64.rs +++ /dev/null @@ -1,189 +0,0 @@ -/* https://github.com/golang/go/blob/master/src/encoding/base64/base64.go */ - -const ENCODER: [u8; 64] = *b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; -const PADDING: u8 = b'='; - -pub struct Base64< - const SRC_SIZE: usize, - #[cfg(feature="nightly")] const DST_SIZE: usize = {(SRC_SIZE + 2) / 3 * 4}, - #[cfg(feature="nightly")] const SRC_SIZE_REM3_0: bool = {SRC_SIZE % 3 == 0}, - #[cfg(feature="nightly")] const SRC_SIZE_REM3_1: bool = {SRC_SIZE % 3 == 1}, - #[cfg(feature="nightly")] const SRC_SIZE_REM3_2: bool = {SRC_SIZE % 3 == 2}, ->; - -#[cfg(feature="nightly")] impl< - const SRC_SIZE: usize, - const DST_SIZE: usize, - const SRC_SIZE_REM3_0: bool, - const SRC_SIZE_REM3_1: bool, - const SRC_SIZE_REM3_2: bool, -> Base64 { - pub fn encode(src: [u8; SRC_SIZE]) -> String { - if SRC_SIZE == 0 {// may deleted by compiler when `SRC_SIZE` is not 0 - return String::new() - } - - #[cfg(feature="nightly")] - let mut dst = vec![0; DST_SIZE]; - - let (mut di, mut si) = (0, 0); - let n = (SRC_SIZE / 3) * 3; // `n` is `SRC_SIZE - (SRC_SIZE % 3)` - while si < n { - let val = (src[si+0] as usize)<<16 | (src[si+1] as usize)<<8 | (src[si+2] as usize); - - dst[di+0] = ENCODER[val>>18&0x3F]; - dst[di+1] = ENCODER[val>>12&0x3F]; - dst[di+2] = ENCODER[val>>6&0x3F]; - dst[di+3] = ENCODER[val&0x3F]; - - si += 3; - di += 4; - } - - if SRC_SIZE_REM3_0 {// may deleted by compiler when `SRC_SIZE` is not a multiple of 3 - return (|| unsafe {String::from_utf8_unchecked(dst)})() - } - - let mut val = (src[si+0] as usize) << 16; - if SRC_SIZE_REM3_2 {// may be deleted by compiler when `SRC_SIZE` is congruent to 2 mod 3 - val |= (src[si+1] as usize) << 8; - } - - dst[di+0] = ENCODER[val>>18&0x3F]; - dst[di+1] = ENCODER[val>>12&0x3F]; - - if SRC_SIZE_REM3_2 {// may be deleted by compiler when `SRC_SIZE` is congruent to 2 mod 3 - dst[di+2] = ENCODER[val>>6&0x3F]; - dst[di+3] = PADDING; - } - if SRC_SIZE_REM3_1 {// may be deleted by compiler when `SRC_SIZE` is congruent to 1 mod 3 - dst[di+2] = PADDING; - dst[di+3] = PADDING; - } - - unsafe {String::from_utf8_unchecked(dst)} - } -} - - -#[cfg(not(feature="nightly"))] impl< - const SRC_SIZE: usize, -> Base64 { - pub fn encode(src: [u8; SRC_SIZE]) -> String { - if SRC_SIZE == 0 {// may deleted by compiler when `SRC_SIZE` is not 0 - return String::new() - } - - let mut dst = vec![0; (SRC_SIZE + 2) / 3 * 4]; - - let (mut di, mut si) = (0, 0); - let n = (SRC_SIZE / 3) * 3; // `n` is `SRC_SIZE - (SRC_SIZE % 3)` - while si < n { - let val = (src[si+0] as usize)<<16 | (src[si+1] as usize)<<8 | (src[si+2] as usize); - - dst[di+0] = ENCODER[val>>18&0x3F]; - dst[di+1] = ENCODER[val>>12&0x3F]; - dst[di+2] = ENCODER[val>>6&0x3F]; - dst[di+3] = ENCODER[val&0x3F]; - - si += 3; - di += 4; - } - - let remain = SRC_SIZE - si; // `remain` is `SRC_SIZE % 3` - if remain == 0 { - return (|| unsafe {String::from_utf8_unchecked(dst)})() - } - - let mut val = (src[si+0] as usize) << 16; - if remain == 2 { - val |= (src[si+1] as usize) << 8; - } - - dst[di+0] = ENCODER[val>>18&0x3F]; - dst[di+1] = ENCODER[val>>12&0x3F]; - - match remain { - 2 => { - dst[di+2] = ENCODER[val>>6&0x3F]; - dst[di+3] = PADDING; - } - 1 => { - dst[di+2] = PADDING; - dst[di+3] = PADDING; - } - _ => unsafe {std::hint::unreachable_unchecked()} - } - - unsafe {String::from_utf8_unchecked(dst)} - } -} - -//} - -//impl Base64 { -// pub fn encode(src: [u8; SRC_SIZE]) -> String { -// if src.len() == 0 { -// return String::new() -// } -// -// let mut dst = vec![0; (src.len() + 2) / 3 * 4]; -// -// let (mut di, mut si) = (0, 0); -// let n = (SRC_SIZE / 3) * 3; // `n` is `SRC_SIZE - (SRC_SIZE % 3)` -// while si < n { -// let val = (src[si+0] as usize)<<16 | (src[si+1] as usize)<<8 | (src[si+2] as usize); -// -// dst[di+0] = ENCODER[val>>18&0x3F]; -// dst[di+1] = ENCODER[val>>12&0x3F]; -// dst[di+2] = ENCODER[val>>6&0x3F]; -// dst[di+3] = ENCODER[val&0x3F]; -// -// si += 3; -// di += 4; -// } -// -// #[cfg(feature="nightly")] if SRC_SIZE_REM3_0 {// may deleted by compiler when `SRC_SIZE` is not a multiple of 3 -// return (|| unsafe {String::from_utf8_unchecked(dst)})() -// } -// -// #[cfg(not(feature="nightly"))] let remain = SRC_SIZE - si; // `remain` is `SRC_SIZE % 3` -// #[cfg(not(feature="nightly"))] { -// if remain == 0 {return (|| unsafe {String::from_utf8_unchecked(dst)})()} -// } -// -// let mut val = (src[si+0] as usize) << 16; -// #[cfg(feature="nightly")] if SRC_SIZE_REM3_2 {// may be deleted by compiler when `SRC_SIZE` is congruent to 2 mod 3 -// val |= (src[si+1] as usize) << 8; -// } -// #[cfg(not(feature="nightly"))] if remain == 2 { -// val |= (src[si+1] as usize) << 8; -// } -// -// dst[di+0] = ENCODER[val>>18&0x3F]; -// dst[di+1] = ENCODER[val>>12&0x3F]; -// -// #[cfg(feature="nightly")] if SRC_SIZE_REM3_2 {// may be deleted by compiler when `SRC_SIZE` is congruent to 2 mod 3 -// dst[di+2] = ENCODER[val>>6&0x3F]; -// dst[di+3] = PADDING; -// } -// #[cfg(feature="nightly")] if SRC_SIZE_REM3_1 {// may be deleted by compiler when `SRC_SIZE` is congruent to 1 mod 3 -// dst[di+2] = PADDING; -// dst[di+3] = PADDING; -// } -// #[cfg(not(feature="nightly"))] match remain { -// 2 => { -// dst[di+2] = ENCODER[val>>6&0x3F]; -// dst[di+3] = PADDING; -// } -// 1 => { -// dst[di+2] = PADDING; -// dst[di+3] = PADDING; -// } -// _ => unsafe {std::hint::unreachable_unchecked()} -// } -// -// unsafe {String::from_utf8_unchecked(dst)} -// } -//} -// \ No newline at end of file