Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 9afd85c9 authored by Linus Lüssing's avatar Linus Lüssing Committed by David S. Miller
Browse files

net: Export IGMP/MLD message validation code



With this patch, the IGMP and MLD message validation functions are moved
from the bridge code to IPv4/IPv6 multicast files. Some small
refactoring was done to enhance readibility and to iron out some
differences in behaviour between the IGMP and MLD parsing code (e.g. the
skb-cloning of MLD messages is now only done if necessary, just like the
IGMP part always did).

Finally, these IGMP and MLD message validation functions are exported so
that not only the bridge can use it but batman-adv later, too.

Signed-off-by: default avatarLinus Lüssing <linus.luessing@c0d3.blue>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 3c9e4f87
Loading
Loading
Loading
Loading
+1 −0
Original line number Original line Diff line number Diff line
@@ -130,5 +130,6 @@ extern void ip_mc_unmap(struct in_device *);
extern void ip_mc_remap(struct in_device *);
extern void ip_mc_remap(struct in_device *);
extern void ip_mc_dec_group(struct in_device *in_dev, __be32 addr);
extern void ip_mc_dec_group(struct in_device *in_dev, __be32 addr);
extern void ip_mc_inc_group(struct in_device *in_dev, __be32 addr);
extern void ip_mc_inc_group(struct in_device *in_dev, __be32 addr);
int ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed);


#endif
#endif
+3 −0
Original line number Original line Diff line number Diff line
@@ -3419,6 +3419,9 @@ static inline void skb_checksum_none_assert(const struct sk_buff *skb)
bool skb_partial_csum_set(struct sk_buff *skb, u16 start, u16 off);
bool skb_partial_csum_set(struct sk_buff *skb, u16 start, u16 off);


int skb_checksum_setup(struct sk_buff *skb, bool recalculate);
int skb_checksum_setup(struct sk_buff *skb, bool recalculate);
struct sk_buff *skb_checksum_trimmed(struct sk_buff *skb,
				     unsigned int transport_len,
				     __sum16(*skb_chkf)(struct sk_buff *skb));


