diff --git a/src/iface/interface/ipv6.rs b/src/iface/interface/ipv6.rs index 96e999f7e..ca8c2ef35 100644 --- a/src/iface/interface/ipv6.rs +++ b/src/iface/interface/ipv6.rs @@ -422,6 +422,16 @@ impl InterfaceInner { #[cfg(feature = "medium-ip")] Medium::Ip => None, }, + #[cfg(feature = "multicast")] + Icmpv6Repr::Mld(repr) => match repr { + // [RFC 3810 ยง 6.2], reception checks + MldRepr::Query { .. } + if ip_repr.hop_limit == 1 && ip_repr.src_addr.is_link_local() => + { + self.process_mldv2(ip_repr, repr) + } + _ => None, + }, // Don't report an error if a packet with unknown type // has been handled by an ICMP socket diff --git a/src/iface/interface/multicast.rs b/src/iface/interface/multicast.rs index 68b1c7750..6f67c3d9c 100644 --- a/src/iface/interface/multicast.rs +++ b/src/iface/interface/multicast.rs @@ -1,7 +1,7 @@ use core::result::Result; use heapless::LinearMap; -#[cfg(feature = "proto-ipv4")] +#[cfg(any(feature = "proto-ipv4", feature = "proto-ipv6"))] use super::{check, IpPayload, Packet}; use super::{Interface, InterfaceInner}; use crate::config::IFACE_MAX_MULTICAST_GROUP_COUNT; @@ -34,6 +34,18 @@ pub(crate) enum IgmpReportState { }, } +#[cfg(feature = "proto-ipv6")] +pub(crate) enum MldReportState { + Inactive, + ToGeneralQuery { + timeout: crate::time::Instant, + }, + ToSpecificQuery { + group: Ipv6Address, + timeout: crate::time::Instant, + }, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum GroupState { /// Joining group, we have to send the join packet. @@ -49,6 +61,8 @@ pub(crate) struct State { /// When to report for (all or) the next multicast group membership via IGMP #[cfg(feature = "proto-ipv4")] igmp_report_state: IgmpReportState, + #[cfg(feature = "proto-ipv6")] + mld_report_state: MldReportState, } impl State { @@ -57,6 +71,8 @@ impl State { groups: LinearMap::new(), #[cfg(feature = "proto-ipv4")] igmp_report_state: IgmpReportState::Inactive, + #[cfg(feature = "proto-ipv6")] + mld_report_state: MldReportState::Inactive, } } @@ -306,6 +322,46 @@ impl Interface { } _ => {} } + #[cfg(feature = "proto-ipv6")] + match self.inner.multicast.mld_report_state { + MldReportState::ToGeneralQuery { timeout } if self.inner.now >= timeout => { + let records = self + .inner + .multicast + .groups + .iter() + .filter_map(|(addr, _)| match addr { + IpAddress::Ipv6(addr) => Some(MldAddressRecordRepr::new( + MldRecordType::ModeIsExclude, + *addr, + )), + #[allow(unreachable_patterns)] + _ => None, + }) + .collect::>(); + if let Some(pkt) = self.inner.mldv2_report_packet(&records) { + if let Some(tx_token) = device.transmit(self.inner.now) { + self.inner + .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter) + .unwrap(); + }; + }; + self.inner.multicast.mld_report_state = MldReportState::Inactive; + } + MldReportState::ToSpecificQuery { group, timeout } if self.inner.now >= timeout => { + let record = MldAddressRecordRepr::new(MldRecordType::ModeIsExclude, group); + if let Some(pkt) = self.inner.mldv2_report_packet(&[record]) { + if let Some(tx_token) = device.transmit(self.inner.now) { + // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery. + self.inner + .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter) + .unwrap(); + } + } + self.inner.multicast.mld_report_state = MldReportState::Inactive; + } + _ => {} + } } } @@ -425,4 +481,55 @@ impl InterfaceInner { ) }) } + + /// Host duties of the **MLDv2** protocol. + /// + /// Sets up `mld_report_state` for responding to MLD general/specific membership queries. + /// Membership must not be reported immediately in order to avoid flooding the network + /// after a query is broadcasted by a router; Currently the delay is fixed and not randomized. + #[cfg(feature = "proto-ipv6")] + pub(super) fn process_mldv2<'frame>( + &mut self, + ip_repr: Ipv6Repr, + repr: MldRepr<'frame>, + ) -> Option> { + match repr { + MldRepr::Query { + mcast_addr, + max_resp_code, + .. + } => { + // Do not respont immediately to the query, but wait a random time + let delay = crate::time::Duration::from_millis( + (self.rand.rand_u16() % max_resp_code).into(), + ); + // General query + if mcast_addr.is_unspecified() + && (ip_repr.dst_addr == IPV6_LINK_LOCAL_ALL_NODES + || self.has_ip_addr(ip_repr.dst_addr)) + { + let ipv6_multicast_group_count = self + .multicast + .groups + .keys() + .filter(|a| matches!(a, IpAddress::Ipv6(_))) + .count(); + if ipv6_multicast_group_count != 0 { + self.multicast.mld_report_state = MldReportState::ToGeneralQuery { + timeout: self.now + delay, + }; + } + } + if self.has_multicast_group(mcast_addr) && ip_repr.dst_addr == mcast_addr { + self.multicast.mld_report_state = MldReportState::ToSpecificQuery { + group: mcast_addr, + timeout: self.now + delay, + }; + } + None + } + MldRepr::Report { .. } => None, + MldRepr::ReportRecordReprs { .. } => None, + } + } } diff --git a/src/iface/interface/tests/ipv6.rs b/src/iface/interface/tests/ipv6.rs index f67737b31..75687e5c3 100644 --- a/src/iface/interface/tests/ipv6.rs +++ b/src/iface/interface/tests/ipv6.rs @@ -1378,3 +1378,187 @@ fn test_join_ipv6_multicast_group(#[case] medium: Medium) { assert!(!iface.has_multicast_group(group_addr)); } } + +#[rstest] +#[case(Medium::Ethernet)] +#[cfg(all(feature = "multicast", feature = "medium-ethernet"))] +fn test_handle_valid_multicast_query(#[case] medium: Medium) { + fn recv_icmpv6( + device: &mut crate::tests::TestingDevice, + timestamp: Instant, + ) -> std::vec::Vec>> { + let caps = device.capabilities(); + recv_all(device, timestamp) + .iter() + .filter_map(|frame| { + let ipv6_packet = match caps.medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => { + let eth_frame = EthernetFrame::new_checked(frame).ok()?; + Ipv6Packet::new_checked(eth_frame.payload()).ok()? + } + #[cfg(feature = "medium-ip")] + Medium::Ip => Ipv6Packet::new_checked(&frame[..]).ok()?, + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => todo!(), + }; + let buf = ipv6_packet.into_inner().to_vec(); + Some(Ipv6Packet::new_unchecked(buf)) + }) + .collect::>() + } + + let (mut iface, mut sockets, mut device) = setup(medium); + + let mut timestamp = Instant::ZERO; + + let mut eth_bytes = vec![0u8; 86]; + + let local_ip_addr = Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 101); + let remote_ip_addr = Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 100); + let remote_hw_addr = EthernetAddress([0x52, 0x54, 0x00, 0x00, 0x00, 0x00]); + let query_ip_addr = Ipv6Address::new(0xff02, 0, 0, 0, 0, 0, 0, 0x1234); + + iface.join_multicast_group(query_ip_addr).unwrap(); + iface + .join_multicast_group(local_ip_addr.solicited_node()) + .unwrap(); + + iface.poll(timestamp, &mut device, &mut sockets); + // flush multicast reports from the join_multicast_group calls + recv_icmpv6(&mut device, timestamp); + + let queries = [ + // General query, expect both multicast addresses back + ( + Ipv6Address::UNSPECIFIED, + IPV6_LINK_LOCAL_ALL_NODES, + vec![query_ip_addr, local_ip_addr.solicited_node()], + ), + // Address specific query, expect only the queried address back + (query_ip_addr, query_ip_addr, vec![query_ip_addr]), + ]; + + for (mcast_query, address, _results) in queries.iter() { + let query = Icmpv6Repr::Mld(MldRepr::Query { + max_resp_code: 1000, + mcast_addr: *mcast_query, + s_flag: false, + qrv: 1, + qqic: 60, + num_srcs: 0, + data: &[0, 0, 0, 0], + }); + + let ip_repr = IpRepr::Ipv6(Ipv6Repr { + src_addr: remote_ip_addr, + dst_addr: *address, + next_header: IpProtocol::Icmpv6, + hop_limit: 1, + payload_len: query.buffer_len(), + }); + + let mut frame = EthernetFrame::new_unchecked(&mut eth_bytes); + frame.set_dst_addr(EthernetAddress([0x33, 0x33, 0x00, 0x00, 0x00, 0x00])); + frame.set_src_addr(remote_hw_addr); + frame.set_ethertype(EthernetProtocol::Ipv6); + ip_repr.emit(frame.payload_mut(), &ChecksumCapabilities::default()); + query.emit( + &remote_ip_addr, + address, + &mut Icmpv6Packet::new_unchecked(&mut frame.payload_mut()[ip_repr.header_len()..]), + &ChecksumCapabilities::default(), + ); + + iface.inner.process_ethernet( + &mut sockets, + PacketMeta::default(), + frame.into_inner(), + &mut iface.fragments, + ); + + timestamp += crate::time::Duration::from_millis(1000); + iface.poll(timestamp, &mut device, &mut sockets); + } + + let reports = recv_icmpv6(&mut device, timestamp); + assert_eq!(reports.len(), queries.len()); + + let caps = device.capabilities(); + let checksum_caps = &caps.checksum; + for ((_mcast_query, _address, results), ipv6_packet) in queries.iter().zip(reports) { + let buf = ipv6_packet.into_inner(); + let ipv6_packet = Ipv6Packet::new_unchecked(buf.as_slice()); + + let ipv6_repr = Ipv6Repr::parse(&ipv6_packet).unwrap(); + let ip_payload = ipv6_packet.payload(); + assert_eq!(ipv6_repr.dst_addr, IPV6_LINK_LOCAL_ALL_MLDV2_ROUTERS); + + // The first 2 octets of this payload hold the next-header indicator and the + // Hop-by-Hop header length (in 8-octet words, minus 1). The remaining 6 octets + // hold the Hop-by-Hop PadN and Router Alert options. + let hbh_header = Ipv6HopByHopHeader::new_checked(&ip_payload[..8]).unwrap(); + let hbh_repr = Ipv6HopByHopRepr::parse(&hbh_header).unwrap(); + + assert_eq!(hbh_repr.options.len(), 3); + assert_eq!( + hbh_repr.options[0], + Ipv6OptionRepr::Unknown { + type_: Ipv6OptionType::Unknown(IpProtocol::Icmpv6.into()), + length: 0, + data: &[], + } + ); + assert_eq!( + hbh_repr.options[1], + Ipv6OptionRepr::RouterAlert(Ipv6OptionRouterAlert::MulticastListenerDiscovery) + ); + assert_eq!(hbh_repr.options[2], Ipv6OptionRepr::PadN(0)); + + let icmpv6_packet = + Icmpv6Packet::new_checked(&ip_payload[hbh_repr.buffer_len()..]).unwrap(); + let icmpv6_repr = Icmpv6Repr::parse( + &ipv6_packet.src_addr(), + &ipv6_packet.dst_addr(), + &icmpv6_packet, + checksum_caps, + ) + .unwrap(); + + let record_data = match icmpv6_repr { + Icmpv6Repr::Mld(MldRepr::Report { + nr_mcast_addr_rcrds, + data, + }) => { + assert_eq!(nr_mcast_addr_rcrds, results.len() as u16); + data + } + other => panic!("unexpected icmpv6_repr: {:?}", other), + }; + + let mut record_reprs = Vec::new(); + let mut payload = record_data; + + // FIXME: parsing multiple address records should be done by the MLD code + while !payload.is_empty() { + let record = MldAddressRecord::new_checked(payload).unwrap(); + let mut record_repr = MldAddressRecordRepr::parse(&record).unwrap(); + payload = record_repr.payload; + record_repr.payload = &[]; + record_reprs.push(record_repr); + } + + let expected_records = results + .iter() + .map(|addr| MldAddressRecordRepr { + num_srcs: 0, + mcast_addr: *addr, + record_type: MldRecordType::ModeIsExclude, + aux_data_len: 0, + payload: &[], + }) + .collect::>(); + + assert_eq!(record_reprs, expected_records); + } +}