pyln: Implement sphinx onion packet generation in python

Suggested-by: Rusty Russell <@rustyrussell>
Signed-off-by: Christian Decker <@cdecker>
This commit is contained in:
Christian Decker 2020-07-31 13:35:30 +02:00 committed by Rusty Russell
parent 0b5e6c5be1
commit 96b182a084
2 changed files with 519 additions and 6 deletions

View File

@ -1,6 +1,15 @@
from .primitives import varint_decode, varint_encode
from io import BytesIO, SEEK_CUR
from .primitives import varint_decode, varint_encode, Secret
from .wire import PrivateKey, PublicKey, ecdh
from binascii import hexlify, unhexlify
from collections import namedtuple
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, hmac
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms
from hashlib import sha256
from io import BytesIO, SEEK_CUR
from typing import List, Optional, Union
import coincurve
import os
import struct
@ -50,14 +59,19 @@ class LegacyOnionPayload(OnionPayload):
self.outgoing_cltv_value = outgoing_cltv_value
if isinstance(short_channel_id, str) and 'x' in short_channel_id:
# Convert the short_channel_id from its string representation to its numeric representation
# Convert the short_channel_id from its string representation to
# its numeric representation
block, tx, out = short_channel_id.split('x')
num_scid = int(block) << 40 | int(tx) << 16 | int(out)
self.short_channel_id = num_scid
elif isinstance(short_channel_id, int):
self.short_channel_id = short_channel_id
else:
raise ValueError("short_channel_id format cannot be recognized: {}".format(short_channel_id))
raise ValueError(
"short_channel_id format cannot be recognized: {}".format(
short_channel_id
)
)
@classmethod
def from_bytes(cls, b):
@ -242,6 +256,238 @@ class SignatureField(TlvField):
pass
VERSION_SIZE = 1
REALM_SIZE = 1
HMAC_SIZE = 32
PUBKEY_SIZE = 33
ROUTING_INFO_SIZE = 1300
TOTAL_PACKET_SIZE = VERSION_SIZE + PUBKEY_SIZE + HMAC_SIZE + ROUTING_INFO_SIZE
class RoutingOnion(object):
def __init__(
self, version: int,
ephemeralkey: PublicKey,
payloads: bytes,
hmac: bytes
):
assert(len(payloads) == ROUTING_INFO_SIZE)
self.version = version
self.payloads = payloads
self.ephemeralkey = ephemeralkey
self.hmac = hmac
@classmethod
def from_bin(cls, b: bytes):
if len(b) != TOTAL_PACKET_SIZE:
raise ValueError(
"Encoded binary RoutingOnion size mismatch: {} != {}".format(
len(b), TOTAL_PACKET_SIZE
)
)
version = int(b[0])
ephemeralkey = PublicKey(b[1:34])
payloads = b[34:1334]
hmac = b[1334:]
assert(len(payloads) == ROUTING_INFO_SIZE
and len(hmac) == HMAC_SIZE)
return cls(version=version, ephemeralkey=ephemeralkey,
payloads=payloads, hmac=hmac)
@classmethod
def from_hex(cls, s: str):
return cls.from_bin(unhexlify(s))
def to_bin(self) -> bytes:
ephkey = self.ephemeralkey.to_bytes()
return struct.pack("b", self.version) + \
ephkey + \
self.payloads + \
self.hmac
def to_hex(self):
return hexlify(self.to_bin())
KeySet = namedtuple('KeySet', ['rho', 'mu', 'um', 'pad', 'gamma', 'pi'])
def xor_inplace(d: Union[bytearray, memoryview],
a: Union[bytearray, memoryview],
b: Union[bytearray, memoryview]):
"""Compute a xor b and store the result in d
"""
assert(len(a) == len(b) and len(d) == len(b))
for i in range(len(a)):
d[i] = a[i] ^ b[i]
def xor(a: Union[bytearray, memoryview],
b: Union[bytearray, memoryview]) -> bytearray:
assert(len(a) == len(b))
d = bytearray(len(a))
xor_inplace(d, a, b)
return d
def generate_key(secret: bytes, prefix: bytes):
h = hmac.HMAC(prefix, hashes.SHA256(), backend=default_backend())
h.update(secret)
return h.finalize()
def generate_keyset(secret: Secret) -> KeySet:
types = [bytes(f, 'ascii') for f in KeySet._fields]
keys = [generate_key(secret.data, t) for t in types]
return KeySet(*keys)
class SphinxHopParam(object):
def __init__(self, secret: Secret, ephemeralkey: PublicKey):
self.secret = secret
self.ephemeralkey = ephemeralkey
self.blind = blind(self.ephemeralkey, self.secret)
self.keys = generate_keyset(self.secret)
class SphinxHop(object):
def __init__(self, pubkey: PublicKey, payload: bytes):
self.pubkey = pubkey
self.payload = payload
self.hmac: Optional[bytes] = None
def __len__(self):
return len(self.payload) + HMAC_SIZE
def blind(pubkey, sharedsecret) -> Secret:
m = sha256()
m.update(pubkey.to_bytes())
m.update(sharedsecret.to_bytes())
return Secret(m.digest())
def blind_group_element(pubkey, blind: Secret) -> PublicKey:
pubkey = coincurve.PublicKey(data=pubkey.to_bytes())
blinded = pubkey.multiply(blind.to_bytes(), update=False)
return PublicKey(blinded.format(compressed=True))
def chacha20_stream(key: bytes, dest: Union[bytearray, memoryview]):
algorithm = algorithms.ChaCha20(key, b'\x00' * 16)
cipher = Cipher(algorithm, None, backend=default_backend())
encryptor = cipher.encryptor()
encryptor.update_into(dest, dest)
class SphinxPath(object):
def __init__(self, hops: List[SphinxHop], assocdata: bytes = None,
session_key: Optional[Secret] = None):
self.hops = hops
self.assocdata: Optional[bytes] = assocdata
if session_key is not None:
self.session_key = session_key
else:
self.session_key = Secret(os.urandom(32))
def get_filler(self) -> memoryview:
filler_size = sum(len(h) for h in self.hops[1:])
filler = memoryview(bytearray(filler_size))
params = self.get_hop_params()
for i in range(len(self.hops[:-1])):
h = self.hops[i]
p = params[i]
filler_offset = sum(len(sph) for sph in self.hops[:i])
filler_start = ROUTING_INFO_SIZE - filler_offset
filler_end = ROUTING_INFO_SIZE + len(h)
filler_len = filler_end - filler_start
stream = bytearray(filler_end)
chacha20_stream(p.keys.rho, stream)
xor_inplace(filler[:filler_len], filler[:filler_len],
stream[filler_start:filler_end])
return filler
def compile(self) -> RoutingOnion:
buf = bytearray(ROUTING_INFO_SIZE)
# Prefill the buffer with the pseudorandom stream to avoid telling the
# last hop the real payload size through zero ranges.
padkey = generate_key(self.session_key.data, b'pad')
params = self.get_hop_params()
chacha20_stream(padkey, buf)
filler = self.get_filler()
nexthmac = bytes(32)
for i, h, p in zip(
range(len(self.hops)),
reversed(self.hops),
reversed(params)):
h.hmac = nexthmac
shift_size = len(h)
assert(shift_size == len(h.payload) + HMAC_SIZE)
buf[shift_size:] = buf[:ROUTING_INFO_SIZE - shift_size]
buf[:shift_size] = h.payload + h.hmac
# Encrypt
chacha20_stream(p.keys.rho, buf)
if i == 0:
# Place the filler at the correct position
buf[ROUTING_INFO_SIZE - len(filler):] = filler
# Finally compute the hmac that the next hop will use to verify
# the onion's integrity.
hh = hmac.HMAC(p.keys.mu, hashes.SHA256(),
backend=default_backend())
hh.update(buf)
if self.assocdata is not None:
hh.update(self.assocdata)
nexthmac = hh.finalize()
return RoutingOnion(
version=0,
ephemeralkey=params[0].ephemeralkey,
hmac=nexthmac,
payloads=buf,
)
def get_hop_params(self) -> List[SphinxHopParam]:
assert(self.session_key is not None)
secret = ecdh(PrivateKey(self.session_key.data),
self.hops[0].pubkey)
sph = SphinxHopParam(
ephemeralkey=PrivateKey(self.session_key.data).public_key(),
secret=secret,
)
params = [sph]
for i, h in enumerate(self.hops[1:]):
prev = params[-1]
ek = blind_group_element(prev.ephemeralkey,
prev.blind)
# Start by blinding the current hop's pubkey with the session_key
temp = blind_group_element(h.pubkey, self.session_key)
# Then apply blind for all previous hops
for p in params:
temp = blind_group_element(temp, p.blind)
# Finally hash the compressed resulting pubkey to get the secret
secret = Secret(sha256(temp.to_bytes()).digest())
sph = SphinxHopParam(secret=secret, ephemeralkey=ek)
params.append(sph)
return params
# A mapping of known TLV types
tlv_types = {
2: (Tu64Field, 'amt_to_forward'),

View File

@ -1,6 +1,10 @@
from binascii import unhexlify
from binascii import hexlify, unhexlify
from io import BytesIO
from pyln.proto import onion
from typing import Tuple
import json
import os
import unittest
def test_legacy_payload():
@ -58,3 +62,266 @@ def test_tu_fields():
for i, o in pairs:
f = onion.Tu64Field(1, i)
assert(f.to_bytes() == o)
dirname = os.path.dirname(__file__)
vector_base = os.path.join(dirname, '..', '..', '..', 'tests', 'vectors')
have_vectors = os.path.exists(os.path.join(vector_base, 'onion-test-v0.json'))
def get_vector(filename):
fullname = os.path.join(vector_base, filename)
return json.load(open(fullname, 'r'))
@unittest.skipIf(not have_vectors, "Need the test vectors")
def test_onion_parse():
"""Make sure we parse the serialized onion into its components.
"""
vec = get_vector('onion-test-v0.json')
o = vec['onion']
o = onion.RoutingOnion.from_hex(o)
assert(o.version == 0)
assert(hexlify(o.hmac) == b'b8640887e027e946df96488b47fbc4a4fadaa8beda4abe446fafea5403fae2ef')
assert(o.to_bin() == unhexlify(vec['onion']))
def test_generate_keyset():
unhex = unhexlify
secret = onion.Secret(unhex(
b'53eb63ea8a3fec3b3cd433b85cd62a4b145e1dda09391b348c4e1cd36a03ea66'
))
keys = onion.generate_keyset(secret)
expected = onion.KeySet(
rho=unhex(b'ce496ec94def95aadd4bec15cdb41a740c9f2b62347c4917325fcc6fb0453986'),
mu=unhex(b'b57061dc6d0a2b9f261ac410c8b26d64ac5506cbba30267a649c28c179400eba'),
um=unhex(b'3ca76e96fad1a0300928639d203b4369e81254032156c936179077b08091ca49'),
pad=unhex(b'3c348715f933c32b5571e2c9136b17c4da2e8fd13e35b7092deff56650eea958'),
gamma=unhex(b'c5b96917bc536aff7c2d6584bd60cf3b99151ccac18f173133f1fd0bdcae08b5'),
pi=unhex(b'3a70333f46a4fd1b3f72acae87760b147b07fe4923131066906a4044d4f1ddd1'),
)
assert(keys == expected)
def test_blind():
tests = [
(b'02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619',
b'53eb63ea8a3fec3b3cd433b85cd62a4b145e1dda09391b348c4e1cd36a03ea66',
b'2ec2e5da605776054187180343287683aa6a51b4b1c04d6dd49c45d8cffb3c36'),
(b'028f9438bfbf7feac2e108d677e3a82da596be706cc1cf342b75c7b7e22bf4e6e2',
b'a6519e98832a0b179f62123b3567c106db99ee37bef036e783263602f3488fae',
b'bf66c28bc22e598cfd574a1931a2bafbca09163df2261e6d0056b2610dab938f'),
(b'03bfd8225241ea71cd0843db7709f4c222f62ff2d4516fd38b39914ab6b83e0da0',
b'3a6b412548762f0dbccce5c7ae7bb8147d1caf9b5471c34120b30bc9c04891cc',
b'a1f2dadd184eb1627049673f18c6325814384facdee5bfd935d9cb031a1698a5'),
(b'031dde6926381289671300239ea8e57ffaf9bebd05b9a5b95beaf07af05cd43595',
b'21e13c2d7cfe7e18836df50872466117a295783ab8aab0e7ecc8c725503ad02d',
b'7cfe0b699f35525029ae0fa437c69d0f20f7ed4e3916133f9cacbb13c82ff262'),
(b'03a214ebd875aab6ddfd77f22c5e7311d7f77f17a169e599f157bbcdae8bf071f4',
b'b5756b9b542727dbafc6765a49488b023a725d631af688fc031217e90770c328',
b'c96e00dddaf57e7edcd4fb5954be5b65b09f17cb6d20651b4e90315be5779205'),
]
for pubkey, sharedsecret, expected in tests:
expected = onion.Secret(unhexlify(expected))
pubkey = onion.PublicKey(unhexlify(pubkey))
sharedsecret = onion.Secret(unhexlify(sharedsecret))
res = onion.blind(pubkey, sharedsecret)
assert(res == expected)
def test_blind_group_element():
tests = [
(b'031dde6926381289671300239ea8e57ffaf9bebd05b9a5b95beaf07af05cd43595',
b'7cfe0b699f35525029ae0fa437c69d0f20f7ed4e3916133f9cacbb13c82ff262',
b'03a214ebd875aab6ddfd77f22c5e7311d7f77f17a169e599f157bbcdae8bf071f4'),
(b'028f9438bfbf7feac2e108d677e3a82da596be706cc1cf342b75c7b7e22bf4e6e2',
b'bf66c28bc22e598cfd574a1931a2bafbca09163df2261e6d0056b2610dab938f',
b'03bfd8225241ea71cd0843db7709f4c222f62ff2d4516fd38b39914ab6b83e0da0'),
(b'03bfd8225241ea71cd0843db7709f4c222f62ff2d4516fd38b39914ab6b83e0da0',
b'a1f2dadd184eb1627049673f18c6325814384facdee5bfd935d9cb031a1698a5',
b'031dde6926381289671300239ea8e57ffaf9bebd05b9a5b95beaf07af05cd43595'),
(b'031dde6926381289671300239ea8e57ffaf9bebd05b9a5b95beaf07af05cd43595',
b'7cfe0b699f35525029ae0fa437c69d0f20f7ed4e3916133f9cacbb13c82ff262',
b'03a214ebd875aab6ddfd77f22c5e7311d7f77f17a169e599f157bbcdae8bf071f4'),
]
for pubkey, blind, expected in tests:
expected = onion.PublicKey(unhexlify(expected))
pubkey = onion.PublicKey(unhexlify(pubkey))
blind = onion.Secret(unhexlify(blind))
res = onion.blind_group_element(pubkey, blind)
assert(res.to_bytes() == expected.to_bytes())
def test_xor():
tab = [
(b'\x01', b'\x01', b'\x00'),
(b'\x01', b'\x00', b'\x01'),
(b'\x00', b'\x01', b'\x01'),
(b'\x00', b'\x00', b'\x00'),
(b'\xa0', b'\x01', b'\xa1'),
]
for a, b, expected in tab:
assert(bytearray(expected) == onion.xor(a, b))
d = bytearray(len(a))
onion.xor_inplace(d, a, b)
assert(d == expected)
def sphinx_path_from_test_vector(filename: str) -> Tuple[onion.SphinxPath, dict]:
"""Loads a sphinx test vector from the repo root.
"""
path = os.path.dirname(__file__)
root = os.path.join(path, '..', '..', '..')
filename = os.path.join(root, filename)
v = json.load(open(filename, 'r'))
session_key = onion.Secret(unhexlify(v['generate']['session_key']))
associated_data = unhexlify(v['generate']['associated_data'])
hops = []
for h in v['generate']['hops']:
payload = unhexlify(h['payload'])
if h['type'] == 'raw' or h['type'] == 'tlv':
b = BytesIO()
onion.varint_encode(len(payload), b)
payload = b.getvalue() + payload
elif h['type'] == 'legacy':
padlen = 32 - len(payload)
payload = b'\x00' + payload + (b'\x00' * padlen)
assert(len(payload) == 33)
pubkey = onion.PublicKey(unhexlify(h['pubkey']))
hops.append(onion.SphinxHop(
pubkey=pubkey,
payload=payload,
))
return onion.SphinxPath(hops=hops, session_key=session_key,
assocdata=associated_data), v
def test_hop_params():
"""Test that we generate the onion parameters correctly.
Extracted from running the c-lightning implementation:
```bash
devtools/onion runtest tests/vectors/onion-test-multi-frame.json
```
"""
sp, v = sphinx_path_from_test_vector(
'tests/vectors/onion-test-multi-frame.json'
)
params = sp.get_hop_params()
expected = [(
b'02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619',
b'53eb63ea8a3fec3b3cd433b85cd62a4b145e1dda09391b348c4e1cd36a03ea66',
b'2ec2e5da605776054187180343287683aa6a51b4b1c04d6dd49c45d8cffb3c36'
), (
b'028f9438bfbf7feac2e108d677e3a82da596be706cc1cf342b75c7b7e22bf4e6e2',
b'a6519e98832a0b179f62123b3567c106db99ee37bef036e783263602f3488fae',
b'bf66c28bc22e598cfd574a1931a2bafbca09163df2261e6d0056b2610dab938f'
), (
b'03bfd8225241ea71cd0843db7709f4c222f62ff2d4516fd38b39914ab6b83e0da0',
b'3a6b412548762f0dbccce5c7ae7bb8147d1caf9b5471c34120b30bc9c04891cc',
b'a1f2dadd184eb1627049673f18c6325814384facdee5bfd935d9cb031a1698a5'
), (
b'031dde6926381289671300239ea8e57ffaf9bebd05b9a5b95beaf07af05cd43595',
b'21e13c2d7cfe7e18836df50872466117a295783ab8aab0e7ecc8c725503ad02d',
b'7cfe0b699f35525029ae0fa437c69d0f20f7ed4e3916133f9cacbb13c82ff262'
), (
b'03a214ebd875aab6ddfd77f22c5e7311d7f77f17a169e599f157bbcdae8bf071f4',
b'b5756b9b542727dbafc6765a49488b023a725d631af688fc031217e90770c328',
b'c96e00dddaf57e7edcd4fb5954be5b65b09f17cb6d20651b4e90315be5779205'
)]
assert(len(params) == len(sp.hops))
for a, b in zip(expected, params):
assert(a[0] == hexlify(b.ephemeralkey.to_bytes()))
assert(a[1] == hexlify(b.secret.to_bytes()))
assert(a[2] == hexlify(b.blind.to_bytes()))
def test_filler():
"""Generate the filler from a sphinx path
The expected filler was generated using the following test vector, and by
instrumenting the sphinx code:
```bash
devtools/onion runtest tests/vectors/onion-test-multi-frame.json
```
"""
expected = (
b'b77d99c935d3f32469844f7e09340a91ded147557bdd0456c369f7e449587c0f566'
b'6faab58040146db49024db88553729bce12b860391c29c1779f022ae48a9cb314ca'
b'35d73fc91addc92632bcf7ba6fd9f38e6fd30fabcedbd5407b6648073c38331ee7a'
b'b0332f41f550c180e1601f8c25809ed75b3a1e78635a2ef1b828e92c9658e76e49f'
b'995d72cf9781eec0c838901d0bdde3ac21c13b4979ac9e738a1c4d0b9741d58e777'
b'ad1aed01263ad1390d36a18a6b92f4f799dcf75edbb43b7515e8d72cb4f827a9af0'
b'e7b9338d07b1a24e0305b5535f5b851b1144bad6238b9d9482b5ba6413f1aafac3c'
b'dde5067966ed8b78f7c1c5f916a05f874d5f17a2b7d0ae75d66a5f1bb6ff932570d'
b'c5a0cf3ce04eb5d26bc55c2057af1f8326e20a7d6f0ae644f09d00fac80de60f20a'
b'ceee85be41a074d3e1dda017db79d0070b99f54736396f206ee3777abd4c00a4bb9'
b'5c871750409261e3b01e59a3793a9c20159aae4988c68397a1443be6370fd9614e4'
b'6108291e615691729faea58537209fa668a172d066d0efff9bc77c2bd34bd77870a'
b'd79effd80140990e36731a0b72092f8d5bc8cd346762e93b2bf203d00264e4bc136'
b'fc142de8f7b69154deb05854ea88e2d7506222c95ba1aab065c8a'
)
sp, v = sphinx_path_from_test_vector(
'tests/vectors/onion-test-multi-frame.json'
)
filler = sp.get_filler()
assert(2 * len(filler) == len(expected))
assert(hexlify(filler) == expected)
def test_chacha20_stream():
"""Test that we can generate a correct stream for encryption/decryption
"""
tests = [(
b'ce496ec94def95aadd4bec15cdb41a740c9f2b62347c4917325fcc6fb0453986',
b'e5f14350c2a76fc232b5e46d421e9615471ab9e0bc887beff8c95fdb878f7b3a'
), (
b'450ffcabc6449094918ebe13d4f03e433d20a3d28a768203337bc40b6e4b2c59',
b'03455084337a8dbe5d5bfa27f825f3a9ae4f431f6f7a16ad786704887cbd85bd'
), (
b'11bf5c4f960239cb37833936aa3d02cea82c0f39fd35f566109c41f9eac8deea',
b'e22ea443b8a275174533abc584fae578e80ed4c1851d0554235171e45e1e2a18'
), (
b'cbe784ab745c13ff5cffc2fbe3e84424aa0fd669b8ead4ee562901a4a4e89e9e',
b'35de88a5f7e63d2c0072992046827fc997c3312b54591844fc713c0cca433626'
)]
for a, b in tests:
stream = bytearray(32)
onion.chacha20_stream(unhexlify(a), stream)
assert(hexlify(stream) == b)
# And since we're at it make sure we can actually encrypt inplace on a
# memoryview.
stream = memoryview(bytearray(64))
onion.chacha20_stream(unhexlify(a), memoryview(stream[16:-16]))
assert(hexlify(stream) == b'00' * 16 + b + b'00' * 16)
def test_sphinx_path_compile():
f = 'tests/vectors/onion-test-multi-frame.json'
sp, v = sphinx_path_from_test_vector(f)
o = sp.compile()
assert(o.to_bin() == unhexlify(v['onion']))