| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127 |
- package netsafe
- import (
- "context"
- "net"
- "strings"
- "testing"
- )
- func TestIsBlockedIP(t *testing.T) {
- cases := []struct {
- ip string
- want bool
- }{
- {"127.0.0.1", true},
- {"::1", true},
- {"10.0.0.5", true},
- {"172.16.0.1", true},
- {"192.168.1.1", true},
- {"169.254.0.1", true},
- {"0.0.0.0", true},
- {"::", true},
- {"8.8.8.8", false},
- {"1.1.1.1", false},
- {"2606:4700:4700::1111", false},
- }
- for _, c := range cases {
- t.Run(c.ip, func(t *testing.T) {
- ip := net.ParseIP(c.ip)
- if ip == nil {
- t.Fatalf("could not parse %q", c.ip)
- }
- if got := IsBlockedIP(ip); got != c.want {
- t.Fatalf("IsBlockedIP(%s) = %v, want %v", c.ip, got, c.want)
- }
- })
- }
- }
- func TestAllowPrivateFromContext_Default(t *testing.T) {
- if AllowPrivateFromContext(context.Background()) {
- t.Fatal("default context should report AllowPrivate=false")
- }
- }
- func TestAllowPrivateFromContext_RoundTrip(t *testing.T) {
- ctx := ContextWithAllowPrivate(context.Background(), true)
- if !AllowPrivateFromContext(ctx) {
- t.Fatal("expected AllowPrivate=true after ContextWithAllowPrivate(true)")
- }
- ctx = ContextWithAllowPrivate(ctx, false)
- if AllowPrivateFromContext(ctx) {
- t.Fatal("expected AllowPrivate=false after overriding with false")
- }
- }
- func TestNormalizeHost_Valid(t *testing.T) {
- cases := []struct {
- in string
- want string
- }{
- {"example.com", "example.com"},
- {" example.com ", "example.com"},
- {"a.b.c.example.com", "a.b.c.example.com"},
- {"10.0.0.1", "10.0.0.1"},
- {"[2606:4700:4700::1111]", "2606:4700:4700::1111"},
- {"2606:4700:4700::1111", "2606:4700:4700::1111"},
- }
- for _, c := range cases {
- t.Run(c.in, func(t *testing.T) {
- got, err := NormalizeHost(c.in)
- if err != nil {
- t.Fatalf("NormalizeHost(%q) returned error: %v", c.in, err)
- }
- if !strings.EqualFold(got, c.want) {
- t.Fatalf("NormalizeHost(%q) = %q, want %q", c.in, got, c.want)
- }
- })
- }
- }
- func TestNormalizeHost_Invalid(t *testing.T) {
- cases := []string{
- "",
- " ",
- "-leading-dash.com",
- "trailing-dash-.com",
- "bad host with spaces",
- "under_score.example.com",
- "exa$mple.com",
- strings.Repeat("a", 254),
- }
- for _, in := range cases {
- t.Run(in, func(t *testing.T) {
- if _, err := NormalizeHost(in); err == nil {
- t.Fatalf("NormalizeHost(%q) expected error, got nil", in)
- }
- })
- }
- }
- func TestSSRFGuardedDialContext_BlocksLiteralPrivateIP(t *testing.T) {
- _, err := SSRFGuardedDialContext(context.Background(), "tcp", "127.0.0.1:1")
- if err == nil {
- t.Fatal("expected dial to 127.0.0.1 to be blocked")
- }
- if !strings.Contains(err.Error(), "blocked") {
- t.Fatalf("expected 'blocked' in error, got: %v", err)
- }
- }
- func TestSSRFGuardedDialContext_AllowPrivateBypassesGuard(t *testing.T) {
- ctx := ContextWithAllowPrivate(context.Background(), true)
- _, err := SSRFGuardedDialContext(ctx, "tcp", "127.0.0.1:1")
- if err == nil {
- t.Fatal("dial to a closed loopback port should still fail at the connect step")
- }
- if strings.Contains(err.Error(), "blocked private/internal address") {
- t.Fatalf("expected guard to be bypassed when AllowPrivate=true, got: %v", err)
- }
- }
- func TestSSRFGuardedDialContext_BadAddress(t *testing.T) {
- if _, err := SSRFGuardedDialContext(context.Background(), "tcp", "no-port"); err == nil {
- t.Fatal("expected error for address without port")
- }
- }
|