netsafe.go 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. package netsafe
  2. import (
  3. "context"
  4. "fmt"
  5. "net"
  6. "regexp"
  7. "strings"
  8. "time"
  9. )
  // IsBlockedIP reports whether ip must not be dialed by an SSRF guard because it
  // is not a publicly routable unicast address.
  //
  // Blocked: loopback, RFC 1918 / RFC 4193 private ranges, link-local unicast,
  // all multicast, the unspecified address, 0.0.0.0/8 ("this network"),
  // 100.64.0.0/10 (RFC 6598 shared address space) and 240.0.0.0/4 (reserved,
  // including the 255.255.255.255 limited broadcast).
  //
  // IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) are judged by their IPv4 value.
  // A nil or malformed IP (neither 4 nor 16 bytes) is reported as blocked so the
  // check fails closed instead of letting an unparseable address through.
  func IsBlockedIP(ip net.IP) bool {
  	if len(ip) != net.IPv4len && len(ip) != net.IPv6len {
  		return true
  	}
  	if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() ||
  		ip.IsLinkLocalMulticast() || ip.IsMulticast() || ip.IsUnspecified() {
  		return true
  	}
  	if ip4 := ip.To4(); ip4 != nil {
  		switch {
  		case ip4[0] == 0: // 0.0.0.0/8, never a valid destination
  			return true
  		case ip4[0] == 100 && ip4[1]&0xc0 == 64: // 100.64.0.0/10
  			return true
  		case ip4[0] >= 240: // 240.0.0.0/4 reserved + 255.255.255.255
  			return true
  		}
  	}
  	return false
  }
  // allowPrivateCtxKey is the context key for the allow-private flag. Because the
  // type is unexported, code outside this package cannot build an equal key, so
  // the value cannot be set or clobbered by accident.
  type allowPrivateCtxKey struct{}

  // ContextWithAllowPrivate returns a child of ctx that carries the allow-private
  // flag. SSRFGuardedDialContext reads the flag through AllowPrivateFromContext
  // and, when it is true, does not reject addresses that IsBlockedIP reports.
  func ContextWithAllowPrivate(ctx context.Context, allow bool) context.Context {
  	key := allowPrivateCtxKey{}
  	return context.WithValue(ctx, key, allow)
  }

  // AllowPrivateFromContext returns the flag stored by ContextWithAllowPrivate.
  // If ctx carries no such flag, it returns false.
  func AllowPrivateFromContext(ctx context.Context) bool {
  	allow, ok := ctx.Value(allowPrivateCtxKey{}).(bool)
  	if !ok {
  		return false
  	}
  	return allow
  }
  // defaultDialer makes every connection attempt for SSRFGuardedDialContext.
  // Each dial waits at most 10 seconds for the connect to complete. All other
  // net.Dialer settings keep their zero-value defaults.
  var defaultDialer = &net.Dialer{
  	Timeout: 10 * time.Second,
  }
  23. func SSRFGuardedDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
  24. host, port, err := net.SplitHostPort(addr)
  25. if err != nil {
  26. return nil, err
  27. }
  28. allowPrivate := AllowPrivateFromContext(ctx)
  29. var ips []net.IPAddr
  30. if ip := net.ParseIP(host); ip != nil {
  31. ips = []net.IPAddr{{IP: ip}}
  32. } else {
  33. ips, err = net.DefaultResolver.LookupIPAddr(ctx, host)
  34. if err != nil {
  35. return nil, err
  36. }
  37. }
  38. var lastErr error
  39. for _, ipAddr := range ips {
  40. if !allowPrivate && IsBlockedIP(ipAddr.IP) {
  41. lastErr = fmt.Errorf("blocked private/internal address %s", ipAddr.IP)
  42. continue
  43. }
  44. conn, derr := defaultDialer.DialContext(ctx, network, net.JoinHostPort(ipAddr.IP.String(), port))
  45. if derr == nil {
  46. return conn, nil
  47. }
  48. lastErr = derr
  49. }
  50. if lastErr == nil {
  51. lastErr = fmt.Errorf("no usable address for %s", host)
  52. }
  53. return nil, lastErr
  54. }
  // hostnamePattern matches one or more dot-separated labels made of ASCII
  // letters, digits and hyphens. No label may be empty or begin or end with a
  // hyphen. The pattern does not limit label length; NormalizeHost checks that
  // separately.
  var hostnamePattern = regexp.MustCompile(`^[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?(\.[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?)*$`)

  // NormalizeHost validates a host taken from input and returns it in canonical
  // form.
  //
  // Surrounding whitespace is trimmed. IP literals, optionally wrapped in square
  // brackets as in "[::1]", are returned in the form produced by net.IP.String,
  // so "::ffff:127.0.0.1" becomes "127.0.0.1". Square brackets are only accepted
  // around IP literals.
  //
  // Any other input must be a dot-separated letters/digits/hyphen hostname of at
  // most 253 characters, with each label at most 63 characters (RFC 1035). It is
  // returned unchanged, with case preserved.
  //
  // A hostname whose final label is numeric is rejected. This covers decimal
  // forms such as "127.1" or "0177.0.0.1" and hex forms such as "0x7f000001".
  // net.ParseIP does not accept these forms, but the WHATWG URL host parser and
  // inet_aton-style resolvers interpret them as IPv4 addresses. Passing them
  // through as hostnames would hide an IP literal from callers.
  func NormalizeHost(addr string) (string, error) {
  	addr = strings.TrimSpace(addr)
  	if addr == "" {
  		return "", fmt.Errorf("address is required")
  	}
  	bracketed := strings.HasPrefix(addr, "[") && strings.HasSuffix(addr, "]")
  	if bracketed {
  		addr = addr[1 : len(addr)-1]
  	}
  	if ip := net.ParseIP(addr); ip != nil {
  		return ip.String(), nil
  	}
  	if bracketed {
  		return "", fmt.Errorf("invalid host %q: brackets are only allowed around IP literals", addr)
  	}
  	if len(addr) > 253 || !hostnamePattern.MatchString(addr) {
  		return "", fmt.Errorf("invalid host %q", addr)
  	}
  	labels := strings.Split(addr, ".")
  	for _, label := range labels {
  		if len(label) > 63 {
  			return "", fmt.Errorf("invalid host %q: label longer than 63 characters", addr)
  		}
  	}
  	// This mirrors the WHATWG "ends in a number" rule. The final label is
  	// numeric if it is all decimal digits, or "0x"/"0X" followed only by hex
  	// digits (possibly none). The regex guarantees that labels are non-empty.
  	last := strings.ToLower(labels[len(labels)-1])
  	numeric := strings.Trim(last, "0123456789") == ""
  	if strings.HasPrefix(last, "0x") {
  		numeric = strings.Trim(last[2:], "0123456789abcdef") == ""
  	}
  	if numeric {
  		return "", fmt.Errorf("invalid host %q: numeric final label looks like an IPv4 address", addr)
  	}
  	return addr, nil
  }