diff --git a/ssl/quic/quic_impl.c b/ssl/quic/quic_impl.c index faf4cdb10c..ed0fa4ceaa 100644 --- a/ssl/quic/quic_impl.c +++ b/ssl/quic/quic_impl.c @@ -4527,6 +4527,18 @@ SSL *ossl_quic_new_domain(SSL_CTX *ctx, uint64_t flags) { QUIC_DOMAIN *qd = NULL; QUIC_ENGINE_ARGS engine_args = {0}; + uint64_t domain_flags; + + domain_flags = ctx->domain_flags; + if ((flags & (SSL_DOMAIN_FLAG_SINGLE_THREAD + | SSL_DOMAIN_FLAG_MULTI_THREAD + | SSL_DOMAIN_FLAG_THREAD_ASSISTED)) != 0) + domain_flags = flags; + else + domain_flags = ctx->domain_flags | flags; + + if (!ossl_adjust_domain_flags(domain_flags, &domain_flags)) + return NULL; if ((qd = OPENSSL_zalloc(sizeof(*qd))) == NULL) { QUIC_RAISE_NON_NORMAL_ERROR(NULL, ERR_R_CRYPTO_LIB, NULL); @@ -4545,7 +4557,7 @@ SSL *ossl_quic_new_domain(SSL_CTX *ctx, uint64_t flags) #if defined(OPENSSL_THREADS) engine_args.mutex = qd->mutex; #endif - if (need_notifier_for_domain_flags(ctx->domain_flags)) + if (need_notifier_for_domain_flags(domain_flags)) engine_args.reactor_flags |= QUIC_REACTOR_FLAG_USE_NOTIFIER; if ((qd->engine = ossl_quic_engine_new(&engine_args)) == NULL) { @@ -4558,6 +4570,7 @@ SSL *ossl_quic_new_domain(SSL_CTX *ctx, uint64_t flags) qd->engine, NULL)) goto err; + ossl_quic_obj_set_domain_flags(&qd->obj, domain_flags); return &qd->obj.ssl; err: diff --git a/ssl/quic/quic_obj_local.h b/ssl/quic/quic_obj_local.h index bf81b24a24..7cdda2f65a 100644 --- a/ssl/quic/quic_obj_local.h +++ b/ssl/quic/quic_obj_local.h @@ -327,5 +327,15 @@ ossl_quic_obj_get0_port_leader(const QUIC_OBJ *obj) : NULL; } +/* + * Change the domain flags. Should only be called immediately after + * ossl_quic_obj_init(). + */ +static ossl_inline ossl_unused void +ossl_quic_obj_set_domain_flags(QUIC_OBJ *obj, uint64_t domain_flags) +{ + obj->domain_flags = domain_flags; +} + # endif #endif diff --git a/ssl/ssl_lib.c b/ssl/ssl_lib.c index f0d6165397..c1a51e2344 100644 --- a/ssl/ssl_lib.c +++ b/ssl/ssl_lib.c @@ -8020,41 +8020,50 @@ SSL *SSL_new_domain(SSL_CTX *ctx, uint64_t flags) #endif } +int ossl_adjust_domain_flags(uint64_t domain_flags, uint64_t *p_domain_flags) +{ + if ((domain_flags & ~OSSL_QUIC_SUPPORTED_DOMAIN_FLAGS) != 0) { + ERR_raise_data(ERR_LIB_SSL, ERR_R_UNSUPPORTED, + "unsupported domain flag requested"); + return 0; + } + + if ((domain_flags & SSL_DOMAIN_FLAG_THREAD_ASSISTED) != 0) + domain_flags |= SSL_DOMAIN_FLAG_MULTI_THREAD; + + if ((domain_flags & (SSL_DOMAIN_FLAG_MULTI_THREAD + | SSL_DOMAIN_FLAG_SINGLE_THREAD)) == 0) + domain_flags |= SSL_DOMAIN_FLAG_MULTI_THREAD; + + if ((domain_flags & SSL_DOMAIN_FLAG_SINGLE_THREAD) != 0 + && (domain_flags & SSL_DOMAIN_FLAG_MULTI_THREAD) != 0) { + ERR_raise_data(ERR_LIB_SSL, ERR_R_PASSED_INVALID_ARGUMENT, + "mutually exclusive domain flags specified"); + return 0; + } + + /* + * Note: We treat MULTI_THREAD as a no-op in non-threaded builds, but + * not THREAD_ASSISTED. + */ +# ifndef OPENSSL_THREADS + if ((domain_flags & SSL_DOMAIN_FLAG_THREAD_ASSISTED) != 0) { + ERR_raise_data(ERR_LIB_SSL, ERR_R_UNSUPPORTED, + "thread assisted mode not available in this build"); + return 0; + } +# endif + + *p_domain_flags = domain_flags; + return 1; +} + int SSL_CTX_set_domain_flags(SSL_CTX *ctx, uint64_t domain_flags) { #ifndef OPENSSL_NO_QUIC if (IS_QUIC_CTX(ctx)) { - if ((domain_flags & ~OSSL_QUIC_SUPPORTED_DOMAIN_FLAGS) != 0) { - ERR_raise_data(ERR_LIB_SSL, ERR_R_UNSUPPORTED, - "unsupported domain flag requested"); + if (!ossl_adjust_domain_flags(domain_flags, &domain_flags)) return 0; - } - - if ((domain_flags & SSL_DOMAIN_FLAG_THREAD_ASSISTED) != 0) - domain_flags |= SSL_DOMAIN_FLAG_MULTI_THREAD; - - if ((domain_flags & (SSL_DOMAIN_FLAG_MULTI_THREAD - | SSL_DOMAIN_FLAG_SINGLE_THREAD)) == 0) - domain_flags |= SSL_DOMAIN_FLAG_MULTI_THREAD; - - if ((domain_flags & SSL_DOMAIN_FLAG_SINGLE_THREAD) != 0 - && (domain_flags & SSL_DOMAIN_FLAG_MULTI_THREAD) != 0) { - ERR_raise_data(ERR_LIB_SSL, ERR_R_PASSED_INVALID_ARGUMENT, - "mutually exclusive domain flags specified"); - return 0; - } - - /* - * Note: We treat MULTI_THREAD as a no-op in non-threaded builds, but - * not THREAD_ASSISTED. - */ -# ifndef OPENSSL_THREADS - if ((domain_flags & SSL_DOMAIN_FLAG_THREAD_ASSISTED) != 0) { - ERR_raise_data(ERR_LIB_SSL, ERR_R_UNSUPPORTED, - "thread assisted mode not available in this build"); - return 0; - } -# endif ctx->domain_flags = domain_flags; return 1; diff --git a/ssl/ssl_local.h b/ssl/ssl_local.h index 7d9727aef7..8aa2cd5799 100644 --- a/ssl/ssl_local.h +++ b/ssl/ssl_local.h @@ -2908,6 +2908,9 @@ int ssl_get_md_idx(int md_nid); __owur const EVP_MD *ssl_handshake_md(SSL_CONNECTION *s); __owur const EVP_MD *ssl_prf_md(SSL_CONNECTION *s); +__owur int ossl_adjust_domain_flags(uint64_t domain_flags, + uint64_t *p_domain_flags); + /* * ssl_log_rsa_client_key_exchange logs |premaster| to the SSL_CTX associated * with |ssl|, if logging is enabled. It returns one on success and zero on