subtypes: add some parsing for subtypes, so that it passes

this probably could be consolidated, as it splits
out all the print_to/fromwire method stuff for the Subtype class
This commit is contained in:
lisa neigut 2019-03-27 18:58:29 -07:00 committed by Rusty Russell
parent 94395c6a9a
commit 37d6545191
1 changed files with 145 additions and 17 deletions

View File

@ -55,6 +55,9 @@ class FieldType(object):
def is_assignable(self):
return self.name in ['u8', 'u16', 'u32', 'u64', 'bool', 'struct amount_msat', 'struct amount_sat', 'var_int'] or self.name.startswith('enum ')
def needs_ptr(self):
return not self.is_assignable()
# We only accelerate the u8 case: it's common and trivial.
def has_array_helper(self):
return self.name in ['u8']
@ -69,6 +72,12 @@ class FieldType(object):
return 'u64'
return basetype
def is_subtype(self):
for subtype in subtypes:
if subtype.name == self.base():
return True
return False
# Returns base size
@staticmethod
def _typesize(typename):
@ -144,6 +153,7 @@ class Field(object):
self.num_elems = 1
self.optional = False
self.is_tlv = False
self.is_subtype = False
if name.endswith('_tlv'):
self.is_tlv = True
@ -169,23 +179,29 @@ class Field(object):
# Bolts use just a number: Guess type based on size.
if options.bolt:
if size == '$': # this is a subtype
if size.startswith('$'): # this is a subtype
self.fieldtype = FieldType('struct {}'.format(name))
elif size == 'var_int':
base_size = 8
self.fieldtype = FieldType(size)
self.is_subtype = True
if size[1:] == prevname:
self.lenvar = size[1:]
else:
raise ValueError('Expected size field for subtype field {}'.format(name))
else:
base_size = int(size)
self.fieldtype = Field._guess_type(message, self.name, base_size)
# There are some arrays which we have to guess, based on sizes.
tsize = FieldType._typesize(self.fieldtype.name)
if base_size % tsize != 0:
raise ValueError('Invalid size {} for {}.{} not a multiple of {}'
.format(base_size,
self.message,
self.name,
tsize))
self.num_elems = int(base_size / tsize)
if size == 'var_int':
base_size = 8
self.fieldtype = FieldType(size)
else:
base_size = int(size)
self.fieldtype = Field._guess_type(message, self.name, base_size)
# There are some arrays which we have to guess, based on sizes.
tsize = FieldType._typesize(self.fieldtype.name)
if base_size % tsize != 0:
raise ValueError('Invalid size {} for {}.{} not a multiple of {}'
.format(base_size,
self.message,
self.name,
tsize))
self.num_elems = int(base_size / tsize)
else:
# Real typename.
self.fieldtype = FieldType(size)
@ -255,6 +271,15 @@ fromwire_tlv_impl_templ = """static bool fromwire_{tlv_name}_{name}({ctx}{args})
}}
"""
fromwire_subtype_impl_templ = """static bool fromwire_{name}({ctx}{args})
{{
{fields}
{subcalls}
\treturn cursor != NULL;
}}
"""
fromwire_header_templ = """bool fromwire_{name}({ctx}const void *p{args});
"""
@ -408,8 +433,9 @@ class Message(object):
subcalls.append('({})[i] = fromwire_{}({}, {}cursor, {}plen);'
.format(name, basetype, ctx, p_ref, p_ref))
else:
subcalls.append('fromwire_{}({}cursor, {}plen, {} + i);'
.format(basetype, p_ref, p_ref, name))
ctx_arg = ctx + ', ' if f.fieldtype.is_subtype() else ''
subcalls.append('fromwire_{}({}{}cursor, {}plen, {} + i);'
.format(basetype, ctx_arg, p_ref, p_ref, name))
def print_fromwire(self, is_header):
ctx_arg = 'const tal_t *ctx, ' if self.has_variable_fields else ''
@ -859,6 +885,99 @@ class Subtype(Message):
def print_struct(self):
return TlvMessage._inner_print_struct(self.name, self.fields)
def print_towire(self):
""" prints towire function definition for a subtype"""
field_decls = []
for f in self.fields:
if f.optional:
raise TypeError("Optional fields on subtypes not currently supported. {}".format(f.name))
if f.is_len_var:
field_decls.append('\t{0} {1} = tal_count({2}->{3});'.format(
f.fieldtype.name, f.name, self.name, f.lenvar_for.name
))
subcalls = CCode()
for f in self.fields:
basetype = f.fieldtype.base()
for c in f.comments:
subcalls.append('/*{} */'.format(c))
if f.is_padding():
subcalls.append('towire_pad(p, {});'.format(f.num_elems))
elif f.is_array():
self.print_towire_array(subcalls, basetype, f, f.num_elems,
is_tlv=True)
elif f.is_variable_size():
self.print_towire_array(subcalls, basetype, f, f.lenvar,
is_tlv=True)
elif f.is_len_var:
subcalls.append('towire_{}(p, {});'.format(basetype, f.name))
else:
ref = '&' if f.fieldtype.needs_ptr() else ''
subcalls.append('towire_{}(p, {}{}->{});'.format(basetype, ref, self.name, f.name))
return subtype_towire_stub.format(
name=self.name,
field_decls='\n'.join(field_decls),
subcalls=str(subcalls))
def print_fromwire(self):
""" prints fromwire function definition for a subtype.
these are significantly different in that they take in a struct
to populate, instead of fields.
"""
ctx_arg = 'const tal_t *ctx, ' if self.has_variable_fields else ''
args = 'const u8 **cursor, size_t *plen, struct {name} *{name}'.format(name=self.name)
fields = ['\t{} {};\n'.format(f.fieldtype.name, f.name) for f in self.fields if f.is_len_var]
subcalls = CCode()
for f in self.fields:
basetype = f.fieldtype.base()
if f.optional:
raise TypeError('Optional fields on subtypes not currently supported')
for c in f.comments:
subcalls.append('/*{} */'.format(c))
if f.is_padding():
subcalls.append('fromwire_pad(cursor, plen, {});'
.format(f.num_elems))
elif f.is_array():
name = '*{}->{}'.format(self.name, f.name)
self.print_fromwire_array('ctx', subcalls, basetype, f, name,
f.num_elems, is_tlv=True)
elif f.is_variable_size():
subcalls.append("// 2nd case {name}".format(name=f.name))
typename = f.fieldtype.name
# If structs are varlen, need array of ptrs to them.
if basetype in varlen_structs:
typename += ' *'
subcalls.append('{}->{} = {} ? tal_arr(ctx, {}, {}) : NULL;'
.format(self.name, f.name, f.lenvar, typename, f.lenvar))
name = '{}->{}'.format(self.name, f.name)
# Allocate these off the array itself, if they need alloc.
self.print_fromwire_array('*' + f.name, subcalls, basetype, f,
name, f.lenvar, is_tlv=True)
else:
if f.is_assignable():
if f.is_len_var:
s = '{} = fromwire_{}(cursor, plen);'.format(f.name, basetype)
else:
s = '{}->{} = fromwire_{}(cursor, plen);'.format(
self.name, f.name, basetype)
else:
ref = '&' if f.fieldtype.needs_ptr() else ''
s = 'fromwire_{}(cursor, plen, {}{}->{});'.format(
basetype, ref, self.name, f.name)
subcalls.append(s)
return fromwire_subtype_impl_templ.format(
name=self.name,
ctx=ctx_arg,
args=''.join(args),
fields=''.join(fields),
subcalls=str(subcalls)
)
tlv_message_towire_stub = """static void towire_{tlv_name}_{name}(u8 **p, struct tlv_msg_{name} *{name}) {{
{field_decls}
@ -866,6 +985,12 @@ tlv_message_towire_stub = """static void towire_{tlv_name}_{name}(u8 **p, struct
}}
"""
subtype_towire_stub = """static void towire_{name}(u8 **p, const struct {name} *{name}) {{
{field_decls}
{subcalls}
}}
"""
tlv_struct_template = """
struct {tlv_name} {{
{msg_type_structs}
@ -1345,6 +1470,9 @@ else:
if not options.header:
towire_decls += build_tlv_towires(tlv_fields)
fromwire_decls += build_tlv_fromwires(tlv_fields)
for subtype in subtypes:
towire_decls.append(subtype.print_towire())
fromwire_decls.append(subtype.print_fromwire())
towire_decls += [m.print_towire(options.header) for m in toplevel_messages + messages_with_option]
fromwire_decls += [m.print_fromwire(options.header) for m in toplevel_messages + messages_with_option]