msggen: Generate the cln-rpc Rust structs

We're generating these structs so we can parse them directly into
native objects.
This commit is contained in:
Christian Decker 2022-01-14 13:54:23 +01:00
parent 0fc0ffc961
commit b0053e2ca2
2 changed files with 383 additions and 0 deletions

View File

@ -0,0 +1,140 @@
from msggen.model import Method, CompositeField, Service
from msggen.rust import RustGenerator
from pathlib import Path
import subprocess
import json
def repo_root():
path = subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
return Path(path.strip().decode('UTF-8'))
def load_jsonrpc_method(name):
"""Load a method based on the file naming conventions for the JSON-RPC.
"""
base_path = (repo_root() / "doc" / "schemas").resolve()
req_file = base_path / f"{name.lower()}.request.json"
resp_file = base_path / f"{name.lower()}.schema.json"
request = CompositeField.from_js(json.load(open(req_file)), path=name)
response = CompositeField.from_js(json.load(open(resp_file)), path=name)
# Normalize the method request and response typename so they no
# longer conflict.
request.typename += "Request"
response.typename += "Response"
return Method(
name=name,
request=request,
response=response,
)
def load_jsonrpc_service():
method_names = [
"Getinfo",
# "ListPeers",
"ListFunds",
# "ListConfigs",
"ListChannels",
"AddGossip",
"AutoCleanInvoice",
"CheckMessage",
# "check", # No point in mapping this one
"Close",
# "connect",
# "createinvoice",
# "createonion",
# "datastore",
# "decodepay",
# "decode",
# "deldatastore",
# "delexpiredinvoice",
# "delinvoice",
# "delpay",
# "disableoffer",
# "disconnect",
# "feerates",
# "fetchinvoice",
# "fundchannel_cancel",
# "fundchannel_complete",
# "fundchannel",
# "fundchannel_start",
# "funderupdate",
# "fundpsbt",
# "getinfo",
# "getlog",
# "getroute",
# "getsharedsecret",
# "help",
# "invoice",
# "keysend",
# "listchannels",
# "listconfigs",
# "listdatastore",
# "listforwards",
# "listfunds",
# "listinvoices",
# "listnodes",
# "listoffers",
# "listpays",
# "listpeers",
# "listsendpays",
# "listtransactions",
# "multifundchannel",
# "multiwithdraw",
# "newaddr",
# "notifications",
# "offerout",
# "offer",
# "openchannel_abort",
# "openchannel_bump",
# "openchannel_init",
# "openchannel_signed",
# "openchannel_update",
# "parsefeerate",
# "pay",
# "ping",
# "plugin",
# "reserveinputs",
# "sendcustommsg",
# "sendinvoice",
# "sendonionmessage",
# "sendonion",
# "sendpay",
# "sendpsbt",
# "setchannelfee",
# "signmessage",
# "signpsbt",
# "stop",
# "txdiscard",
# "txprepare",
# "txsend",
# "unreserveinputs",
# "utxopsbt",
# "waitanyinvoice",
# "waitblockheight",
# "waitinvoice",
# "waitsendpay",
# "withdraw",
]
methods = [load_jsonrpc_method(name) for name in method_names]
service = Service(name="Node", methods=methods)
service.includes = ['primitives.proto'] # Make sure we have the primitives included.
return service
def genrustjsonrpc(service):
fname = repo_root() / "cln-rpc" / "src" / "model.rs"
dest = open(fname, "w")
RustGenerator(dest).generate(service)
def run():
service = load_jsonrpc_service()
genrustjsonrpc(service)
if __name__ == "__main__":
run()

View File

