diff --git a/contrib/pyln-proto/pyln/proto/onion.py b/contrib/pyln-proto/pyln/proto/onion.py new file mode 100644 index 000000000..9c9e9945d --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/onion.py @@ -0,0 +1,236 @@ +from .primitives import varint_decode, varint_encode +from io import BytesIO, SEEK_CUR +from binascii import hexlify, unhexlify +import struct + + +class OnionPayload(object): + + @classmethod + def from_bytes(cls, b): + if isinstance(b, bytes): + b = BytesIO(b) + + realm = b.read(1) + b.seek(-1, SEEK_CUR) + + if realm == b'\x00': + return LegacyOnionPayload.from_bytes(b) + elif realm != b'\x01': + return TlvPayload.from_bytes(b, skip_length=False) + else: + raise ValueError("Onion payloads with realm 0x01 are unsupported") + + @classmethod + def from_hex(cls, s): + if isinstance(s, str): + s = s.encode('ASCII') + return cls.from_bytes(bytes(unhexlify(s))) + + def to_bytes(self): + raise ValueError("OnionPayload is an abstract class, use " + "LegacyOnionPayload or TlvPayload instead") + + def to_hex(self): + return hexlify(self.to_bytes()).decode('ASCII') + + +class LegacyOnionPayload(OnionPayload): + + def __init__(self, amt_to_forward, outgoing_cltv_value, + short_channel_id=None, padding=None): + assert(padding is None or len(padding) == 12) + self.padding = b'\x00' * 12 if padding is None else padding + + if isinstance(amt_to_forward, str): + self.amt_to_forward = int(amt_to_forward) + else: + self.amt_to_forward = amt_to_forward + + self.outgoing_cltv_value = outgoing_cltv_value + + if isinstance(short_channel_id, str) and 'x' in short_channel_id: + # Convert the short_channel_id from its string representation to its numeric representation + block, tx, out = short_channel_id.split('x') + num_scid = int(block) << 40 | int(tx) << 16 | int(out) + self.short_channel_id = num_scid + elif isinstance(short_channel_id, int): + self.short_channel_id = short_channel_id + else: + raise ValueError("short_channel_id format cannot be recognized: {}".format(short_channel_id)) + + @classmethod + def from_bytes(cls, b): + if isinstance(b, bytes): + b = BytesIO(b) + + assert(b.read(1) == b'\x00') + + s, a, o = struct.unpack("!QQL", b.read(20)) + padding = b.read(12) + return LegacyOnionPayload(a, o, s, padding) + + def to_bytes(self, include_realm=True): + b = b'' + if include_realm: + b += b'\x00' + + b += struct.pack("!Q", self.short_channel_id) + b += struct.pack("!Q", self.amt_to_forward) + b += struct.pack("!L", self.outgoing_cltv_value) + b += self.padding + assert(len(b) == 32 + include_realm) + return b + + def to_hex(self, include_realm=True): + return hexlify(self.to_bytes(include_realm)).decode('ASCII') + + def __str__(self): + return ("LegacyOnionPayload[scid={self.short_channel_id}, " + "amt_to_forward={self.amt_to_forward}, " + "outgoing_cltv={self.outgoing_cltv_value}]").format(self=self) + + +class TlvPayload(OnionPayload): + + def __init__(self, fields=None): + self.fields = [] if fields is None else fields + + @classmethod + def from_bytes(cls, b, skip_length=False): + if isinstance(b, str): + b = b.encode('ASCII') + if isinstance(b, bytes): + b = BytesIO(b) + + if skip_length: + # Consume the entire remainder of the buffer. + payload_length = len(b.getvalue()) - b.tell() + else: + payload_length = varint_decode(b) + + instance = TlvPayload() + + start = b.tell() + while b.tell() < start + payload_length: + typenum = varint_decode(b) + if typenum is None: + break + length = varint_decode(b) + if length is None: + raise ValueError( + "Unable to read length at position {}".format(b.tell()) + ) + val = b.read(length) + + # Get the subclass that is the correct interpretation of this + # field. Default to the binary field type. + c = tlv_types.get(typenum, (TlvField, "unknown")) + cls = c[0] + field = cls.from_bytes(typenum=typenum, b=val, description=c[1]) + instance.fields.append(field) + + return instance + + @classmethod + def from_hex(cls, h): + return cls.from_bytes(unhexlify(h)) + + def add_field(self, typenum, value): + self.fields.append(TlvField(typenum=typenum, value=value)) + + def get(self, key, default=None): + for f in self.fields: + if f.typenum == key: + return f + return default + + def to_bytes(self): + ser = [f.to_bytes() for f in self.fields] + b = BytesIO() + varint_encode(sum([len(b) for b in ser]), b) + for f in ser: + b.write(f) + return b.getvalue() + + def __str__(self): + return "TlvPayload[" + ', '.join([str(f) for f in self.fields]) + "]" + + +class TlvField(object): + + def __init__(self, typenum, value=None, description=None): + self.typenum = typenum + self.value = value + self.description = description + + @classmethod + def from_bytes(cls, typenum, b, description=None): + return TlvField(typenum=typenum, value=b, description=description) + + def __str__(self): + return "TlvField[{description},{num}={hex}]".format( + description=self.description, + num=self.typenum, + hex=hexlify(self.value).decode('ASCII') + ) + + def to_bytes(self): + b = BytesIO() + varint_encode(self.typenum, b) + varint_encode(len(self.value), b) + b.write(self.value) + return b.getvalue() + + +class Tu32Field(TlvField): + pass + + +class Tu64Field(TlvField): + pass + + +class ShortChannelIdField(TlvField): + pass + + +class TextField(TlvField): + + @classmethod + def from_bytes(cls, typenum, b, description=None): + val = b.decode('UTF-8') + return TextField(typenum, value=val, description=description) + + def to_bytes(self): + b = BytesIO() + val = self.value.encode('UTF-8') + varint_encode(self.typenum, b) + varint_encode(len(val), b) + b.write(val) + return b.getvalue() + + def __str__(self): + return "TextField[{description},{num}=\"{val}\"]".format( + description=self.description, + num=self.typenum, + val=self.value, + ) + + +class HashField(TlvField): + pass + + +class SignatureField(TlvField): + pass + + +# A mapping of known TLV types +tlv_types = { + 2: (Tu64Field, 'amt_to_forward'), + 4: (Tu32Field, 'outgoing_cltv_value'), + 6: (ShortChannelIdField, 'short_channel_id'), + 34349334: (TextField, 'noise_message_body'), + 34349336: (SignatureField, 'noise_message_signature'), +} diff --git a/contrib/pyln-proto/pyln/proto/primitives.py b/contrib/pyln-proto/pyln/proto/primitives.py new file mode 100644 index 000000000..4c1d10ebe --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/primitives.py @@ -0,0 +1,70 @@ +import struct + + +def varint_encode(i, w): + """Encode an integer `i` into the writer `w` + """ + if i < 0xFD: + w.write(struct.pack("!B", i)) + elif i <= 0xFFFF: + w.write(struct.pack("!BH", 0xFD, i)) + elif i <= 0xFFFFFFFF: + w.write(struct.pack("!BL", 0xFE, i)) + else: + w.write(struct.pack("!BQ", 0xFF, i)) + + +def varint_decode(r): + """Decode an integer from reader `r` + """ + raw = r.read(1) + if len(raw) != 1: + return None + + i, = struct.unpack("!B", raw) + if i < 0xFD: + return i + elif i == 0xFD: + return struct.unpack("!H", r.read(2))[0] + elif i == 0xFE: + return struct.unpack("!L", r.read(4))[0] + else: + return struct.unpack("!Q", r.read(8))[0] + + +class ShortChannelId(object): + def __init__(self, block, txnum, outnum): + self.block = block + self.txnum = txnum + self.outnum = outnum + + @classmethod + def from_bytes(cls, b): + assert(len(b) == 8) + i, = struct.unpack("!Q", b) + return cls.from_int(i) + + @classmethod + def from_int(cls, i): + block = (i >> 40) & 0xFFFFFF + txnum = (i >> 16) & 0xFFFFFF + outnum = (i >> 0) & 0xFFFF + return cls(block=block, txnum=txnum, outnum=outnum) + + @classmethod + def from_str(self, s): + block, txnum, outnum = s.split('x') + return ShortChannelId(block=int(block), txnum=int(txnum), + outnum=int(outnum)) + + def to_int(self): + return self.block << 40 | self.txnum << 16 | self.outnum + + def to_bytes(self): + return struct.pack("!Q", self.to_int()) + + def __str__(self): + return "{self.block}x{self.txnum}x{self.outnum}".format(self=self) + + def __eq__(self, other): + return self.block == other.block and self.txnum == other.txnum and self.outnum == other.outnum diff --git a/contrib/pyln-proto/pyln/proto/zbase32.py b/contrib/pyln-proto/pyln/proto/zbase32.py new file mode 100644 index 000000000..fbae94f22 --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/zbase32.py @@ -0,0 +1,56 @@ +import bitstring + + +zbase32_chars = b'ybndrfg8ejkmcpqxot1uwisza345h769' +zbase32_revchars = [ + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 18, 255, 25, 26, 27, 30, 29, 7, 31, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 24, 1, 12, 3, 8, 5, 6, 28, 21, 9, 10, 255, 11, 2, + 16, 13, 14, 4, 22, 17, 19, 255, 20, 15, 0, 23, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255 +] + + +def bitarray_to_u5(barr): + assert len(barr) % 5 == 0 + ret = [] + s = bitstring.ConstBitStream(barr) + while s.pos != s.len: + ret.append(s.read(5).uint) + return ret + + +def u5_to_bitarray(arr): + ret = bitstring.BitArray() + for a in arr: + ret += bitstring.pack("uint:5", a) + return ret + + +def encode(b): + uint5s = bitarray_to_u5(b) + res = [zbase32_chars[c] for c in uint5s] + return bytes(res) + + +def decode(b): + if isinstance(b, str): + b = b.encode('ASCII') + + uint5s = [] + for c in b: + uint5s.append(zbase32_revchars[c]) + dec = u5_to_bitarray(uint5s) + return dec.bytes diff --git a/contrib/pyln-proto/requirements.txt b/contrib/pyln-proto/requirements.txt index 5daa26d68..5c7fec741 100644 --- a/contrib/pyln-proto/requirements.txt +++ b/contrib/pyln-proto/requirements.txt @@ -1,2 +1,3 @@ +bitstring==3.1.6 cryptography==2.7 coincurve==12.0.0 diff --git a/contrib/pyln-proto/tests/test_onion.py b/contrib/pyln-proto/tests/test_onion.py new file mode 100644 index 000000000..b45c41fe5 --- /dev/null +++ b/contrib/pyln-proto/tests/test_onion.py @@ -0,0 +1,32 @@ +from binascii import unhexlify + +from pyln.proto import onion + + +def test_legacy_payload(): + legacy = unhexlify( + b'00000067000001000100000000000003e800000075000000000000000000000000' + ) + payload = onion.OnionPayload.from_bytes(legacy) + assert(payload.to_bytes(include_realm=True) == legacy) + + +def test_tlv_payload(): + tlv = unhexlify( + b'58fe020c21160c48656c6c6f20776f726c6421fe020c21184076e8acd54afbf2361' + b'0b7166ba689afcc9e8ec3c44e442e765012dfc1d299958827d0205f7e4e1a12620e' + b'7fc8ce1c7d3651acefde899c33f12b6958d3304106a0' + ) + payload = onion.OnionPayload.from_bytes(tlv) + assert(payload.to_bytes() == tlv) + + fields = payload.fields + assert(len(fields) == 2) + assert(isinstance(fields[0], onion.TextField)) + assert(fields[0].typenum == 34349334 and fields[0].value == "Hello world!") + assert(fields[1].typenum == 34349336 and fields[1].value == unhexlify( + b'76e8acd54afbf23610b7166ba689afcc9e8ec3c44e442e765012dfc1d299958827d' + b'0205f7e4e1a12620e7fc8ce1c7d3651acefde899c33f12b6958d3304106a0' + )) + + assert(payload.to_bytes() == tlv) diff --git a/contrib/pyln-proto/tests/test_primitives.py b/contrib/pyln-proto/tests/test_primitives.py new file mode 100644 index 000000000..4c4a14681 --- /dev/null +++ b/contrib/pyln-proto/tests/test_primitives.py @@ -0,0 +1,30 @@ +from binascii import hexlify, unhexlify +from pyln.proto import zbase32 +from pyln.proto.primitives import ShortChannelId + + +def test_short_channel_id(): + num = 618150934845652992 + b = unhexlify(b'08941d00090d0000') + s = '562205x2317x0' + s1 = ShortChannelId.from_int(num) + s2 = ShortChannelId.from_str(s) + s3 = ShortChannelId.from_bytes(b) + expected = ShortChannelId(block=562205, txnum=2317, outnum=0) + + assert(s1 == expected) + assert(s2 == expected) + assert(s3 == expected) + + assert(expected.to_bytes() == b) + assert(str(expected) == s) + assert(expected.to_int() == num) + + +def test_zbase32(): + zb32 = b'd75qtmgijm79rpooshmgzjwji9gj7dsdat8remuskyjp9oq1ugkaoj6orbxzhuo4njtyh96e3aq84p1tiuz77nchgxa1s4ka4carnbiy' + b = zbase32.decode(zb32) + assert(hexlify(b) == b'1f76e8acd54afbf23610b7166ba689afcc9e8ec3c44e442e765012dfc1d299958827d0205f7e4e1a12620e7fc8ce1c7d3651acefde899c33f12b6958d3304106a0') + + enc = zbase32.encode(b) + assert(enc == zb32)