diff --git a/tools/generate-wire.py b/tools/generate-wire.py index 7fabefc3e..d5c75402f 100755 --- a/tools/generate-wire.py +++ b/tools/generate-wire.py @@ -55,6 +55,9 @@ class FieldType(object): def is_assignable(self): return self.name in ['u8', 'u16', 'u32', 'u64', 'bool', 'struct amount_msat', 'struct amount_sat', 'var_int'] or self.name.startswith('enum ') + def needs_ptr(self): + return not self.is_assignable() + # We only accelerate the u8 case: it's common and trivial. def has_array_helper(self): return self.name in ['u8'] @@ -69,6 +72,12 @@ class FieldType(object): return 'u64' return basetype + def is_subtype(self): + for subtype in subtypes: + if subtype.name == self.base(): + return True + return False + # Returns base size @staticmethod def _typesize(typename): @@ -144,6 +153,7 @@ class Field(object): self.num_elems = 1 self.optional = False self.is_tlv = False + self.is_subtype = False if name.endswith('_tlv'): self.is_tlv = True @@ -169,23 +179,29 @@ class Field(object): # Bolts use just a number: Guess type based on size. if options.bolt: - if size == '$': # this is a subtype + if size.startswith('$'): # this is a subtype self.fieldtype = FieldType('struct {}'.format(name)) - elif size == 'var_int': - base_size = 8 - self.fieldtype = FieldType(size) + self.is_subtype = True + if size[1:] == prevname: + self.lenvar = size[1:] + else: + raise ValueError('Expected size field for subtype field {}'.format(name)) else: - base_size = int(size) - self.fieldtype = Field._guess_type(message, self.name, base_size) - # There are some arrays which we have to guess, based on sizes. - tsize = FieldType._typesize(self.fieldtype.name) - if base_size % tsize != 0: - raise ValueError('Invalid size {} for {}.{} not a multiple of {}' - .format(base_size, - self.message, - self.name, - tsize)) - self.num_elems = int(base_size / tsize) + if size == 'var_int': + base_size = 8 + self.fieldtype = FieldType(size) + else: + base_size = int(size) + self.fieldtype = Field._guess_type(message, self.name, base_size) + # There are some arrays which we have to guess, based on sizes. + tsize = FieldType._typesize(self.fieldtype.name) + if base_size % tsize != 0: + raise ValueError('Invalid size {} for {}.{} not a multiple of {}' + .format(base_size, + self.message, + self.name, + tsize)) + self.num_elems = int(base_size / tsize) else: # Real typename. self.fieldtype = FieldType(size) @@ -255,6 +271,15 @@ fromwire_tlv_impl_templ = """static bool fromwire_{tlv_name}_{name}({ctx}{args}) }} """ +fromwire_subtype_impl_templ = """static bool fromwire_{name}({ctx}{args}) +{{ + +{fields} +{subcalls} +\treturn cursor != NULL; +}} +""" + fromwire_header_templ = """bool fromwire_{name}({ctx}const void *p{args}); """ @@ -408,8 +433,9 @@ class Message(object): subcalls.append('({})[i] = fromwire_{}({}, {}cursor, {}plen);' .format(name, basetype, ctx, p_ref, p_ref)) else: - subcalls.append('fromwire_{}({}cursor, {}plen, {} + i);' - .format(basetype, p_ref, p_ref, name)) + ctx_arg = ctx + ', ' if f.fieldtype.is_subtype() else '' + subcalls.append('fromwire_{}({}{}cursor, {}plen, {} + i);' + .format(basetype, ctx_arg, p_ref, p_ref, name)) def print_fromwire(self, is_header): ctx_arg = 'const tal_t *ctx, ' if self.has_variable_fields else '' @@ -859,6 +885,99 @@ class Subtype(Message): def print_struct(self): return TlvMessage._inner_print_struct(self.name, self.fields) + def print_towire(self): + """ prints towire function definition for a subtype""" + field_decls = [] + for f in self.fields: + if f.optional: + raise TypeError("Optional fields on subtypes not currently supported. {}".format(f.name)) + if f.is_len_var: + field_decls.append('\t{0} {1} = tal_count({2}->{3});'.format( + f.fieldtype.name, f.name, self.name, f.lenvar_for.name + )) + + subcalls = CCode() + for f in self.fields: + basetype = f.fieldtype.base() + for c in f.comments: + subcalls.append('/*{} */'.format(c)) + + if f.is_padding(): + subcalls.append('towire_pad(p, {});'.format(f.num_elems)) + elif f.is_array(): + self.print_towire_array(subcalls, basetype, f, f.num_elems, + is_tlv=True) + elif f.is_variable_size(): + self.print_towire_array(subcalls, basetype, f, f.lenvar, + is_tlv=True) + elif f.is_len_var: + subcalls.append('towire_{}(p, {});'.format(basetype, f.name)) + else: + ref = '&' if f.fieldtype.needs_ptr() else '' + subcalls.append('towire_{}(p, {}{}->{});'.format(basetype, ref, self.name, f.name)) + return subtype_towire_stub.format( + name=self.name, + field_decls='\n'.join(field_decls), + subcalls=str(subcalls)) + + def print_fromwire(self): + """ prints fromwire function definition for a subtype. + these are significantly different in that they take in a struct + to populate, instead of fields. + """ + ctx_arg = 'const tal_t *ctx, ' if self.has_variable_fields else '' + args = 'const u8 **cursor, size_t *plen, struct {name} *{name}'.format(name=self.name) + fields = ['\t{} {};\n'.format(f.fieldtype.name, f.name) for f in self.fields if f.is_len_var] + subcalls = CCode() + for f in self.fields: + basetype = f.fieldtype.base() + if f.optional: + raise TypeError('Optional fields on subtypes not currently supported') + + for c in f.comments: + subcalls.append('/*{} */'.format(c)) + + if f.is_padding(): + subcalls.append('fromwire_pad(cursor, plen, {});' + .format(f.num_elems)) + elif f.is_array(): + name = '*{}->{}'.format(self.name, f.name) + self.print_fromwire_array('ctx', subcalls, basetype, f, name, + f.num_elems, is_tlv=True) + elif f.is_variable_size(): + subcalls.append("// 2nd case {name}".format(name=f.name)) + typename = f.fieldtype.name + # If structs are varlen, need array of ptrs to them. + if basetype in varlen_structs: + typename += ' *' + subcalls.append('{}->{} = {} ? tal_arr(ctx, {}, {}) : NULL;' + .format(self.name, f.name, f.lenvar, typename, f.lenvar)) + + name = '{}->{}'.format(self.name, f.name) + # Allocate these off the array itself, if they need alloc. + self.print_fromwire_array('*' + f.name, subcalls, basetype, f, + name, f.lenvar, is_tlv=True) + else: + if f.is_assignable(): + if f.is_len_var: + s = '{} = fromwire_{}(cursor, plen);'.format(f.name, basetype) + else: + s = '{}->{} = fromwire_{}(cursor, plen);'.format( + self.name, f.name, basetype) + else: + ref = '&' if f.fieldtype.needs_ptr() else '' + s = 'fromwire_{}(cursor, plen, {}{}->{});'.format( + basetype, ref, self.name, f.name) + subcalls.append(s) + + return fromwire_subtype_impl_templ.format( + name=self.name, + ctx=ctx_arg, + args=''.join(args), + fields=''.join(fields), + subcalls=str(subcalls) + ) + tlv_message_towire_stub = """static void towire_{tlv_name}_{name}(u8 **p, struct tlv_msg_{name} *{name}) {{ {field_decls} @@ -866,6 +985,12 @@ tlv_message_towire_stub = """static void towire_{tlv_name}_{name}(u8 **p, struct }} """ +subtype_towire_stub = """static void towire_{name}(u8 **p, const struct {name} *{name}) {{ +{field_decls} +{subcalls} +}} +""" + tlv_struct_template = """ struct {tlv_name} {{ {msg_type_structs} @@ -1345,6 +1470,9 @@ else: if not options.header: towire_decls += build_tlv_towires(tlv_fields) fromwire_decls += build_tlv_fromwires(tlv_fields) + for subtype in subtypes: + towire_decls.append(subtype.print_towire()) + fromwire_decls.append(subtype.print_fromwire()) towire_decls += [m.print_towire(options.header) for m in toplevel_messages + messages_with_option] fromwire_decls += [m.print_fromwire(options.header) for m in toplevel_messages + messages_with_option]