diff --git a/go.mod b/go.mod index e503db6..e6f925c 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/vincenzopalazzo/lnprototest-v2 -go 1.23 +go 1.22 require ( github.com/akamensky/argparse v1.4.0 @@ -37,6 +37,7 @@ require ( github.com/lightningnetwork/lnd/tlv v1.0.3 // indirect github.com/lightningnetwork/lnd/tor v1.0.1 // indirect github.com/miekg/dns v1.1.43 // indirect + github.com/sourcegraph/jsonrpc2 v0.2.0 // indirect github.com/stretchr/testify v1.8.4 // indirect golang.org/x/crypto v0.21.0 // indirect golang.org/x/net v0.21.0 // indirect diff --git a/go.sum b/go.sum index fe2c719..27874b3 100644 --- a/go.sum +++ b/go.sum @@ -111,6 +111,7 @@ github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9 github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 h1:+9834+KizmvFV7pXQGSXQTsaWhq2GjuNUt0aUU0YBYw= @@ -224,6 +225,8 @@ github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/soheilhy/cmux v0.1.5 h1:jjzc5WVemNEDTLwv9tlmemhC73tI08BNOIGwBOo10Js= github.com/soheilhy/cmux v0.1.5/go.mod h1:T7TcVDs9LWfQgPlPsdngu6I6QIoyIFZDDC6sNE1GqG0= +github.com/sourcegraph/jsonrpc2 v0.2.0 h1:KjN/dC4fP6aN9030MZCJs9WQbTOjWHhrtKVpzzSrr/U= +github.com/sourcegraph/jsonrpc2 v0.2.0/go.mod h1:ZafdZgk/axhT1cvZAPOhw+95nz2I/Ra5qMlU4gTRwIo= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/lnprototest/server.go b/lnprototest/server.go index 06c49b6..95ee470 100644 --- a/lnprototest/server.go +++ b/lnprototest/server.go @@ -55,7 +55,8 @@ func (self *ProtoTestServer) Connect(nodeId string, port uint32, network wire.Bi if err != nil { return err } - hostname := fmt.Sprintf("127.0.01:%d", &port) + hostname := fmt.Sprintf("127.0.0.1:%d", port) + fmt.Printf("\n*******%s******\n", hostname) addr, err := net.ResolveTCPAddr("tcp", hostname) if err != nil { return err diff --git a/server/rpc.go b/server/rpc.go index c620b2c..3d42a75 100644 --- a/server/rpc.go +++ b/server/rpc.go @@ -2,31 +2,43 @@ package server import ( "bytes" + "context" "encoding/hex" + "encoding/json" + "fmt" "github.com/btcsuite/btcd/wire" + "github.com/sourcegraph/jsonrpc2" ) type ConnectRPC struct { NodeId string - port uint64 + Port uint32 } -func (_ *ConnectRPC) Call(request *ConnectRPC, response *ConnectRPC) error { - if err := lnprototestServer.Connect(request.NodeId, uint32(request.port), wire.SimNet); err != nil { +func ConnectCall(request *json.RawMessage, response *json.RawMessage) error { + var connect ConnectRPC + fmt.Println(string(*request)) + if err := json.Unmarshal(*request, &connect); err != nil { + return nil + } + if err := lnprototestServer.Connect(connect.NodeId, connect.Port, wire.SimNet); err != nil { return err } - response.NodeId = request.NodeId - response.port = request.port + *response = *request return nil } type SendRPC struct { - msg string + Msg string } -func (_ *SendRPC) Call(request *SendRPC, response *SendRPC) error { - buff, err := hex.DecodeString(request.msg) +func SendCall(request *json.RawMessage, response *json.RawMessage) error { + var sendCall SendRPC + if err := json.Unmarshal(*request, &sendCall); err != nil { + return nil + } + buff, err := hex.DecodeString(sendCall.Msg) if err != nil { return err } @@ -40,6 +52,49 @@ func (_ *SendRPC) Call(request *SendRPC, response *SendRPC) error { return err } - response.msg = hex.EncodeToString(buffResp.Bytes()) - return nil + if buffResp == nil { + return fmt.Errorf("empty answer from the node") + } + + sendCall.Msg = hex.EncodeToString(buffResp.Bytes()) + *response, err = json.Marshal(sendCall) + return err +} + +type RPCHandler struct{} + +// Handle implements the jsonrpc2.Handler interface. +func (h *RPCHandler) Handle(ctx context.Context, c *jsonrpc2.Conn, r *jsonrpc2.Request) { + switch r.Method { + case "connect": + var response json.RawMessage + if err := ConnectCall(r.Params, &response); err != nil { + if err := c.ReplyWithError(ctx, r.ID, &jsonrpc2.Error{ + Code: -1, + Message: fmt.Sprintf("%s", err), + Data: nil, + }); err != nil { + return + } + } + if err := c.Reply(ctx, r.ID, response); err != nil { + return + } + case "send": + var response json.RawMessage + if err := SendCall(r.Params, &response); err != nil { + if err := c.ReplyWithError(ctx, r.ID, &jsonrpc2.Error{ + Code: -1, + Message: fmt.Sprintf("%s", err), + Data: nil, + }); err != nil { + return + } + } + default: + err := &jsonrpc2.Error{Code: jsonrpc2.CodeMethodNotFound, Message: "Method not found"} + if err := c.ReplyWithError(ctx, r.ID, err); err != nil { + return + } + } } diff --git a/server/server.go b/server/server.go index 996879c..ead08c7 100644 --- a/server/server.go +++ b/server/server.go @@ -1,13 +1,14 @@ package server import ( + "context" "errors" "fmt" "net" - "net/rpc" - "net/rpc/jsonrpc" "os" + "github.com/sourcegraph/jsonrpc2" + "github.com/vincenzopalazzo/lnprototest-v2/lnprototest" ) @@ -15,30 +16,17 @@ var lnprototestServer *lnprototest.ProtoTestServer = nil type Server struct { dataDir string + context context.Context } func Make(datadir string) (*Server, error) { return &Server{ dataDir: datadir, + context: context.Background(), }, nil } -func (self *Server) RegisterRPCs() error { - if err := rpc.Register(new(ConnectRPC)); err != nil { - return err - } - - if err := rpc.Register(new(SendRPC)); err != nil { - return err - } - return nil -} - func (self *Server) Listen() error { - if err := self.RegisterRPCs(); err != nil { - return err - } - unixPath := fmt.Sprintf("%s/lnprototest.sock", self.dataDir) if _, err := os.Stat(unixPath); !errors.Is(err, os.ErrNotExist) { os.Remove(unixPath) @@ -62,6 +50,7 @@ func (self *Server) Listen() error { if err != nil { continue } - go jsonrpc.ServeConn(conn) + + _ = jsonrpc2.NewConn(self.context, jsonrpc2.NewPlainObjectStream(conn), &RPCHandler{}) } }