87 lines
1.9 KiB
Go
87 lines
1.9 KiB
Go
package ipguard
|
|
|
|
import (
|
|
"net"
|
|
"net/http"
|
|
"testing"
|
|
)
|
|
|
|
func TestClientIP(t *testing.T) {
|
|
t.Run("extracts host from remote addr", func(t *testing.T) {
|
|
req := &http.Request{RemoteAddr: "10.1.2.3:1234"}
|
|
ip := ClientIP(req)
|
|
if ip == nil || ip.String() != "10.1.2.3" {
|
|
t.Fatalf("unexpected ip: %v", ip)
|
|
}
|
|
})
|
|
|
|
t.Run("supports plain ip", func(t *testing.T) {
|
|
req := &http.Request{RemoteAddr: "8.8.8.8"}
|
|
ip := ClientIP(req)
|
|
if ip == nil || ip.String() != "8.8.8.8" {
|
|
t.Fatalf("unexpected ip: %v", ip)
|
|
}
|
|
})
|
|
|
|
t.Run("invalid remote addr", func(t *testing.T) {
|
|
req := &http.Request{RemoteAddr: "invalid"}
|
|
if ip := ClientIP(req); ip != nil {
|
|
t.Fatalf("expected nil ip, got %v", ip)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestAllowed(t *testing.T) {
|
|
clientIP := net.ParseIP("10.1.2.3")
|
|
if clientIP == nil {
|
|
t.Fatal("failed to parse test ip")
|
|
}
|
|
|
|
t.Run("allows when cidr matches", func(t *testing.T) {
|
|
allowed, err := Allowed(clientIP, []string{"10.0.0.0/8"})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if !allowed {
|
|
t.Fatal("expected allowed")
|
|
}
|
|
})
|
|
|
|
t.Run("denies when cidr does not match", func(t *testing.T) {
|
|
allowed, err := Allowed(clientIP, []string{"192.168.0.0/16"})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if allowed {
|
|
t.Fatal("expected denied")
|
|
}
|
|
})
|
|
|
|
t.Run("allows when cidr list is empty", func(t *testing.T) {
|
|
allowed, err := Allowed(clientIP, nil)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if !allowed {
|
|
t.Fatal("expected allowed")
|
|
}
|
|
})
|
|
|
|
t.Run("invalid cidr fails", func(t *testing.T) {
|
|
_, err := Allowed(clientIP, []string{"not-a-cidr"})
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
})
|
|
|
|
t.Run("nil client ip denied when cidrs configured", func(t *testing.T) {
|
|
allowed, err := Allowed(nil, []string{"10.0.0.0/8"})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if allowed {
|
|
t.Fatal("expected denied")
|
|
}
|
|
})
|
|
}
|