package mutil

import (
	"net/http"
	"strconv"

	"github.com/go-chi/chi/v5"
	"github.com/tech/sendico/pkg/mlogger"
	"github.com/tech/sendico/pkg/model"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.uber.org/zap"
)

// GetParam returns the value chi.URLParam yields for the URL parameter
// paramName of request r.
func GetParam(r *http.Request, paramName string) string {
	return chi.URLParam(r, paramName)
}

// GetID returns the "id" URL parameter of r.
func GetID(r *http.Request) string {
	return GetParam(r, "id")
}

// GetAccountID returns the URL parameter named by AccountRefName().
func GetAccountID(r *http.Request) string {
	return GetParam(r, AccountRefName())
}

// GetObjRef returns the URL parameter named by ObjRefName().
func GetObjRef(r *http.Request) string {
	return GetParam(r, ObjRefName())
}

// GetOrganizationID returns the URL parameter named by OrganizationRefName().
func GetOrganizationID(r *http.Request) string {
	return GetParam(r, OrganizationRefName())
}

// GetOrganizationRef parses the organization URL parameter as a hex-encoded
// ObjectID. The error is the one reported by primitive.ObjectIDFromHex.
func GetOrganizationRef(r *http.Request) (primitive.ObjectID, error) {
	return primitive.ObjectIDFromHex(GetOrganizationID(r))
}

// GetStatusID returns the URL parameter used to identify a status.
//
// NOTE(review): this reads OrganizationRefName(), the same parameter that
// GetOrganizationID reads. That looks like a copy-paste slip, but no
// status-specific ref name is visible from this file, so the lookup is left
// unchanged. Confirm the intended parameter name.
func GetStatusID(r *http.Request) string {
	return GetParam(r, OrganizationRefName())
}

// GetStatusRef parses the value returned by GetStatusID as a hex-encoded
// ObjectID. The error is the one reported by primitive.ObjectIDFromHex.
func GetStatusRef(r *http.Request) (primitive.ObjectID, error) {
	return primitive.ObjectIDFromHex(GetStatusID(r))
}

// GetProjectID returns the URL parameter named by ProjectRefName().
func GetProjectID(r *http.Request) string {
	return GetParam(r, ProjectRefName())
}

// GetProjectRef parses the project URL parameter as a hex-encoded ObjectID.
// The error is the one reported by primitive.ObjectIDFromHex.
func GetProjectRef(r *http.Request) (primitive.ObjectID, error) {
	return primitive.ObjectIDFromHex(GetProjectID(r))
}

// GetInvitationID returns the URL parameter named by InvitationRefName().
func GetInvitationID(r *http.Request) string {
	return GetParam(r, InvitationRefName())
}

// GetInvitationRef parses the invitation URL parameter as a hex-encoded
// ObjectID. The error is the one reported by primitive.ObjectIDFromHex.
func GetInvitationRef(r *http.Request) (primitive.ObjectID, error) {
	// Parse the invitation parameter, not the organization one: every other
	// GetXRef helper in this file parses its own GetXID value.
	return primitive.ObjectIDFromHex(GetInvitationID(r))
}

// GetToken returns the URL parameter named by TokenName().
func GetToken(r *http.Request) string {
	return GetParam(r, TokenName())
}

// parseFunc converts the raw string form of a parameter into a value of type T.
type parseFunc[T any] func(string) (T, error)

// GetOptionalParam reads the query-string parameter key from r and parses it
// with parse.
//
// It returns (nil, nil) when the parameter is absent or empty. When parse
// fails, the failure is logged at debug level and the parse error is returned
// with a nil value. Otherwise it returns a pointer to the parsed value.
func GetOptionalParam[T any](logger mlogger.Logger, r *http.Request, key string, parse parseFunc[T]) (*T, error) {
	raw := r.URL.Query().Get(key)
	if raw == "" {
		return nil, nil
	}

	parsed, err := parse(raw)
	if err != nil {
		// The log field is named after the query key and carries the raw value.
		logger.Debug("Malformed query parameter", zap.Error(err), zap.String(key, raw))
		return nil, err
	}
	return &parsed, nil
}

// GetOptionalInt64Param reads the optional query parameter key as a base-10
// int64. See GetOptionalParam for the absent/malformed behaviour.
func GetOptionalInt64Param(logger mlogger.Logger, r *http.Request, key string) (*int64, error) {
	parseInt64 := func(s string) (int64, error) {
		return strconv.ParseInt(s, 10, 64)
	}
	return GetOptionalParam(logger, r, key, parseInt64)
}

// GetLimit reads the optional "limit" query parameter as an int64.
func GetLimit(logger mlogger.Logger, r *http.Request) (*int64, error) {
	return GetOptionalInt64Param(logger, r, "limit")
}

// GetOffset reads the optional "offset" query parameter as an int64.
func GetOffset(logger mlogger.Logger, r *http.Request) (*int64, error) {
	return GetOptionalInt64Param(logger, r, "offset")
}

// GetLimitAndOffset reads the optional "limit" and "offset" query parameters,
// in that order. If either one fails to parse, both results are nil and the
// parse error is returned.
func GetLimitAndOffset(logger mlogger.Logger, r *http.Request) (*int64, *int64, error) {
	limit, err := GetLimit(logger, r)
	if err != nil {
		return nil, nil, err
	}

	offset, err := GetOffset(logger, r)
	if err != nil {
		return nil, nil, err
	}

	return limit, offset, nil
}

// GetOptionalBoolParam reads the optional query parameter key as a bool using
// strconv.ParseBool. See GetOptionalParam for the absent/malformed behaviour.
func GetOptionalBoolParam(logger mlogger.Logger, r *http.Request, key string) (*bool, error) {
	return GetOptionalParam(logger, r, key, strconv.ParseBool)
}

// GetCascadeParam reads the optional "cascade" query parameter as a bool.
func GetCascadeParam(logger mlogger.Logger, r *http.Request) (*bool, error) {
	return GetOptionalBoolParam(logger, r, "cascade")
}

// GetArchiveParam reads the optional "archived" query parameter as a bool.
func GetArchiveParam(logger mlogger.Logger, r *http.Request) (*bool, error) {
	return GetOptionalBoolParam(logger, r, "archived")
}

// GetViewCursor builds a model.ViewCursor from the optional "limit", "offset"
// and "archived" query parameters of r. Parameters that are absent leave the
// corresponding field nil. If any parameter is malformed, it returns nil and
// the parse error.
func GetViewCursor(logger mlogger.Logger, r *http.Request) (*model.ViewCursor, error) {
	limit, offset, err := GetLimitAndOffset(logger, r)
	if err != nil {
		return nil, err
	}

	archived, err := GetArchiveParam(logger, r)
	if err != nil {
		return nil, err
	}

	return &model.ViewCursor{
		Limit:      limit,
		Offset:     offset,
		IsArchived: archived,
	}, nil
}