pyln.proto.message: use BufferedIOBase instead of bytes for binary ops.
Instead of val_to_bin/val_from_bin which deal with bytes, we implement read and write which use streams. This simplifies the API. Suggested-by: Christian Decker Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
parent
ed4eadc8f3
commit
47631cc23c
|
@ -42,28 +42,26 @@ wants an array of some type.
|
|||
|
||||
return '[' + s + ']'
|
||||
|
||||
def val_to_bin(self, v, otherfields):
|
||||
b = bytes()
|
||||
def write(self, io_out, v, otherfields):
|
||||
for i in v:
|
||||
b += self.elemtype.val_to_bin(i, otherfields)
|
||||
return b
|
||||
self.elemtype.write(io_out, i, otherfields)
|
||||
|
||||
def arr_from_bin(self, bytestream, otherfields, arraysize):
|
||||
"""arraysize None means take rest of bytestream exactly"""
|
||||
totsize = 0
|
||||
def read_arr(self, io_in, otherfields, arraysize):
|
||||
"""arraysize None means take rest of io entirely and exactly"""
|
||||
vals = []
|
||||
i = 0
|
||||
while True:
|
||||
if arraysize is None and totsize == len(bytestream):
|
||||
return vals, totsize
|
||||
elif i == arraysize:
|
||||
return vals, totsize
|
||||
val, size = self.elemtype.val_from_bin(bytestream[totsize:],
|
||||
otherfields)
|
||||
totsize += size
|
||||
i += 1
|
||||
while arraysize is None or len(vals) < arraysize:
|
||||
# Throws an exception on partial read, so None means completely empty.
|
||||
val = self.elemtype.read(io_in, otherfields)
|
||||
if val is None:
|
||||
if arraysize is not None:
|
||||
raise ValueError('{}: not enough remaining to read'
|
||||
.format(self))
|
||||
break
|
||||
|
||||
vals.append(val)
|
||||
|
||||
return vals
|
||||
|
||||
|
||||
class SizedArrayType(ArrayType):
|
||||
"""A fixed-size array"""
|
||||
|
@ -82,13 +80,13 @@ class SizedArrayType(ArrayType):
|
|||
raise ValueError("Length of {} != {}", s, self.arraysize)
|
||||
return a, b
|
||||
|
||||
def val_to_bin(self, v, otherfields):
|
||||
def write(self, io_out, v, otherfields):
|
||||
if len(v) != self.arraysize:
|
||||
raise ValueError("Length of {} != {}", v, self.arraysize)
|
||||
return super().val_to_bin(v, otherfields)
|
||||
return super().write(io_out, v, otherfields)
|
||||
|
||||
def val_from_bin(self, bytestream, otherfields):
|
||||
return super().arr_from_bin(bytestream, otherfields, self.arraysize)
|
||||
def read(self, io_in, otherfields):
|
||||
return super().read_arr(io_in, otherfields, self.arraysize)
|
||||
|
||||
|
||||
class EllipsisArrayType(ArrayType):
|
||||
|
@ -97,9 +95,9 @@ when the tlv ends"""
|
|||
def __init__(self, tlv, name, elemtype):
|
||||
super().__init__(tlv, name, elemtype)
|
||||
|
||||
def val_from_bin(self, bytestream, otherfields):
|
||||
def read(self, io_in, otherfields):
|
||||
"""Takes rest of bytestream"""
|
||||
return super().arr_from_bin(bytestream, otherfields, None)
|
||||
return super().read_arr(io_in, otherfields, None)
|
||||
|
||||
def only_at_tlv_end(self):
|
||||
"""These only make sense at the end of a TLV"""
|
||||
|
@ -142,10 +140,6 @@ class LengthFieldType(FieldType):
|
|||
return v
|
||||
return self.calc_value(otherfields)
|
||||
|
||||
def val_to_bin(self, _, otherfields):
|
||||
return self.underlying_type.val_to_bin(self.calc_value(otherfields),
|
||||
otherfields)
|
||||
|
||||
def val_to_str(self, _, otherfields):
|
||||
return self.underlying_type.val_to_str(self.calc_value(otherfields),
|
||||
otherfields)
|
||||
|
@ -155,9 +149,13 @@ class LengthFieldType(FieldType):
|
|||
they're implied by the length of other fields"""
|
||||
return ''
|
||||
|
||||
def val_from_bin(self, bytestream, otherfields):
|
||||
def read(self, io_in, otherfields):
|
||||
"""We store this, but it'll be removed from the fields as soon as it's used (i.e. by DynamicArrayType's val_from_bin)"""
|
||||
return self.underlying_type.val_from_bin(bytestream, otherfields)
|
||||
return self.underlying_type.read(io_in, otherfields)
|
||||
|
||||
def write(self, io_out, _, otherfields):
|
||||
self.underlying_type.write(io_out, self.calc_value(otherfields),
|
||||
otherfields)
|
||||
|
||||
def val_from_str(self, s):
|
||||
raise ValueError('{} is implied, cannot be specified'.format(self))
|
||||
|
@ -182,6 +180,6 @@ class DynamicArrayType(ArrayType):
|
|||
assert type(lenfield.fieldtype) is LengthFieldType
|
||||
self.lenfield = lenfield
|
||||
|
||||
def val_from_bin(self, bytestream, otherfields):
|
||||
return super().arr_from_bin(bytestream, otherfields,
|
||||
self.lenfield.fieldtype._maybe_calc_value(self.lenfield.name, otherfields))
|
||||
def read(self, io_in, otherfields):
|
||||
return super().read_arr(io_in, otherfields,
|
||||
self.lenfield.fieldtype._maybe_calc_value(self.lenfield.name, otherfields))
|
||||
|
|
|
@ -1,4 +1,22 @@
|
|||
import struct
|
||||
import io
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def try_unpack(name: str,
|
||||
io_out: io.BufferedIOBase,
|
||||
structfmt: str,
|
||||
empty_ok: bool) -> Optional[int]:
|
||||
"""Unpack a single value using struct.unpack.
|
||||
|
||||
If need_all, never return None, otherwise returns None if EOF."""
|
||||
b = io_out.read(struct.calcsize(structfmt))
|
||||
if len(b) == 0 and empty_ok:
|
||||
return None
|
||||
elif len(b) < struct.calcsize(structfmt):
|
||||
raise ValueError("{}: not enough bytes", name)
|
||||
|
||||
return struct.unpack(structfmt, b)[0]
|
||||
|
||||
|
||||
def split_field(s):
|
||||
|
@ -57,15 +75,11 @@ class IntegerType(FieldType):
|
|||
a, b = split_field(s)
|
||||
return int(a), b
|
||||
|
||||
def val_to_bin(self, v, otherfields):
|
||||
return struct.pack(self.structfmt, v)
|
||||
def write(self, io_out, v, otherfields):
|
||||
io_out.write(struct.pack(self.structfmt, v))
|
||||
|
||||
def val_from_bin(self, bytestream, otherfields):
|
||||
"Returns value, bytesused"
|
||||
if self.bytelen > len(bytestream):
|
||||
raise ValueError('{}: not enough remaining to read'.format(self))
|
||||
return struct.unpack_from(self.structfmt,
|
||||
bytestream)[0], self.bytelen
|
||||
def read(self, io_in, otherfields):
|
||||
return try_unpack(self.name, io_in, self.structfmt, empty_ok=True)
|
||||
|
||||
|
||||
class ShortChannelIDType(IntegerType):
|
||||
|
@ -110,30 +124,24 @@ class TruncatedIntType(FieldType):
|
|||
.format(a, self.name))
|
||||
return int(a), b
|
||||
|
||||
def val_to_bin(self, v, otherfields):
|
||||
def write(self, io_out, v, otherfields):
|
||||
binval = struct.pack('>Q', v)
|
||||
while len(binval) != 0 and binval[0] == 0:
|
||||
binval = binval[1:]
|
||||
if len(binval) > self.maxbytes:
|
||||
raise ValueError('{} exceeds maximum {} capacity'
|
||||
.format(v, self.name))
|
||||
return binval
|
||||
|
||||
def val_from_bin(self, bytestream, otherfields):
|
||||
"Returns value, bytesused"
|
||||
binval = bytes()
|
||||
while len(binval) < len(bytestream):
|
||||
if len(binval) == 0 and bytestream[len(binval)] == 0:
|
||||
raise ValueError('{} encoding is not minimal: {}'
|
||||
.format(self.name, bytestream))
|
||||
binval += bytes([bytestream[len(binval)]])
|
||||
io_out.write(binval)
|
||||
|
||||
def read(self, io_in, otherfields):
|
||||
binval = io_in.read()
|
||||
if len(binval) > self.maxbytes:
|
||||
raise ValueError('{} is too long for {}'.format(binval, self.name))
|
||||
|
||||
if len(binval) > 0 and binval[0] == 0:
|
||||
raise ValueError('{} encoding is not minimal: {}'
|
||||
.format(self.name, binval))
|
||||
# Pad with zeroes and convert as u64
|
||||
return (struct.unpack_from('>Q', bytes(8 - len(binval)) + binval)[0],
|
||||
len(binval))
|
||||
return struct.unpack_from('>Q', bytes(8 - len(binval)) + binval)[0]
|
||||
|
||||
|
||||
class FundamentalHexType(FieldType):
|
||||
|
@ -154,16 +162,18 @@ class FundamentalHexType(FieldType):
|
|||
raise ValueError("Length of {} != {}", a, self.bytelen)
|
||||
return ret, b
|
||||
|
||||
def val_to_bin(self, v, otherfields):
|
||||
def write(self, io_out, v, otherfields):
|
||||
if len(bytes(v)) != self.bytelen:
|
||||
raise ValueError("Length of {} != {}", v, self.bytelen)
|
||||
return bytes(v)
|
||||
io_out.write(v)
|
||||
|
||||
def val_from_bin(self, bytestream, otherfields):
|
||||
"Returns value, size from bytestream"
|
||||
if self.bytelen > len(bytestream):
|
||||
def read(self, io_in, otherfields):
|
||||
val = io_in.read(self.bytelen)
|
||||
if len(val) == 0:
|
||||
return None
|
||||
elif len(val) != self.bytelen:
|
||||
raise ValueError('{}: not enough remaining'.format(self))
|
||||
return bytestream[:self.bytelen], self.bytelen
|
||||
return val
|
||||
|
||||
|
||||
class BigSizeType(FieldType):
|
||||
|
@ -177,37 +187,34 @@ class BigSizeType(FieldType):
|
|||
|
||||
# For the convenience of TLV header parsing
|
||||
@staticmethod
|
||||
def to_bin(v):
|
||||
def write(io_out, v, otherfields=None):
|
||||
if v < 253:
|
||||
return bytes([v])
|
||||
io_out.write(bytes([v]))
|
||||
elif v < 2**16:
|
||||
return bytes([253]) + struct.pack('>H', v)
|
||||
io_out.write(bytes([253]) + struct.pack('>H', v))
|
||||
elif v < 2**32:
|
||||
return bytes([254]) + struct.pack('>I', v)
|
||||
io_out.write(bytes([254]) + struct.pack('>I', v))
|
||||
else:
|
||||
return bytes([255]) + struct.pack('>Q', v)
|
||||
io_out.write(bytes([255]) + struct.pack('>Q', v))
|
||||
|
||||
@staticmethod
|
||||
def from_bin(bytestream):
|
||||
"Returns value, bytesused"
|
||||
if bytestream[0] < 253:
|
||||
return int(bytestream[0]), 1
|
||||
elif bytestream[0] == 253:
|
||||
return struct.unpack_from('>H', bytestream[1:])[0], 3
|
||||
elif bytestream[0] == 254:
|
||||
return struct.unpack_from('>I', bytestream[1:])[0], 5
|
||||
def read(io_in, otherfields=None):
|
||||
"Returns value, or None on EOF"
|
||||
b = io_in.read(1)
|
||||
if len(b) == 0:
|
||||
return None
|
||||
if b[0] < 253:
|
||||
return int(b[0])
|
||||
elif b[0] == 253:
|
||||
return try_unpack('BigSize', io_in, '>H', empty_ok=False)
|
||||
elif b[0] == 254:
|
||||
return try_unpack('BigSize', io_in, '>I', empty_ok=False)
|
||||
else:
|
||||
return struct.unpack_from('>Q', bytestream[1:])[0], 9
|
||||
return try_unpack('BigSize', io_in, '>Q', empty_ok=False)
|
||||
|
||||
def val_to_str(self, v, otherfields):
|
||||
return "{}".format(int(v))
|
||||
|
||||
def val_to_bin(self, v, otherfields):
|
||||
return self.to_bin(v)
|
||||
|
||||
def val_from_bin(self, bytestream, otherfields):
|
||||
return self.from_bin(bytestream)
|
||||
|
||||
|
||||
def fundamental_types():
|
||||
# From 01-messaging.md#fundamental-types:
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import struct
|
||||
from .fundamental_types import fundamental_types, BigSizeType, split_field
|
||||
import io
|
||||
from .fundamental_types import fundamental_types, BigSizeType, split_field, try_unpack
|
||||
from .array_types import (
|
||||
SizedArrayType, DynamicArrayType, LengthFieldType, EllipsisArrayType
|
||||
)
|
||||
|
@ -253,24 +254,21 @@ inherit from this too.
|
|||
|
||||
return '{' + s + '}'
|
||||
|
||||
def val_to_bin(self, v, otherfields):
|
||||
def write(self, io_out, v, otherfields):
|
||||
self._raise_if_badvals(v)
|
||||
b = bytes()
|
||||
for fname, val in v.items():
|
||||
field = self.find_field(fname)
|
||||
b += field.fieldtype.val_to_bin(val, otherfields)
|
||||
return b
|
||||
field.fieldtype.write(io_out, val, otherfields)
|
||||
|
||||
def val_from_bin(self, bytestream, otherfields):
|
||||
totsize = 0
|
||||
def read(self, io_in, otherfields):
|
||||
vals = {}
|
||||
for field in self.fields:
|
||||
val, size = field.fieldtype.val_from_bin(bytestream[totsize:],
|
||||
otherfields)
|
||||
totsize += size
|
||||
val = field.fieldtype.read(io_in, otherfields)
|
||||
if val is None:
|
||||
raise ValueError("{}.{}: short read".format(self, field))
|
||||
vals[field.name] = val
|
||||
|
||||
return vals, totsize
|
||||
return vals
|
||||
|
||||
@staticmethod
|
||||
def field_from_csv(namespace, parts):
|
||||
|
@ -433,17 +431,15 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
|
|||
|
||||
return '{' + s + '}'
|
||||
|
||||
def val_to_bin(self, v, otherfields):
|
||||
b = bytes()
|
||||
|
||||
def write(self, iobuf, v, otherfields):
|
||||
# If they didn't specify this tlvstream, it's empty.
|
||||
if v is None:
|
||||
return b
|
||||
return
|
||||
|
||||
# Make a tuple of (fieldnum, val_to_bin, val) so we can sort into
|
||||
# ascending order as TLV spec requires.
|
||||
def copy_val(val, otherfields):
|
||||
return val
|
||||
def write_raw_val(iobuf, val, otherfields):
|
||||
iobuf.write(val)
|
||||
|
||||
def get_value(tup):
|
||||
"""Get value from num, fun, val tuple"""
|
||||
|
@ -454,43 +450,40 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
|
|||
f = self.find_field(fieldname)
|
||||
if f is None:
|
||||
# fieldname can be an integer for a raw field.
|
||||
ordered.append((int(fieldname), copy_val, v[fieldname]))
|
||||
ordered.append((int(fieldname), write_raw_val, v[fieldname]))
|
||||
else:
|
||||
ordered.append((f.number, f.val_to_bin, v[fieldname]))
|
||||
ordered.append((f.number, f.write, v[fieldname]))
|
||||
|
||||
ordered.sort(key=get_value)
|
||||
|
||||
for tup in ordered:
|
||||
value = tup[1](tup[2], otherfields)
|
||||
b += (BigSizeType.to_bin(tup[0])
|
||||
+ BigSizeType.to_bin(len(value))
|
||||
+ value)
|
||||
for typenum, writefunc, val in ordered:
|
||||
buf = io.BytesIO()
|
||||
writefunc(buf, val, otherfields)
|
||||
BigSizeType.write(iobuf, typenum)
|
||||
BigSizeType.write(iobuf, len(buf.getvalue()))
|
||||
iobuf.write(buf.getvalue())
|
||||
|
||||
return b
|
||||
|
||||
def val_from_bin(self, bytestream, otherfields):
|
||||
totsize = 0
|
||||
def read(self, io_in, otherfields):
|
||||
vals = {}
|
||||
|
||||
while totsize < len(bytestream):
|
||||
tlv_type, size = BigSizeType.from_bin(bytestream[totsize:])
|
||||
totsize += size
|
||||
tlv_len, size = BigSizeType.from_bin(bytestream[totsize:])
|
||||
totsize += size
|
||||
while True:
|
||||
tlv_type = BigSizeType.read(io_in)
|
||||
if tlv_type is None:
|
||||
return vals
|
||||
|
||||
tlv_len = BigSizeType.read(io_in)
|
||||
if tlv_len is None:
|
||||
raise ValueError("{}: truncated tlv_len field".format(self))
|
||||
binval = io_in.read(tlv_len)
|
||||
if len(binval) != tlv_len:
|
||||
raise ValueError("{}: truncated tlv {} value"
|
||||
.format(tlv_type, self))
|
||||
f = self.find_field_by_number(tlv_type)
|
||||
if f is None:
|
||||
vals[tlv_type] = bytestream[totsize:totsize + tlv_len]
|
||||
size = len(vals[tlv_type])
|
||||
# Raw fields are allowed, just index by number.
|
||||
vals[tlv_type] = binval
|
||||
else:
|
||||
vals[f.name], size = f.val_from_bin(bytestream
|
||||
[totsize:totsize
|
||||
+ tlv_len],
|
||||
otherfields)
|
||||
if size != tlv_len:
|
||||
raise ValueError("Truncated tlv field")
|
||||
totsize += size
|
||||
|
||||
return vals, totsize
|
||||
vals[f.name] = f.read(io.BytesIO(binval), otherfields)
|
||||
|
||||
def name_and_val(self, name, v):
|
||||
"""This is overridden by LengthFieldType to return nothing"""
|
||||
|
@ -541,10 +534,15 @@ class Message(object):
|
|||
return missing
|
||||
|
||||
@staticmethod
|
||||
def from_bin(namespace, binmsg):
|
||||
"""Decode a binary wire format to a Message within that namespace"""
|
||||
typenum = struct.unpack_from(">H", binmsg)[0]
|
||||
off = 2
|
||||
def read(namespace, io_in):
|
||||
"""Read and decode a Message within that namespace.
|
||||
|
||||
Returns None on EOF
|
||||
|
||||
"""
|
||||
typenum = try_unpack('message_type', io_in, ">H", empty_ok=True)
|
||||
if typenum is None:
|
||||
return None
|
||||
|
||||
mtype = namespace.get_msgtype_by_number(typenum)
|
||||
if not mtype:
|
||||
|
@ -552,16 +550,21 @@ class Message(object):
|
|||
|
||||
fields = {}
|
||||
for f in mtype.fields:
|
||||
v, size = f.fieldtype.val_from_bin(binmsg[off:], fields)
|
||||
off += size
|
||||
fields[f.name] = v
|
||||
fields[f.name] = f.fieldtype.read(io_in, fields)
|
||||
if fields[f.name] is None:
|
||||
# optional fields are OK to be missing at end!
|
||||
raise ValueError('{}: truncated at field {}'
|
||||
.format(mtype, f.name))
|
||||
|
||||
return Message(mtype, **fields)
|
||||
|
||||
@staticmethod
|
||||
def from_str(namespace, s, incomplete_ok=False):
|
||||
"""Decode a string to a Message within that namespace, of format
|
||||
msgname [ field=...]*."""
|
||||
"""Decode a string to a Message within that namespace.
|
||||
|
||||
Format is msgname [ field=...]*.
|
||||
|
||||
"""
|
||||
parts = s.split()
|
||||
|
||||
mtype = namespace.get_msgtype(parts[0])
|
||||
|
@ -582,14 +585,17 @@ msgname [ field=...]*."""
|
|||
|
||||
return m
|
||||
|
||||
def to_bin(self):
|
||||
"""Encode a Message into its wire format (must not have missing
|
||||
fields)"""
|
||||
def write(self, io_out):
|
||||
"""Write a Message into its wire format.
|
||||
|
||||
Must not have missing fields.
|
||||
|
||||
"""
|
||||
if self.missing_fields():
|
||||
raise ValueError('Missing fields: {}'
|
||||
.format(self.missing_fields()))
|
||||
|
||||
ret = struct.pack(">H", self.messagetype.number)
|
||||
io_out.write(struct.pack(">H", self.messagetype.number))
|
||||
for f in self.messagetype.fields:
|
||||
# Optional fields get val == None. Usually this means they don't
|
||||
# write anything, but length fields are an exception: they intuit
|
||||
|
@ -598,8 +604,7 @@ fields)"""
|
|||
val = self.fields[f.name]
|
||||
else:
|
||||
val = None
|
||||
ret += f.fieldtype.val_to_bin(val, self.fields)
|
||||
return ret
|
||||
f.fieldtype.write(io_out, val, self.fields)
|
||||
|
||||
def to_str(self):
|
||||
"""Encode a Message into a string"""
|
||||
|
|
|
@ -2,3 +2,4 @@ bitstring==3.1.6
|
|||
cryptography==2.8
|
||||
coincurve==13.0.0
|
||||
base58==1.0.2
|
||||
mypy
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#! /usr/bin/python3
|
||||
from pyln.proto.message.fundamental_types import fundamental_types
|
||||
from pyln.proto.message.array_types import SizedArrayType, DynamicArrayType, EllipsisArrayType, LengthFieldType
|
||||
import io
|
||||
|
||||
|
||||
def test_sized_array():
|
||||
|
@ -32,9 +33,11 @@ def test_sized_array():
|
|||
+ [0, 0, 10, 0, 0, 11, 0, 12])]]:
|
||||
v, _ = arrtype.val_from_str(s)
|
||||
assert arrtype.val_to_str(v, None) == s
|
||||
v2, _ = arrtype.val_from_bin(b, None)
|
||||
v2 = arrtype.read(io.BytesIO(b), None)
|
||||
assert v2 == v
|
||||
assert arrtype.val_to_bin(v, None) == b
|
||||
buf = io.BytesIO()
|
||||
arrtype.write(buf, v, None)
|
||||
assert buf.getvalue() == b
|
||||
|
||||
|
||||
def test_ellipsis_array():
|
||||
|
@ -52,23 +55,25 @@ def test_ellipsis_array():
|
|||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
for test in [[EllipsisArrayType(dummy("test1"), "test_arr", byte),
|
||||
"00010203",
|
||||
bytes([0, 1, 2, 3])],
|
||||
[EllipsisArrayType(dummy("test2"), "test_arr", u16),
|
||||
"[0,1,2,256]",
|
||||
bytes([0, 0, 0, 1, 0, 2, 1, 0])],
|
||||
[EllipsisArrayType(dummy("test3"), "test_arr", scid),
|
||||
"[1x2x3,4x5x6,7x8x9,10x11x12]",
|
||||
bytes([0, 0, 1, 0, 0, 2, 0, 3]
|
||||
+ [0, 0, 4, 0, 0, 5, 0, 6]
|
||||
+ [0, 0, 7, 0, 0, 8, 0, 9]
|
||||
+ [0, 0, 10, 0, 0, 11, 0, 12])]]:
|
||||
v, _ = test[0].val_from_str(test[1])
|
||||
assert test[0].val_to_str(v, None) == test[1]
|
||||
v2, _ = test[0].val_from_bin(test[2], None)
|
||||
for arrtype, s, b in [[EllipsisArrayType(dummy("test1"), "test_arr", byte),
|
||||
"00010203",
|
||||
bytes([0, 1, 2, 3])],
|
||||
[EllipsisArrayType(dummy("test2"), "test_arr", u16),
|
||||
"[0,1,2,256]",
|
||||
bytes([0, 0, 0, 1, 0, 2, 1, 0])],
|
||||
[EllipsisArrayType(dummy("test3"), "test_arr", scid),
|
||||
"[1x2x3,4x5x6,7x8x9,10x11x12]",
|
||||
bytes([0, 0, 1, 0, 0, 2, 0, 3]
|
||||
+ [0, 0, 4, 0, 0, 5, 0, 6]
|
||||
+ [0, 0, 7, 0, 0, 8, 0, 9]
|
||||
+ [0, 0, 10, 0, 0, 11, 0, 12])]]:
|
||||
v, _ = arrtype.val_from_str(s)
|
||||
assert arrtype.val_to_str(v, None) == s
|
||||
v2 = arrtype.read(io.BytesIO(b), None)
|
||||
assert v2 == v
|
||||
assert test[0].val_to_bin(v, None) == test[2]
|
||||
buf = io.BytesIO()
|
||||
arrtype.write(buf, v, None)
|
||||
assert buf.getvalue() == b
|
||||
|
||||
|
||||
def test_dynamic_array():
|
||||
|
@ -93,27 +98,29 @@ def test_dynamic_array():
|
|||
|
||||
lenfield = field_dummy('lenfield', LengthFieldType(u16))
|
||||
|
||||
for test in [[DynamicArrayType(dummy("test1"), "test_arr", byte,
|
||||
lenfield),
|
||||
"00010203",
|
||||
bytes([0, 1, 2, 3])],
|
||||
[DynamicArrayType(dummy("test2"), "test_arr", u16,
|
||||
lenfield),
|
||||
"[0,1,2,256]",
|
||||
bytes([0, 0, 0, 1, 0, 2, 1, 0])],
|
||||
[DynamicArrayType(dummy("test3"), "test_arr", scid,
|
||||
lenfield),
|
||||
"[1x2x3,4x5x6,7x8x9,10x11x12]",
|
||||
bytes([0, 0, 1, 0, 0, 2, 0, 3]
|
||||
+ [0, 0, 4, 0, 0, 5, 0, 6]
|
||||
+ [0, 0, 7, 0, 0, 8, 0, 9]
|
||||
+ [0, 0, 10, 0, 0, 11, 0, 12])]]:
|
||||
for arrtype, s, b in [[DynamicArrayType(dummy("test1"), "test_arr", byte,
|
||||
lenfield),
|
||||
"00010203",
|
||||
bytes([0, 1, 2, 3])],
|
||||
[DynamicArrayType(dummy("test2"), "test_arr", u16,
|
||||
lenfield),
|
||||
"[0,1,2,256]",
|
||||
bytes([0, 0, 0, 1, 0, 2, 1, 0])],
|
||||
[DynamicArrayType(dummy("test3"), "test_arr", scid,
|
||||
lenfield),
|
||||
"[1x2x3,4x5x6,7x8x9,10x11x12]",
|
||||
bytes([0, 0, 1, 0, 0, 2, 0, 3]
|
||||
+ [0, 0, 4, 0, 0, 5, 0, 6]
|
||||
+ [0, 0, 7, 0, 0, 8, 0, 9]
|
||||
+ [0, 0, 10, 0, 0, 11, 0, 12])]]:
|
||||
|
||||
lenfield.fieldtype.add_length_for(field_dummy(test[1], test[0]))
|
||||
v, _ = test[0].val_from_str(test[1])
|
||||
otherfields = {test[1]: v}
|
||||
assert test[0].val_to_str(v, otherfields) == test[1]
|
||||
v2, _ = test[0].val_from_bin(test[2], otherfields)
|
||||
lenfield.fieldtype.add_length_for(field_dummy(s, arrtype))
|
||||
v, _ = arrtype.val_from_str(s)
|
||||
otherfields = {s: v}
|
||||
assert arrtype.val_to_str(v, otherfields) == s
|
||||
v2 = arrtype.read(io.BytesIO(b), otherfields)
|
||||
assert v2 == v
|
||||
assert test[0].val_to_bin(v, otherfields) == test[2]
|
||||
buf = io.BytesIO()
|
||||
arrtype.write(buf, v, None)
|
||||
assert buf.getvalue() == b
|
||||
lenfield.fieldtype.len_for = []
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#! /usr/bin/python3
|
||||
from pyln.proto.message.fundamental_types import fundamental_types
|
||||
import io
|
||||
|
||||
|
||||
def test_fundamental_types():
|
||||
|
@ -67,8 +68,10 @@ def test_fundamental_types():
|
|||
for test in expect[t.name]:
|
||||
v, _ = t.val_from_str(test[0])
|
||||
assert t.val_to_str(v, None) == test[0]
|
||||
v2, _ = t.val_from_bin(test[1], None)
|
||||
v2 = t.read(io.BytesIO(test[1]), None)
|
||||
assert v2 == v
|
||||
assert t.val_to_bin(v, None) == test[1]
|
||||
buf = io.BytesIO()
|
||||
t.write(buf, v, None)
|
||||
assert buf.getvalue() == test[1]
|
||||
|
||||
assert untested == set(['varint'])
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#! /usr/bin/python3
|
||||
from pyln.proto.message import MessageNamespace, Message
|
||||
import pytest
|
||||
import io
|
||||
|
||||
|
||||
def test_fundamental():
|
||||
|
@ -51,9 +52,10 @@ def test_static_array():
|
|||
+ [0, 0, 10, 0, 0, 11, 0, 12])]]:
|
||||
m = Message.from_str(ns, test[0])
|
||||
assert m.to_str() == test[0]
|
||||
v = m.to_bin()
|
||||
assert v == test[1]
|
||||
assert Message.from_bin(ns, test[1]).to_str() == test[0]
|
||||
buf = io.BytesIO()
|
||||
m.write(buf)
|
||||
assert buf.getvalue() == test[1]
|
||||
assert Message.read(ns, io.BytesIO(test[1])).to_str() == test[0]
|
||||
|
||||
|
||||
def test_subtype():
|
||||
|
@ -78,9 +80,10 @@ def test_subtype():
|
|||
+ [0, 0, 0, 7, 0, 0, 0, 8])]]:
|
||||
m = Message.from_str(ns, test[0])
|
||||
assert m.to_str() == test[0]
|
||||
v = m.to_bin()
|
||||
assert v == test[1]
|
||||
assert Message.from_bin(ns, test[1]).to_str() == test[0]
|
||||
buf = io.BytesIO()
|
||||
m.write(buf)
|
||||
assert buf.getvalue() == test[1]
|
||||
assert Message.read(ns, io.BytesIO(test[1])).to_str() == test[0]
|
||||
|
||||
# Test missing field logic.
|
||||
m = Message.from_str(ns, "test1", incomplete_ok=True)
|
||||
|
@ -111,16 +114,19 @@ def test_tlv():
|
|||
+ [253, 0, 255, 4, 1, 2, 3, 4])]]:
|
||||
m = Message.from_str(ns, test[0])
|
||||
assert m.to_str() == test[0]
|
||||
v = m.to_bin()
|
||||
assert v == test[1]
|
||||
assert Message.from_bin(ns, test[1]).to_str() == test[0]
|
||||
buf = io.BytesIO()
|
||||
m.write(buf)
|
||||
assert buf.getvalue() == test[1]
|
||||
assert Message.read(ns, io.BytesIO(test[1])).to_str() == test[0]
|
||||
|
||||
# Ordering test (turns into canonical ordering)
|
||||
m = Message.from_str(ns, 'test1 tlvs={tlv1={field1=01020304,field2=5},tlv2={field3=01020304},4=010203}')
|
||||
assert m.to_bin() == bytes([0, 1]
|
||||
+ [1, 8, 1, 2, 3, 4, 0, 0, 0, 5]
|
||||
+ [4, 3, 1, 2, 3]
|
||||
+ [253, 0, 255, 4, 1, 2, 3, 4])
|
||||
buf = io.BytesIO()
|
||||
m.write(buf)
|
||||
assert buf.getvalue() == bytes([0, 1]
|
||||
+ [1, 8, 1, 2, 3, 4, 0, 0, 0, 5]
|
||||
+ [4, 3, 1, 2, 3]
|
||||
+ [253, 0, 255, 4, 1, 2, 3, 4])
|
||||
|
||||
|
||||
def test_message_constructor():
|
||||
|
@ -135,10 +141,12 @@ def test_message_constructor():
|
|||
m = Message(ns.get_msgtype('test1'),
|
||||
tlvs='{tlv1={field1=01020304,field2=5}'
|
||||
',tlv2={field3=01020304},4=010203}')
|
||||
assert m.to_bin() == bytes([0, 1]
|
||||
+ [1, 8, 1, 2, 3, 4, 0, 0, 0, 5]
|
||||
+ [4, 3, 1, 2, 3]
|
||||
+ [253, 0, 255, 4, 1, 2, 3, 4])
|
||||
buf = io.BytesIO()
|
||||
m.write(buf)
|
||||
assert buf.getvalue() == bytes([0, 1]
|
||||
+ [1, 8, 1, 2, 3, 4, 0, 0, 0, 5]
|
||||
+ [4, 3, 1, 2, 3]
|
||||
+ [253, 0, 255, 4, 1, 2, 3, 4])
|
||||
|
||||
|
||||
def test_dynamic_array():
|
||||
|
@ -151,13 +159,15 @@ def test_dynamic_array():
|
|||
# This one is fine.
|
||||
m = Message(ns.get_msgtype('test1'),
|
||||
arr1='01020304', arr2='[1,2,3,4]')
|
||||
assert m.to_bin() == bytes([0, 1]
|
||||
+ [0, 4]
|
||||
+ [1, 2, 3, 4]
|
||||
+ [0, 0, 0, 1,
|
||||
0, 0, 0, 2,
|
||||
0, 0, 0, 3,
|
||||
0, 0, 0, 4])
|
||||
buf = io.BytesIO()
|
||||
m.write(buf)
|
||||
assert buf.getvalue() == bytes([0, 1]
|
||||
+ [0, 4]
|
||||
+ [1, 2, 3, 4]
|
||||
+ [0, 0, 0, 1,
|
||||
0, 0, 0, 2,
|
||||
0, 0, 0, 3,
|
||||
0, 0, 0, 4])
|
||||
|
||||
# These ones are not
|
||||
with pytest.raises(ValueError, match='Inconsistent length.*count'):
|
||||
|
|
Loading…
Reference in New Issue