]> git.baikalelectronics.ru Git - kernel.git/commitdiff
net/packet: remove data races in fanout operations
authorEric Dumazet <edumazet@google.com>
Wed, 14 Apr 2021 19:36:44 +0000 (12:36 -0700)
committerDavid S. Miller <davem@davemloft.net>
Wed, 14 Apr 2021 21:34:38 +0000 (14:34 -0700)
af_packet fanout uses RCU rules to ensure f->arr elements
are not dismantled before RCU grace period.

However, it lacks rcu accessors to make sure KCSAN and other tools
wont detect data races. Stupid compilers could also play games.

Fixes: dc99f600698d ("packet: Add fanout support.")
Signed-off-by: Eric Dumazet <edumazet@google.com>
Reported-by: "Gong, Sishuai" <sishuai@purdue.edu>
Cc: Willem de Bruijn <willemb@google.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
net/packet/af_packet.c
net/packet/internal.h

index 118d585337d72f10cd31ec5ca7c55b508fc18baf..ba96db1880eae89febf77ba6ff943b054cd268d7 100644 (file)
@@ -1359,7 +1359,7 @@ static unsigned int fanout_demux_rollover(struct packet_fanout *f,
        struct packet_sock *po, *po_next, *po_skip = NULL;
        unsigned int i, j, room = ROOM_NONE;
 
-       po = pkt_sk(f->arr[idx]);
+       po = pkt_sk(rcu_dereference(f->arr[idx]));
 
        if (try_self) {
                room = packet_rcv_has_room(po, skb);
@@ -1371,7 +1371,7 @@ static unsigned int fanout_demux_rollover(struct packet_fanout *f,
 
        i = j = min_t(int, po->rollover->sock, num - 1);
        do {
-               po_next = pkt_sk(f->arr[i]);
+               po_next = pkt_sk(rcu_dereference(f->arr[i]));
                if (po_next != po_skip && !READ_ONCE(po_next->pressure) &&
                    packet_rcv_has_room(po_next, skb) == ROOM_NORMAL) {
                        if (i != j)
@@ -1466,7 +1466,7 @@ static int packet_rcv_fanout(struct sk_buff *skb, struct net_device *dev,
        if (fanout_has_flag(f, PACKET_FANOUT_FLAG_ROLLOVER))
                idx = fanout_demux_rollover(f, skb, idx, true, num);
 
-       po = pkt_sk(f->arr[idx]);
+       po = pkt_sk(rcu_dereference(f->arr[idx]));
        return po->prot_hook.func(skb, dev, &po->prot_hook, orig_dev);
 }
 
@@ -1480,7 +1480,7 @@ static void __fanout_link(struct sock *sk, struct packet_sock *po)
        struct packet_fanout *f = po->fanout;
 
        spin_lock(&f->lock);
-       f->arr[f->num_members] = sk;
+       rcu_assign_pointer(f->arr[f->num_members], sk);
        smp_wmb();
        f->num_members++;
        if (f->num_members == 1)
@@ -1495,11 +1495,14 @@ static void __fanout_unlink(struct sock *sk, struct packet_sock *po)
 
        spin_lock(&f->lock);
        for (i = 0; i < f->num_members; i++) {
-               if (f->arr[i] == sk)
+               if (rcu_dereference_protected(f->arr[i],
+                                             lockdep_is_held(&f->lock)) == sk)
                        break;
        }
        BUG_ON(i >= f->num_members);
-       f->arr[i] = f->arr[f->num_members - 1];
+       rcu_assign_pointer(f->arr[i],
+                          rcu_dereference_protected(f->arr[f->num_members - 1],
+                                                    lockdep_is_held(&f->lock)));
        f->num_members--;
        if (f->num_members == 0)
                __dev_remove_pack(&f->prot_hook);
index 5f61e59ebbffaa25a8fdfe31f79211fe6a755c51..48af35b1aed2565267c0288e013e23ff51f2fcac 100644 (file)
@@ -94,7 +94,7 @@ struct packet_fanout {
        spinlock_t              lock;
        refcount_t              sk_ref;
        struct packet_type      prot_hook ____cacheline_aligned_in_smp;
-       struct sock             *arr[];
+       struct sock     __rcu   *arr[];
 };
 
 struct packet_rollover {