diff --git a/CHANGELOG.md b/CHANGELOG.md index 1927abac2..d51ca91bd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,7 @@ changes. - Protocol: handling `query_channel_range` for large numbers of blocks (eg. 4 billion) was slow due to a bug. - Fixed occasional deadlock with peers when exchanging huge amounts of gossip. +- You can no longer make giant unpayable "wumbo" invoices. ### Security diff --git a/lightningd/invoice.c b/lightningd/invoice.c index 22fd96748..e7f296791 100644 --- a/lightningd/invoice.c +++ b/lightningd/invoice.c @@ -370,6 +370,7 @@ static struct command_result *json_invoice(struct command *cmd, u64 *expiry; struct sha256 rhash; bool *exposeprivate; + const struct chainparams *chainparams; #if DEVELOPER const jsmntok_t *routes; #endif @@ -406,6 +407,14 @@ static struct command_result *json_invoice(struct command *cmd, strlen(desc_val)); } + chainparams = get_chainparams(cmd->ld); + if (msatoshi_val && *msatoshi_val > chainparams->max_payment_msat) { + return command_fail(cmd, JSONRPC2_INVALID_PARAMS, + "msatoshi cannot exceed %"PRIu64 + " millisatoshis", + chainparams->max_payment_msat); + } + if (fallbacks) { size_t i; const jsmntok_t *t; @@ -438,7 +447,7 @@ static struct command_result *json_invoice(struct command *cmd, /* Construct bolt11 string. */ info->b11 = new_bolt11(info, msatoshi_val); - info->b11->chain = get_chainparams(cmd->ld); + info->b11->chain = chainparams; info->b11->timestamp = time_now().ts.tv_sec; info->b11->payment_hash = rhash; info->b11->receiver_id = cmd->ld->id; diff --git a/tests/test_invoices.py b/tests/test_invoices.py index d7776953f..7d79b480f 100644 --- a/tests/test_invoices.py +++ b/tests/test_invoices.py @@ -52,6 +52,11 @@ def test_invoice(node_factory): assert 'routes' not in b11 assert 'warning_capacity' in inv + # Make sure no wumbo invoices + with pytest.raises(RpcError, match=r'msatoshi cannot exceed 4294967295 millisatoshis'): + l2.rpc.invoice(4294967295 + 1, 'inv3', '?') + l2.rpc.invoice(4294967295, 'inv3', '?') + def test_invoice_weirdstring(node_factory): l1 = node_factory.get_node()