u32 skb_get_poff(const struct sk_buff *skb);
u32 skb_get_poff(const struct sk_buff *skb);
u32 __skb_get_poff(const struct sk_buff *skb, void *data,
u32 __skb_get_poff(const struct sk_buff *skb, void *data,
+1 −0
Original line number Original line Diff line number Diff line
@@ -142,6 +142,7 @@ void ipv6_mc_unmap(struct inet6_dev *idev);
void ipv6_mc_remap(struct inet6_dev *idev);
void ipv6_mc_remap(struct inet6_dev *idev);
void ipv6_mc_init_dev(struct inet6_dev *idev);
void ipv6_mc_init_dev(struct inet6_dev *idev);
void ipv6_mc_destroy_dev(struct inet6_dev *idev);
void ipv6_mc_destroy_dev(struct inet6_dev *idev);
int ipv6_mc_check_mld(struct sk_buff *skb, struct sk_buff **skb_trimmed);
void addrconf_dad_failure(struct inet6_ifaddr *ifp);
void addrconf_dad_failure(struct inet6_ifaddr *ifp);


bool ipv6_chk_mcast_addr(struct net_device *dev, const struct in6_addr *group,
bool ipv6_chk_mcast_addr(struct net_device *dev, const struct in6_addr *group,
+30 −188
Original line number Original line Diff line number Diff line
@@ -975,9 +975,6 @@ static int br_ip4_multicast_igmp3_report(struct net_bridge *br,
	int err = 0;
	int err = 0;
	__be32 group;
	__be32 group;


	if (!pskb_may_pull(skb, sizeof(*ih)))
		return -EINVAL;

	ih = igmpv3_report_hdr(skb);
	ih = igmpv3_report_hdr(skb);
	num = ntohs(ih->ngrec);
	num = ntohs(ih->ngrec);
	len = sizeof(*ih);
	len = sizeof(*ih);
@@ -1248,25 +1245,14 @@ static int br_ip4_multicast_query(struct net_bridge *br,
			max_delay = 10 * HZ;
			max_delay = 10 * HZ;
			group = 0;
			group = 0;
		}
		}
	} else {
	} else if (skb->len >= sizeof(*ih3)) {
		if (!pskb_may_pull(skb, sizeof(struct igmpv3_query))) {
			err = -EINVAL;
			goto out;
		}

		ih3 = igmpv3_query_hdr(skb);
		ih3 = igmpv3_query_hdr(skb);
		if (ih3->nsrcs)
		if (ih3->nsrcs)
			goto out;
			goto out;


		max_delay = ih3->code ?
		max_delay = ih3->code ?
			    IGMPV3_MRC(ih3->code) * (HZ / IGMP_TIMER_SCALE) : 1;
			    IGMPV3_MRC(ih3->code) * (HZ / IGMP_TIMER_SCALE) : 1;
	}
	} else {

	/* RFC2236+RFC3376 (IGMPv2+IGMPv3) require the multicast link layer
	 * all-systems destination addresses (224.0.0.1) for general queries
	 */
	if (!group && iph->daddr != htonl(INADDR_ALLHOSTS_GROUP)) {
		err = -EINVAL;
		goto out;
		goto out;
	}
	}


@@ -1329,12 +1315,6 @@ static int br_ip6_multicast_query(struct net_bridge *br,
	    (port && port->state == BR_STATE_DISABLED))
	    (port && port->state == BR_STATE_DISABLED))
		goto out;
		goto out;


	/* RFC2710+RFC3810 (MLDv1+MLDv2) require link-local source addresses */
	if (!(ipv6_addr_type(&ip6h->saddr) & IPV6_ADDR_LINKLOCAL)) {
		err = -EINVAL;
		goto out;
	}

	if (skb->len == sizeof(*mld)) {
	if (skb->len == sizeof(*mld)) {
		if (!pskb_may_pull(skb, sizeof(*mld))) {
		if (!pskb_may_pull(skb, sizeof(*mld))) {
			err = -EINVAL;
			err = -EINVAL;
@@ -1358,14 +1338,6 @@ static int br_ip6_multicast_query(struct net_bridge *br,


	is_general_query = group && ipv6_addr_any(group);
	is_general_query = group && ipv6_addr_any(group);


	/* RFC2710+RFC3810 (MLDv1+MLDv2) require the multicast link layer
	 * all-nodes destination address (ff02::1) for general queries
	 */
	if (is_general_query && !ipv6_addr_is_ll_all_nodes(&ip6h->daddr)) {
		err = -EINVAL;
		goto out;
	}

	if (is_general_query) {
	if (is_general_query) {
		saddr.proto = htons(ETH_P_IPV6);
		saddr.proto = htons(ETH_P_IPV6);
		saddr.u.ip6 = ip6h->saddr;
		saddr.u.ip6 = ip6h->saddr;
@@ -1557,66 +1529,22 @@ static int br_multicast_ipv4_rcv(struct net_bridge *br,
				 struct sk_buff *skb,
				 struct sk_buff *skb,
				 u16 vid)
				 u16 vid)
{
{
	struct sk_buff *skb2 = skb;
	struct sk_buff *skb_trimmed = NULL;
	const struct iphdr *iph;
	struct igmphdr *ih;
	struct igmphdr *ih;
	unsigned int len;
	unsigned int offset;
	int err;
	int err;


	/* We treat OOM as packet loss for now. */
	err = ip_mc_check_igmp(skb, &skb_trimmed);
	if (!pskb_may_pull(skb, sizeof(*iph)))
		return -EINVAL;

	iph = ip_hdr(skb);

	if (iph->ihl < 5 || iph->version != 4)
		return -EINVAL;

	if (!pskb_may_pull(skb, ip_hdrlen(skb)))
		return -EINVAL;

	iph = ip_hdr(skb);


	if (unlikely(ip_fast_csum((u8 *)iph, iph->ihl)))
	if (err == -ENOMSG) {
		return -EINVAL;
		if (!ipv4_is_local_multicast(ip_hdr(skb)->daddr))

	if (iph->protocol != IPPROTO_IGMP) {
		if (!ipv4_is_local_multicast(iph->daddr))
			BR_INPUT_SKB_CB(skb)->mrouters_only = 1;
			BR_INPUT_SKB_CB(skb)->mrouters_only = 1;
		return 0;
		return 0;
	} else if (err < 0) {
		return err;
	}
	}


	len = ntohs(iph->tot_len);
	if (skb->len < len || len < ip_hdrlen(skb))
		return -EINVAL;

	if (skb->len > len) {
		skb2 = skb_clone(skb, GFP_ATOMIC);
		if (!skb2)
			return -ENOMEM;

		err = pskb_trim_rcsum(skb2, len);
		if (err)
			goto err_out;
	}

	len -= ip_hdrlen(skb2);
	offset = skb_network_offset(skb2) + ip_hdrlen(skb2);
	__skb_pull(skb2, offset);
	skb_reset_transport_header(skb2);

	err = -EINVAL;
	if (!pskb_may_pull(skb2, sizeof(*ih)))
		goto out;

	if (skb_checksum_simple_validate(skb2))
		goto out;

	err = 0;

	BR_INPUT_SKB_CB(skb)->igmp = 1;
	BR_INPUT_SKB_CB(skb)->igmp = 1;
	ih = igmp_hdr(skb2);
	ih = igmp_hdr(skb);


	switch (ih->type) {
	switch (ih->type) {
	case IGMP_HOST_MEMBERSHIP_REPORT:
	case IGMP_HOST_MEMBERSHIP_REPORT:
@@ -1625,21 +1553,19 @@ static int br_multicast_ipv4_rcv(struct net_bridge *br,
		err = br_ip4_multicast_add_group(br, port, ih->group, vid);
		err = br_ip4_multicast_add_group(br, port, ih->group, vid);
		break;
		break;
	case IGMPV3_HOST_MEMBERSHIP_REPORT:
	case IGMPV3_HOST_MEMBERSHIP_REPORT:
		err = br_ip4_multicast_igmp3_report(br, port, skb2, vid);
		err = br_ip4_multicast_igmp3_report(br, port, skb_trimmed, vid);
		break;
		break;
	case IGMP_HOST_MEMBERSHIP_QUERY:
	case IGMP_HOST_MEMBERSHIP_QUERY:
		err = br_ip4_multicast_query(br, port, skb2, vid);
		err = br_ip4_multicast_query(br, port, skb_trimmed, vid);
		break;
		break;
	case IGMP_HOST_LEAVE_MESSAGE:
	case IGMP_HOST_LEAVE_MESSAGE:
		br_ip4_multicast_leave_group(br, port, ih->group, vid);
		br_ip4_multicast_leave_group(br, port, ih->group, vid);
		break;
		break;
	}
	}


out:
	if (skb_trimmed)
	__skb_push(skb2, offset);
		kfree_skb(skb_trimmed);
err_out:

	if (skb2 != skb)
		kfree_skb(skb2);
	return err;
	return err;
}
}


@@ -1649,126 +1575,42 @@ static int br_multicast_ipv6_rcv(struct net_bridge *br,
				 struct sk_buff *skb,
				 struct sk_buff *skb,
				 u16 vid)
				 u16 vid)
{
{
	struct sk_buff *skb2;
	struct sk_buff *skb_trimmed = NULL;
	const struct ipv6hdr *ip6h;
	struct mld_msg *mld;
	u8 icmp6_type;
	u8 nexthdr;
	__be16 frag_off;
	unsigned int len;
	int offset;
	int err;
	int err;


	if (!pskb_may_pull(skb, sizeof(*ip6h)))
	err = ipv6_mc_check_mld(skb, &skb_trimmed);
		return -EINVAL;

	ip6h = ipv6_hdr(skb);


	/*
	if (err == -ENOMSG) {
	 * We're interested in MLD messages only.
		if (!ipv6_addr_is_ll_all_nodes(&ipv6_hdr(skb)->daddr))
	 *  - Version is 6
	 *  - MLD has always Router Alert hop-by-hop option
	 *  - But we do not support jumbrograms.
	 */
	if (ip6h->version != 6)
		return 0;

	/* Prevent flooding this packet if there is no listener present */
	if (!ipv6_addr_is_ll_all_nodes(&ip6h->daddr))
			BR_INPUT_SKB_CB(skb)->mrouters_only = 1;
			BR_INPUT_SKB_CB(skb)->mrouters_only = 1;

	if (ip6h->nexthdr != IPPROTO_HOPOPTS ||
	    ip6h->payload_len == 0)
		return 0;

	len = ntohs(ip6h->payload_len) + sizeof(*ip6h);
	if (skb->len < len)
		return -EINVAL;

	nexthdr = ip6h->nexthdr;
	offset = ipv6_skip_exthdr(skb, sizeof(*ip6h), &nexthdr, &frag_off);

	if (offset < 0 || nexthdr != IPPROTO_ICMPV6)
		return 0;
		return 0;

	} else if (err < 0) {
	/* Okay, we found ICMPv6 header */
		return err;
	skb2 = skb_clone(skb, GFP_ATOMIC);
	if (!skb2)
		return -ENOMEM;

	err = -EINVAL;
	if (!pskb_may_pull(skb2, offset + sizeof(struct icmp6hdr)))
		goto out;

	len -= offset - skb_network_offset(skb2);

	__skb_pull(skb2, offset);
	skb_reset_transport_header(skb2);
	skb_postpull_rcsum(skb2, skb_network_header(skb2),
			   skb_network_header_len(skb2));

	icmp6_type = icmp6_hdr(skb2)->icmp6_type;

	switch (icmp6_type) {
	case ICMPV6_MGM_QUERY:
	case ICMPV6_MGM_REPORT:
	case ICMPV6_MGM_REDUCTION:
	case ICMPV6_MLD2_REPORT:
		break;
	default:
		err = 0;
		goto out;
	}

	/* Okay, we found MLD message. Check further. */
	if (skb2->len > len) {
		err = pskb_trim_rcsum(skb2, len);
		if (err)
			goto out;
		err = -EINVAL;
	}
	}


	ip6h = ipv6_hdr(skb2);

	if (skb_checksum_validate(skb2, IPPROTO_ICMPV6, ip6_compute_pseudo))
		goto out;

	err = 0;

	BR_INPUT_SKB_CB(skb)->igmp = 1;
	BR_INPUT_SKB_CB(skb)->igmp = 1;
	mld = (struct mld_msg *)skb_transport_header(skb);


	switch (icmp6_type) {
	switch (mld->mld_type) {
	case ICMPV6_MGM_REPORT:
	case ICMPV6_MGM_REPORT:
	    {
		struct mld_msg *mld;
		if (!pskb_may_pull(skb2, sizeof(*mld))) {
			err = -EINVAL;
			goto out;
		}
		mld = (struct mld_msg *)skb_transport_header(skb2);
		BR_INPUT_SKB_CB(skb)->mrouters_only = 1;
		BR_INPUT_SKB_CB(skb)->mrouters_only = 1;
		err = br_ip6_multicast_add_group(br, port, &mld->mld_mca, vid);
		err = br_ip6_multicast_add_group(br, port, &mld->mld_mca, vid);
		break;
		break;
	    }
	case ICMPV6_MLD2_REPORT:
	case ICMPV6_MLD2_REPORT:
		err = br_ip6_multicast_mld2_report(br, port, skb2, vid);
		err = br_ip6_multicast_mld2_report(br, port, skb_trimmed, vid);
		break;
		break;
	case ICMPV6_MGM_QUERY:
	case ICMPV6_MGM_QUERY:
		err = br_ip6_multicast_query(br, port, skb2, vid);
		err = br_ip6_multicast_query(br, port, skb_trimmed, vid);
		break;
		break;
	case ICMPV6_MGM_REDUCTION:
	case ICMPV6_MGM_REDUCTION:
	    {
		struct mld_msg *mld;
		if (!pskb_may_pull(skb2, sizeof(*mld))) {
			err = -EINVAL;
			goto out;
		}
		mld = (struct mld_msg *)skb_transport_header(skb2);
		br_ip6_multicast_leave_group(br, port, &mld->mld_mca, vid);
		br_ip6_multicast_leave_group(br, port, &mld->mld_mca, vid);
	    }
		break;
	}
	}


out:
	if (skb_trimmed)
	kfree_skb(skb2);
		kfree_skb(skb_trimmed);

	return err;
	return err;
}
}
#endif
#endif
+87 −0
Original line number Original line Diff line number Diff line
@@ -4030,6 +4030,93 @@ int skb_checksum_setup(struct sk_buff *skb, bool recalculate)
}
}
EXPORT_SYMBOL(skb_checksum_setup);
EXPORT_SYMBOL(skb_checksum_setup);


/**
 * skb_checksum_maybe_trim - maybe trims the given skb
 * @skb: the skb to check
 * @transport_len: the data length beyond the network header
 *
 * Checks whether the given skb has data beyond the given transport length.
 * If so, returns a cloned skb trimmed to this transport length.
 * Otherwise returns the provided skb. Returns NULL in error cases
 * (e.g. transport_len exceeds skb length or out-of-memory).
 *
 * Caller needs to set the skb transport header and release the returned skb.
 * Provided skb is consumed.
 */
static struct sk_buff *skb_checksum_maybe_trim(struct sk_buff *skb,
					       unsigned int transport_len)
{
	struct sk_buff *skb_chk;
	unsigned int len = skb_transport_offset(skb) + transport_len;
	int ret;

	if (skb->len < len) {
		kfree_skb(skb);
		return NULL;
	} else if (skb->len == len) {
		return skb;
	}

	skb_chk = skb_clone(skb, GFP_ATOMIC);
	kfree_skb(skb);

	if (!skb_chk)
		return NULL;

	ret = pskb_trim_rcsum(skb_chk, len);
	if (ret) {
		kfree_skb(skb_chk);
		return NULL;
	}

	return skb_chk;
}

/**
 * skb_checksum_trimmed - validate checksum of an skb
 * @skb: the skb to check
 * @transport_len: the data length beyond the network header
 * @skb_chkf: checksum function to use
 *
 * Applies the given checksum function skb_chkf to the provided skb.
 * Returns a checked and maybe trimmed skb. Returns NULL on error.
 *
 * If the skb has data beyond the given transport length, then a
 * trimmed & cloned skb is checked and returned.
 *
 * Caller needs to set the skb transport header and release the returned skb.
 * Provided skb is consumed.
 */
struct sk_buff *skb_checksum_trimmed(struct sk_buff *skb,
				     unsigned int transport_len,
				     __sum16(*skb_chkf)(struct sk_buff *skb))
{
	struct sk_buff *skb_chk;
	unsigned int offset = skb_transport_offset(skb);
	int ret;

	skb_chk = skb_checksum_maybe_trim(skb, transport_len);
	if (!skb_chk)
		return NULL;

	if (!pskb_may_pull(skb_chk, offset)) {
		kfree_skb(skb_chk);
		return NULL;
	}

	__skb_pull(skb_chk, offset);
	ret = skb_chkf(skb_chk);
	__skb_push(skb_chk, offset);

	if (ret) {
		kfree_skb(skb_chk);
		return NULL;
	}

	return skb_chk;
}
EXPORT_SYMBOL(skb_checksum_trimmed);

void __skb_warn_lro_forwarding(const struct sk_buff *skb)
void __skb_warn_lro_forwarding(const struct sk_buff *skb)
{
{
	net_warn_ratelimited("%s: received packets cannot be forwarded while LRO is enabled\n",
	net_warn_ratelimited("%s: received packets cannot be forwarded while LRO is enabled\n",
Loading