e->valid = 0;
- msk = mptcp_token_get_sock(e->token);
+ msk = mptcp_token_get_sock(net, e->token);
if (!msk) {
spin_unlock_bh(&join_entry_locks[i]);
return false;
}
- /* If this fails, the token got re-used in the mean time by another
- * mptcp socket in a different netns, i.e. entry is outdated.
- */
- if (!net_eq(sock_net((struct sock *)msk), net))
- goto err_put;
-
subflow_req->remote_nonce = e->remote_nonce;
subflow_req->local_nonce = e->local_nonce;
subflow_req->backup = e->backup;
subflow_req->msk = msk;
spin_unlock_bh(&join_entry_locks[i]);
return true;
-
-err_put:
- spin_unlock_bh(&join_entry_locks[i]);
- sock_put((struct sock *)msk);
- return false;
}
void __init mptcp_join_cookie_init(void)
/**
* mptcp_token_get_sock - retrieve mptcp connection sock using its token
+ * @net: restrict to this namespace
* @token: token of the mptcp connection to retrieve
*
* This function returns the mptcp connection structure with the given token.
*
* returns NULL if no connection with the given token value exists.
*/
-struct mptcp_sock *mptcp_token_get_sock(u32 token)
+struct mptcp_sock *mptcp_token_get_sock(struct net *net, u32 token)
{
struct hlist_nulls_node *pos;
struct token_bucket *bucket;
again:
sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) {
msk = mptcp_sk(sk);
- if (READ_ONCE(msk->token) != token)
+ if (READ_ONCE(msk->token) != token ||
+ !net_eq(sock_net(sk), net))
continue;
+
if (!refcount_inc_not_zero(&sk->sk_refcnt))
goto not_found;
- if (READ_ONCE(msk->token) != token) {
+
+ if (READ_ONCE(msk->token) != token ||
+ !net_eq(sock_net(sk), net)) {
sock_put(sk);
goto again;
}
GFP_USER);
KUNIT_EXPECT_NOT_ERR_OR_NULL(test, req);
mptcp_token_init_request((struct request_sock *)req);
+ sock_net_set((struct sock *)req, &init_net);
return req;
}
KUNIT_ASSERT_EQ(test, 0,
mptcp_token_new_request((struct request_sock *)req));
KUNIT_EXPECT_NE(test, 0, (int)req->token);
- KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(req->token));
+ KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(&init_net, req->token));
/* cleanup */
mptcp_token_destroy_request((struct request_sock *)req);
msk = kunit_kzalloc(test, sizeof(struct mptcp_sock), GFP_USER);
KUNIT_EXPECT_NOT_ERR_OR_NULL(test, msk);
refcount_set(&((struct sock *)msk)->sk_refcnt, 1);
+ sock_net_set((struct sock *)msk, &init_net);
return msk;
}
mptcp_token_new_connect((struct sock *)icsk));
KUNIT_EXPECT_NE(test, 0, (int)ctx->token);
KUNIT_EXPECT_EQ(test, ctx->token, msk->token);
- KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(ctx->token));
+ KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(&init_net, ctx->token));
KUNIT_EXPECT_EQ(test, 2, (int)refcount_read(&sk->sk_refcnt));
mptcp_token_destroy(msk);
- KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(ctx->token));
+ KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(&init_net, ctx->token));
}
static void mptcp_token_test_accept(struct kunit *test)
mptcp_token_new_request((struct request_sock *)req));
msk->token = req->token;
mptcp_token_accept(req, msk);
- KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(msk->token));
+ KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(&init_net, msk->token));
/* this is now a no-op */
mptcp_token_destroy_request((struct request_sock *)req);
- KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(msk->token));
+ KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(&init_net, msk->token));
/* cleanup */
mptcp_token_destroy(msk);
/* simulate race on removal */
refcount_set(&sk->sk_refcnt, 0);
- KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(msk->token));
+ KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(&init_net, msk->token));
/* cleanup */
mptcp_token_destroy(msk);