@ -0,0 +1,243 @@
from typing import TextIO
from typing import Tuple
from textwrap import dedent, indent
import logging
import sys
from .model import (ArrayField, CompositeField, EnumField,
PrimitiveField, Service)
logger = logging.getLogger(__name__)
# The following words need to be changed, otherwise they'd clash with
# built-in keywords.
keywords = ["in", "type"]
# Manual overrides for some of the auto-generated types for paths
# Manual overrides for some of the auto-generated types for paths
overrides = {
'ListPeers.peers[].channels[].state_changes[].old_state': "ChannelState",
'ListPeers.peers[].channels[].state_changes[].new_state': "ChannelState",
'ListPeers.peers[].channels[].state_changes[].cause': "ChannelStateChangeCause",
'ListPeers.peers[].channels[].opener': "ChannelSide",
'ListPeers.peers[].channels[].closer': "ChannelSide",
'ListPeers.peers[].channels[].features[]': "string",
'ListFunds.channels[].state': 'ChannelState',
}
# A map of schema type to rust primitive types.
typemap = {
'boolean': 'bool',
'hex': 'String',
'msat': 'Amount',
'number': 'i64',
'pubkey': 'String',
'short_channel_id': 'String',
'signature': 'String',
'string': 'String',
'txid': 'String',
}
header = f"""#![allow(non_camel_case_types)]
//! This file was automatically generated using the following command:
//!
//! ```bash
//! {" ".join(sys.argv)}
//! ```
//!
//! Do not edit this file, it'll be overwritten. Rather edit the schema that
//! this file was generated from
"""
def normalize_varname(field):
"""Make sure that the variable name of this field is valid.
"""
# Dashes are not valid names
field.path = field.path.replace("-", "_")
return field
def gen_field(field):
if isinstance(field, CompositeField):
return gen_composite(field)
elif isinstance(field, EnumField):
return gen_enum(field)
elif isinstance(field, ArrayField):
return gen_array(field)
elif isinstance(field, PrimitiveField):
return gen_primitive(field)
else:
raise ValueError(f"Unmanaged type {field}")
def gen_enum(e):
defi, decl = "", ""
if e.description != "":
decl += f"/// {e.description}\n"
decl += f"#[derive(Copy, Clone, Debug, Deserialize, Serialize)]\n#[serde(rename_all = \"lowercase\")]\npub enum {e.typename} {{\n"
for v in e.variants:
if v is None:
continue
norm = v.normalized()
# decl += f" #[serde(rename = \"{v}\")]\n"
decl += f" {norm},\n"
decl += "}\n\n"
typename = e.typename
if e.path in overrides:
decl = "" # No declaration if we have an override
typename = overrides[e.path]
if e.required:
defi = f" // Path `{e.path}`\n #[serde(rename = \"{e.name}\")]\n pub {e.name.normalized()}: {typename},\n"
else:
defi = f' #[serde(skip_serializing_if = "Option::is_none")]'
defi = f" pub {e.name.normalized()}: Option<{typename}>,\n"
return defi, decl
def gen_primitive(p):
defi, decl = "", ""
org = p.name.name
typename = typemap.get(p.typename, p.typename)
normalize_varname(p)
if p.required:
defi = f" #[serde(alias = \"{org}\")]\n pub {p.name}: {typename},\n"
else:
defi = f" #[serde(alias = \"{org}\", skip_serializing_if = \"Option::is_none\")]\n pub {p.name}: Option<{typename}>,\n"
return defi, decl
def gen_array(a):
name = a.name.normalized().replace("[]", "")
logger.debug(f"Generating array field {a.name} -> {name} ({a.path})")
_, decl = gen_field(a.itemtype)
if isinstance(a.itemtype, PrimitiveField):
itemtype = a.itemtype.typename
elif isinstance(a.itemtype, CompositeField):
itemtype = a.itemtype.typename
elif isinstance(a.itemtype, EnumField):
itemtype = a.itemtype.typename
if a.path in overrides:
decl = "" # No declaration if we have an override
itemtype = overrides[a.path]
itemtype = typemap.get(itemtype, itemtype)
alias = a.name.normalized()[:-2] # Strip the `[]` suffix for arrays.
defi = f" #[serde(alias = \"{alias}\")]\n pub {name}: {'Vec<'*a.dims}{itemtype}{'>'*a.dims},\n"
return (defi, decl)
def gen_composite(c) -> Tuple[str, str]:
logger.debug(f"Generating composite field {c.name} ({c.path})")
fields = []
for f in c.fields:
fields.append(gen_field(f))
r = "".join([f[1] for f in fields])
r += f"""#[derive(Clone, Debug, Deserialize, Serialize)]\npub struct {c.typename} {{\n"""
r += "".join([f[0] for f in fields])
r += "}\n\n"
return ("", r)
class RustGenerator:
def __init__(self, dest: TextIO):
self.dest = dest
def write(self, text: str, numindent: int = 0) -> None:
raw = dedent(text)
if numindent > 0:
raw = indent(text, "\t" * numindent)
self.dest.write(raw)
def generate_requests(self, service: Service):
self.write("""\
pub mod requests {
#[allow(unused_imports)]
use crate::primitives::*;
#[allow(unused_imports)]
use serde::{{Deserialize, Serialize}};
""")
for meth in service.methods:
req = meth.request
_, decl = gen_composite(req)
self.write(decl, numindent=1)
self.write("}\n\n")
def generate_responses(self, service: Service):
self.write("""
pub mod responses {
#[allow(unused_imports)]
use crate::primitives::*;
#[allow(unused_imports)]
use serde::{{Deserialize, Serialize}};
""")
for meth in service.methods:
res = meth.response
_, decl = gen_composite(res)
self.write(decl, numindent=1)
self.write("}\n\n")
def generate_enums(self, service: Service):
"""The Request and Response enums serve as parsing primitives.
"""
self.write(f"""\
use serde::{{Deserialize, Serialize}};
pub use requests::*;
pub use responses::*;
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "method", content = "params")]
#[serde(rename_all = "lowercase")]
pub enum Request {{
""")
for method in service.methods:
self.write(f"{method.name}(requests::{method.request.typename}),\n", numindent=1)
self.write(f"""\
}}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "method", content = "result")]
#[serde(rename_all = "lowercase")]
pub enum Response {{
""")
for method in service.methods:
self.write(f"{method.name}(responses::{method.response.typename}),\n", numindent=1)
self.write(f"""\
}}
""")
def generate(self, service: Service) -> None:
self.write(header)
self.generate_enums(service)
self.generate_requests(service)
self.generate_responses(service)