netsafe.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. // Package netsafe provides SSRF-safe HTTP dialing primitives. A dialer
  2. // installed via SSRFGuardedDialContext resolves the host, rejects
  3. // private/internal IPs unless the per-request context whitelists them,
  4. // and dials the resolved IP directly so the IP checked is the IP used —
  5. // closing the DNS-rebinding TOCTOU window.
  6. package netsafe
  7. import (
  8. "context"
  9. "fmt"
  10. "net"
  11. "regexp"
  12. "strings"
  13. "time"
  14. )
  15. // IsBlockedIP returns true for loopback, RFC1918 private, link-local
  16. // (including 169.254.169.254 cloud-metadata), and unspecified addresses.
  17. func IsBlockedIP(ip net.IP) bool {
  18. return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() ||
  19. ip.IsLinkLocalMulticast() || ip.IsUnspecified()
  20. }
  21. type allowPrivateCtxKey struct{}
  22. // ContextWithAllowPrivate marks a context as permitting outbound requests
  23. // to private/internal IPs. Use only for callers (e.g. LAN-resident nodes)
  24. // where the admin has opted in explicitly.
  25. func ContextWithAllowPrivate(ctx context.Context, allow bool) context.Context {
  26. return context.WithValue(ctx, allowPrivateCtxKey{}, allow)
  27. }
  28. func AllowPrivateFromContext(ctx context.Context) bool {
  29. v, _ := ctx.Value(allowPrivateCtxKey{}).(bool)
  30. return v
  31. }
  32. var defaultDialer = &net.Dialer{Timeout: 10 * time.Second}
  33. // SSRFGuardedDialContext is a net/http Transport.DialContext implementation
  34. // that enforces IsBlockedIP unless the context opts in via
  35. // ContextWithAllowPrivate.
  36. func SSRFGuardedDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
  37. host, port, err := net.SplitHostPort(addr)
  38. if err != nil {
  39. return nil, err
  40. }
  41. allowPrivate := AllowPrivateFromContext(ctx)
  42. var ips []net.IPAddr
  43. if ip := net.ParseIP(host); ip != nil {
  44. ips = []net.IPAddr{{IP: ip}}
  45. } else {
  46. ips, err = net.DefaultResolver.LookupIPAddr(ctx, host)
  47. if err != nil {
  48. return nil, err
  49. }
  50. }
  51. var lastErr error
  52. for _, ipAddr := range ips {
  53. if !allowPrivate && IsBlockedIP(ipAddr.IP) {
  54. lastErr = fmt.Errorf("blocked private/internal address %s", ipAddr.IP)
  55. continue
  56. }
  57. conn, derr := defaultDialer.DialContext(ctx, network, net.JoinHostPort(ipAddr.IP.String(), port))
  58. if derr == nil {
  59. return conn, nil
  60. }
  61. lastErr = derr
  62. }
  63. if lastErr == nil {
  64. lastErr = fmt.Errorf("no usable address for %s", host)
  65. }
  66. return nil, lastErr
  67. }
  68. // hostnamePattern accepts RFC 1123 hostnames (letters, digits, hyphens,
  69. // dots). Bracketed IPv6 forms ("[::1]") are stripped before this check
  70. // runs in NormalizeHost.
  71. var hostnamePattern = regexp.MustCompile(`^[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?(\.[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?)*$`)
  72. // NormalizeHost validates that addr is a plain hostname or IP literal with
  73. // no embedded path/userinfo/port/scheme — anything that could be used to
  74. // smuggle URL components past callers that string-format URLs from user
  75. // input. Returns the bare host (no brackets); callers wrap IPv6 via
  76. // net.JoinHostPort as needed.
  77. func NormalizeHost(addr string) (string, error) {
  78. addr = strings.TrimSpace(addr)
  79. if addr == "" {
  80. return "", fmt.Errorf("address is required")
  81. }
  82. if strings.HasPrefix(addr, "[") && strings.HasSuffix(addr, "]") {
  83. addr = addr[1 : len(addr)-1]
  84. }
  85. if ip := net.ParseIP(addr); ip != nil {
  86. return ip.String(), nil
  87. }
  88. if len(addr) > 253 || !hostnamePattern.MatchString(addr) {
  89. return "", fmt.Errorf("invalid host %q", addr)
  90. }
  91. return addr, nil
  92. }