gen-wire: Cleanup of the wire generator tool

The wiregen tool was a bit hard to maintain since it was printing all
over the place, mixing template and processing logic. This commit
tears the two apart, externalizes everything that is not a single line
of code, and repackages it into templates. Specifically, each function
now has its own template, and the header and implementation files each
have one as well.

Furthermore, this simplifies some of the boilerplate for mapping types
to sizes and back again by extracting the mappings into dicts.

All changes have been verified to produce identical results on the
current wire definitions, except for a few whitespace changes.
This commit is contained in:
Christian Decker 2017-02-28 23:07:38 +01:00 committed by Rusty Russell
parent 7419fde9a0
commit 4a995a42de
1 changed files with 214 additions and 212 deletions

View File

@ -8,50 +8,77 @@ import re
Enumtype = namedtuple('Enumtype', ['name', 'value'])
# Base size, in bytes, of every type name the generator knows about.
# Type names that are absent from this table are handled by the caller.
type2size = {
    # Padding is counted one byte at a time.
    'pad': 1,

    # Fixed-size structures.
    'struct channel_id': 32,
    'struct short_channel_id': 8,
    'struct ipv6': 16,
    'secp256k1_ecdsa_signature': 64,
    'struct pubkey': 33,
    'struct sha256': 32,

    # Plain integer and boolean types.
    'u64': 8,
    'u32': 4,
    'u16': 2,
    'u8': 1,
    'bool': 1,
}
class FieldType(object):
def __init__(self,name):
self.name = name
self.tsize = FieldType._typesize(name)
def is_assignable(self):
return self.name == 'u8' or self.name == 'u16' or self.name == 'u32' or self.name == 'u64' or self.name == 'bool'
return self.name in ['u8', 'u16', 'u32', 'u64', 'bool']
# Returns base size
@staticmethod
def _typesize(typename):
if typename == 'pad':
return 1
elif typename == 'struct short_channel_id':
return 8
elif typename == 'struct channel_id':
return 32
elif typename == 'struct ipv6':
return 16
elif typename == 'secp256k1_ecdsa_signature':
return 64
elif typename == 'struct pubkey':
return 33
elif typename == 'struct sha256':
return 32
elif typename == 'u64':
return 8
elif typename == 'u32':
return 4
elif typename == 'u16':
return 2
elif typename == 'u8':
return 1
elif typename == 'bool':
return 1
else:
if typename in type2size:
return type2size[typename]
elif typename.startswith('struct '):
# We allow unknown structures, for extensibility (can only happen
# if explicitly specified in csv)
if typename.startswith('struct '):
return 0
return 0
else:
raise ValueError('Unknown typename {}'.format(typename))
# Full (message, fieldname)-mappings.
#
# Looked up with the exact (message name, field name) tuple as the key.
# Each key must appear only once: in a dict display a repeated key silently
# overwrites the earlier entry, which hides typos and makes the table
# harder to audit.
typemap = {
    ('update_fail_htlc', 'reason'): FieldType('u8'),
    ('node_announcement', 'alias'): FieldType('u8'),
    ('update_add_htlc', 'onion_routing_packet'): FieldType('u8'),
    ('error', 'data'): FieldType('u8'),
    ('shutdown', 'scriptpubkey'): FieldType('u8'),
    ('node_announcement', 'rgb_color'): FieldType('u8'),
    ('node_announcement', 'addresses'): FieldType('u8'),
    ('node_announcement', 'ipv6'): FieldType('struct ipv6'),
    ('announcement_signatures', 'short_channel_id'): FieldType('struct short_channel_id'),
    ('channel_announcement', 'short_channel_id'): FieldType('struct short_channel_id'),
    ('channel_update', 'short_channel_id'): FieldType('struct short_channel_id'),
}
# Field-name fragments and the type they imply.
#
# A fragment matches when it occurs anywhere inside the field name, so the
# order of the entries is kept as-is.
partialtypemap = {
    'signature': FieldType('secp256k1_ecdsa_signature'),
    'features': FieldType('u8'),
    'channel_id': FieldType('struct channel_id'),
    'pad': FieldType('pad'),
}
# Fallback lookup from a field's base size to the type used for it.
sizetypemap = {
    33: FieldType('struct pubkey'),
    32: FieldType('struct sha256'),
    8: FieldType('u64'),
    4: FieldType('u32'),
    2: FieldType('u16'),
    1: FieldType('u8'),
}
class Field(object):
def __init__(self,message,name,size,comments,typename=None):
def __init__(self, message, name, size, comments, typename=None):
self.message = message
self.comments = comments
self.name = name.replace('-', '_')
@ -103,61 +130,53 @@ class Field(object):
# Returns FieldType
@staticmethod
def _guess_type(message, fieldname, base_size):
if fieldname.startswith('pad'):
return FieldType('pad')
# Check for full (message, fieldname)-matches
if (message, fieldname) in typemap:
return typemap[(message, fieldname)]
if fieldname.endswith('short_channel_id'):
return FieldType('struct short_channel_id')
# Check for partial field names
for k, v in partialtypemap.items():
if k in fieldname:
return v
if fieldname.endswith('channel_id'):
return FieldType('struct channel_id')
if message == 'node_announcement' and fieldname == 'ipv6':
return FieldType('struct ipv6')
if message == 'node_announcement' and fieldname == 'alias':
return FieldType('u8')
if fieldname.endswith('features'):
return FieldType('u8')
# We translate signatures and pubkeys.
if 'signature' in fieldname:
return FieldType('secp256k1_ecdsa_signature')
# We whitelist specific things here, otherwise we'd treat everything
# as a u8 array.
if message == 'update_fail_htlc' and fieldname == 'reason':
return FieldType('u8')
if message == 'update_add_htlc' and fieldname == 'onion_routing_packet':
return FieldType('u8')
if message == 'node_announcement' and fieldname == 'alias':
return FieldType('u8')
if message == 'error' and fieldname == 'data':
return FieldType('u8')
if message == 'shutdown' and fieldname == 'scriptpubkey':
return FieldType('u8')
if message == 'node_announcement' and fieldname == 'rgb_color':
return FieldType('u8')
if message == 'node_announcement' and fieldname == 'addresses':
return FieldType('u8')
# The remainder should be fixed sizes.
if base_size == 33:
return FieldType('struct pubkey')
if base_size == 32:
return FieldType('struct sha256')
if base_size == 8:
return FieldType('u64')
if base_size == 4:
return FieldType('u32')
if base_size == 2:
return FieldType('u16')
if base_size == 1:
return FieldType('u8')
# Check for size matches
if base_size in sizetypemap:
return sizetypemap[base_size]
raise ValueError('Unknown size {} for {}'.format(base_size,fieldname))
# C source templates, filled in with str.format(). Literal C braces are
# written doubled ({{ and }}) so that format() emits a single brace.

