Files
sendico/api/gateway/chain/internal/service/gateway/rpcclient/clients.go
2026-02-03 00:40:46 +01:00

199 lines
5.4 KiB
Go

package rpcclient
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/ethereum/go-ethereum/ethclient"
"github.com/ethereum/go-ethereum/rpc"
"github.com/tech/sendico/gateway/chain/internal/service/gateway/shared"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger"
pmodel "github.com/tech/sendico/pkg/model"
"go.uber.org/zap"
)
// Clients holds pre-initialised RPC clients keyed by network name.
// Values are built by Prepare; the accessor methods and Close tolerate a nil
// receiver.
type Clients struct {
// logger receives lifecycle messages, e.g. the per-network log emitted by Close.
logger mlogger.Logger
// clients maps each network name to the pair of client handles dialled for it.
clients map[pmodel.ChainNetwork]clientEntry
}
// clientEntry pairs the two client handles kept for a single network.
type clientEntry struct {
// eth is the ethclient handle; Prepare builds it from rpc via ethclient.NewClient.
eth *ethclient.Client
// rpc is the raw RPC client returned by rpc.DialOptions.
rpc *rpc.Client
}
// Prepare dials every configured network up front and returns a ready-to-use
// client set keyed by network name.
//
// Each network's RPC URL is whitespace-trimmed and dialled with a 15-second
// timeout derived from ctx. HTTP traffic of every client is routed through a
// loggingRoundTripper tagged with the network name.
//
// Prepare is all-or-nothing: on the first failure every client dialled so far
// is closed and a nil *Clients is returned. It returns
//   - an internal error when logger is nil or a dial fails;
//   - an invalid-argument error when a network has an empty RPC URL or when no
//     client was initialised at all (for example, networks is empty).
//
// When the same network name appears more than once the last entry wins; the
// client dialled for the earlier entry is closed so it is not leaked.
func Prepare(ctx context.Context, logger mlogger.Logger, networks []shared.Network) (*Clients, error) {
	if logger == nil {
		return nil, merrors.Internal("rpc clients: logger is required")
	}
	clientLogger := logger.Named("rpc_client")
	result := &Clients{
		logger:  clientLogger,
		clients: make(map[pmodel.ChainNetwork]clientEntry, len(networks)),
	}
	for _, network := range networks {
		rpcURL := strings.TrimSpace(network.RPCURL)
		if rpcURL == "" {
			// Abort the whole preparation and release what was already dialled.
			result.Close()
			err := merrors.InvalidArgument(fmt.Sprintf("rpc url not configured for network %s", network.Name))
			clientLogger.Warn("Rpc url missing", zap.String("network", string(network.Name)))
			return nil, err
		}
		fields := []zap.Field{
			zap.String("network", string(network.Name)),
		}
		clientLogger.Info("Initialising rpc client", fields...)

		httpClient := &http.Client{
			Transport: &loggingRoundTripper{
				logger:   clientLogger,
				network:  network.Name,
				endpoint: rpcURL,
				base:     http.DefaultTransport,
			},
		}
		// The timeout only bounds the dial itself, so the context is cancelled
		// right away instead of being deferred to function exit inside the loop.
		dialCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
		rpcCli, err := rpc.DialOptions(dialCtx, rpcURL, rpc.WithHTTPClient(httpClient))
		cancel()
		if err != nil {
			result.Close()
			clientLogger.Warn("Failed to dial rpc endpoint", append(fields, zap.Error(err))...)
			return nil, merrors.Internal(fmt.Sprintf("rpc dial failed for %s: %s", network.Name, err.Error()))
		}

		// A repeated network name would otherwise overwrite the map entry and
		// orphan the previously dialled client without ever closing it.
		if previous, duplicate := result.clients[network.Name]; duplicate {
			clientLogger.Warn("Duplicate network configuration, replacing rpc client", fields...)
			if previous.rpc != nil {
				previous.rpc.Close()
			} else if previous.eth != nil {
				previous.eth.Close()
			}
		}
		result.clients[network.Name] = clientEntry{
			eth: ethclient.NewClient(rpcCli),
			rpc: rpcCli,
		}
		clientLogger.Info("RPC client ready", fields...)
	}
	if len(result.clients) == 0 {
		clientLogger.Warn("No rpc clients were initialised")
		return nil, merrors.InvalidArgument("no rpc clients initialised")
	}
	clientLogger.Info("RPC clients initialised", zap.Int("count", len(result.clients)))
	return result, nil
}
// Client looks up the ethclient handle prepared for the given network.
// It fails with an internal error on a nil receiver and with an
// invalid-argument error when no usable handle is registered for network.
func (c *Clients) Client(network pmodel.ChainNetwork) (*ethclient.Client, error) {
	if c == nil {
		return nil, merrors.Internal("RPC clients not initialised")
	}
	if found, present := c.clients[network]; present && found.eth != nil {
		return found.eth, nil
	}
	return nil, merrors.InvalidArgument(fmt.Sprintf("RPC client not configured for network %s", network))
}
// RPCClient hands out the raw RPC client of the given network for low-level
// calls. A nil receiver yields an internal error; a network without a usable
// raw client yields an invalid-argument error.
func (c *Clients) RPCClient(network pmodel.ChainNetwork) (*rpc.Client, error) {
	if c == nil {
		return nil, merrors.Internal("rpc clients not initialised")
	}
	handles, known := c.clients[network]
	if known && handles.rpc != nil {
		return handles.rpc, nil
	}
	msg := fmt.Sprintf("rpc client not configured for network %s", network)
	return nil, merrors.InvalidArgument(msg)
}
// Close shuts down every registered client and, when a logger is set, logs one
// entry per network. Calling it on a nil receiver is a no-op.
func (c *Clients) Close() {
	if c == nil {
		return
	}
	for network, handles := range c.clients {
		// Prefer the raw client; fall back to the eth handle only when no raw
		// client is stored for this entry.
		switch {
		case handles.rpc != nil:
			handles.rpc.Close()
		case handles.eth != nil:
			handles.eth.Close()
		}
		if c.logger == nil {
			continue
		}
		c.logger.Info("RPC client closed", zap.String("network", string(network)))
	}
}
// loggingRoundTripper is an http.RoundTripper that logs the RPC HTTP traffic
// passing through it and delegates the actual exchange to base.
type loggingRoundTripper struct {
// logger receives the request/response log entries.
logger mlogger.Logger
// network is attached to every log entry as the "network" field.
network pmodel.ChainNetwork
// endpoint is the RPC URL this transport was created for; RoundTrip in this
// file does not read it.
endpoint string
// base performs the HTTP exchange; http.DefaultTransport is used when it is nil.
base http.RoundTripper
}
// RoundTrip implements http.RoundTripper. It forwards req to the wrapped
// transport (http.DefaultTransport when base is nil) and logs the request and
// response payloads, each truncated to 2048 bytes, tagged with the network.
//
// Both bodies are buffered fully in memory so they can be logged and replayed:
// the request payload is sent on a clone of req (the caller's request is left
// unmodified apart from its body being consumed and closed), and the returned
// response carries a fresh in-memory body.
//
// Besides the debug-level request/response entries, a warning is logged when
// the status code is >= 400, the response body is empty, or the body is not
// valid JSON.
//
// An error is returned, with a nil response, when either body cannot be read
// completely or when the underlying transport fails.
func (l *loggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	// Resolve the fallback into a local instead of assigning l.base: RoundTrip
	// may run on several goroutines and a lazy write to the field would race.
	transport := l.base
	if transport == nil {
		transport = http.DefaultTransport
	}

	fields := []zap.Field{
		zap.String("network", string(l.network)),
	}

	if req.Body != nil {
		reqBody, readErr := io.ReadAll(req.Body)
		// A RoundTripper owns the request body and must close it, also on
		// errors. A close failure is not actionable once the read is done.
		_ = req.Body.Close()
		if readErr != nil {
			// Forwarding a partially read payload would send a corrupt call.
			l.logger.Warn("RPC request body read failed", append(fields, zap.Error(readErr))...)
			return nil, fmt.Errorf("reading rpc request body: %w", readErr)
		}
		// Replay the buffered payload on a copy so the caller's request is not
		// modified.
		outbound := req.Clone(req.Context())
		outbound.Body = io.NopCloser(strings.NewReader(string(reqBody)))
		req = outbound
		if len(reqBody) > 0 {
			fields = append(fields, zap.String("rpc_request", truncate(string(reqBody), 2048)))
		}
	}
	l.logger.Debug("RPC request", fields...)

	resp, err := transport.RoundTrip(req)
	if err != nil {
		l.logger.Warn("RPC http request failed", append(fields, zap.Error(err))...)
		return nil, err
	}

	// Response fields live in their own slice so later appends can never write
	// into the backing array of the request fields.
	respFields := make([]zap.Field, 0, len(fields)+3)
	respFields = append(respFields, fields...)
	respFields = append(respFields, zap.Int("status_code", resp.StatusCode))
	if contentType := strings.TrimSpace(resp.Header.Get("Content-Type")); contentType != "" {
		respFields = append(respFields, zap.String("content_type", contentType))
	}

	bodyBytes, readErr := io.ReadAll(resp.Body)
	// The network body is done with either way; the caller gets a replayable
	// in-memory copy below.
	_ = resp.Body.Close()
	if readErr != nil {
		// Do not hand a truncated body to the caller as if it were complete.
		l.logger.Warn("RPC response body read failed", append(respFields, zap.Error(readErr))...)
		return nil, fmt.Errorf("reading rpc response body: %w", readErr)
	}
	resp.Body = io.NopCloser(strings.NewReader(string(bodyBytes)))

	if len(bodyBytes) > 0 {
		respFields = append(respFields, zap.String("rpc_response", truncate(string(bodyBytes), 2048)))
	}
	l.logger.Debug("RPC response", respFields...)

	// At most one warning per response, in order of severity.
	switch {
	case resp.StatusCode >= 400:
		l.logger.Warn("RPC response error", respFields...)
	case len(bodyBytes) == 0:
		l.logger.Warn("RPC response empty body", respFields...)
	case !json.Valid(bodyBytes):
		l.logger.Warn("RPC response invalid JSON", respFields...)
	}
	return resp, nil
}
// truncate shortens s to at most limit bytes for logging.
//
// Strings that already fit are returned unchanged. Longer strings are cut and,
// when limit is greater than 3, end in "..." (the ellipsis counts towards
// limit). The cut never lands inside a multi-byte UTF-8 sequence, so valid
// input always yields valid output; the result may therefore be up to three
// bytes shorter than limit. A non-positive limit yields the empty string.
func truncate(s string, limit int) string {
	if limit <= 0 {
		// Also guards against slicing with a negative bound, which would panic.
		return ""
	}
	if len(s) <= limit {
		return s
	}

	suffix := "..."
	if limit <= len(suffix) {
		// No room for an ellipsis next to actual content.
		suffix = ""
	}
	cut := limit - len(suffix)

	// Step back while the byte at the cut is a UTF-8 continuation byte
	// (10xxxxxx) so a rune is never split. A rune has at most three
	// continuation bytes, so the walk is bounded; that keeps invalid input
	// from being trimmed away entirely. cut < len(s) holds here, so s[cut] is
	// always in range.
	for back := 0; back < 3 && cut > 0 && s[cut]&0xC0 == 0x80; back++ {
		cut--
	}
	return s[:cut] + suffix
}