// Package netsafe provides SSRF-safe HTTP dialing primitives. A dialer
// installed via SSRFGuardedDialContext resolves the host, rejects
// private/internal IPs unless the per-request context allows them,
// and dials the resolved IP directly so the IP checked is the IP used —
// closing the DNS-rebinding TOCTOU window.
- package netsafe
- import (
- "context"
- "fmt"
- "net"
- "regexp"
- "strings"
- "time"
- )
// IsBlockedIP reports whether ip falls in a range that outbound requests
// must not reach by default: loopback, private (RFC 1918 IPv4 and RFC 4193
// IPv6 unique-local, per net.IP.IsPrivate), link-local unicast (which
// includes the 169.254.169.254 cloud-metadata address), link-local
// multicast, or the unspecified address.
//
// NOTE(review): ranges such as CGNAT 100.64.0.0/10 and non-link-local
// multicast are not covered by these checks.
func IsBlockedIP(ip net.IP) bool {
	switch {
	case ip.IsLoopback(), ip.IsPrivate():
		return true
	case ip.IsLinkLocalUnicast(), ip.IsLinkLocalMulticast():
		return true
	case ip.IsUnspecified():
		return true
	default:
		return false
	}
}
// allowPrivateCtxKey is the unexported context-key type under which the
// allow-private flag is stored. Being a distinct type, it cannot collide
// with keys defined by other packages.
type allowPrivateCtxKey struct{}

// ContextWithAllowPrivate returns a context derived from ctx that carries
// allow as its allow-private flag. A true flag permits outbound requests to
// private/internal IPs; use it only for callers (e.g. LAN-resident nodes)
// where the admin has opted in explicitly.
func ContextWithAllowPrivate(ctx context.Context, allow bool) context.Context {
	var key allowPrivateCtxKey
	return context.WithValue(ctx, key, allow)
}

// AllowPrivateFromContext reports whether ctx carries a true allow-private
// flag set via ContextWithAllowPrivate. A context on which the flag was
// never set reports false.
func AllowPrivateFromContext(ctx context.Context) bool {
	allow, ok := ctx.Value(allowPrivateCtxKey{}).(bool)
	return ok && allow
}
- var defaultDialer = &net.Dialer{Timeout: 10 * time.Second}
- // SSRFGuardedDialContext is a net/http Transport.DialContext implementation
- // that enforces IsBlockedIP unless the context opts in via
- // ContextWithAllowPrivate.
- func SSRFGuardedDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
- host, port, err := net.SplitHostPort(addr)
- if err != nil {
- return nil, err
- }
- allowPrivate := AllowPrivateFromContext(ctx)
- var ips []net.IPAddr
- if ip := net.ParseIP(host); ip != nil {
- ips = []net.IPAddr{{IP: ip}}
- } else {
- ips, err = net.DefaultResolver.LookupIPAddr(ctx, host)
- if err != nil {
- return nil, err
- }
- }
- var lastErr error
- for _, ipAddr := range ips {
- if !allowPrivate && IsBlockedIP(ipAddr.IP) {
- lastErr = fmt.Errorf("blocked private/internal address %s", ipAddr.IP)
- continue
- }
- conn, derr := defaultDialer.DialContext(ctx, network, net.JoinHostPort(ipAddr.IP.String(), port))
- if derr == nil {
- return conn, nil
- }
- lastErr = derr
- }
- if lastErr == nil {
- lastErr = fmt.Errorf("no usable address for %s", host)
- }
- return nil, lastErr
- }
// hostnamePattern accepts RFC 1123 hostnames: dot-separated labels made of
// letters, digits and hyphens, where each label is 1 to 63 characters long
// and neither starts nor ends with a hyphen. Bracketed forms ("[::1]") never
// reach this check; NormalizeHost handles them first.
var hostnamePattern = regexp.MustCompile(`^[A-Za-z0-9]([A-Za-z0-9-]{0,61}[A-Za-z0-9])?(\.[A-Za-z0-9]([A-Za-z0-9-]{0,61}[A-Za-z0-9])?)*$`)

// NormalizeHost validates that addr is a plain hostname or IP literal with
// no embedded path/userinfo/port/scheme — anything that could be used to
// smuggle URL components past callers that string-format URLs from user
// input.
//
// Surrounding whitespace is trimmed. IP literals are returned in the
// canonical form produced by net.IP.String. Square brackets are accepted
// only around an IPv6 literal and are removed; callers wrap IPv6 via
// net.JoinHostPort as needed. Hostnames are returned as given (case is
// preserved) and must be at most 253 characters long with labels of at most
// 63 characters.
//
// It returns an error when addr is empty or is neither an IP literal nor a
// valid hostname.
func NormalizeHost(addr string) (string, error) {
	addr = strings.TrimSpace(addr)
	if addr == "" {
		return "", fmt.Errorf("address is required")
	}

	// Brackets are URL syntax for IPv6 literals only, so "[example.com]" and
	// "[1.2.3.4]" are rejected instead of being silently unwrapped. The colon
	// test is what tells a real IPv6 literal (including the IPv4-mapped
	// "::ffff:a.b.c.d" form) apart from a dotted-quad IPv4 address.
	if strings.HasPrefix(addr, "[") && strings.HasSuffix(addr, "]") {
		inner := addr[1 : len(addr)-1]
		ip := net.ParseIP(inner)
		if ip == nil || !strings.Contains(inner, ":") {
			return "", fmt.Errorf("invalid host %q", addr)
		}
		return ip.String(), nil
	}

	if ip := net.ParseIP(addr); ip != nil {
		return ip.String(), nil
	}
	if len(addr) > 253 || !hostnamePattern.MatchString(addr) {
		return "", fmt.Errorf("invalid host %q", addr)
	}
	return addr, nil
}
|