mlx: use TRIE & struct based param decoding

Also fix two bugs with the properties parameter to the set_params call:
- the parameter wasn't listed in the settables table
- the parameter was ignored unless there was a public key present

Reviewed-by: Richard Levitte <levitte@openssl.org>
Reviewed-by: Tomas Mraz <tomas@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/27859)
This commit is contained in:
Pauli 2025-06-20 11:29:00 +10:00
parent b622ae3917
commit cf13e66522
1 changed files with 76 additions and 81 deletions

View File

@ -6,6 +6,9 @@
* in the file LICENSE in the source distribution or at
* https://www.openssl.org/source/license.html
*/
{-
use OpenSSL::paramnames qw(produce_param_decoder);
-}
#include <openssl/core_dispatch.h>
#include <openssl/core_names.h>
@ -144,6 +147,11 @@ static int mlx_kem_match(const void *vkey1, const void *vkey2, int selection)
&& EVP_PKEY_eq(key1->xkey, key2->xkey);
}
{- produce_param_decoder('ml_kem_import_export',
(['PKEY_PARAM_PRIV_KEY', 'privkey', 'octet_string'],
['PKEY_PARAM_PUB_KEY', 'pubkey', 'octet_string'],
)); -}
typedef struct export_cb_arg_st {
const char *algorithm_name;
uint8_t *pubenc;
@ -160,7 +168,7 @@ typedef struct export_cb_arg_st {
static int export_sub_cb(const OSSL_PARAM *params, void *varg)
{
EXPORT_CB_ARG *sub_arg = varg;
const OSSL_PARAM *p = NULL;
struct ml_kem_import_export_st p;
size_t len;
/*
@ -170,11 +178,11 @@ static int export_sub_cb(const OSSL_PARAM *params, void *varg)
*/
if (ossl_param_is_empty(params))
return 1;
if (sub_arg->pubenc != NULL
&& (p = OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_PUB_KEY)) != NULL) {
p = ml_kem_import_export_decoder(params);
if (sub_arg->pubenc != NULL && p.pubkey != NULL) {
void *pub = sub_arg->pubenc + sub_arg->puboff;
if (OSSL_PARAM_get_octet_string(p, &pub, sub_arg->publen, &len) != 1)
if (OSSL_PARAM_get_octet_string(p.pubkey, &pub, sub_arg->publen, &len) != 1)
return 0;
if (len != sub_arg->publen) {
ERR_raise_data(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR,
@ -185,11 +193,10 @@ static int export_sub_cb(const OSSL_PARAM *params, void *varg)
}
++sub_arg->pubcount;
}
if (sub_arg->prvenc != NULL
&& (p = OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_PRIV_KEY)) != NULL) {
if (sub_arg->prvenc != NULL && p.privkey != NULL) {
void *prv = sub_arg->prvenc + sub_arg->prvoff;
if (OSSL_PARAM_get_octet_string(p, &prv, sub_arg->prvlen, &len) != 1)
if (OSSL_PARAM_get_octet_string(p.privkey, &prv, sub_arg->prvlen, &len) != 1)
return 0;
if (len != sub_arg->prvlen) {
ERR_raise_data(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR,
@ -319,14 +326,8 @@ static int mlx_kem_export(void *vkey, int selection, OSSL_CALLBACK *param_cb,
static const OSSL_PARAM *mlx_kem_imexport_types(int selection)
{
static const OSSL_PARAM key_types[] = {
OSSL_PARAM_octet_string(OSSL_PKEY_PARAM_PUB_KEY, NULL, 0),
OSSL_PARAM_octet_string(OSSL_PKEY_PARAM_PRIV_KEY, NULL, 0),
OSSL_PARAM_END
};
if ((selection & OSSL_KEYMGMT_SELECT_KEYPAIR) != 0)
return key_types;
return ml_kem_import_export_list;
return NULL;
}
@ -411,7 +412,7 @@ static int mlx_kem_key_fromdata(MLX_KEY *key,
const OSSL_PARAM params[],
int include_private)
{
const OSSL_PARAM *param_prv_key = NULL, *param_pub_key;
struct ml_kem_import_export_st p;
const void *pubenc = NULL, *prvenc = NULL;
size_t pubkey_bytes, prvkey_bytes;
size_t publen = 0, prvlen = 0;
@ -422,16 +423,15 @@ static int mlx_kem_key_fromdata(MLX_KEY *key,
pubkey_bytes = key->minfo->pubkey_bytes + key->xinfo->pubkey_bytes;
prvkey_bytes = key->minfo->prvkey_bytes + key->xinfo->prvkey_bytes;
p = ml_kem_import_export_decoder(params);
/* What does the caller want to set? */
param_pub_key = OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_PUB_KEY);
if (param_pub_key != NULL &&
OSSL_PARAM_get_octet_string_ptr(param_pub_key, &pubenc, &publen) != 1)
if (p.pubkey != NULL &&
OSSL_PARAM_get_octet_string_ptr(p.pubkey, &pubenc, &publen) != 1)
return 0;
if (include_private)
param_prv_key = OSSL_PARAM_locate_const(params,
OSSL_PKEY_PARAM_PRIV_KEY);
if (param_prv_key != NULL &&
OSSL_PARAM_get_octet_string_ptr(param_prv_key, &prvenc, &prvlen) != 1)
if (include_private
&& p.privkey != NULL
&& OSSL_PARAM_get_octet_string_ptr(p.privkey, &prvenc, &prvlen) != 1)
return 0;
/* The caller MUST specify at least one of the public or private keys. */
@ -472,19 +472,18 @@ static int mlx_kem_import(void *vkey, int selection, const OSSL_PARAM params[])
return mlx_kem_key_fromdata(key, params, include_private);
}
{- produce_param_decoder('mlx_get_params',
(['PKEY_PARAM_BITS', 'bits', 'int'],
['PKEY_PARAM_SECURITY_BITS', 'secbits', 'int'],
['PKEY_PARAM_MAX_SIZE', 'maxsize', 'int'],
['PKEY_PARAM_SECURITY_CATEGORY', 'seccat', 'int'],
['PKEY_PARAM_ENCODED_PUBLIC_KEY', 'pub', 'octet_string'],
['PKEY_PARAM_PRIV_KEY', 'priv', 'octet_string'],
)); -}
static const OSSL_PARAM *mlx_kem_gettable_params(void *provctx)
{
static const OSSL_PARAM arr[] = {
OSSL_PARAM_int(OSSL_PKEY_PARAM_BITS, NULL),
OSSL_PARAM_int(OSSL_PKEY_PARAM_SECURITY_BITS, NULL),
OSSL_PARAM_int(OSSL_PKEY_PARAM_MAX_SIZE, NULL),
OSSL_PARAM_int(OSSL_PKEY_PARAM_SECURITY_CATEGORY, NULL),
OSSL_PARAM_octet_string(OSSL_PKEY_PARAM_ENCODED_PUBLIC_KEY, NULL, 0),
OSSL_PARAM_octet_string(OSSL_PKEY_PARAM_PRIV_KEY, NULL, 0),
OSSL_PARAM_END
};
return arr;
return mlx_get_params_list;
}
/*
@ -493,42 +492,40 @@ static const OSSL_PARAM *mlx_kem_gettable_params(void *provctx)
static int mlx_kem_get_params(void *vkey, OSSL_PARAM params[])
{
MLX_KEY *key = vkey;
OSSL_PARAM *p, *pub, *prv = NULL;
OSSL_PARAM *pub, *prv = NULL;
EXPORT_CB_ARG sub_arg;
int selection;
struct mlx_get_params_st p;
size_t publen = key->minfo->pubkey_bytes + key->xinfo->pubkey_bytes;
size_t prvlen = key->minfo->prvkey_bytes + key->xinfo->prvkey_bytes;
p = mlx_get_params_decoder(params);
/* The reported "bit" count is those of the ML-KEM key */
p = OSSL_PARAM_locate(params, OSSL_PKEY_PARAM_BITS);
if (p != NULL)
if (!OSSL_PARAM_set_int(p, key->minfo->bits))
if (p.bits != NULL)
if (!OSSL_PARAM_set_int(p.bits, key->minfo->bits))
return 0;
/* The reported security bits are those of the ML-KEM key */
p = OSSL_PARAM_locate(params, OSSL_PKEY_PARAM_SECURITY_BITS);
if (p != NULL)
if (!OSSL_PARAM_set_int(p, key->minfo->secbits))
if (p.secbits != NULL)
if (!OSSL_PARAM_set_int(p.secbits, key->minfo->secbits))
return 0;
/* The reported security category are those of the ML-KEM key */
p = OSSL_PARAM_locate(params, OSSL_PKEY_PARAM_SECURITY_CATEGORY);
if (p != NULL)
if (!OSSL_PARAM_set_int(p, key->minfo->security_category))
if (p.seccat != NULL)
if (!OSSL_PARAM_set_int(p.seccat, key->minfo->security_category))
return 0;
/* The ciphertext sizes are additive */
p = OSSL_PARAM_locate(params, OSSL_PKEY_PARAM_MAX_SIZE);
if (p != NULL)
if (!OSSL_PARAM_set_int(p, key->minfo->ctext_bytes + key->xinfo->pubkey_bytes))
if (p.maxsize != NULL)
if (!OSSL_PARAM_set_int(p.maxsize, key->minfo->ctext_bytes + key->xinfo->pubkey_bytes))
return 0;
if (!mlx_kem_have_pubkey(key))
return 1;
memset(&sub_arg, 0, sizeof(sub_arg));
pub = OSSL_PARAM_locate(params, OSSL_PKEY_PARAM_ENCODED_PUBLIC_KEY);
if (pub != NULL) {
if ((pub = p.pub) != NULL) {
if (pub->data_type != OSSL_PARAM_OCTET_STRING)
return 0;
pub->return_size = publen;
@ -545,8 +542,7 @@ static int mlx_kem_get_params(void *vkey, OSSL_PARAM params[])
}
}
if (mlx_kem_have_prvkey(key)) {
prv = OSSL_PARAM_locate(params, OSSL_PKEY_PARAM_PRIV_KEY);
if (prv != NULL) {
if ((prv = p.priv) != NULL) {
if (prv->data_type != OSSL_PARAM_OCTET_STRING)
return 0;
prv->return_size = prvlen;
@ -582,29 +578,36 @@ static int mlx_kem_get_params(void *vkey, OSSL_PARAM params[])
return 1;
}
{- produce_param_decoder('mlx_set_params',
(['PKEY_PARAM_ENCODED_PUBLIC_KEY', 'pub', 'octet_string'],
['PKEY_PARAM_PROPERTIES', 'propq', 'utf8_string'],
)); -}
static const OSSL_PARAM *mlx_kem_settable_params(void *provctx)
{
static const OSSL_PARAM arr[] = {
OSSL_PARAM_octet_string(OSSL_PKEY_PARAM_ENCODED_PUBLIC_KEY, NULL, 0),
OSSL_PARAM_END
};
return arr;
return mlx_set_params_list;
}
static int mlx_kem_set_params(void *vkey, const OSSL_PARAM params[])
{
MLX_KEY *key = vkey;
const OSSL_PARAM *p;
struct mlx_set_params_st p;
const void *pubenc = NULL;
size_t publen = 0;
if (ossl_param_is_empty(params))
return 1;
/* Only one settable parameter is supported */
p = OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_ENCODED_PUBLIC_KEY);
if (p == NULL)
p = mlx_set_params_decoder(params);
if (p.propq != NULL) {
OPENSSL_free(key->propq);
key->propq = NULL;
if (!OSSL_PARAM_get_utf8_string(p.propq, &key->propq, 0))
return 0;
}
if (p.pub == NULL)
return 1;
/* Key mutation is reportedly generally not allowed */
@ -615,17 +618,9 @@ static int mlx_kem_set_params(void *vkey, const OSSL_PARAM params[])
return 0;
}
/* An unlikely failure mode is the parameter having some unexpected type */
if (!OSSL_PARAM_get_octet_string_ptr(p, &pubenc, &publen))
if (!OSSL_PARAM_get_octet_string_ptr(p.pub, &pubenc, &publen))
return 0;
p = OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_PROPERTIES);
if (p != NULL) {
OPENSSL_free(key->propq);
key->propq = NULL;
if (!OSSL_PARAM_get_utf8_string(p, &key->propq, 0))
return 0;
}
if (publen != key->minfo->pubkey_bytes + key->xinfo->pubkey_bytes) {
ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_KEY);
return 0;
@ -634,22 +629,27 @@ static int mlx_kem_set_params(void *vkey, const OSSL_PARAM params[])
return load_keys(key, pubenc, publen, NULL, 0);
}
{- produce_param_decoder('mlx_gen_set_params',
(['PKEY_PARAM_PROPERTIES', 'propq', 'utf8_string'],
)); -}
static int mlx_kem_gen_set_params(void *vgctx, const OSSL_PARAM params[])
{
PROV_ML_KEM_GEN_CTX *gctx = vgctx;
const OSSL_PARAM *p;
struct mlx_gen_set_params_st p;
if (gctx == NULL)
return 0;
if (ossl_param_is_empty(params))
return 1;
p = OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_PROPERTIES);
if (p != NULL) {
if (p->data_type != OSSL_PARAM_UTF8_STRING)
p = mlx_gen_set_params_decoder(params);
if (p.propq != NULL) {
if (p.propq->data_type != OSSL_PARAM_UTF8_STRING)
return 0;
OPENSSL_free(gctx->propq);
if ((gctx->propq = OPENSSL_strdup(p->data)) == NULL)
if ((gctx->propq = OPENSSL_strdup(p.propq->data)) == NULL)
return 0;
}
return 1;
@ -682,12 +682,7 @@ static void *mlx_kem_gen_init(int evp_type, OSSL_LIB_CTX *libctx,
static const OSSL_PARAM *mlx_kem_gen_settable_params(ossl_unused void *vgctx,
ossl_unused void *provctx)
{
static OSSL_PARAM settable[] = {
OSSL_PARAM_octet_string(OSSL_PKEY_PARAM_PROPERTIES, NULL, 0),
OSSL_PARAM_END
};
return settable;
return mlx_gen_set_params_list;
}
static void *mlx_kem_gen(void *vgctx, OSSL_CALLBACK *osslcb, void *cbarg)