]> git.baikalelectronics.ru Git - kernel.git/commitdiff
bpf: tcp: Support tcp_congestion_ops in bpf
authorMartin KaFai Lau <kafai@fb.com>
Thu, 9 Jan 2020 00:35:08 +0000 (16:35 -0800)
committerAlexei Starovoitov <ast@kernel.org>
Thu, 9 Jan 2020 16:46:18 +0000 (08:46 -0800)
This patch makes "struct tcp_congestion_ops" to be the first user
of BPF STRUCT_OPS.  It allows implementing a tcp_congestion_ops
in bpf.

The BPF implemented tcp_congestion_ops can be used like
regular kernel tcp-cc through sysctl and setsockopt.  e.g.
[root@arch-fb-vm1 bpf]# sysctl -a | egrep congestion
net.ipv4.tcp_allowed_congestion_control = reno cubic bpf_cubic
net.ipv4.tcp_available_congestion_control = reno bic cubic bpf_cubic
net.ipv4.tcp_congestion_control = bpf_cubic

There has been attempt to move the TCP CC to the user space
(e.g. CCP in TCP).   The common arguments are faster turn around,
get away from long-tail kernel versions in production...etc,
which are legit points.

BPF has been the continuous effort to join both kernel and
userspace upsides together (e.g. XDP to gain the performance
advantage without bypassing the kernel).  The recent BPF
advancements (in particular BTF-aware verifier, BPF trampoline,
BPF CO-RE...) made implementing kernel struct ops (e.g. tcp cc)
possible in BPF.  It allows a faster turnaround for testing algorithm
in the production while leveraging the existing (and continue growing)
BPF feature/framework instead of building one specifically for
userspace TCP CC.

This patch allows write access to a few fields in tcp-sock
(in bpf_tcp_ca_btf_struct_access()).

The optional "get_info" is unsupported now.  It can be added
later.  One possible way is to output the info with a btf-id
to describe the content.

Signed-off-by: Martin KaFai Lau <kafai@fb.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Acked-by: Andrii Nakryiko <andriin@fb.com>
Acked-by: Yonghong Song <yhs@fb.com>
Link: https://lore.kernel.org/bpf/20200109003508.3856115-1-kafai@fb.com
include/linux/filter.h
include/net/tcp.h
kernel/bpf/bpf_struct_ops_types.h
net/core/filter.c
net/ipv4/Makefile
net/ipv4/bpf_tcp_ca.c [new file with mode: 0644]
net/ipv4/tcp_cong.c
net/ipv4/tcp_ipv4.c
net/ipv4/tcp_minisocks.c
net/ipv4/tcp_output.c

index 70e6dd960bcafc54ce44b194f9041f0ae4e1a37d..a366a0b64a57df04655464e034e8096aa8c444a2 100644 (file)
@@ -843,6 +843,8 @@ int bpf_prog_create(struct bpf_prog **pfp, struct sock_fprog_kern *fprog);
 int bpf_prog_create_from_user(struct bpf_prog **pfp, struct sock_fprog *fprog,
                              bpf_aux_classic_check_t trans, bool save_orig);
 void bpf_prog_destroy(struct bpf_prog *fp);
+const struct bpf_func_proto *
+bpf_base_func_proto(enum bpf_func_id func_id);
 
 int sk_attach_filter(struct sock_fprog *fprog, struct sock *sk);
 int sk_attach_bpf(u32 ufd, struct sock *sk);
