From 4a995a42de6f7857c44c45b00480e5d54285e31d Mon Sep 17 00:00:00 2001 From: Christian Decker Date: Tue, 28 Feb 2017 23:07:38 +0100 Subject: [PATCH] gen-wire: Cleanup of the wire generator tool The wiregen tool was a bit hard to maintain since it was printing all over the place, mixing template and processing logic. This commit tears the two apart, externalizes everything that is not a single code line, and repackages it into templates. Specifically functions are now their own template and header/implementation files are a template. Furthermore this simplifies some of the boilerplate of mapping types to sizes and back again, by extracting them into dicts. All changes have been verified to produce identical results on the current wire definitions, except a bit of whitespace changes. --- tools/generate-wire.py | 426 +++++++++++++++++++++-------------------- 1 file changed, 214 insertions(+), 212 deletions(-) diff --git a/tools/generate-wire.py b/tools/generate-wire.py index 09d183a36..deef23433 100755 --- a/tools/generate-wire.py +++ b/tools/generate-wire.py @@ -8,50 +8,77 @@ import re Enumtype = namedtuple('Enumtype', ['name', 'value']) +type2size = { + 'pad': 1, + 'struct channel_id': 32, + 'struct short_channel_id': 8, + 'struct ipv6': 16, + 'secp256k1_ecdsa_signature': 64, + 'struct pubkey': 33, + 'struct sha256': 32, + 'u64': 8, + 'u32': 4, + 'u16': 2, + 'u8': 1, + 'bool': 1 +} + class FieldType(object): def __init__(self,name): self.name = name self.tsize = FieldType._typesize(name) def is_assignable(self): - return self.name == 'u8' or self.name == 'u16' or self.name == 'u32' or self.name == 'u64' or self.name == 'bool' + return self.name in ['u8', 'u16', 'u32', 'u64', 'bool'] # Returns base size @staticmethod def _typesize(typename): - if typename == 'pad': - return 1 - elif typename == 'struct short_channel_id': - return 8 - elif typename == 'struct channel_id': - return 32 - elif typename == 'struct ipv6': - return 16 - elif typename == 'secp256k1_ecdsa_signature': - return 64 - elif typename == 'struct pubkey': - return 33 - elif typename == 'struct sha256': - return 32 - elif typename == 'u64': - return 8 - elif typename == 'u32': - return 4 - elif typename == 'u16': - return 2 - elif typename == 'u8': - return 1 - elif typename == 'bool': - return 1 - else: + if typename in type2size: + return type2size[typename] + elif typename.startswith('struct '): # We allow unknown structures, for extensiblity (can only happen # if explicitly specified in csv) - if typename.startswith('struct '): - return 0 + return 0 + else: raise ValueError('Unknown typename {}'.format(typename)) +# Full (message, fieldname)-mappings +typemap = { + ('update_fail_htlc', 'reason'): FieldType('u8'), + ('node_announcement', 'alias'): FieldType('u8'), + ('update_add_htlc', 'onion_routing_packet'): FieldType('u8'), + ('error', 'data'): FieldType('u8'), + ('shutdown', 'scriptpubkey'): FieldType('u8'), + ('node_announcement', 'rgb_color'): FieldType('u8'), + ('node_announcement', 'addresses'): FieldType('u8'), + ('node_announcement', 'ipv6'): FieldType('struct ipv6'), + ('node_announcement', 'alias'): FieldType('u8'), + ('announcement_signatures', 'short_channel_id'): FieldType('struct short_channel_id'), + ('channel_announcement', 'short_channel_id'): FieldType('struct short_channel_id'), + ('channel_update', 'short_channel_id'): FieldType('struct short_channel_id') +} + +# Partial names that map to a datatype +partialtypemap = { + 'signature': FieldType('secp256k1_ecdsa_signature'), + 'features': FieldType('u8'), + 'channel_id': FieldType('struct channel_id'), + 'pad': FieldType('pad'), +} + +# Size to typename match +sizetypemap = { + 33: FieldType('struct pubkey'), + 32: FieldType('struct sha256'), + 8: FieldType('u64'), + 4: FieldType('u32'), + 2: FieldType('u16'), + 1: FieldType('u8') +} + class Field(object): - def __init__(self,message,name,size,comments,typename=None): + def __init__(self, message, name, size, comments, typename=None): self.message = message self.comments = comments self.name = name.replace('-', '_') @@ -103,61 +130,53 @@ class Field(object): # Returns FieldType @staticmethod def _guess_type(message, fieldname, base_size): - if fieldname.startswith('pad'): - return FieldType('pad') + # Check for full (message, fieldname)-matches + if (message, fieldname) in typemap: + return typemap[(message, fieldname)] - if fieldname.endswith('short_channel_id'): - return FieldType('struct short_channel_id') + # Check for partial field names + for k, v in partialtypemap.items(): + if k in fieldname: + return v - if fieldname.endswith('channel_id'): - return FieldType('struct channel_id') - - if message == 'node_announcement' and fieldname == 'ipv6': - return FieldType('struct ipv6') - - if message == 'node_announcement' and fieldname == 'alias': - return FieldType('u8') - - if fieldname.endswith('features'): - return FieldType('u8') - - # We translate signatures and pubkeys. - if 'signature' in fieldname: - return FieldType('secp256k1_ecdsa_signature') - - # We whitelist specific things here, otherwise we'd treat everything - # as a u8 array. - if message == 'update_fail_htlc' and fieldname == 'reason': - return FieldType('u8') - if message == 'update_add_htlc' and fieldname == 'onion_routing_packet': - return FieldType('u8') - if message == 'node_announcement' and fieldname == 'alias': - return FieldType('u8') - if message == 'error' and fieldname == 'data': - return FieldType('u8') - if message == 'shutdown' and fieldname == 'scriptpubkey': - return FieldType('u8') - if message == 'node_announcement' and fieldname == 'rgb_color': - return FieldType('u8') - if message == 'node_announcement' and fieldname == 'addresses': - return FieldType('u8') - - # The remainder should be fixed sizes. - if base_size == 33: - return FieldType('struct pubkey') - if base_size == 32: - return FieldType('struct sha256') - if base_size == 8: - return FieldType('u64') - if base_size == 4: - return FieldType('u32') - if base_size == 2: - return FieldType('u16') - if base_size == 1: - return FieldType('u8') + # Check for size matches + if base_size in sizetypemap: + return sizetypemap[base_size] raise ValueError('Unknown size {} for {}'.format(base_size,fieldname)) +fromwire_impl_templ = """bool fromwire_{name}({ctx}const void *p, size_t *plen{args}) +{{ +{fields} + const u8 *cursor = p; + size_t tmp_len; + + if (!plen) {{ + tmp_len = tal_count(p); + plen = &tmp_len; + }} + if (fromwire_u16(&cursor, plen) != {enum.name}) + return false; +{subcalls} + return cursor != NULL; +}} +""" + +fromwire_header_templ = """bool fromwire_{name}({ctx}const void *p, size_t *plen{args}); +""" + +towire_header_templ = """u8 *towire_{name}(const tal_t *ctx{args}); +""" +towire_impl_templ = """u8 *towire_{name}(const tal_t *ctx{args}) +{{ +{field_decls} + u8 *p = tal_arr(ctx, u8, 0); + towire_u16(&p, {enumname}); +{subcalls} + + return memcheck(p, tal_count(p)); +}} +""" class Message(object): def __init__(self,name,enum,comments): self.name = name @@ -166,7 +185,7 @@ class Message(object): self.fields = [] self.has_variable_fields = False - def checkLenField(self,field): + def checkLenField(self, field): for f in self.fields: if f.name == field.lenvar: if f.fieldtype.name != 'u16': @@ -191,141 +210,117 @@ class Message(object): self.fields.append(field) def print_fromwire(self,is_header): - if self.has_variable_fields: - ctx_arg = 'const tal_t *ctx, ' - else: - ctx_arg = '' - - print('bool fromwire_{}({}const void *p, size_t *plen' - .format(self.name, ctx_arg), end='') + ctx_arg = 'const tal_t *ctx, ' if self.has_variable_fields else '' + args = [] + for f in self.fields: - if f.is_len_var: + if f.is_len_var or f.is_padding(): continue - if f.is_padding(): - continue - if f.is_array(): - print(', {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems), end='') + elif f.is_array(): + args.append(', {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems)) elif f.is_variable_size(): - print(', {} **{}'.format(f.fieldtype.name, f.name), end='') + args.append(', {} **{}'.format(f.fieldtype.name, f.name)) else: - print(', {} *{}'.format(f.fieldtype.name, f.name), end='') + args.append(', {} *{}'.format(f.fieldtype.name, f.name)) - if is_header: - print(');') - return - - print(')\n' - '{') - - for f in self.fields: - if f.is_len_var: - print('\t{} {};'.format(f.fieldtype.name, f.name)); - - print('\tconst u8 *cursor = p;\n' - '\tsize_t tmp_len;\n' - '\n' - '\tif (!plen) {{\n' - '\t\ttmp_len = tal_count(p);\n' - '\t\tplen = &tmp_len;\n' - '\t}}\n' - '\tif (fromwire_u16(&cursor, plen) != {})\n' - '\t\treturn false;' - .format(self.enum.name)) + template = fromwire_header_templ if is_header else fromwire_impl_templ + fields = ['\t{} {};\n'.format(f.fieldtype.name, f.name) for f in self.fields if f.is_len_var] + subcalls = [] for f in self.fields: basetype=f.fieldtype.name if f.fieldtype.name.startswith('struct '): basetype=f.fieldtype.name[7:] for c in f.comments: - print('\t/*{} */'.format(c)) + subcalls.append('\t/*{} */'.format(c)) if f.is_padding(): - print('\tfromwire_pad(&cursor, plen, {});' - .format(f.num_elems)) + subcalls.append('\tfromwire_pad(&cursor, plen, {});' + .format(f.num_elems)) elif f.is_array(): - print("\t//1th case", f.name) - print('\tfromwire_{}_array(&cursor, plen, {}, {});' - .format(basetype, f.name, f.num_elems)) + subcalls.append("\t//1th case {name}".format(name=f.name)) + subcalls.append('\tfromwire_{}_array(&cursor, plen, {}, {});' + .format(basetype, f.name, f.num_elems)) elif f.is_variable_size(): - print("\t//2th case", f.name) - print('\t*{} = tal_arr(ctx, {}, {});' - .format(f.name, f.fieldtype.name, f.lenvar)) - print('\tfromwire_{}_array(&cursor, plen, *{}, {});' - .format(basetype, f.name, f.lenvar)) + subcalls.append("\t//2th case {name}".format(name=f.name)) + subcalls.append('\t*{} = tal_arr(ctx, {}, {});' + .format(f.name, f.fieldtype.name, f.lenvar)) + subcalls.append('\tfromwire_{}_array(&cursor, plen, *{}, {});' + .format(basetype, f.name, f.lenvar)) elif f.is_assignable(): - print("\t//3th case", f.name) + subcalls.append("\t//3th case {name}".format(name=f.name)) if f.is_len_var: - print('\t{} = fromwire_{}(&cursor, plen);' - .format(f.name, basetype)) + subcalls.append('\t{} = fromwire_{}(&cursor, plen);' + .format(f.name, basetype)) else: - print('\t*{} = fromwire_{}(&cursor, plen);' - .format(f.name, basetype)) + subcalls.append('\t*{} = fromwire_{}(&cursor, plen);' + .format(f.name, basetype)) else: - print("\t//4th case", f.name) - print('\tfromwire_{}(&cursor, plen, {});' - .format(basetype, f.name)) + subcalls.append("\t//4th case {name}".format(name=f.name)) + subcalls.append('\tfromwire_{}(&cursor, plen, {});' + .format(basetype, f.name)) - print('\n' - '\treturn cursor != NULL;\n' - '}\n') + return template.format( + name=self.name, + ctx=ctx_arg, + args=''.join(args), + fields=''.join(fields), + enum=self.enum, + subcalls='\n'.join(subcalls) + ) def print_towire(self,is_header): - print('u8 *towire_{}(const tal_t *ctx' - .format(self.name), end='') - + template = towire_header_templ if is_header else towire_impl_templ + args = [] for f in self.fields: if f.is_padding() or f.is_len_var: continue if f.is_array(): - print(', const {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems), end='') + args.append(', const {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems)) elif f.is_assignable(): - print(', {} {}'.format(f.fieldtype.name, f.name), end='') + args.append(', {} {}'.format(f.fieldtype.name, f.name)) else: - print(', const {} *{}'.format(f.fieldtype.name, f.name), end='') + args.append(', const {} *{}'.format(f.fieldtype.name, f.name)) - if is_header: - print(');') - return - - print(')\n' - '{\n') + field_decls = [] for f in self.fields: if f.is_len_var: - print('\t{0} {1} = {2} ? tal_count({2}) : 0;' - .format(f.fieldtype.name, f.name, f.lenvar_for.name)); - - print('\tu8 *p = tal_arr(ctx, u8, 0);\n' - '' - '\ttowire_u16(&p, {});'.format(self.enum.name)) + field_decls.append('\t{0} {1} = {2} ? tal_count({2}) : 0;'.format( + f.fieldtype.name, f.name, f.lenvar_for.name + )); + subcalls = [] for f in self.fields: basetype=f.fieldtype.name - if f.fieldtype.name.startswith('struct '): - basetype=f.fieldtype.name[7:] + if basetype.startswith('struct '): + basetype=basetype[7:] for c in f.comments: - print('\t/*{} */'.format(c)) + subcalls.append('\t/*{} */'.format(c)) if f.is_padding(): - print('\ttowire_pad(&p, {});' + subcalls.append('\ttowire_pad(&p, {});' .format(f.num_elems)) elif f.is_array(): - print('\ttowire_{}_array(&p, {}, {});' + subcalls.append('\ttowire_{}_array(&p, {}, {});' .format(basetype, f.name, f.num_elems)) elif f.is_variable_size(): - print('\ttowire_{}_array(&p, {}, {});' + subcalls.append('\ttowire_{}_array(&p, {}, {});' .format(basetype, f.name, f.lenvar)) else: - print('\ttowire_{}(&p, {});' + subcalls.append('\ttowire_{}(&p, {});' .format(basetype, f.name)) - # Make sure we haven't encoded any uninitialzied fields! - print('\n' - '\treturn memcheck(p, tal_count(p));\n' - '}\n') - + return template.format( + name=self.name, + args=''.join(args), + enumname=self.enum.name, + field_decls='\n'.join(field_decls), + subcalls='\n'.join(subcalls), + ) + parser = argparse.ArgumentParser(description='Generate C from from CSV') parser.add_argument('--header', action='store_true', help="Create wire header") parser.add_argument('headerfilename', help='The filename of the header') @@ -333,19 +328,6 @@ parser.add_argument('enumname', help='The name of the enum to produce') parser.add_argument('files', nargs='*', help='Files to read in (or stdin)') options = parser.parse_args() -if options.header: - idem = re.sub(r'[^A-Z]+', '_', options.headerfilename.upper()) - print('#ifndef LIGHTNING_{0}\n' - '#define LIGHTNING_{0}\n' - '#include \n' - '#include '.format(idem)) -else: - print('#include <{}>\n' - '#include \n' - '#include \n' - '#include \n' - ''.format(options.headerfilename)) - # Maps message names to messages messages = [] comments = [] @@ -385,40 +367,60 @@ for line in fileinput.input(options.files): break comments=[] -if options.header: - for i in includes: - print(i, end='') +header_template = """#ifndef LIGHTNING_{idem} +#define LIGHTNING_{idem} +#include +#include +{includes} +enum {enumname} {{ +{enums}}}; +const char *{enumname}_name(int e); - print('') +{func_decls} +#endif /* LIGHTNING_{idem} */ +""" - # Dump out enum, sorted by value order. - print('enum {} {{'.format(options.enumname)) - for m in messages: - for c in m.comments: - print('\t/*{} */'.format(c)) - print('\t{} = {},'.format(m.enum.name, m.enum.value)) - print('};') - print('const char *{}_name(int e);'.format(options.enumname)) -else: - print('const char *{}_name(int e)'.format(options.enumname)) - print('{{\n' - '\tstatic char invalidbuf[sizeof("INVALID ") + STR_MAX_CHARS(e)];\n' - '\n' - '\tswitch ((enum {})e) {{'.format(options.enumname)); - for m in messages: - print('\tcase {0}: return "{0}";'.format(m.enum.name)) - print('\t}\n' - '\n' - '\tsprintf(invalidbuf, "INVALID %i", e);\n' - '\treturn invalidbuf;\n' - '}\n' - '') +impl_template = """#include <{headerfilename}> +#include +#include +#include +const char *{enumname}_name(int e) +{{ + static char invalidbuf[sizeof("INVALID ") + STR_MAX_CHARS(e)]; + + switch ((enum {enumname})e) {{ + {cases} + }} + + sprintf(invalidbuf, "INVALID %i", e); + return invalidbuf; +}} + +{func_decls} +""" + +idem = re.sub(r'[^A-Z]+', '_', options.headerfilename.upper()) +template = header_template if options.header else impl_template + +# Dump out enum, sorted by value order. +enums = "" for m in messages: - m.print_fromwire(options.header) + for c in m.comments: + enums += '\t/*{} */\n'.format(c) + enums += '\t{} = {},\n'.format(m.enum.name, m.enum.value) +includes = '\n'.join(includes) +cases = ['case {enum.name}: return "{enum.name}";'.format(enum=m.enum) for m in messages] -for m in messages: - m.print_towire(options.header) - -if options.header: - print('#endif /* LIGHTNING_{} */\n'.format(idem)) +fromwire_decls = [m.print_fromwire(options.header) for m in messages] +towire_decls = [m.print_towire(options.header) for m in messages] + +print(template.format( + headerfilename=options.headerfilename, + cases='\n\t'.join(cases), + idem=idem, + includes=includes, + enumname=options.enumname, + enums=enums, + func_decls='\n'.join(fromwire_decls + towire_decls), +))