tlv: allocate tlv structs from within

let's let the fromwire__tlv methods allocate the tlv-objects and
return them. we also want to initialize all of their underlying
messages to NULL, and fail if we discover a duplicate mesage type.

if parsing fails, instead of returning a struct we return NULL.

Suggested-By: @rustyrussell
This commit is contained in:
lisa neigut 2019-03-28 14:27:28 -07:00 committed by Rusty Russell
parent 5aea65b463
commit b89ea071e8
1 changed files with 22 additions and 8 deletions

View File

@ -504,8 +504,9 @@ class Message(object):
elif f.is_tlv:
if not f.is_variable_size():
raise TypeError('TLV {} not variable size'.format(f.name))
subcalls.append('if (!fromwire__{tlv_name}(ctx, &cursor, &plen, &{tlv_len}, {tlv_name}))'
subcalls.append('{tlv_name} = fromwire__{tlv_name}(ctx, &cursor, &plen, &{tlv_len});'
.format(tlv_name=f.name, tlv_len=f.lenvar))
subcalls.append('if (!{tlv_name})'.format(tlv_name=f.name))
subcalls.append('return false;')
elif f.is_variable_size():
subcalls.append("//2nd case {name}".format(name=f.name))
@ -846,10 +847,12 @@ tlv__type_impl_towire_template = """static void towire__{tlv_name}(const tal_t *
{fields}}}
"""
tlv__type_impl_fromwire_template = """static bool fromwire__{tlv_name}(const tal_t *ctx, const u8 **p, size_t *plen, const u16 *len, struct {tlv_name} *{tlv_name}) {{
tlv__type_impl_fromwire_template = """static struct {tlv_name} *fromwire__{tlv_name}(const tal_t *ctx, const u8 **p, size_t *plen, const u16 *len) {{
\tu8 msg_type, msg_len;
\tif (*plen < *len)
\t\treturn false;
\t\treturn NULL;
\tstruct {tlv_name} *{tlv_name} = talz(ctx, struct {tlv_name});
\twhile (*plen) {{
\t\tmsg_type = fromwire_u8(p, plen);
@ -865,14 +868,25 @@ tlv__type_impl_fromwire_template = """static bool fromwire__{tlv_name}(const tal
\t\t\tplen -= msg_len;
\t\t}}
\t}}
\treturn *p != NULL;
\tif (!*p) {{
\t\ttal_free({tlv_name});
\t\treturn NULL;
\t}}
\treturn {tlv_name};
}}
"""
case_tmpl = """\t\tcase {tlv_msg_enum}:
\t\t\t{tlv_name}->{tlv_msg_name} = tal(ctx, struct tlv_msg_{tlv_msg_name});
\t\t\tif (!fromwire_{tlv_name}_{tlv_msg_name}({ctx_arg}*p, plen, msg_len, {tlv_name}->{tlv_msg_name}))
\t\t\t\treturn false;
\t\t\tif ({tlv_name}->{tlv_msg_name} != NULL) {{
\t\t\t\tfromwire_fail(p, plen);
\t\t\t\ttal_free({tlv_name});
\t\t\t\treturn NULL;
\t\t\t}}
\t\t\t{tlv_name}->{tlv_msg_name} = tal({tlv_name}, struct tlv_msg_{tlv_msg_name});
\t\t\tif (!fromwire_{tlv_name}_{tlv_msg_name}({ctx_arg}*p, plen, msg_len, {tlv_name}->{tlv_msg_name})) {{
\t\t\t\ttal_free({tlv_name});
\t\t\t\treturn NULL;
\t\t\t}}
\t\t\tbreak;
"""
@ -906,7 +920,7 @@ def print_tlv_towire(tlv_field_name, messages):
def print_tlv_fromwire(tlv_field_name, messages):
cases = ""
for m in messages:
ctx_arg = 'ctx, ' if m.has_variable_fields else ''
ctx_arg = tlv_field_name + ', ' if m.has_variable_fields else ''
cases += case_tmpl.format(ctx_arg=ctx_arg,
tlv_msg_enum=m.enum.name,
tlv_name=tlv_field_name,