# Definition of a generated fromwire_<name>() function.
# Placeholders: name, ctx, args, fields, enum (its .name attribute is
# used), subcalls.
fromwire_impl_templ = """bool fromwire_{name}({ctx}const void *p, size_t *plen{args})
{{
{fields}
const u8 *cursor = p;
size_t tmp_len;
if (!plen) {{
tmp_len = tal_count(p);
plen = &tmp_len;
}}
if (fromwire_u16(&cursor, plen) != {enum.name})
return false;
{subcalls}
return cursor != NULL;
}}
"""
# Prototype of a generated fromwire_<name>() function.
# Placeholders: name, ctx, args.
fromwire_header_templ = """bool fromwire_{name}({ctx}const void *p, size_t *plen{args});
"""
# Prototype of a generated towire_<name>() function.
# Placeholders: name, args.
towire_header_templ = """u8 *towire_{name}(const tal_t *ctx{args});
"""
# Definition of a generated towire_<name>() function.
# Placeholders: name, args, field_decls, enumname, subcalls.
towire_impl_templ = """u8 *towire_{name}(const tal_t *ctx{args})
{{
{field_decls}
u8 *p = tal_arr(ctx, u8, 0);
towire_u16(&p, {enumname});
{subcalls}
return memcheck(p, tal_count(p));
}}
"""
class Message(object):
def __init__(self,name,enum,comments):
self.name = name
@ -166,7 +185,7 @@ class Message(object):
self.fields = []
self.has_variable_fields = False
def checkLenField(self,field):
def checkLenField(self, field):
for f in self.fields:
if f.name == field.lenvar:
if f.fieldtype.name != 'u16':
@ -191,141 +210,117 @@ class Message(object):
self.fields.append(field)
def print_fromwire(self,is_header):
if self.has_variable_fields:
ctx_arg = 'const tal_t *ctx, '
else:
ctx_arg = ''
print('bool fromwire_{}({}const void *p, size_t *plen'
.format(self.name, ctx_arg), end='')
ctx_arg = 'const tal_t *ctx, ' if self.has_variable_fields else ''
args = []
for f in self.fields:
if f.is_len_var:
if f.is_len_var or f.is_padding():
continue
if f.is_padding():
continue
if f.is_array():
print(', {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems), end='')
elif f.is_array():
args.append(', {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems))
elif f.is_variable_size():
print(', {} **{}'.format(f.fieldtype.name, f.name), end='')
args.append(', {} **{}'.format(f.fieldtype.name, f.name))
else:
print(', {} *{}'.format(f.fieldtype.name, f.name), end='')
args.append(', {} *{}'.format(f.fieldtype.name, f.name))
if is_header:
print(');')
return
print(')\n'
'{')
for f in self.fields:
if f.is_len_var:
print('\t{} {};'.format(f.fieldtype.name, f.name));
print('\tconst u8 *cursor = p;\n'
'\tsize_t tmp_len;\n'
'\n'
'\tif (!plen) {{\n'
'\t\ttmp_len = tal_count(p);\n'
'\t\tplen = &tmp_len;\n'
'\t}}\n'
'\tif (fromwire_u16(&cursor, plen) != {})\n'
'\t\treturn false;'
.format(self.enum.name))
template = fromwire_header_templ if is_header else fromwire_impl_templ
fields = ['\t{} {};\n'.format(f.fieldtype.name, f.name) for f in self.fields if f.is_len_var]
subcalls = []
for f in self.fields:
basetype=f.fieldtype.name
if f.fieldtype.name.startswith('struct '):
basetype=f.fieldtype.name[7:]
for c in f.comments:
print('\t/*{} */'.format(c))
subcalls.append('\t/*{} */'.format(c))
if f.is_padding():
print('\tfromwire_pad(&cursor, plen, {});'
.format(f.num_elems))
subcalls.append('\tfromwire_pad(&cursor, plen, {});'
.format(f.num_elems))
elif f.is_array():
print("\t//1th case", f.name)
print('\tfromwire_{}_array(&cursor, plen, {}, {});'
.format(basetype, f.name, f.num_elems))
subcalls.append("\t//1th case {name}".format(name=f.name))
subcalls.append('\tfromwire_{}_array(&cursor, plen, {}, {});'
.format(basetype, f.name, f.num_elems))
elif f.is_variable_size():
print("\t//2th case", f.name)
print('\t*{} = tal_arr(ctx, {}, {});'
.format(f.name, f.fieldtype.name, f.lenvar))
print('\tfromwire_{}_array(&cursor, plen, *{}, {});'
.format(basetype, f.name, f.lenvar))
subcalls.append("\t//2th case {name}".format(name=f.name))
subcalls.append('\t*{} = tal_arr(ctx, {}, {});'
.format(f.name, f.fieldtype.name, f.lenvar))
subcalls.append('\tfromwire_{}_array(&cursor, plen, *{}, {});'
.format(basetype, f.name, f.lenvar))
elif f.is_assignable():
print("\t//3th case", f.name)
subcalls.append("\t//3th case {name}".format(name=f.name))
if f.is_len_var:
print('\t{} = fromwire_{}(&cursor, plen);'
.format(f.name, basetype))
subcalls.append('\t{} = fromwire_{}(&cursor, plen);'
.format(f.name, basetype))
else:
print('\t*{} = fromwire_{}(&cursor, plen);'
.format(f.name, basetype))
subcalls.append('\t*{} = fromwire_{}(&cursor, plen);'
.format(f.name, basetype))
else:
print("\t//4th case", f.name)
print('\tfromwire_{}(&cursor, plen, {});'
.format(basetype, f.name))
subcalls.append("\t//4th case {name}".format(name=f.name))
subcalls.append('\tfromwire_{}(&cursor, plen, {});'
.format(basetype, f.name))
print('\n'
'\treturn cursor != NULL;\n'
'}\n')
return template.format(
name=self.name,
ctx=ctx_arg,
args=''.join(args),
fields=''.join(fields),
enum=self.enum,
subcalls='\n'.join(subcalls)
)
def print_towire(self,is_header):
print('u8 *towire_{}(const tal_t *ctx'
.format(self.name), end='')
template = towire_header_templ if is_header else towire_impl_templ
args = []
for f in self.fields:
if f.is_padding() or f.is_len_var:
continue
if f.is_array():
print(', const {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems), end='')
args.append(', const {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems))
elif f.is_assignable():
print(', {} {}'.format(f.fieldtype.name, f.name), end='')
args.append(', {} {}'.format(f.fieldtype.name, f.name))
else:
print(', const {} *{}'.format(f.fieldtype.name, f.name), end='')
args.append(', const {} *{}'.format(f.fieldtype.name, f.name))
if is_header:
print(');')
return
print(')\n'
'{\n')
field_decls = []
for f in self.fields:
if f.is_len_var:
print('\t{0} {1} = {2} ? tal_count({2}) : 0;'
.format(f.fieldtype.name, f.name, f.lenvar_for.name));
print('\tu8 *p = tal_arr(ctx, u8, 0);\n'
''
'\ttowire_u16(&p, {});'.format(self.enum.name))
field_decls.append('\t{0} {1} = {2} ? tal_count({2}) : 0;'.format(
f.fieldtype.name, f.name, f.lenvar_for.name
));
subcalls = []
for f in self.fields:
basetype=f.fieldtype.name
if f.fieldtype.name.startswith('struct '):
basetype=f.fieldtype.name[7:]
if basetype.startswith('struct '):
basetype=basetype[7:]
for c in f.comments:
print('\t/*{} */'.format(c))
subcalls.append('\t/*{} */'.format(c))
if f.is_padding():
print('\ttowire_pad(&p, {});'
subcalls.append('\ttowire_pad(&p, {});'
.format(f.num_elems))
elif f.is_array():
print('\ttowire_{}_array(&p, {}, {});'
subcalls.append('\ttowire_{}_array(&p, {}, {});'
.format(basetype, f.name, f.num_elems))
elif f.is_variable_size():
print('\ttowire_{}_array(&p, {}, {});'
subcalls.append('\ttowire_{}_array(&p, {}, {});'
.format(basetype, f.name, f.lenvar))
else:
print('\ttowire_{}(&p, {});'
subcalls.append('\ttowire_{}(&p, {});'
.format(basetype, f.name))
# Make sure we haven't encoded any uninitialized fields!
print('\n'
'\treturn memcheck(p, tal_count(p));\n'
'}\n')
return template.format(
name=self.name,
args=''.join(args),
enumname=self.enum.name,
field_decls='\n'.join(field_decls),
subcalls='\n'.join(subcalls),
)
parser = argparse.ArgumentParser(description='Generate C from from CSV')
parser.add_argument('--header', action='store_true', help="Create wire header")
parser.add_argument('headerfilename', help='The filename of the header')
@ -333,19 +328,6 @@ parser.add_argument('enumname', help='The name of the enum to produce')
parser.add_argument('files', nargs='*', help='Files to read in (or stdin)')
options = parser.parse_args()
if options.header:
idem = re.sub(r'[^A-Z]+', '_', options.headerfilename.upper())
print('#ifndef LIGHTNING_{0}\n'
'#define LIGHTNING_{0}\n'
'#include <ccan/tal/tal.h>\n'
'#include <wire/wire.h>'.format(idem))
else:
print('#include <{}>\n'
'#include <ccan/mem/mem.h>\n'
'#include <ccan/tal/str/str.h>\n'
'#include <stdio.h>\n'
''.format(options.headerfilename))
# Maps message names to messages
messages = []
comments = []
@ -385,40 +367,60 @@ for line in fileinput.input(options.files):
break
comments=[]
if options.header:
for i in includes:
print(i, end='')
header_template = """#ifndef LIGHTNING_{idem}
#define LIGHTNING_{idem}
#include <ccan/tal/tal.h>
#include <wire/wire.h>
{includes}
enum {enumname} {{
{enums}}};
const char *{enumname}_name(int e);
print('')
{func_decls}
#endif /* LIGHTNING_{idem} */
"""
# Dump out enum, sorted by value order.
print('enum {} {{'.format(options.enumname))
for m in messages:
for c in m.comments:
print('\t/*{} */'.format(c))
print('\t{} = {},'.format(m.enum.name, m.enum.value))
print('};')
print('const char *{}_name(int e);'.format(options.enumname))
else:
print('const char *{}_name(int e)'.format(options.enumname))
print('{{\n'
'\tstatic char invalidbuf[sizeof("INVALID ") + STR_MAX_CHARS(e)];\n'
'\n'
'\tswitch ((enum {})e) {{'.format(options.enumname));
for m in messages:
print('\tcase {0}: return "{0}";'.format(m.enum.name))
print('\t}\n'
'\n'
'\tsprintf(invalidbuf, "INVALID %i", e);\n'
'\treturn invalidbuf;\n'
'}\n'
'')
impl_template = """#include <{headerfilename}>
#include <ccan/mem/mem.h>
#include <ccan/tal/str/str.h>
#include <stdio.h>
const char *{enumname}_name(int e)
{{
static char invalidbuf[sizeof("INVALID ") + STR_MAX_CHARS(e)];
switch ((enum {enumname})e) {{
{cases}
}}
sprintf(invalidbuf, "INVALID %i", e);
return invalidbuf;
}}
{func_decls}
"""
idem = re.sub(r'[^A-Z]+', '_', options.headerfilename.upper())
template = header_template if options.header else impl_template
# Dump out enum, sorted by value order.
enums = ""
for m in messages:
m.print_fromwire(options.header)
for c in m.comments:
enums += '\t/*{} */\n'.format(c)
enums += '\t{} = {},\n'.format(m.enum.name, m.enum.value)
includes = '\n'.join(includes)
cases = ['case {enum.name}: return "{enum.name}";'.format(enum=m.enum) for m in messages]
for m in messages:
m.print_towire(options.header)
if options.header:
print('#endif /* LIGHTNING_{} */\n'.format(idem))
fromwire_decls = [m.print_fromwire(options.header) for m in messages]
towire_decls = [m.print_towire(options.header) for m in messages]
print(template.format(
headerfilename=options.headerfilename,
cases='\n\t'.join(cases),
idem=idem,
includes=includes,
enumname=options.enumname,
enums=enums,
func_decls='\n'.join(fromwire_decls + towire_decls),
))