diff --git a/.github/workflows/tls_codec.yml b/.github/workflows/tls_codec.yml index 12519a0ce..c2ba7049f 100644 --- a/.github/workflows/tls_codec.yml +++ b/.github/workflows/tls_codec.yml @@ -71,9 +71,23 @@ jobs: - uses: dtolnay/rust-toolchain@master with: toolchain: ${{ matrix.rust }} - targets: ${{ matrix.target }} + targets: ${{ matrix.targets }} - run: ${{ matrix.deps }} - uses: RustCrypto/actions/cargo-hack-install@master - - run: cargo hack test --feature-powerset - - run: cargo hack test -p tls_codec_derive --feature-powerset --test encode\* --test decode\* - - run: cargo hack test -p tls_codec_derive --feature-powerset --doc + - run: cargo hack test --target ${{ matrix.targets }} --feature-powerset + - run: cargo hack test --target ${{ matrix.targets }} -p tls_codec_derive --feature-powerset --test encode\* --test decode\* + - run: cargo hack test --target ${{ matrix.targets }} -p tls_codec_derive --feature-powerset --doc + - run: cargo test --target ${{ matrix.targets }} --benches + + fuzz: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v7 + - uses: dtolnay/rust-toolchain@nightly + - uses: taiki-e/install-action@v2 + with: + tool: cargo-fuzz + - run: | + for fuzz_target in inverse string deserialize bytes_inverse; do + cargo fuzz run --target x86_64-unknown-linux-gnu "$fuzz_target" -- -max_total_time=5 + done diff --git a/Cargo.lock b/Cargo.lock index 181e5d950..7f493753a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1729,6 +1729,16 @@ dependencies = [ [[package]] name = "tls_codec" version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de2e01245e2bb89d6f05801c564fa27624dbd7b1846859876c7dad82e90bf6b" +dependencies = [ + "tls_codec_derive 0.4.2", + "zeroize", +] + +[[package]] +name = "tls_codec" +version = "0.4.3-pre.1" dependencies = [ "arbitrary", "ciborium", @@ -1736,18 +1746,29 @@ dependencies = [ "serde", "serde_bytes", "serde_json", - "tls_codec_derive", + "tls_codec_derive 0.4.3-pre.1", 
"zeroize", ] [[package]] name = "tls_codec_derive" version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d2e76690929402faae40aebdda620a2c0e25dd6d3b9afe48867dfd95991f4bd" dependencies = [ "proc-macro2", "quote", "syn", - "tls_codec", +] + +[[package]] +name = "tls_codec_derive" +version = "0.4.3-pre.1" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "tls_codec 0.4.3-pre.1", "trybuild", ] @@ -1972,7 +1993,7 @@ dependencies = [ "signature", "spki", "tempfile", - "tls_codec", + "tls_codec 0.4.2", "tokio", "x509-cert-test-support", ] @@ -2041,6 +2062,20 @@ name = "zeroize" version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e13c156562582aa81c60cb29407084cdb54c4164760106ab78e6c5b0858cf64e" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c50655cbb0fe3fc43170059e702f1ce5e19b84cec58dc87b037a09935c2f328" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] [[package]] name = "zmij" diff --git a/tls_codec/CHANGELOG.md b/tls_codec/CHANGELOG.md index 406f5d407..3c57663aa 100644 --- a/tls_codec/CHANGELOG.md +++ b/tls_codec/CHANGELOG.md @@ -7,12 +7,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## 0.4.3 + - [#2351](https://github.com/RustCrypto/formats/pull/2351): Implement `Size`, `SerializeBytes`, and `DeserializeBytes` for `String` (and `Size`/`SerializeBytes` for `str` / `&str`), encoding the UTF-8 bytes as a `VLByteVec`. Also implement `SerializeBytes` for `ContentLength` and `VLByteSlice`. - [#2322](https://github.com/RustCrypto/formats/pull/2322): Add `VLByteVec` and `SecretVLByteVec`, which are `#[serde(transparent)]` wrappers serializing via `serde_bytes`. They produce a much more compact representation in `serde` formats that distinguish byte arrays from sequences of `u8` (e.g. 
CBOR, MessagePack, bincode). Their `serde` output is not compatible with `VLBytes` / `SecretVLBytes`, but their `Deserialize` impls are backwards-compatible: in self-describing `serde` formats they also accept the legacy `VLBytes` / `SecretVLBytes` encoding (a struct with a `vec` field containing a sequence of `u8`). Deprecate `VLBytes` and `SecretVLBytes` in favour of `VLByteVec` and `SecretVLByteVec`. - [#1656](https://github.com/RustCrypto/formats/pull/1656) Add `TlsVarInt` type for variable-length integers. ### Fixed - [#2348](https://github.com/RustCrypto/formats/pull/2348) Use `write_all` everywhere instead of write to prevent partial writes from going undetected. The `Error::InvalidWriteLength` variant is deprecated as it is no longer returned. +- [#XXXX](https://github.com/RustCrypto/formats/pull/XXXX) Element-vector deserialization (`Vec`, `TlsVecU*`, `SecretTlsVecU*`), for both the `Deserialize` (`std::io::Read`) and `DeserializeBytes` implementations, now measures actual byte consumption instead of relying on `tls_serialized_len()` and enforces the declared length exactly. This makes the two implementations agree for non-canonical inner encodings (e.g. non-minimal varint lengths) and rejects input whose elements overshoot the declared vector boundary. + - Element vectors now reject zero-length elements that would otherwise never advance the read cursor, preventing an infinite loop and unbounded allocation on malicious input. + - Byte-vector deserialization (`TlsByteVecU*`, `VLBytes`, `VLByteVec`) now uses `checked_add` when computing the content range, returning `Error::InvalidVectorLength` instead of overflowing `usize` on targets where the length field is as wide as the pointer width. + - Read-based byte-vector deserialization (`TlsByteVecU*`, `VLBytes`, `VLByteVec`) no longer eagerly allocates a buffer sized by the untrusted length field, avoiding large allocations from bogus length prefixes. 
+ - Fixed swapped doc comments on the `Deserializable*` / `Undeserializable*` type aliases generated by `#[conditionally_deserializable]`. + - Avoid potential overflows when summing up lengths. ## 0.4.2 diff --git a/tls_codec/Cargo.toml b/tls_codec/Cargo.toml index 32bfec919..b8876af4a 100644 --- a/tls_codec/Cargo.toml +++ b/tls_codec/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tls_codec" -version = "0.4.2" +version = "0.4.3-pre.1" authors = ["RustCrypto Developers"] license = "Apache-2.0 OR MIT" documentation = "https://docs.rs/tls_codec/" @@ -12,13 +12,13 @@ edition = "2024" rust-version = "1.85" [dependencies] -zeroize = { version = "1.8", default-features = false, features = ["alloc"] } +zeroize = { version = "1.9", default-features = false, features = ["alloc"] } # optional dependencies arbitrary = { version = "1.4", features = ["derive"], optional = true } -tls_codec_derive = { version = "=0.4.2", path = "./derive", optional = true } -serde = { version = "1.0.184", features = ["derive"], optional = true } -serde_bytes = { version = "0.11.17", optional = true } +tls_codec_derive = { version = "=0.4.3-pre.1", path = "./derive", optional = true } +serde = { version = "1.0.228", features = ["derive"], optional = true } +serde_bytes = { version = "0.11.19", optional = true } [dev-dependencies] criterion = { version = "0.6", default-features = false } diff --git a/tls_codec/benches/tls_vec.rs b/tls_codec/benches/tls_vec.rs index dea7112c1..4565f18c1 100644 --- a/tls_codec/benches/tls_vec.rs +++ b/tls_codec/benches/tls_vec.rs @@ -48,7 +48,41 @@ fn byte_vector(c: &mut Criterion) { TlsByteSliceU32(&long_vec).tls_serialize_detached().unwrap() }, |serialized_long_vec| { - TlsVecU32::::tls_deserialize(&mut serialized_long_vec.as_slice()).unwrap() + // Decode into the byte-vector type so this exercises the + // bulk `read_bytes_bounded` path, not the generic per-element + // loop (`TlsVecU32`).
+ TlsByteVecU32::tls_deserialize(&mut serialized_long_vec.as_slice()).unwrap() + }, + BatchSize::SmallInput, + ) + }); +} + +/// Benchmarks the generic per-element deserialize loop with a multi-byte +/// element type, so the loop's per-element bookkeeping isn't hidden behind +/// trivial `u8` decoding. +fn typed_vector(c: &mut Criterion) { + use tls_codec::*; + c.bench_function("TLS Serialize Typed Vector", |b| { + b.iter_batched_ref( + || { + ( + TlsVecU32::from(vec![0x7777u16; N]), + Vec::with_capacity(8 + 2 * N), + ) + }, + |(long_vec, buf)| long_vec.tls_serialize(buf).unwrap(), + BatchSize::SmallInput, + ) + }); + c.bench_function("TLS Deserialize Typed Vector", |b| { + b.iter_batched_ref( + || { + let long_vec = vec![0x7777u16; N]; + TlsSliceU32(&long_vec).tls_serialize_detached().unwrap() + }, + |serialized_long_vec| { + TlsVecU32::::tls_deserialize(&mut serialized_long_vec.as_slice()).unwrap() }, BatchSize::SmallInput, ) @@ -78,6 +112,7 @@ fn slice(c: &mut Criterion) { } fn benchmark(c: &mut Criterion) { vector(c); + typed_vector(c); slice(c); byte_vector(c); byte_slice(c); diff --git a/tls_codec/derive/Cargo.toml b/tls_codec/derive/Cargo.toml index c99730ae2..3f041aa67 100644 --- a/tls_codec/derive/Cargo.toml +++ b/tls_codec/derive/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tls_codec_derive" -version = "0.4.2" +version = "0.4.3-pre.1" authors = ["RustCrypto Developers"] license = "Apache-2.0 OR MIT" documentation = "https://docs.rs/tls_codec_derive/" diff --git a/tls_codec/derive/src/lib.rs b/tls_codec/derive/src/lib.rs index f132943a2..9a81e445c 100644 --- a/tls_codec/derive/src/lib.rs +++ b/tls_codec/derive/src/lib.rs @@ -27,7 +27,7 @@ //! function returns an `Error::UnknownValue` with a `u64` value of the unknown //! type. //! -//! ``` +//! ```no_run //! # #[cfg(feature = "std")] //! # { //! use tls_codec_derive::{TlsDeserialize, TlsSerialize, TlsSize}; @@ -761,6 +761,35 @@ fn define_discriminant_constants( Ok(quote! 
{ #(#discriminant_constants)* }) } +/// Builds an expression that sums `base` and all `terms` into a `usize`, +/// guarding against overflow only on targets where it can actually occur. +/// +/// On 64-bit targets every length is bounded by addressable memory +/// (`isize::MAX`), so the sum can't overflow `usize`; a plain, branch-free +/// addition is emitted to keep the serialization hot path free of checks. On +/// narrower targets the serialized form of a large structure can carry enough +/// length-prefix / discriminant overhead to exceed `usize::MAX`, so we saturate +/// rather than silently wrapping to a small value (a wrapped length would be +/// written to the wire as a truncated, mismatched length prefix). +/// +/// This mirrors `tls_codec::len_add` but is emitted inline so the helper does +/// not need to be part of `tls_codec`'s public API. +fn sum_lengths(terms: &[TokenStream2], base: TokenStream2) -> TokenStream2 { + quote! { + { + #[cfg(target_pointer_width = "64")] + let __tls_len: usize = #base #(+ #terms)*; + #[cfg(not(target_pointer_width = "64"))] + let __tls_len: usize = { + let mut __tls_len: usize = #base; + #(__tls_len = __tls_len.saturating_add(#terms);)* + __tls_len + }; + __tls_len + } + } +} + #[allow(unused_variables)] fn impl_tls_size(parsed_ast: TlsStruct) -> TokenStream2 { match parsed_ast { @@ -779,13 +808,18 @@ fn impl_tls_size(parsed_ast: TlsStruct) -> TokenStream2 { .iter() .map(|p| p.for_trait("Size")) .collect::>(); + let field_lengths = prefixes + .iter() + .zip(members.iter()) + .map(|(prefix, member)| quote! { #prefix::tls_serialized_len(&self.#member) }) + .collect::>(); + let serialized_len = sum_lengths(&field_lengths, quote! { 0usize }); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); quote! 
{ impl #impl_generics tls_codec::Size for #ident #ty_generics #where_clause { #[inline] fn tls_serialized_len(&self) -> usize { - #(#prefixes::tls_serialized_len(&self.#members) + )* - 0 + #serialized_len } } @@ -811,13 +845,27 @@ fn impl_tls_size(parsed_ast: TlsStruct) -> TokenStream2 { let variant_id = &variant.ident; let members = &variant.members; let bindings = make_n_ids(members.len()); - let prefixes = variant.member_prefixes.iter().map(|p| p.for_trait("Size")).collect::>(); + let prefixes = variant + .member_prefixes + .iter() + .map(|p| p.for_trait("Size")) + .collect::>(); + let field_lengths = prefixes + .iter() + .zip(bindings.iter()) + .map(|(prefix, binding)| quote! { #prefix::tls_serialized_len(#binding) }) + .collect::>(); + let variant_len = sum_lengths(&field_lengths, quote! { 0usize }); quote! { - #ident::#variant_id { #(#members: #bindings,)* } => 0 #(+ #prefixes::tls_serialized_len(#bindings))*, + #ident::#variant_id { #(#members: #bindings,)* } => #variant_len, } }) .collect::>(); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let total_len = sum_lengths( + &[quote! { field_len }], + quote! { core::mem::size_of::<#repr>() }, + ); quote! { impl #impl_generics tls_codec::Size for #ident #ty_generics #where_clause { #[inline] @@ -825,7 +873,7 @@ fn impl_tls_size(parsed_ast: TlsStruct) -> TokenStream2 { let field_len = match self { #(#field_arms)* }; - core::mem::size_of::<#repr>() + field_len + #total_len } } @@ -1414,9 +1462,9 @@ fn impl_conditionally_deserializable(mut annotated_item: ItemStruct) -> TokenStr quote! 
{ #annotated_item - #[doc = #doc_string_deserializable] - #annotated_item_visibility type #undeserializable_ident #original_ty_generics = #annotated_item_ident #undeserializable_ty_generics; #[doc = #doc_string_undeserializable] + #annotated_item_visibility type #undeserializable_ident #original_ty_generics = #annotated_item_ident #undeserializable_ty_generics; + #[doc = #doc_string_deserializable] #annotated_item_visibility type #deserializable_ident #original_ty_generics = #annotated_item_ident #deserializable_ty_generics; #deserialize_implementation diff --git a/tls_codec/fuzz/Cargo.lock b/tls_codec/fuzz/Cargo.lock index b40265728..9c657155e 100644 --- a/tls_codec/fuzz/Cargo.lock +++ b/tls_codec/fuzz/Cargo.lock @@ -94,7 +94,7 @@ dependencies = [ [[package]] name = "tls_codec" -version = "0.4.2" +version = "0.4.3-pre.1" dependencies = [ "arbitrary", "tls_codec_derive", @@ -111,7 +111,7 @@ dependencies = [ [[package]] name = "tls_codec_derive" -version = "0.4.2" +version = "0.4.3-pre.1" dependencies = [ "proc-macro2", "quote", diff --git a/tls_codec/fuzz/Cargo.toml b/tls_codec/fuzz/Cargo.toml index 08c71f89c..40b7f1a11 100644 --- a/tls_codec/fuzz/Cargo.toml +++ b/tls_codec/fuzz/Cargo.toml @@ -32,3 +32,15 @@ name = "string" path = "fuzz_targets/string.rs" test = false doc = false + +[[bin]] +name = "deserialize" +path = "fuzz_targets/deserialize.rs" +test = false +doc = false + +[[bin]] +name = "bytes_inverse" +path = "fuzz_targets/bytes_inverse.rs" +test = false +doc = false diff --git a/tls_codec/fuzz/fuzz_targets/bytes_inverse.rs b/tls_codec/fuzz/fuzz_targets/bytes_inverse.rs new file mode 100644 index 000000000..63e6e4997 --- /dev/null +++ b/tls_codec/fuzz/fuzz_targets/bytes_inverse.rs @@ -0,0 +1,29 @@ +#![no_main] +#![allow(deprecated)] + +//! Round-trip fuzzing of the slice-based [`DeserializeBytes`] path. +//! +//! The existing `inverse` target only exercises the `Read`-based +//! [`Deserialize`] implementation. This one covers `tls_deserialize_bytes` +//! 
and additionally checks that both paths agree on a valid serialization. + +use libfuzzer_sys::fuzz_target; +use tls_codec::{Deserialize, DeserializeBytes, Serialize, Size, VLBytes}; + +fuzz_target!(|expected: VLBytes| { + let serialized = expected.tls_serialize_detached().unwrap(); + + // Assert that the serialized length matches the predicted length. + assert_eq!(expected.tls_serialized_len(), serialized.len()); + + // Slice-based deserialization round-trips and consumes all bytes. + let (got, remainder) = VLBytes::tls_deserialize_bytes(&serialized).unwrap(); + assert!(remainder.is_empty()); + assert_eq!(expected, got); + + // The `Read`-based path must agree with the slice-based path. + let mut read_slice = serialized.as_slice(); + let got_read = VLBytes::tls_deserialize(&mut read_slice).unwrap(); + assert!(read_slice.is_empty()); + assert_eq!(expected, got_read); +}); diff --git a/tls_codec/fuzz/fuzz_targets/deserialize.rs b/tls_codec/fuzz/fuzz_targets/deserialize.rs new file mode 100644 index 000000000..021f5835c --- /dev/null +++ b/tls_codec/fuzz/fuzz_targets/deserialize.rs @@ -0,0 +1,88 @@ +#![no_main] +#![allow(deprecated)] + +//! Robustness + differential fuzzing of the deserializers. +//! +//! For raw, potentially malformed input this target checks that: +//! * deserializing never panics (only returns `Ok`/`Err`), and +//! * the `Read`-based [`Deserialize`] path and the slice-based +//! [`DeserializeBytes`] path agree: for the same input both either fail, or +//! both succeed with the same value and consume the same number of bytes. +//! * on success, re-serializing honors the [`Size`] contract. + +use libfuzzer_sys::fuzz_target; +use tls_codec::{ + Deserialize, DeserializeBytes, Serialize, Size, TlsByteVecU8, TlsByteVecU16, TlsByteVecU24, + TlsByteVecU32, TlsVarInt, TlsVecU8, TlsVecU16, TlsVecU24, TlsVecU32, VLByteVec, VLBytes, +}; + +/// Run both deserialization paths on `data` for type `$t` and compare them. +macro_rules! 
differential { + ($t:ty, $data:expr) => {{ + let mut read_slice: &[u8] = $data; + let read_res = <$t as Deserialize>::tls_deserialize(&mut read_slice); + let bytes_res = <$t as DeserializeBytes>::tls_deserialize_bytes($data); + + match (read_res, bytes_res) { + (Ok(a), Ok((b, remainder))) => { + assert_eq!(a, b, "value mismatch for {}", stringify!($t)); + assert_eq!( + read_slice.len(), + remainder.len(), + "consumed-length mismatch for {}", + stringify!($t) + ); + + // The `Size` contract: the predicted length matches what is + // actually written. Note we deliberately do *not* require the + // re-serialized length to equal the number of bytes consumed: + // in non-MLS mode `TlsVarInt` accepts non-canonical (non-minimal) + // encodings that re-serialize to fewer bytes. + let serialized = a.tls_serialize_detached().unwrap(); + assert_eq!( + serialized.len(), + a.tls_serialized_len(), + "serialized length mismatch for {}", + stringify!($t) + ); + } + (Err(_), Err(_)) => {} + (read_res, bytes_res) => panic!( + "Ok/Err divergence for {}: read_ok={}, bytes_ok={}", + stringify!($t), + read_res.is_ok(), + bytes_res.is_ok() + ), + } + }}; +} + +fuzz_target!(|data: &[u8]| { + // Fixed-size primitives. + differential!(u8, data); + differential!(u16, data); + differential!(u32, data); + differential!(u64, data); + differential!([u8; 4], data); + + // Length-prefixed byte containers (TLS style). + differential!(TlsVecU8, data); + differential!(TlsVecU16, data); + differential!(TlsVecU24, data); + differential!(TlsVecU32, data); + differential!(TlsByteVecU8, data); + differential!(TlsByteVecU16, data); + differential!(TlsByteVecU24, data); + differential!(TlsByteVecU32, data); + + // Nested length-prefixed vectors. + differential!(TlsVecU16>, data); + + // QUIC-style variable-length encodings. + differential!(VLBytes, data); + differential!(VLByteVec, data); + differential!(TlsVarInt, data); + + // UTF-8 strings. 
+ differential!(String, data); +}); diff --git a/tls_codec/src/lib.rs b/tls_codec/src/lib.rs index fb344afe3..838d7cc20 100644 --- a/tls_codec/src/lib.rs +++ b/tls_codec/src/lib.rs @@ -123,6 +123,88 @@ impl From for Error { } } +/// Read exactly `len` bytes from `reader` into a freshly allocated vector. +/// +/// Unlike `vec![0u8; len]` followed by `read_exact`, this does **not** eagerly +/// allocate `len` bytes up front: the initial allocation is capped so that a +/// bogus (large) length field in untrusted input can't trigger a huge +/// allocation before any bytes are read. The vector grows as data actually +/// arrives. +/// +/// Returns [`Error::EndOfStream`] if the reader is exhausted before `len` bytes +/// are read. +#[cfg(feature = "std")] +fn read_bytes_bounded(reader: &mut R, len: usize) -> Result, Error> { + /// Upper bound on the up-front allocation, so that a bogus (large) length + /// field in untrusted input can't trigger a huge allocation before any + /// bytes are read. + const MAX_PREALLOC: usize = 4096; + + // Cap the initial allocation. `read_to_end` grows the vector as data + // actually arrives, and `Take` bounds the reader to `len` so growth can + // never exceed the request. + let mut result = Vec::with_capacity(core::cmp::min(len, MAX_PREALLOC)); + + // `read_to_end` reads directly into the vector's spare capacity and retries + // `ErrorKind::Interrupted` internally, unlike a bare `read` loop. + reader.take(len as u64).read_to_end(&mut result)?; + + // `Take` caps output at `len`, so a short read means the stream was + // exhausted early. + if result.len() != len { + return Err(Error::EndOfStream); + } + Ok(result) +} + +/// Adds two serialized-length components, guarding against `usize` overflow +/// only on platforms where it can actually occur. +/// +/// On 64-bit targets every length is bounded by the amount of addressable +/// memory (`isize::MAX`), so a sum of serialized lengths can never overflow +/// `usize`. 
There this compiles down to a plain addition with no branch, +/// keeping the serialization hot path free of overflow checks. +/// +/// On narrower targets (32-bit; 16-bit is not officially supported), +/// the serialized form of a large in-memory structure can carry enough +/// length-prefix / discriminant overhead to exceed `usize::MAX`. +/// There we saturate at `usize::MAX` rather than silently wrapping to a small +/// value, so the oversized length is subsequently rejected by the +/// length-encoding bounds checks instead of producing a truncated, mismatched +/// length prefix on the wire. +/// +/// `tls_codec_derive` emits the equivalent logic inline, so this helper does not +/// need to be part of the public API. +#[inline(always)] +#[cfg(target_pointer_width = "64")] +pub(crate) const fn len_add(a: usize, b: usize) -> usize { + a + b +} + +#[inline(always)] +#[cfg(not(target_pointer_width = "64"))] +pub(crate) const fn len_add(a: usize, b: usize) -> usize { + a.saturating_add(b) +} + +/// Like [`len_add`], but for contexts that can surface an error. +/// +/// On 64-bit targets this is a plain, branch-free addition (see [`len_add`] for +/// why it can't overflow). On narrower targets an overflow becomes +/// [`Error::InvalidVectorLength`] so a wrapped, too-small length is never +/// written to the wire. +#[inline(always)] +#[cfg(target_pointer_width = "64")] +pub(crate) fn checked_len_add(a: usize, b: usize) -> Result { + Ok(a + b) +} + +#[inline(always)] +#[cfg(not(target_pointer_width = "64"))] +pub(crate) fn checked_len_add(a: usize, b: usize) -> Result { + a.checked_add(b).ok_or(Error::InvalidVectorLength) +} + /// The `Size` trait needs to be implemented by any struct that should be /// efficiently serialized.
/// This allows to collect the length of a serialized structure before allocating diff --git a/tls_codec/src/quic_vec.rs b/tls_codec/src/quic_vec.rs index 542791d2c..06486036f 100644 --- a/tls_codec/src/quic_vec.rs +++ b/tls_codec/src/quic_vec.rs @@ -92,7 +92,6 @@ impl DeserializeBytes for Vec { #[inline(always)] fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> { let (length, mut remainder) = ContentLength::tls_deserialize_bytes(bytes)?; - let len_len = length.0.bytes_len(); let length: usize = length.0.value().try_into()?; if length == 0 { @@ -101,12 +100,29 @@ impl DeserializeBytes for Vec { } let mut result = Vec::new(); - let mut read = len_len; - while (read - len_len) < length { + let mut read = 0usize; + while read < length { let (element, next_remainder) = T::tls_deserialize_bytes(remainder)?; + // Measure how many bytes the element actually consumed from the + // input rather than trusting `tls_serialized_len`. + let consumed = remainder.len() - next_remainder.len(); remainder = next_remainder; - read += element.tls_serialized_len(); result.push(element); + // A zero-length element would never advance `read`, causing an + // infinite loop that keeps allocating. Reject such input. + if consumed == 0 { + return Err(Error::DecodingError( + "Vector element consumed 0 bytes; refusing to loop".into(), + )); + } + read += consumed; + } + // The declared length is authoritative: the elements must consume + // exactly `length` bytes, not overshoot it. + if read != length { + return Err(Error::DecodingError(format!( + "Vector length mismatch: declared {length} bytes but elements consumed {read}" + ))); } Ok((result, remainder)) } @@ -148,11 +164,13 @@ impl SerializeBytes for &[T] { // We need to pre-compute the length of the content. // This requires more computations but the other option would be to buffer // the entire content, which can end up requiring a lot of memory. 
- let content_length = self.iter().fold(0, |acc, e| acc + e.tls_serialized_len()); + let content_length = self.iter().try_fold(0usize, |acc, e| { + crate::checked_len_add(acc, e.tls_serialized_len()) + })?; let length = ContentLength::from_usize(content_length)?; let len_len = length.0.bytes_len(); - let mut out = Vec::with_capacity(content_length + len_len); + let mut out = Vec::with_capacity(crate::checked_len_add(content_length, len_len)?); out.resize(len_len, 0); length.0.write_bytes(&mut out)?; @@ -185,7 +203,9 @@ impl SerializeBytes for Vec { impl Size for &[T] { #[inline(always)] fn tls_serialized_len(&self) -> usize { - let content_length = self.iter().fold(0, |acc, e| acc + e.tls_serialized_len()); + let content_length = self + .iter() + .fold(0, |acc, e| crate::len_add(acc, e.tls_serialized_len())); let len_len = ContentLength::from_usize(content_length) .map(|content_length| content_length.0.bytes_len()) .unwrap_or({ @@ -193,7 +213,7 @@ impl Size for &[T] { // trait. Let's say there's no content for now. 0 }); - content_length + len_len + crate::len_add(content_length, len_len) } } @@ -633,18 +653,31 @@ pub mod rw { impl Deserialize for Vec { #[inline(always)] fn tls_deserialize(bytes: &mut R) -> Result { - let (length, len_len) = read_length(bytes)?; + let (length, _len_len) = read_length(bytes)?; if length == 0 { // An empty vector. return Ok(Vec::new()); } + // The declared length is authoritative and delimits the vector's + // content. Bound the reader to exactly `length` bytes and decode + // elements until it is exhausted. This measures actual consumption + // instead of trusting `tls_serialized_len()`, keeping this in sync + // with the `DeserializeBytes` implementation for non-canonical + // encodings (e.g. non-minimal varint lengths). 
+ let mut sub = std::io::Read::take(bytes, length as u64); let mut result = Vec::new(); - let mut read = len_len; - while (read - len_len) < length { - let element = T::tls_deserialize(bytes)?; - read += element.tls_serialized_len(); + while sub.limit() > 0 { + let before = sub.limit(); + let element = T::tls_deserialize(&mut sub)?; + // A zero-length element would never advance the reader, causing + // an infinite loop that keeps allocating. Reject such input. + if sub.limit() == before { + return Err(Error::DecodingError( + "Vector element consumed 0 bytes; refusing to loop".into(), + )); + } result.push(element); } Ok(result) @@ -672,7 +705,9 @@ pub mod rw { // We need to pre-compute the length of the content. // This requires more computations but the other option would be to buffer // the entire content, which can end up requiring a lot of memory. - let content_length = self.iter().fold(0, |acc, e| acc + e.tls_serialized_len()); + let content_length = self.iter().try_fold(0usize, |acc, e| { + crate::checked_len_add(acc, e.tls_serialized_len()) + })?; let len_len = write_length(writer, content_length)?; // Serialize the elements @@ -692,7 +727,7 @@ pub mod rw { return Err(Error::LibraryError); } - Ok(content_length + len_len) + crate::checked_len_add(content_length, len_len) } } } @@ -701,7 +736,7 @@ pub mod rw { #[cfg(feature = "std")] mod rw_bytes { use super::*; - use crate::{Deserialize, Serialize}; + use crate::{Deserialize, Serialize, read_bytes_bounded}; #[inline(always)] fn tls_serialize_bytes( @@ -743,11 +778,9 @@ mod rw_bytes { return Ok(Self::new(vec![])); } - let mut result = Self { - vec: vec![0u8; length.0.value().try_into()?], - }; - bytes.read_exact(result.vec.as_mut_slice())?; - Ok(result) + let len: usize = length.0.value().try_into()?; + let vec = read_bytes_bounded(bytes, len)?; + Ok(Self { vec }) } } @@ -773,11 +806,9 @@ mod rw_bytes { return Ok(Self::new(vec![])); } - let mut result = Self { - vec: vec![0u8; length.0.value().try_into()?], 
- }; - bytes.read_exact(result.vec.as_mut_slice())?; - Ok(result) + let len: usize = length.0.value().try_into()?; + let vec = read_bytes_bounded(bytes, len)?; + Ok(Self { vec }) } } diff --git a/tls_codec/src/tls_vec.rs b/tls_codec/src/tls_vec.rs index 83bdee167..a34cce8d0 100644 --- a/tls_codec/src/tls_vec.rs +++ b/tls_codec/src/tls_vec.rs @@ -22,7 +22,7 @@ macro_rules! impl_size { fn tls_serialized_length(&$self) -> usize { $self.as_slice() .iter() - .fold($len_len, |acc, e| acc + e.tls_serialized_len()) + .fold($len_len, |acc, e| crate::len_add(acc, e.tls_serialized_len())) } } } @@ -53,17 +53,17 @@ macro_rules! impl_byte_deserialize { u16::MAX ))); } - let mut result = Self { - vec: vec![0u8; len], - }; - bytes.read_exact(result.vec.as_mut_slice())?; - Ok(result) + // Read into a bounded buffer rather than allocating `len` bytes up + // front, so an oversized length field can't trigger a huge + // allocation before any bytes are read. + let vec = crate::read_bytes_bounded(bytes, len)?; + Ok(Self { vec }) } #[inline(always)] fn deserialize_bytes_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> { let (type_len, remainder) = <$size>::tls_deserialize_bytes(bytes)?; - let len = type_len.try_into().unwrap(); + let len: usize = type_len.try_into().unwrap(); // When fuzzing we limit the maximum size to allocate. // XXX: We should think about a configurable limit for the allocation // here. @@ -74,9 +74,12 @@ macro_rules! impl_byte_deserialize { u16::MAX ))); } - let vec = bytes - .get($len_len..len + $len_len) - .ok_or(Error::EndOfStream)?; + // Use `checked_add` to avoid overflowing `usize` on targets where + // the length field is as wide as (or wider than) the pointer width. + let end = len + .checked_add($len_len) + .ok_or(Error::InvalidVectorLength)?; + let vec = bytes.get($len_len..end).ok_or(Error::EndOfStream)?; let result = Self { vec: vec.to_vec() }; Ok((result, &remainder.get(len..).ok_or(Error::EndOfStream)?)) } @@ -90,11 +93,24 @@ macro_rules! 
impl_deserialize { fn deserialize(bytes: &mut R) -> Result { let mut result = Self { vec: Vec::new() }; let len = <$size>::tls_deserialize(bytes)?; - let mut read = len.tls_serialized_len(); - let len_len = read; - while (read - len_len) < len.try_into().unwrap() { - let element = T::tls_deserialize(bytes)?; - read += element.tls_serialized_len(); + let length: usize = len.try_into().unwrap(); + // The declared length is authoritative and delimits the vector's + // content. Bound the reader to exactly `length` bytes and decode + // elements until it is exhausted. This measures actual consumption + // instead of trusting `tls_serialized_len()`, keeping this in sync + // with the `DeserializeBytes` implementation for non-canonical + // encodings (e.g. non-minimal varint lengths). + let mut sub = Read::take(bytes, length as u64); + while sub.limit() > 0 { + let before = sub.limit(); + let element = T::tls_deserialize(&mut sub)?; + // A zero-length element would never advance the reader, causing + // an infinite loop that keeps allocating. Reject such input. + if sub.limit() == before { + return Err(Error::DecodingError(format!( + "Vector element consumed 0 bytes; refusing to loop" + ))); + } result.push(element); } Ok(result) @@ -108,13 +124,30 @@ macro_rules! impl_deserialize_bytes { fn deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> { let mut result = Self { vec: Vec::new() }; let (len, mut remainder) = <$size>::tls_deserialize_bytes(bytes)?; - let mut read = len.tls_serialized_len(); - let len_len = read; - while (read - len_len) < len.try_into().unwrap() { + let length: usize = len.try_into().unwrap(); + let mut read = 0usize; + while read < length { let (element, next_remainder) = T::tls_deserialize_bytes(remainder)?; + // Measure how many bytes the element actually consumed from the + // input rather than trusting `tls_serialized_len`. 
+ let consumed = remainder.len() - next_remainder.len(); remainder = next_remainder; - read += element.tls_serialized_len(); result.push(element); + // A zero-length element would never advance `read`, causing an + // infinite loop that keeps allocating. Reject such input. + if consumed == 0 { + return Err(Error::DecodingError(alloc::format!( + "Vector element consumed 0 bytes; refusing to loop" + ))); + } + read += consumed; + } + // The declared length is authoritative: the elements must consume + // exactly `length` bytes, not overshoot it. + if read != length { + return Err(Error::DecodingError(alloc::format!( + "Vector length mismatch: declared {length} bytes but elements consumed {read}" + ))); } Ok((result, remainder)) } @@ -169,8 +202,16 @@ macro_rules! impl_serialize_common { ($self:ident, $size:ty, $name:ident, $len_len:literal $(,#[$std_enabled:meta])?) => { $(#[$std_enabled])? fn get_content_lengths(&$self) -> Result<(usize, usize), Error> { - let tls_serialized_len = $self.tls_serialized_len(); - let byte_length = tls_serialized_len - $len_len; + // Sum the element lengths with an overflow check on platforms where + // `usize` is narrow enough for it to matter (see `crate::len_add`). + // Computing `byte_length` directly (rather than deriving it from + // `tls_serialized_len()`) lets us reject a true overflow instead of + // trusting a possibly-saturated value from the `Size` impl. 
+ let byte_length = $self + .as_slice() + .iter() + .try_fold(0usize, |acc, e| crate::checked_len_add(acc, e.tls_serialized_len()))?; + let tls_serialized_len = crate::checked_len_add(byte_length, $len_len)?; let max_len = <$size>::MAX.try_into().unwrap(); debug_assert!( diff --git a/tls_codec/tests/decode.rs b/tls_codec/tests/decode.rs index e6b389e86..b2d560f02 100644 --- a/tls_codec/tests/decode.rs +++ b/tls_codec/tests/decode.rs @@ -4,10 +4,15 @@ #![allow(deprecated)] use tls_codec::{ - Error, Serialize, Size, TlsByteSliceU16, TlsByteVecU8, TlsByteVecU16, TlsSliceU16, TlsVecU8, - TlsVecU16, TlsVecU32, U24, VLByteSlice, VLBytes, + Error, Serialize, Size, TlsByteSliceU16, TlsByteVecU8, TlsByteVecU16, TlsByteVecU32, + TlsSliceU16, TlsVecU8, TlsVecU16, TlsVecU32, U24, VLByteSlice, VLBytes, }; +/// A `VLBytes` element encoded with a *non-minimal* 2-byte varint length prefix +/// (`0x40 0x01`) for a single payload byte (`0xAA`). Actual wire size is 3, but +/// `tls_serialized_len()` reports the canonical size of 2. +const NON_MINIMAL_ELEMENT: &[u8] = &[0x40, 0x01, 0xAA]; + #[test] fn deserialize_primitives() { use tls_codec::Deserialize; @@ -328,3 +333,114 @@ fn deserialize_bytes_empty_vl_bytes() { let b: &[u8] = &[]; VLBytes::tls_deserialize_bytes(b).expect_err("Empty bytes were parsed successfully"); } + +// The `Deserialize` (Read) and `DeserializeBytes` paths for element vectors must +// agree, even for non-canonical inner encodings. The Read path bounds the reader +// to the declared length and measures actual consumption instead of trusting +// `tls_serialized_len()`. +#[test] +fn read_and_bytes_paths_agree_on_non_minimal_inner_length() { + use tls_codec::{Deserialize, DeserializeBytes}; + // `Vec` uses a varint outer length. Declared content length = 3, + // followed by the single (non-minimally encoded) 3-byte element. 
+ let mut input = vec![0x03]; + input.extend_from_slice(NON_MINIMAL_ELEMENT); + + let bytes_res = Vec::::tls_deserialize_exact_bytes(&input); + + let mut read_input = input.as_slice(); + let read_res = Vec::::tls_deserialize(&mut read_input); + + if cfg!(feature = "mls") { + // MLS requires minimum-size length encoding, so both paths reject it. + assert!(bytes_res.is_err()); + assert!(read_res.is_err()); + } else { + // Both paths must accept it and produce the identical result. + let expected = vec![VLBytes::new(vec![0xAA])]; + assert_eq!(bytes_res.as_ref().unwrap(), &expected); + assert_eq!(read_res.as_ref().unwrap(), &expected); + // And the Read path must have consumed the whole input. + assert!(read_input.is_empty()); + } +} + +#[test] +fn read_and_bytes_paths_agree_on_non_minimal_inner_length_fixed_len_vec() { + use tls_codec::{Deserialize, DeserializeBytes}; + // Same as above but with a fixed-size (u8) outer length field. + let mut input = vec![0x03]; + input.extend_from_slice(NON_MINIMAL_ELEMENT); + + let bytes_res = TlsVecU8::::tls_deserialize_exact_bytes(&input); + + let mut read_input = input.as_slice(); + let read_res = TlsVecU8::::tls_deserialize(&mut read_input); + + if cfg!(feature = "mls") { + assert!(bytes_res.is_err()); + assert!(read_res.is_err()); + } else { + let expected = TlsVecU8::from(vec![VLBytes::new(vec![0xAA])]); + assert_eq!(bytes_res.as_ref().unwrap(), &expected); + assert_eq!(read_res.as_ref().unwrap(), &expected); + assert!(read_input.is_empty()); + } +} + +#[test] +fn canonical_element_vec_still_round_trips() { + use tls_codec::{Deserialize, DeserializeBytes}; + // Guard against a regression in the bounded-reader loop for canonical input. 
+ let original: Vec = vec![ + VLBytes::new(vec![1, 2, 3]), + VLBytes::new(vec![]), + VLBytes::new(vec![4]), + ]; + let serialized = original.tls_serialize_detached().unwrap(); + + let mut read_input = serialized.as_slice(); + let read = Vec::::tls_deserialize(&mut read_input).unwrap(); + assert_eq!(read, original); + assert!(read_input.is_empty()); + + let (bytes, remainder) = Vec::::tls_deserialize_bytes(&serialized).unwrap(); + assert_eq!(bytes, original); + assert!(remainder.is_empty()); +} + +// Read-based byte-vector deserialization must not eagerly allocate based on an +// untrusted length field. +#[test] +fn oversized_length_field_does_not_over_allocate() { + use tls_codec::Deserialize; + // TlsByteVecU32 declares ~4 GiB of content but only 3 bytes are present. + // The bounded reader must return an error promptly instead of allocating + // 4 GiB up front. + let input = &[0xFF, 0xFF, 0xFF, 0xFF, 1, 2, 3]; + let mut read_input = input.as_slice(); + let res = TlsByteVecU32::tls_deserialize(&mut read_input); + assert!(res.is_err()); + + // Same for the varint-length VLBytes: an 8-byte varint declaring a huge + // length with only one trailing byte. + let input = &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00]; + let mut read_input = input.as_slice(); + let res = VLBytes::tls_deserialize(&mut read_input); + assert!(res.is_err()); +} + +#[test] +fn byte_vec_round_trips_beyond_prealloc_cap() { + use tls_codec::Deserialize; + // Larger than the internal MAX_PREALLOC (4096) so the bounded reader has to + // loop across multiple chunks and grow the vector. 
+ let payload = vec![0x5Au8; 10_000]; + let original = TlsByteVecU32::from(payload.clone()); + let serialized = original.tls_serialize_detached().unwrap(); + + let mut read_input = serialized.as_slice(); + let read = TlsByteVecU32::tls_deserialize(&mut read_input).unwrap(); + assert_eq!(read.as_slice(), payload.as_slice()); + assert!(read_input.is_empty()); +} diff --git a/tls_codec/tests/decode_bytes.rs b/tls_codec/tests/decode_bytes.rs index 413ed4ee9..01cc01f61 100644 --- a/tls_codec/tests/decode_bytes.rs +++ b/tls_codec/tests/decode_bytes.rs @@ -1,4 +1,7 @@ -use tls_codec::{DeserializeBytes, TlsByteVecU8, TlsByteVecU16, TlsByteVecU24, TlsByteVecU32}; +use tls_codec::{ + DeserializeBytes, Error, Size, TlsByteVecU8, TlsByteVecU16, TlsByteVecU24, TlsByteVecU32, + TlsVecU8, +}; #[test] fn deserialize_tls_byte_vec_u8() { @@ -35,3 +38,128 @@ fn deserialize_tls_byte_vec_u32() { assert_eq!(result.as_slice(), expected_result); assert_eq!(rest, []); } + +/// A zero-serialized-length element that always deserializes successfully +/// without consuming any input. Before the fix this caused an infinite loop. +#[derive(Debug, Clone, PartialEq)] +struct Zero; + +impl Size for Zero { + fn tls_serialized_len(&self) -> usize { + 0 + } +} + +impl DeserializeBytes for Zero { + fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> { + Ok((Zero, bytes)) + } +} + +#[cfg(feature = "std")] +impl tls_codec::Deserialize for Zero { + fn tls_deserialize(_: &mut R) -> Result { + Ok(Zero) + } +} + +// zero-sized elements must not loop forever +#[test] +fn zero_sized_element_does_not_hang_bytes() { + // Length prefix says "1 byte of content" but `Zero` consumes nothing. 
+ let input = [1u8]; + let res = TlsVecU8::::tls_deserialize_bytes(&input); + assert!( + matches!(res, Err(Error::DecodingError(_))), + "expected a decoding error, got {res:?}" + ); +} + +#[cfg(feature = "std")] +#[test] +fn zero_sized_element_does_not_hang_read() { + use tls_codec::Deserialize; + let input = [1u8]; + let res = TlsVecU8::::tls_deserialize(&mut input.as_slice()); + assert!( + matches!(res, Err(Error::DecodingError(_))), + "expected a decoding error, got {res:?}" + ); +} + +#[test] +fn zero_sized_element_does_not_hang_quic_vec() { + // QUIC-style `Vec` with a varint length prefix of 1. + let input = [1u8]; + let res = Vec::::tls_deserialize_bytes(&input); + assert!( + matches!(res, Err(Error::DecodingError(_))), + "expected a decoding error, got {res:?}" + ); +} + +// declared length must be enforced exactly (no overshoot) +#[test] +fn overshooting_declared_length_is_rejected_bytes() { + // Declared content length is 3 bytes, but two `u16` elements consume 4. + // The trailing 0xAA proves the decoder must not silently read past 3. + let input = [3u8, 0x00, 0x01, 0x00, 0x02, 0xAA]; + let res = TlsVecU8::::tls_deserialize_bytes(&input); + assert!( + matches!(res, Err(Error::DecodingError(_))), + "expected a decoding error, got {res:?}" + ); +} + +#[test] +fn overshooting_declared_length_is_rejected_quic_vec() { + // varint length prefix = 3, followed by two u16 (4 bytes) + trailing byte. 
+ let input = [3u8, 0x00, 0x01, 0x00, 0x02, 0xAA]; + let res = Vec::::tls_deserialize_bytes(&input); + assert!( + matches!(res, Err(Error::DecodingError(_))), + "expected a decoding error, got {res:?}" + ); +} + +#[cfg(feature = "std")] +#[test] +fn overshooting_declared_length_is_rejected_read() { + use tls_codec::Deserialize; + let input = [3u8, 0x00, 0x01, 0x00, 0x02, 0xAA]; + let res = TlsVecU8::::tls_deserialize(&mut input.as_slice()); + // The Read path bounds the reader to the declared length (3 bytes), so it + // never reads past the boundary: the first `u16` consumes 2 bytes and the + // second cannot fit in the remaining byte, yielding `EndOfStream`. + assert!( + matches!(res, Err(Error::EndOfStream)), + "expected end of stream, got {res:?}" + ); +} + +// An exactly-fitting vector still round-trips +#[test] +fn exact_length_still_decodes() { + // Declared length 4 == two u16 elements. + let input = [4u8, 0x00, 0x01, 0x00, 0x02]; + let (v, rest) = TlsVecU8::::tls_deserialize_bytes(&input).expect("should decode"); + assert_eq!(v.as_slice(), &[1u16, 2]); + assert!(rest.is_empty()); +} + +// Oversized length must not overflow `usize`; it must fail cleanly with `EndOfStream` or `InvalidVectorLength` +#[test] +fn truncated_byte_vec_reports_end_of_stream() { + // A u32 length prefix advertising a huge length, followed by a single content byte. + // The range computation must not overflow `usize`; it must fail cleanly. + let mut input = u32::MAX.to_be_bytes().to_vec(); + input.push(0xAA); // one stray content byte, far fewer than advertised + let res = TlsByteVecU32::tls_deserialize_bytes(&input); + assert!( + matches!( + res, + Err(Error::EndOfStream) | Err(Error::InvalidVectorLength) + ), + "expected EndOfStream/InvalidVectorLength, got {res:?}" + ); +}