api login method
This commit is contained in:
@@ -3,6 +3,7 @@ package routers
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/jwtauth/v5"
|
||||
api "github.com/tech/sendico/pkg/api/http"
|
||||
@@ -13,11 +14,52 @@ import (
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"github.com/tech/sendico/server/interface/api/sresponse"
|
||||
emodel "github.com/tech/sendico/server/interface/model"
|
||||
"github.com/tech/sendico/server/internal/api/routers/ipguard"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type tokenHandlerFunc = func(r *http.Request, t *emodel.AccountToken) http.HandlerFunc
|
||||
|
||||
func (ar *AuthorizedRouter) validateClientPolicy(r *http.Request, t *emodel.AccountToken) http.HandlerFunc {
|
||||
clientID := strings.TrimSpace(t.ClientID)
|
||||
if clientID == "" {
|
||||
// Legacy tokens without client_id remain valid until expiration.
|
||||
return nil
|
||||
}
|
||||
client, err := ar.rtdb.GetClient(r.Context(), clientID)
|
||||
if errors.Is(err, merrors.ErrNoData) || client == nil {
|
||||
ar.logger.Debug("Client not found for access token", zap.String("client_id", clientID))
|
||||
return response.Unauthorized(ar.logger, ar.service, "client not found")
|
||||
}
|
||||
if err != nil {
|
||||
ar.logger.Warn("Failed to resolve client for access token", zap.Error(err), zap.String("client_id", clientID))
|
||||
return response.Internal(ar.logger, ar.service, err)
|
||||
}
|
||||
if client.IsRevoked {
|
||||
return response.Unauthorized(ar.logger, ar.service, "client has been revoked")
|
||||
}
|
||||
if client.AccountRef != nil && *client.AccountRef != t.AccountRef {
|
||||
return response.Unauthorized(ar.logger, ar.service, "client account mismatch")
|
||||
}
|
||||
|
||||
clientIP := ipguard.ClientIP(r)
|
||||
allowed, err := ipguard.Allowed(clientIP, client.AllowedCIDRs)
|
||||
if err != nil {
|
||||
ar.logger.Warn("Client IP policy contains invalid CIDR", zap.Error(err), zap.String("client_id", clientID))
|
||||
return response.Forbidden(ar.logger, ar.service, "client_ip_policy_invalid", "client ip policy is invalid")
|
||||
}
|
||||
if !allowed {
|
||||
rawIP := ""
|
||||
if clientIP != nil {
|
||||
rawIP = clientIP.String()
|
||||
}
|
||||
ar.logger.Warn("Client IP policy denied authorized request", zap.String("client_id", clientID), zap.String("remote_ip", rawIP))
|
||||
return response.Forbidden(ar.logger, ar.service, "ip_not_allowed", "request ip is not allowed for this client")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ar *AuthorizedRouter) tokenHandler(service mservice.Type, endpoint string, method api.HTTPMethod, handler tokenHandlerFunc) {
|
||||
hndlr := func(r *http.Request) http.HandlerFunc {
|
||||
_, claims, err := jwtauth.FromContext(r.Context())
|
||||
@@ -30,6 +72,9 @@ func (ar *AuthorizedRouter) tokenHandler(service mservice.Type, endpoint string,
|
||||
ar.logger.Debug("Failed to decode account token", zap.Error(err))
|
||||
return response.BadRequest(ar.logger, ar.service, "credentials_unreadable", "faild to parse credentials")
|
||||
}
|
||||
if h := ar.validateClientPolicy(r, t); h != nil {
|
||||
return h
|
||||
}
|
||||
return handler(r, t)
|
||||
}
|
||||
ar.imp.InstallHandler(service, endpoint, method, hndlr)
|
||||
@@ -48,7 +93,7 @@ func (ar *AuthorizedRouter) AccountHandler(service mservice.Type, endpoint strin
|
||||
}
|
||||
return response.Internal(ar.logger, ar.service, err)
|
||||
}
|
||||
accessToken, err := ar.imp.CreateAccessToken(&a)
|
||||
accessToken, err := ar.imp.CreateAccessTokenForClient(&a, t.ClientID)
|
||||
if err != nil {
|
||||
ar.logger.Warn("Failed to generate access token", zap.Error(err))
|
||||
return response.Internal(ar.logger, ar.service, err)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"github.com/go-chi/jwtauth/v5"
|
||||
"github.com/tech/sendico/pkg/auth"
|
||||
"github.com/tech/sendico/pkg/db/account"
|
||||
"github.com/tech/sendico/pkg/db/refreshtokens"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"github.com/tech/sendico/server/interface/middleware"
|
||||
@@ -14,11 +15,12 @@ import (
|
||||
type AuthorizedRouter struct {
|
||||
logger mlogger.Logger
|
||||
db account.DB
|
||||
rtdb refreshtokens.DB
|
||||
imp *re.HttpEndpointRouter
|
||||
service mservice.Type
|
||||
}
|
||||
|
||||
func NewRouter(logger mlogger.Logger, apiEndpoint string, router chi.Router, db account.DB, enforcer auth.Enforcer, config *middleware.TokenConfig, signature *middleware.Signature) *AuthorizedRouter {
|
||||
func NewRouter(logger mlogger.Logger, apiEndpoint string, router chi.Router, db account.DB, rtdb refreshtokens.DB, enforcer auth.Enforcer, config *middleware.TokenConfig, signature *middleware.Signature) *AuthorizedRouter {
|
||||
ja := jwtauth.New(signature.Algorithm, signature.PrivateKey, signature.PublicKey)
|
||||
router.Use(jwtauth.Verifier(ja))
|
||||
router.Use(jwtauth.Authenticator(ja))
|
||||
@@ -26,6 +28,7 @@ func NewRouter(logger mlogger.Logger, apiEndpoint string, router chi.Router, db
|
||||
ar := AuthorizedRouter{
|
||||
logger: l,
|
||||
db: db,
|
||||
rtdb: rtdb,
|
||||
imp: re.NewHttpEndpointRouter(l, apiEndpoint, router, config, signature),
|
||||
service: mservice.Accounts,
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ func NewDispatcher(logger mlogger.Logger, router chi.Router, db account.DB, vdb
|
||||
d.public = rpublic.NewRouter(d.logger, endpoint, db, vdb, rtdb, r, &config.Token, &signature)
|
||||
})
|
||||
router.Group(func(r chi.Router) {
|
||||
d.protected = rauthorized.NewRouter(d.logger, endpoint, r, db, enforcer, &config.Token, &signature)
|
||||
d.protected = rauthorized.NewRouter(d.logger, endpoint, r, db, rtdb, enforcer, &config.Token, &signature)
|
||||
})
|
||||
|
||||
return d
|
||||
|
||||
@@ -10,8 +10,12 @@ import (
|
||||
)
|
||||
|
||||
func (er *HttpEndpointRouter) CreateAccessToken(user *model.Account) (sresponse.TokenData, error) {
|
||||
return er.CreateAccessTokenForClient(user, "")
|
||||
}
|
||||
|
||||
func (er *HttpEndpointRouter) CreateAccessTokenForClient(user *model.Account, clientID string) (sresponse.TokenData, error) {
|
||||
ja := jwtauth.New(er.signature.Algorithm, er.signature.PrivateKey, er.signature.PublicKey)
|
||||
_, res, err := ja.Encode(emodel.Account2Claims(user, er.config.Expiration.Account))
|
||||
_, res, err := ja.Encode(emodel.Account2ClaimsForClient(user, er.config.Expiration.Account, clientID))
|
||||
token := sresponse.TokenData{
|
||||
Token: res,
|
||||
Expiration: time.Now().Add(time.Duration(er.config.Expiration.Account) * time.Hour),
|
||||
|
||||
64
api/edge/bff/internal/api/routers/ipguard/ipguard.go
Normal file
64
api/edge/bff/internal/api/routers/ipguard/ipguard.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package ipguard
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ClientIP resolves caller IP from request remote address.
|
||||
// The service relies on trusted proxy middleware to normalize RemoteAddr.
|
||||
func ClientIP(r *http.Request) net.IP {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
raw := strings.TrimSpace(r.RemoteAddr)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
if ip := net.ParseIP(raw); ip != nil {
|
||||
return ip
|
||||
}
|
||||
host, _, err := net.SplitHostPort(raw)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return net.ParseIP(host)
|
||||
}
|
||||
|
||||
func parseCIDRs(raw []string) ([]*net.IPNet, error) {
|
||||
blocks := make([]*net.IPNet, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
clean := strings.TrimSpace(item)
|
||||
if clean == "" {
|
||||
continue
|
||||
}
|
||||
_, block, err := net.ParseCIDR(clean)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
blocks = append(blocks, block)
|
||||
}
|
||||
return blocks, nil
|
||||
}
|
||||
|
||||
// Allowed reports whether clientIP is allowed by configured CIDRs.
|
||||
// Empty CIDR list means unrestricted access.
|
||||
func Allowed(clientIP net.IP, cidrs []string) (bool, error) {
|
||||
blocks, err := parseCIDRs(cidrs)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if len(blocks) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
if clientIP == nil {
|
||||
return false, nil
|
||||
}
|
||||
for _, block := range blocks {
|
||||
if block.Contains(clientIP) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
86
api/edge/bff/internal/api/routers/ipguard/ipguard_test.go
Normal file
86
api/edge/bff/internal/api/routers/ipguard/ipguard_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package ipguard
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClientIP(t *testing.T) {
|
||||
t.Run("extracts host from remote addr", func(t *testing.T) {
|
||||
req := &http.Request{RemoteAddr: "10.1.2.3:1234"}
|
||||
ip := ClientIP(req)
|
||||
if ip == nil || ip.String() != "10.1.2.3" {
|
||||
t.Fatalf("unexpected ip: %v", ip)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("supports plain ip", func(t *testing.T) {
|
||||
req := &http.Request{RemoteAddr: "8.8.8.8"}
|
||||
ip := ClientIP(req)
|
||||
if ip == nil || ip.String() != "8.8.8.8" {
|
||||
t.Fatalf("unexpected ip: %v", ip)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid remote addr", func(t *testing.T) {
|
||||
req := &http.Request{RemoteAddr: "invalid"}
|
||||
if ip := ClientIP(req); ip != nil {
|
||||
t.Fatalf("expected nil ip, got %v", ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAllowed(t *testing.T) {
|
||||
clientIP := net.ParseIP("10.1.2.3")
|
||||
if clientIP == nil {
|
||||
t.Fatal("failed to parse test ip")
|
||||
}
|
||||
|
||||
t.Run("allows when cidr matches", func(t *testing.T) {
|
||||
allowed, err := Allowed(clientIP, []string{"10.0.0.0/8"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !allowed {
|
||||
t.Fatal("expected allowed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("denies when cidr does not match", func(t *testing.T) {
|
||||
allowed, err := Allowed(clientIP, []string{"192.168.0.0/16"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if allowed {
|
||||
t.Fatal("expected denied")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("allows when cidr list is empty", func(t *testing.T) {
|
||||
allowed, err := Allowed(clientIP, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !allowed {
|
||||
t.Fatal("expected allowed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid cidr fails", func(t *testing.T) {
|
||||
_, err := Allowed(clientIP, []string{"not-a-cidr"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil client ip denied when cidrs configured", func(t *testing.T) {
|
||||
allowed, err := Allowed(nil, []string{"10.0.0.0/8"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if allowed {
|
||||
t.Fatal("expected denied")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package routers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
@@ -9,36 +10,45 @@ import (
|
||||
|
||||
"github.com/tech/sendico/pkg/api/http/response"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"github.com/tech/sendico/pkg/mutil/mask"
|
||||
"github.com/tech/sendico/server/interface/api/srequest"
|
||||
"github.com/tech/sendico/server/interface/api/sresponse"
|
||||
"github.com/tech/sendico/server/internal/api/routers/ipguard"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const pendingLoginTTLMinutes = 10
|
||||
const apiLoginGrantType = "password"
|
||||
const apiLoginClientAuthMethod = "client_secret_post"
|
||||
|
||||
func (pr *PublicRouter) logUserIn(ctx context.Context, _ *http.Request, req *srequest.Login) http.HandlerFunc {
|
||||
func (pr *PublicRouter) authenticateAccount(ctx context.Context, req *srequest.Login) (*model.Account, http.HandlerFunc) {
|
||||
// Get the account database entry
|
||||
trimmedLogin := strings.TrimSpace(req.Login)
|
||||
account, err := pr.db.GetByEmail(ctx, strings.ToLower(trimmedLogin))
|
||||
if errors.Is(err, merrors.ErrNoData) || (account == nil) {
|
||||
pr.logger.Debug("User not found while logging in", zap.Error(err), zap.String("login", req.Login))
|
||||
return response.Unauthorized(pr.logger, pr.service, "user not found")
|
||||
return nil, response.Unauthorized(pr.logger, pr.service, "user not found")
|
||||
}
|
||||
if err != nil {
|
||||
pr.logger.Warn("Failed to query user with email", zap.Error(err), zap.String("login", req.Login))
|
||||
return response.Internal(pr.logger, pr.service, err)
|
||||
return nil, response.Internal(pr.logger, pr.service, err)
|
||||
}
|
||||
|
||||
if !account.IsActive() {
|
||||
return response.Forbidden(pr.logger, pr.service, "account_not_verified", "Account verification required")
|
||||
return nil, response.Forbidden(pr.logger, pr.service, "account_not_verified", "Account verification required")
|
||||
}
|
||||
|
||||
if !account.MatchPassword(req.Password) {
|
||||
return response.Unauthorized(pr.logger, pr.service, "password does not match")
|
||||
return nil, response.Unauthorized(pr.logger, pr.service, "password does not match")
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (pr *PublicRouter) respondPendingLogin(account *model.Account) http.HandlerFunc {
|
||||
pendingToken, err := pr.imp.CreatePendingToken(account, pendingLoginTTLMinutes)
|
||||
if err != nil {
|
||||
pr.logger.Warn("Failed to generate pending token", zap.Error(err))
|
||||
@@ -48,20 +58,144 @@ func (pr *PublicRouter) logUserIn(ctx context.Context, _ *http.Request, req *sre
|
||||
return sresponse.LoginPending(pr.logger, account, &pendingToken, mask.Email(account.Login))
|
||||
}
|
||||
|
||||
func (a *PublicRouter) login(r *http.Request) http.HandlerFunc {
|
||||
// TODO: add rate check
|
||||
func hasGrantType(grants []string, target string) bool {
|
||||
for _, grant := range grants {
|
||||
if strings.EqualFold(strings.TrimSpace(grant), target) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (pr *PublicRouter) validateClientIPPolicy(r *http.Request, clientID string, client *model.Client) http.HandlerFunc {
|
||||
if client == nil {
|
||||
return response.Unauthorized(pr.logger, pr.service, "client not found")
|
||||
}
|
||||
clientIP := ipguard.ClientIP(r)
|
||||
allowed, err := ipguard.Allowed(clientIP, client.AllowedCIDRs)
|
||||
if err != nil {
|
||||
pr.logger.Warn("Client IP policy contains invalid CIDR", zap.Error(err), zap.String("client_id", clientID))
|
||||
return response.Forbidden(pr.logger, pr.service, "client_ip_policy_invalid", "client ip policy is invalid")
|
||||
}
|
||||
if !allowed {
|
||||
rawIP := ""
|
||||
if clientIP != nil {
|
||||
rawIP = clientIP.String()
|
||||
}
|
||||
pr.logger.Warn("Client IP policy denied request", zap.String("client_id", clientID), zap.String("remote_ip", rawIP))
|
||||
return response.Forbidden(pr.logger, pr.service, "ip_not_allowed", "request ip is not allowed for this client")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pr *PublicRouter) validateAPIClient(ctx context.Context, r *http.Request, req *srequest.Login, account *model.Account) http.HandlerFunc {
|
||||
client, err := pr.rtdb.GetClient(ctx, req.ClientID)
|
||||
if errors.Is(err, merrors.ErrNoData) || client == nil {
|
||||
pr.logger.Debug("API login rejected: client not found", zap.String("client_id", req.ClientID))
|
||||
return response.Unauthorized(pr.logger, pr.service, "client not found")
|
||||
}
|
||||
if err != nil {
|
||||
pr.logger.Warn("API login rejected: failed to load client", zap.Error(err), zap.String("client_id", req.ClientID))
|
||||
return response.Internal(pr.logger, pr.service, err)
|
||||
}
|
||||
if client.IsRevoked {
|
||||
return response.Forbidden(pr.logger, pr.service, "client_revoked", "client has been revoked")
|
||||
}
|
||||
if !hasGrantType(client.GrantTypes, apiLoginGrantType) {
|
||||
return response.Forbidden(pr.logger, pr.service, "client_grant_not_allowed", "client does not allow password grant")
|
||||
}
|
||||
method := strings.ToLower(strings.TrimSpace(client.TokenEndpointAuthMethod))
|
||||
if method == "" {
|
||||
method = apiLoginClientAuthMethod
|
||||
}
|
||||
if method != apiLoginClientAuthMethod {
|
||||
return response.Forbidden(pr.logger, pr.service, "client_auth_method_unsupported", "unsupported client auth method")
|
||||
}
|
||||
|
||||
storedSecret := strings.TrimSpace(client.ClientSecret)
|
||||
if storedSecret == "" {
|
||||
return response.Forbidden(pr.logger, pr.service, "client_secret_missing", "client secret is not configured")
|
||||
}
|
||||
if subtle.ConstantTimeCompare([]byte(storedSecret), []byte(req.ClientSecret)) != 1 {
|
||||
pr.logger.Debug("API login rejected: invalid client secret", zap.String("client_id", req.ClientID))
|
||||
return response.Unauthorized(pr.logger, pr.service, "invalid client secret")
|
||||
}
|
||||
if client.AccountRef != nil {
|
||||
accountRef := account.GetID()
|
||||
if accountRef == nil || *client.AccountRef != *accountRef {
|
||||
return response.Forbidden(pr.logger, pr.service, "client_account_mismatch", "client is bound to another account")
|
||||
}
|
||||
}
|
||||
if h := pr.validateClientIPPolicy(r, req.ClientID, client); h != nil {
|
||||
return h
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pr *PublicRouter) respondAPILogin(ctx context.Context, r *http.Request, req *srequest.Login, account *model.Account) http.HandlerFunc {
|
||||
if req.ClientID == "" || req.DeviceID == "" {
|
||||
return response.BadRequest(pr.logger, pr.service, "missing_session", "session identifier is required")
|
||||
}
|
||||
accessToken, err := pr.imp.CreateAccessTokenForClient(account, req.ClientID)
|
||||
if err != nil {
|
||||
pr.logger.Warn("Failed to generate access token for API login", zap.Error(err))
|
||||
return response.Internal(pr.logger, pr.service, err)
|
||||
}
|
||||
return pr.refreshAndRespondLogin(ctx, r, &req.SessionIdentifier, account, &accessToken)
|
||||
}
|
||||
|
||||
func decodeLogin(r *http.Request, logger mlogger.Logger) (*srequest.Login, http.HandlerFunc) {
|
||||
var req srequest.Login
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
a.logger.Info("Failed to decode login request", zap.Error(err))
|
||||
return response.BadPayload(a.logger, mservice.Accounts, err)
|
||||
logger.Info("Failed to decode login request", zap.Error(err))
|
||||
return nil, response.BadPayload(logger, mservice.Accounts, err)
|
||||
}
|
||||
req.Login = strings.TrimSpace(req.Login)
|
||||
req.Password = strings.TrimSpace(req.Password)
|
||||
req.ClientID = strings.TrimSpace(req.ClientID)
|
||||
req.DeviceID = strings.TrimSpace(req.DeviceID)
|
||||
req.ClientSecret = strings.TrimSpace(req.ClientSecret)
|
||||
|
||||
if req.Login == "" {
|
||||
return response.BadRequest(a.logger, mservice.Accounts, "email_missing", "login request has no user name")
|
||||
return nil, response.BadRequest(logger, mservice.Accounts, "email_missing", "login request has no user name")
|
||||
}
|
||||
if req.Password == "" {
|
||||
return response.BadRequest(a.logger, mservice.Accounts, "password_missing", "login request has no password")
|
||||
return nil, response.BadRequest(logger, mservice.Accounts, "password_missing", "login request has no password")
|
||||
}
|
||||
return a.logUserIn(r.Context(), r, &req)
|
||||
return &req, nil
|
||||
}
|
||||
|
||||
func (a *PublicRouter) login(r *http.Request) http.HandlerFunc {
|
||||
// TODO: add rate check
|
||||
req, h := decodeLogin(r, a.logger)
|
||||
if h != nil {
|
||||
return h
|
||||
}
|
||||
account, h := a.authenticateAccount(r.Context(), req)
|
||||
if h != nil {
|
||||
return h
|
||||
}
|
||||
return a.respondPendingLogin(account)
|
||||
}
|
||||
|
||||
func (a *PublicRouter) apiLogin(r *http.Request) http.HandlerFunc {
|
||||
req, h := decodeLogin(r, a.logger)
|
||||
if h != nil {
|
||||
return h
|
||||
}
|
||||
if req.ClientID == "" {
|
||||
return response.BadRequest(a.logger, mservice.Accounts, "client_id_missing", "clientId is required")
|
||||
}
|
||||
if req.ClientSecret == "" {
|
||||
return response.BadRequest(a.logger, mservice.Accounts, "client_secret_missing", "clientSecret is required")
|
||||
}
|
||||
account, h := a.authenticateAccount(r.Context(), req)
|
||||
if h != nil {
|
||||
return h
|
||||
}
|
||||
if h = a.validateAPIClient(r.Context(), r, req, account); h != nil {
|
||||
return h
|
||||
}
|
||||
return a.respondAPILogin(r.Context(), r, req, account)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package routers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/tech/sendico/pkg/api/http/response"
|
||||
@@ -21,6 +22,9 @@ func (pr *PublicRouter) refreshAccessToken(r *http.Request) http.HandlerFunc {
|
||||
|
||||
account, token, err := pr.validateRefreshToken(r.Context(), r, &req)
|
||||
if err != nil {
|
||||
if errors.Is(err, errClientIPNotAllowed) {
|
||||
return response.Forbidden(pr.logger, pr.service, "ip_not_allowed", "request ip is not allowed for this client")
|
||||
}
|
||||
pr.logger.Warn("Failed to process access token refreshment request", zap.Error(err))
|
||||
return response.Auto(pr.logger, pr.service, err)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package routers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/tech/sendico/pkg/api/http/response"
|
||||
@@ -20,6 +21,9 @@ func (pr *PublicRouter) rotateRefreshToken(r *http.Request) http.HandlerFunc {
|
||||
|
||||
account, token, err := pr.validateRefreshToken(r.Context(), r, &req)
|
||||
if err != nil {
|
||||
if errors.Is(err, errClientIPNotAllowed) {
|
||||
return response.Forbidden(pr.logger, pr.service, "ip_not_allowed", "request ip is not allowed for this client")
|
||||
}
|
||||
pr.logger.Warn("Failed to validate refresh token", zap.Error(err))
|
||||
return response.Auto(pr.logger, pr.service, err)
|
||||
}
|
||||
|
||||
@@ -40,6 +40,7 @@ func NewRouter(logger mlogger.Logger, apiEndpoint string, db account.DB, vdb ver
|
||||
}
|
||||
|
||||
hr.InstallHandler(hr.service, "/login", api.Post, hr.login)
|
||||
hr.InstallHandler(hr.service, "/login/api", api.Post, hr.apiLogin)
|
||||
hr.InstallHandler(hr.service, "/rotate", api.Post, hr.rotateRefreshToken)
|
||||
hr.InstallHandler(hr.service, "/refresh", api.Post, hr.refreshAccessToken)
|
||||
|
||||
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var errClientIPNotAllowed = errors.New("client_ip_not_allowed")
|
||||
|
||||
func validateToken(token string, rt *model.RefreshToken) string {
|
||||
if rt.AccountRef == nil {
|
||||
return "missing account reference"
|
||||
@@ -31,7 +33,23 @@ func validateToken(token string, rt *model.RefreshToken) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (pr *PublicRouter) validateRefreshToken(ctx context.Context, _ *http.Request, req *srequest.TokenRefreshRotate) (*model.Account, *sresponse.TokenData, error) {
|
||||
func (pr *PublicRouter) validateRefreshToken(ctx context.Context, r *http.Request, req *srequest.TokenRefreshRotate) (*model.Account, *sresponse.TokenData, error) {
|
||||
client, err := pr.rtdb.GetClient(ctx, req.ClientID)
|
||||
if errors.Is(err, merrors.ErrNoData) || client == nil {
|
||||
pr.logger.Info("Refresh token rejected: client not found", zap.String("client_id", req.ClientID))
|
||||
return nil, nil, merrors.Unauthorized("client not found")
|
||||
}
|
||||
if err != nil {
|
||||
pr.logger.Warn("Failed to fetch client for refresh token validation", zap.Error(err), zap.String("client_id", req.ClientID))
|
||||
return nil, nil, err
|
||||
}
|
||||
if client.IsRevoked {
|
||||
return nil, nil, merrors.Unauthorized("client has been revoked")
|
||||
}
|
||||
if h := pr.validateClientIPPolicy(r, req.ClientID, client); h != nil {
|
||||
return nil, nil, errClientIPNotAllowed
|
||||
}
|
||||
|
||||
rt, err := pr.rtdb.GetByCRT(ctx, req)
|
||||
if errors.Is(err, merrors.ErrNoData) {
|
||||
pr.logger.Info("Refresh token not found", zap.String("client_id", req.ClientID), zap.String("device_id", req.DeviceID))
|
||||
@@ -49,7 +67,7 @@ func (pr *PublicRouter) validateRefreshToken(ctx context.Context, _ *http.Reques
|
||||
return nil, nil, merrors.Unauthorized("user not found")
|
||||
}
|
||||
|
||||
accessToken, err := pr.imp.CreateAccessToken(&account)
|
||||
accessToken, err := pr.imp.CreateAccessTokenForClient(&account, req.ClientID)
|
||||
if err != nil {
|
||||
pr.logger.Warn("Failed to generate access token", zap.Error(err))
|
||||
return nil, nil, err
|
||||
|
||||
@@ -9,9 +9,9 @@ import (
|
||||
emodel "github.com/tech/sendico/server/interface/model"
|
||||
)
|
||||
|
||||
func (a *VerificationAPI) createAccessToken(account *model.Account) (sresponse.TokenData, error) {
|
||||
func (a *VerificationAPI) createAccessToken(account *model.Account, clientID string) (sresponse.TokenData, error) {
|
||||
ja := jwtauth.New(a.signature.Algorithm, a.signature.PrivateKey, a.signature.PublicKey)
|
||||
_, res, err := ja.Encode(emodel.Account2Claims(account, a.tokenConfig.Expiration.Account))
|
||||
_, res, err := ja.Encode(emodel.Account2ClaimsForClient(account, a.tokenConfig.Expiration.Account, clientID))
|
||||
token := sresponse.TokenData{
|
||||
Token: res,
|
||||
Expiration: time.Now().Add(time.Duration(a.tokenConfig.Expiration.Account) * time.Hour),
|
||||
|
||||
@@ -53,7 +53,7 @@ func (a *VerificationAPI) verifyCode(r *http.Request, account *model.Account, to
|
||||
if req.SessionIdentifier.ClientID == "" || req.SessionIdentifier.DeviceID == "" {
|
||||
return response.BadRequest(a.logger, a.Name(), "missing_session", "session identifier is required")
|
||||
}
|
||||
accessToken, err := a.createAccessToken(account)
|
||||
accessToken, err := a.createAccessToken(account, req.SessionIdentifier.ClientID)
|
||||
if err != nil {
|
||||
a.logger.Warn("Failed to generate access token", zap.Error(err))
|
||||
return response.Internal(a.logger, a.Name(), err)
|
||||
|
||||
Reference in New Issue
Block a user