index 7df37e2fddca40621779e6dd8943d1a7451698dc..9dd975be7fdf1a3381886ae28278a0f13db3e892 100644 (file)
@@ -1007,6 +1007,7 @@ enum tcp_ca_ack_event_flags {
 #define TCP_CONG_NON_RESTRICTED 0x1
 /* Requires ECN/ECT set on all packets */
 #define TCP_CONG_NEEDS_ECN     0x2
+#define TCP_CONG_MASK  (TCP_CONG_NON_RESTRICTED | TCP_CONG_NEEDS_ECN)
 
 union tcp_cc_info;
 
@@ -1101,6 +1102,7 @@ u32 tcp_reno_undo_cwnd(struct sock *sk);
 void tcp_reno_cong_avoid(struct sock *sk, u32 ack, u32 acked);
 extern struct tcp_congestion_ops tcp_reno;
 
+struct tcp_congestion_ops *tcp_ca_find(const char *name);
 struct tcp_congestion_ops *tcp_ca_find_key(u32 key);
 u32 tcp_ca_get_key_by_name(struct net *net, const char *name, bool *ecn_ca);
 #ifdef CONFIG_INET
index 7bb13ff49ec2e068312e956f2b95ff605c260e60..066d83ea1c99c580b9b0d898c55dd7c7eb52d7fa 100644 (file)
@@ -1,4 +1,9 @@
 /* SPDX-License-Identifier: GPL-2.0 */
 /* internal file - do not include directly */
 
-/* To be filled in a later patch */
+#ifdef CONFIG_BPF_JIT
+#ifdef CONFIG_INET
+#include <net/tcp.h>
+BPF_STRUCT_OPS_TYPE(tcp_congestion_ops)
+#endif
+#endif
index 42fd17c48c5f7036ec70b39695059c050801faad..a702761ef3696ea2189a8cda4e454cf3c7918b90 100644 (file)
@@ -5935,7 +5935,7 @@ bool bpf_helper_changes_pkt_data(void *func)
        return false;
 }
 
