config_envelope.go 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. package middleware
import (
	"bytes"
	"crypto/subtle"
	"io"
	"net/http"
	"strconv"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/mhsanaei/3x-ui/v3/internal/util/wirecodec"
)
  10. // maxDecodedConfigBytes caps a decompressed request body (defense in depth on
  11. // top of wirecodec's own ceiling).
  12. const maxDecodedConfigBytes = 8 << 20
  13. // ConfigEnvelopeMiddleware advertises node envelope support on every response
  14. // and, for requests that opt into the envelope, decompresses (zstd) and verifies
  15. // the X-Config-Sha256 integrity hash before the body reaches the handler. A
  16. // request carrying neither envelope header passes through untouched, so old
  17. // panels and plain calls keep working (mixed-version safe).
  18. func ConfigEnvelopeMiddleware() gin.HandlerFunc {
  19. return func(c *gin.Context) {
  20. c.Header(wirecodec.CapsHeader, wirecodec.CapZstd)
  21. enc := c.GetHeader("Content-Encoding")
  22. sum := c.GetHeader(wirecodec.HashHeader)
  23. if enc != wirecodec.EncodingZstd && sum == "" {
  24. c.Next()
  25. return
  26. }
  27. // On the envelope path, zstd is the only encoding we understand. Reject any
  28. // other Content-Encoding rather than hashing/forwarding a still-encoded body
  29. // the downstream handler can't read.
  30. if enc != "" && enc != wirecodec.EncodingZstd {
  31. c.AbortWithStatus(http.StatusUnsupportedMediaType)
  32. return
  33. }
  34. raw, err := io.ReadAll(c.Request.Body)
  35. if err != nil {
  36. c.AbortWithStatus(http.StatusBadRequest)
  37. return
  38. }
  39. _ = c.Request.Body.Close()
  40. if enc == wirecodec.EncodingZstd {
  41. decoded, derr := wirecodec.Decompress(raw, maxDecodedConfigBytes)
  42. if derr != nil {
  43. c.AbortWithStatus(http.StatusBadRequest)
  44. return
  45. }
  46. raw = decoded
  47. c.Request.Header.Del("Content-Encoding")
  48. }
  49. if sum != "" {
  50. got := wirecodec.Sha256Hex(raw)
  51. if subtle.ConstantTimeCompare([]byte(got), []byte(sum)) != 1 {
  52. c.AbortWithStatus(http.StatusBadRequest)
  53. return
  54. }
  55. }
  56. c.Request.Body = io.NopCloser(bytes.NewReader(raw))
  57. c.Request.ContentLength = int64(len(raw))
  58. c.Next()
  59. }
  60. }