257 lines
7.7 KiB
Go
257 lines
7.7 KiB
Go
package tronclient
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"math/big"
|
|
"net"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/fbsobreira/gotron-sdk/pkg/client"
|
|
"github.com/fbsobreira/gotron-sdk/pkg/proto/api"
|
|
"github.com/fbsobreira/gotron-sdk/pkg/proto/core"
|
|
"github.com/tech/sendico/pkg/merrors"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/grpc/metadata"
|
|
)
|
|
|
|
// Client wraps the gotron-sdk gRPC client with convenience methods.
//
// Build one with NewClient. Its methods guard against a nil receiver or a nil
// underlying connection: Close and SetAPIKey become no-ops, and the other
// methods return a "client not initialized" error instead of panicking.
type Client struct {
	// grpc is the underlying gotron-sdk connection; nil means the Client was
	// not initialized via NewClient.
	grpc *client.GrpcClient
	// timeout is the per-call timeout NewClient passed to the SDK client
	// (30s when a non-positive value was supplied).
	// NOTE(review): no method in this file reads it — confirm whether other
	// files in the package depend on it before removing.
	timeout time.Duration
}
|
|
|
|
// NewClient creates a new TRON gRPC client connected to the given endpoint.
|
|
func NewClient(grpcURL string, timeout time.Duration, authToken string, forceIPv4 bool) (*Client, error) {
|
|
if grpcURL == "" {
|
|
return nil, merrors.InvalidArgument("tronclient: grpc url is required")
|
|
}
|
|
if timeout <= 0 {
|
|
timeout = 30 * time.Second
|
|
}
|
|
|
|
address, useTLS, err := normalizeGRPCAddress(grpcURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
grpcClient := client.NewGrpcClientWithTimeout(address, timeout)
|
|
|
|
var transportCreds grpc.DialOption
|
|
if useTLS {
|
|
transportCreds = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{MinVersion: tls.VersionTLS12}))
|
|
} else {
|
|
transportCreds = grpc.WithTransportCredentials(insecure.NewCredentials())
|
|
}
|
|
|
|
opts := []grpc.DialOption{transportCreds}
|
|
if forceIPv4 {
|
|
opts = append(opts, grpc.WithContextDialer(grpcForceIPv4Dialer))
|
|
}
|
|
if token := strings.TrimSpace(authToken); token != "" {
|
|
opts = append(opts,
|
|
grpc.WithUnaryInterceptor(grpcTokenUnaryInterceptor(token)),
|
|
grpc.WithStreamInterceptor(grpcTokenStreamInterceptor(token)),
|
|
)
|
|
}
|
|
|
|
if err := grpcClient.Start(opts...); err != nil {
|
|
return nil, merrors.Internal(fmt.Sprintf("tronclient: failed to connect to %s: %v", grpcURL, err))
|
|
}
|
|
|
|
return &Client{
|
|
grpc: grpcClient,
|
|
timeout: timeout,
|
|
}, nil
|
|
}
|
|
|
|
func grpcForceIPv4Dialer(ctx context.Context, address string) (net.Conn, error) {
|
|
host, port, err := net.SplitHostPort(address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ips, err := net.DefaultResolver.LookupIP(ctx, "ip4", host)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(ips) == 0 {
|
|
return nil, merrors.Internal(fmt.Sprintf("no IPv4 address found for %s", host))
|
|
}
|
|
|
|
var dialer net.Dialer
|
|
var lastErr error
|
|
for _, ip := range ips {
|
|
target := net.JoinHostPort(ip.String(), port)
|
|
conn, err := dialer.DialContext(ctx, "tcp4", target)
|
|
if err == nil {
|
|
return conn, nil
|
|
}
|
|
lastErr = err
|
|
if ctx.Err() != nil {
|
|
break
|
|
}
|
|
}
|
|
if lastErr != nil {
|
|
return nil, merrors.Internal(fmt.Sprintf("failed to dial any IPv4 address for %s: %v", host, lastErr))
|
|
}
|
|
return nil, merrors.Internal(fmt.Sprintf("failed to dial IPv4 address for %s", host))
|
|
}
|
|
|
|
func normalizeGRPCAddress(grpcURL string) (string, bool, error) {
|
|
target := strings.TrimSpace(grpcURL)
|
|
useTLS := false
|
|
if target == "" {
|
|
return "", false, merrors.InvalidArgument("tronclient: grpc url is required")
|
|
}
|
|
if strings.Contains(target, "://") {
|
|
u, err := url.Parse(target)
|
|
if err != nil {
|
|
return "", false, merrors.InvalidArgument("tronclient: invalid grpc url")
|
|
}
|
|
if u.Scheme == "https" || u.Scheme == "grpcs" {
|
|
useTLS = true
|
|
}
|
|
host := strings.TrimSpace(u.Host)
|
|
if host == "" {
|
|
return "", false, merrors.InvalidArgument("tronclient: grpc url missing host")
|
|
}
|
|
if useTLS && u.Port() == "" {
|
|
host = host + ":443"
|
|
}
|
|
return host, useTLS, nil
|
|
}
|
|
return target, useTLS, nil
|
|
}
|
|
|
|
func grpcTokenUnaryInterceptor(token string) grpc.UnaryClientInterceptor {
|
|
return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
|
ctx = metadata.AppendToOutgoingContext(ctx, "x-token", token)
|
|
return invoker(ctx, method, req, reply, cc, opts...)
|
|
}
|
|
}
|
|
|
|
func grpcTokenStreamInterceptor(token string) grpc.StreamClientInterceptor {
|
|
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
|
ctx = metadata.AppendToOutgoingContext(ctx, "x-token", token)
|
|
return streamer(ctx, desc, cc, method, opts...)
|
|
}
|
|
}
|
|
|
|
// Close closes the gRPC connection.
|
|
func (c *Client) Close() {
|
|
if c != nil && c.grpc != nil {
|
|
c.grpc.Stop()
|
|
}
|
|
}
|
|
|
|
// SetAPIKey configures the TRON-PRO-API-KEY for TronGrid requests.
|
|
func (c *Client) SetAPIKey(apiKey string) {
|
|
if c != nil && c.grpc != nil {
|
|
c.grpc.SetAPIKey(apiKey)
|
|
}
|
|
}
|
|
|
|
// Transfer creates a native TRX transfer transaction.
|
|
// Addresses should be in base58 format.
|
|
// Amount is in SUN (1 TRX = 1,000,000 SUN).
|
|
func (c *Client) Transfer(from, to string, amountSun int64) (*api.TransactionExtention, error) {
|
|
if c == nil || c.grpc == nil {
|
|
return nil, merrors.Internal("tronclient: client not initialized")
|
|
}
|
|
return c.grpc.Transfer(from, to, amountSun)
|
|
}
|
|
|
|
// TRC20Send creates a TRC20 token transfer transaction.
|
|
// Addresses should be in base58 format.
|
|
// Amount is in the token's smallest unit.
|
|
// FeeLimit is in SUN (recommended: 100_000_000 = 100 TRX).
|
|
func (c *Client) TRC20Send(from, to, contract string, amount *big.Int, feeLimit int64) (*api.TransactionExtention, error) {
|
|
if c == nil || c.grpc == nil {
|
|
return nil, merrors.Internal("tronclient: client not initialized")
|
|
}
|
|
return c.grpc.TRC20Send(from, to, contract, amount, feeLimit)
|
|
}
|
|
|
|
// Broadcast broadcasts a signed transaction to the network.
|
|
func (c *Client) Broadcast(tx *core.Transaction) (*api.Return, error) {
|
|
if c == nil || c.grpc == nil {
|
|
return nil, merrors.Internal("tronclient: client not initialized")
|
|
}
|
|
return c.grpc.Broadcast(tx)
|
|
}
|
|
|
|
// GetTransactionInfoByID retrieves transaction info by its hash.
|
|
// The txID should be a hex string (without 0x prefix).
|
|
func (c *Client) GetTransactionInfoByID(txID string) (*core.TransactionInfo, error) {
|
|
if c == nil || c.grpc == nil {
|
|
return nil, merrors.Internal("tronclient: client not initialized")
|
|
}
|
|
return c.grpc.GetTransactionInfoByID(txID)
|
|
}
|
|
|
|
// GetTransactionByID retrieves the full transaction by its hash.
|
|
func (c *Client) GetTransactionByID(txID string) (*core.Transaction, error) {
|
|
if c == nil || c.grpc == nil {
|
|
return nil, merrors.Internal("tronclient: client not initialized")
|
|
}
|
|
return c.grpc.GetTransactionByID(txID)
|
|
}
|
|
|
|
// TRC20GetDecimals returns the decimals of a TRC20 token.
|
|
func (c *Client) TRC20GetDecimals(contract string) (*big.Int, error) {
|
|
if c == nil || c.grpc == nil {
|
|
return nil, merrors.Internal("tronclient: client not initialized")
|
|
}
|
|
return c.grpc.TRC20GetDecimals(contract)
|
|
}
|
|
|
|
// TRC20ContractBalance returns the balance of an address for a TRC20 token.
|
|
func (c *Client) TRC20ContractBalance(addr, contract string) (*big.Int, error) {
|
|
if c == nil || c.grpc == nil {
|
|
return nil, merrors.Internal("tronclient: client not initialized")
|
|
}
|
|
return c.grpc.TRC20ContractBalance(addr, contract)
|
|
}
|
|
|
|
// AwaitConfirmation polls for transaction confirmation until ctx is cancelled.
|
|
func (c *Client) AwaitConfirmation(ctx context.Context, txID string, pollInterval time.Duration) (*core.TransactionInfo, error) {
|
|
if c == nil || c.grpc == nil {
|
|
return nil, merrors.Internal("tronclient: client not initialized")
|
|
}
|
|
if pollInterval <= 0 {
|
|
pollInterval = 3 * time.Second
|
|
}
|
|
|
|
ticker := time.NewTicker(pollInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
txInfo, err := c.grpc.GetTransactionInfoByID(txID)
|
|
if err == nil && txInfo != nil && txInfo.BlockNumber > 0 {
|
|
return txInfo, nil
|
|
}
|
|
|
|
select {
|
|
case <-ticker.C:
|
|
continue
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
}
|
|
}
|
|
}
|
|
|
|
// TxIDFromExtention extracts the transaction ID hex string from a TransactionExtention.
|
|
func TxIDFromExtention(txExt *api.TransactionExtention) string {
|
|
if txExt == nil || len(txExt.Txid) == 0 {
|
|
return ""
|
|
}
|
|
return hex.EncodeToString(txExt.Txid)
|
|
}
|