-static const struct bpf_func_proto *
+const struct bpf_func_proto *
 bpf_base_func_proto(enum bpf_func_id func_id)
 {
        switch (func_id) {
index d57ecfaf89d48c73f00bf7f2a151365648c1001c..9d97bace13c83bbbf13efb9c1c790d480f77808b 100644 (file)
@@ -65,3 +65,7 @@ obj-$(CONFIG_NETLABEL) += cipso_ipv4.o
 
 obj-$(CONFIG_XFRM) += xfrm4_policy.o xfrm4_state.o xfrm4_input.o \
                      xfrm4_output.o xfrm4_protocol.o
+
+ifeq ($(CONFIG_BPF_JIT),y)
+obj-$(CONFIG_BPF_SYSCALL) += bpf_tcp_ca.o
+endif
diff --git a/net/ipv4/bpf_tcp_ca.c b/net/ipv4/bpf_tcp_ca.c
new file mode 100644 (file)
index 0000000..9c7745b
--- /dev/null
@@ -0,0 +1,230 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Copyright (c) 2019 Facebook  */
+
+#include <linux/types.h>
+#include <linux/bpf_verifier.h>
+#include <linux/bpf.h>
+#include <linux/btf.h>
+#include <linux/filter.h>
+#include <net/tcp.h>
+
+static u32 optional_ops[] = {
+       offsetof(struct tcp_congestion_ops, init),
+       offsetof(struct tcp_congestion_ops, release),
+       offsetof(struct tcp_congestion_ops, set_state),
+       offsetof(struct tcp_congestion_ops, cwnd_event),
+       offsetof(struct tcp_congestion_ops, in_ack_event),
+       offsetof(struct tcp_congestion_ops, pkts_acked),
+       offsetof(struct tcp_congestion_ops, min_tso_segs),
+       offsetof(struct tcp_congestion_ops, sndbuf_expand),
+       offsetof(struct tcp_congestion_ops, cong_control),
+};
+
+static u32 unsupported_ops[] = {
+       offsetof(struct tcp_congestion_ops, get_info),
+};
+
+static const struct btf_type *tcp_sock_type;
+static u32 tcp_sock_id, sock_id;
+
+static int bpf_tcp_ca_init(struct btf *btf)
+{
+       s32 type_id;
+
+       type_id = btf_find_by_name_kind(btf, "sock", BTF_KIND_STRUCT);
+       if (type_id < 0)
+               return -EINVAL;
+       sock_id = type_id;
+
+       type_id = btf_find_by_name_kind(btf, "tcp_sock", BTF_KIND_STRUCT);
+       if (type_id < 0)
+               return -EINVAL;
+       tcp_sock_id = type_id;
+       tcp_sock_type = btf_type_by_id(btf, tcp_sock_id);
+
+       return 0;
+}
+
+static bool is_optional(u32 member_offset)
+{
+       unsigned int i;
+
+       for (i = 0; i < ARRAY_SIZE(optional_ops); i++) {
+               if (member_offset == optional_ops[i])
+                       return true;
+       }
+
+       return false;
+}
+
+static bool is_unsupported(u32 member_offset)
+{
+       unsigned int i;
+
+       for (i = 0; i < ARRAY_SIZE(unsupported_ops); i++) {
+               if (member_offset == unsupported_ops[i])
+                       return true;
+       }
+
+       return false;
+}
+
+extern struct btf *btf_vmlinux;
+
+static bool bpf_tcp_ca_is_valid_access(int off, int size,
+                                      enum bpf_access_type type,
+                                      const struct bpf_prog *prog,
+                                      struct bpf_insn_access_aux *info)
+{
+       if (off < 0 || off >= sizeof(__u64) * MAX_BPF_FUNC_ARGS)
+               return false;
+       if (type != BPF_READ)
+               return false;
+       if (off % size != 0)
+               return false;
+
+       if (!btf_ctx_access(off, size, type, prog, info))
+               return false;
+
+       if (info->reg_type == PTR_TO_BTF_ID && info->btf_id == sock_id)
+               /* promote it to tcp_sock */
+               info->btf_id = tcp_sock_id;
+
+       return true;
+}
+
+static int bpf_tcp_ca_btf_struct_access(struct bpf_verifier_log *log,
+                                       const struct btf_type *t, int off,
+                                       int size, enum bpf_access_type atype,
+                                       u32 *next_btf_id)
+{
+       size_t end;
+
+       if (atype == BPF_READ)
+               return btf_struct_access(log, t, off, size, atype, next_btf_id);
+
+       if (t != tcp_sock_type) {
+               bpf_log(log, "only read is supported\n");
+               return -EACCES;
+       }
+
+       switch (off) {
+       case bpf_ctx_range(struct inet_connection_sock, icsk_ca_priv):
+               end = offsetofend(struct inet_connection_sock, icsk_ca_priv);
+               break;
+       case offsetof(struct inet_connection_sock, icsk_ack.pending):
+               end = offsetofend(struct inet_connection_sock,
+                                 icsk_ack.pending);
+               break;
+       case offsetof(struct tcp_sock, snd_cwnd):
+               end = offsetofend(struct tcp_sock, snd_cwnd);
+               break;
+       case offsetof(struct tcp_sock, snd_cwnd_cnt):
+               end = offsetofend(struct tcp_sock, snd_cwnd_cnt);
+               break;
+       case offsetof(struct tcp_sock, snd_ssthresh):
+               end = offsetofend(struct tcp_sock, snd_ssthresh);
+               break;
+       case offsetof(struct tcp_sock, ecn_flags):
+               end = offsetofend(struct tcp_sock, ecn_flags);
+               break;
+       default:
+               bpf_log(log, "no write support to tcp_sock at off %d\n", off);
+               return -EACCES;
+       }
+
+       if (off + size > end) {
+               bpf_log(log,
+                       "write access at off %d with size %d beyond the member of tcp_sock ended at %zu\n",
+                       off, size, end);
+               return -EACCES;
+       }
+
+       return NOT_INIT;
+}
+
+static const struct bpf_func_proto *
+bpf_tcp_ca_get_func_proto(enum bpf_func_id func_id,
+                         const struct bpf_prog *prog)
+{
+       return bpf_base_func_proto(func_id);
+}
+
+static const struct bpf_verifier_ops bpf_tcp_ca_verifier_ops = {
+       .get_func_proto         = bpf_tcp_ca_get_func_proto,
+       .is_valid_access        = bpf_tcp_ca_is_valid_access,
+       .btf_struct_access      = bpf_tcp_ca_btf_struct_access,
+};
+
+static int bpf_tcp_ca_init_member(const struct btf_type *t,
+                                 const struct btf_member *member,
+                                 void *kdata, const void *udata)
+{
+       const struct tcp_congestion_ops *utcp_ca;
+       struct tcp_congestion_ops *tcp_ca;
+       size_t tcp_ca_name_len;
+       int prog_fd;
+       u32 moff;
+
+       utcp_ca = (const struct tcp_congestion_ops *)udata;
+       tcp_ca = (struct tcp_congestion_ops *)kdata;
+
+       moff = btf_member_bit_offset(t, member) / 8;
+       switch (moff) {
+       case offsetof(struct tcp_congestion_ops, flags):
+               if (utcp_ca->flags & ~TCP_CONG_MASK)
+                       return -EINVAL;
+               tcp_ca->flags = utcp_ca->flags;
+               return 1;
+       case offsetof(struct tcp_congestion_ops, name):
+               tcp_ca_name_len = strnlen(utcp_ca->name, sizeof(utcp_ca->name));
+               if (!tcp_ca_name_len ||
+                   tcp_ca_name_len == sizeof(utcp_ca->name))
+                       return -EINVAL;
+               if (tcp_ca_find(utcp_ca->name))
+                       return -EEXIST;
+               memcpy(tcp_ca->name, utcp_ca->name, sizeof(tcp_ca->name));
+               return 1;
+       }
+
+       if (!btf_type_resolve_func_ptr(btf_vmlinux, member->type, NULL))
+               return 0;
+
+       /* Ensure bpf_prog is provided for compulsory func ptr */
+       prog_fd = (int)(*(unsigned long *)(udata + moff));
+       if (!prog_fd && !is_optional(moff) && !is_unsupported(moff))
+               return -EINVAL;
+
+       return 0;
+}
+
+static int bpf_tcp_ca_check_member(const struct btf_type *t,
+                                  const struct btf_member *member)
+{
+       if (is_unsupported(btf_member_bit_offset(t, member) / 8))
+               return -ENOTSUPP;
+       return 0;
+}
+
+static int bpf_tcp_ca_reg(void *kdata)
+{
+       return tcp_register_congestion_control(kdata);
+}
+
+static void bpf_tcp_ca_unreg(void *kdata)
+{
+       tcp_unregister_congestion_control(kdata);
+}
+
+/* Avoid sparse warning.  It is only used in bpf_struct_ops.c. */
+extern struct bpf_struct_ops bpf_tcp_congestion_ops;
+
+struct bpf_struct_ops bpf_tcp_congestion_ops = {
+       .verifier_ops = &bpf_tcp_ca_verifier_ops,
+       .reg = bpf_tcp_ca_reg,
+       .unreg = bpf_tcp_ca_unreg,
+       .check_member = bpf_tcp_ca_check_member,
+       .init_member = bpf_tcp_ca_init_member,
+       .init = bpf_tcp_ca_init,
+       .name = "tcp_congestion_ops",
+};
index 3737ec096650271b49456de9906ccf26465c7b02..3172e31987be4232af90e7b204742c5bb09ef6ca 100644 (file)
@@ -21,7 +21,7 @@ static DEFINE_SPINLOCK(tcp_cong_list_lock);
 static LIST_HEAD(tcp_cong_list);
 
 /* Simple linear search, don't expect many entries! */
-static struct tcp_congestion_ops *tcp_ca_find(const char *name)
+struct tcp_congestion_ops *tcp_ca_find(const char *name)
 {
        struct tcp_congestion_ops *e;
 
@@ -162,7 +162,7 @@ void tcp_assign_congestion_control(struct sock *sk)
 
        rcu_read_lock();
        ca = rcu_dereference(net->ipv4.tcp_congestion_control);
-       if (unlikely(!try_module_get(ca->owner)))
+       if (unlikely(!bpf_try_module_get(ca, ca->owner)))
                ca = &tcp_reno;
        icsk->icsk_ca_ops = ca;
        rcu_read_unlock();
@@ -208,7 +208,7 @@ void tcp_cleanup_congestion_control(struct sock *sk)
 
        if (icsk->icsk_ca_ops->release)
                icsk->icsk_ca_ops->release(sk);
-       module_put(icsk->icsk_ca_ops->owner);
+       bpf_module_put(icsk->icsk_ca_ops, icsk->icsk_ca_ops->owner);
 }
 
 /* Used by sysctl to change default congestion control */
@@ -222,12 +222,12 @@ int tcp_set_default_congestion_control(struct net *net, const char *name)
        ca = tcp_ca_find_autoload(net, name);
        if (!ca) {
                ret = -ENOENT;
-       } else if (!try_module_get(ca->owner)) {
+       } else if (!bpf_try_module_get(ca, ca->owner)) {
                ret = -EBUSY;
        } else {
                prev = xchg(&net->ipv4.tcp_congestion_control, ca);
                if (prev)
-                       module_put(prev->owner);
+                       bpf_module_put(prev, prev->owner);
 
                ca->flags |= TCP_CONG_NON_RESTRICTED;
                ret = 0;
@@ -366,19 +366,19 @@ int tcp_set_congestion_control(struct sock *sk, const char *name, bool load,
        } else if (!load) {
                const struct tcp_congestion_ops *old_ca = icsk->icsk_ca_ops;
 
-               if (try_module_get(ca->owner)) {
+               if (bpf_try_module_get(ca, ca->owner)) {
                        if (reinit) {
                                tcp_reinit_congestion_control(sk, ca);
                        } else {
                                icsk->icsk_ca_ops = ca;
-                               module_put(old_ca->owner);
+                               bpf_module_put(old_ca, old_ca->owner);
                        }
                } else {
                        err = -EBUSY;
                }
        } else if (!((ca->flags & TCP_CONG_NON_RESTRICTED) || cap_net_admin)) {
                err = -EPERM;
-       } else if (!try_module_get(ca->owner)) {
+       } else if (!bpf_try_module_get(ca, ca->owner)) {
                err = -EBUSY;
        } else {
                tcp_reinit_congestion_control(sk, ca);
index 4adac9c75343d5e62dbcfadc68457704ec4b2979..317ccca548a21ffcfefe33cb89249505d381d5e2 100644 (file)
@@ -2678,7 +2678,8 @@ static void __net_exit tcp_sk_exit(struct net *net)
        int cpu;
 
        if (net->ipv4.tcp_congestion_control)
-               module_put(net->ipv4.tcp_congestion_control->owner);
+               bpf_module_put(net->ipv4.tcp_congestion_control,
+                              net->ipv4.tcp_congestion_control->owner);
 
        for_each_possible_cpu(cpu)
                inet_ctl_sock_destroy(*per_cpu_ptr(net->ipv4.tcp_sk, cpu));
@@ -2785,7 +2786,8 @@ static int __net_init tcp_sk_init(struct net *net)
 
        /* Reno is always built in */
        if (!net_eq(net, &init_net) &&
-           try_module_get(init_net.ipv4.tcp_congestion_control->owner))
+           bpf_try_module_get(init_net.ipv4.tcp_congestion_control,
+                              init_net.ipv4.tcp_congestion_control->owner))
                net->ipv4.tcp_congestion_control = init_net.ipv4.tcp_congestion_control;
        else
                net->ipv4.tcp_congestion_control = &tcp_reno;
index c802bc80c4006f82c2e9189ef1fc11b8f321e70d..ad3b56d9fa7156f724f7558abccb1367fb5ea8d3 100644 (file)
@@ -414,7 +414,7 @@ void tcp_ca_openreq_child(struct sock *sk, const struct dst_entry *dst)
 
                rcu_read_lock();
                ca = tcp_ca_find_key(ca_key);
-               if (likely(ca && try_module_get(ca->owner))) {
+               if (likely(ca && bpf_try_module_get(ca, ca->owner))) {
                        icsk->icsk_ca_dst_locked = tcp_ca_dst_locked(dst);
                        icsk->icsk_ca_ops = ca;
                        ca_got_dst = true;
@@ -425,7 +425,7 @@ void tcp_ca_openreq_child(struct sock *sk, const struct dst_entry *dst)
        /* If no valid choice made yet, assign current system default ca. */
        if (!ca_got_dst &&
            (!icsk->icsk_ca_setsockopt ||
-            !try_module_get(icsk->icsk_ca_ops->owner)))
+            !bpf_try_module_get(icsk->icsk_ca_ops, icsk->icsk_ca_ops->owner)))
                tcp_assign_congestion_control(sk);
 
        tcp_set_ca_state(sk, TCP_CA_Open);
index 58c92a7d671c54564479db061dcec186a8a7bb34..377cfab422dfb4b8691ab5dcb14e66181e3922df 100644 (file)
@@ -3368,8 +3368,8 @@ static void tcp_ca_dst_init(struct sock *sk, const struct dst_entry *dst)
 
        rcu_read_lock();
        ca = tcp_ca_find_key(ca_key);
-       if (likely(ca && try_module_get(ca->owner))) {
-               module_put(icsk->icsk_ca_ops->owner);
+       if (likely(ca && bpf_try_module_get(ca, ca->owner))) {
+               bpf_module_put(icsk->icsk_ca_ops, icsk->icsk_ca_ops->owner);
                icsk->icsk_ca_dst_locked = tcp_ca_dst_locked(dst);
                icsk->icsk_ca_ops = ca;
        }