// auto_https_conn.go (1.2 KB)

  1. package network
  2. import (
  3. "bufio"
  4. "bytes"
  5. "fmt"
  6. "net"
  7. "net/http"
  8. "sync"
  9. )
  // AutoHttpsConn wraps a net.Conn and answers plain-HTTP clients with a
// redirect to the equivalent https:// URL.
//
// On the first Read it reads one chunk (up to 2048 bytes) from the underlying
// connection. If that chunk parses as an HTTP request, it writes a
// 307 Temporary Redirect to https://<Host><request-URI> and closes the
// connection. Otherwise the consumed bytes are replayed by subsequent Reads,
// so the caller (presumably a TLS server, given the https redirect) sees the
// stream unchanged.
//
// NOTE(review): detection relies on a single Read returning the whole request
// head; a plain-HTTP request split across packets, or whose head exceeds the
// sniff buffer, fails to parse and is passed through unchanged.
type AutoHttpsConn struct {
	net.Conn

	// firstBuf holds sniffed bytes not yet handed back to the caller; it is
	// set to nil once fully replayed.
	firstBuf []byte
	// bufStart is the offset in firstBuf of the next byte to replay.
	bufStart int
	// firstErr is the error returned by the sniffing read, if any. Read
	// reports it after firstBuf has been replayed so it is not lost.
	firstErr error
	// readRequestOnce ensures the sniff happens only on the first Read.
	readRequestOnce sync.Once
}

// NewAutoHttpsConn wraps conn in an AutoHttpsConn. Sniffing happens lazily on
// the first Read, so constructing the wrapper performs no I/O.
func NewAutoHttpsConn(conn net.Conn) net.Conn {
	return &AutoHttpsConn{Conn: conn}
}

// readRequest performs the one-time sniff. It reads the first chunk of data
// into c.firstBuf. If that chunk parses as an HTTP request, it writes a 307
// redirect to https://<Host><request-URI>, closes the connection and returns
// true. Otherwise it returns false and leaves the consumed bytes (and any
// read error) for Read to deliver.
func (c *AutoHttpsConn) readRequest() bool {
	const sniffSize = 2048

	buf := make([]byte, sniffSize)
	n, err := c.Conn.Read(buf)
	c.firstBuf = buf[:n]
	if err != nil {
		// io.Reader may return data together with an error; keep both so Read
		// can deliver the bytes first and the error afterwards.
		c.firstErr = err
		return false
	}

	request, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(c.firstBuf)))
	if err != nil {
		// Not a (complete) HTTP request; presumably TLS, so replay the bytes.
		return false
	}

	resp := http.Response{
		StatusCode: http.StatusTemporaryRedirect,
		// Without an explicit version Response.Write emits "HTTP/0.0".
		ProtoMajor: 1,
		ProtoMinor: 1,
		Header:     http.Header{},
		// Emits "Connection: close", matching the Close call below.
		Close: true,
	}
	// URL.RequestURI yields only path?query, so an absolute-form target such
	// as "GET http://host/path" does not get its scheme and host appended
	// after request.Host.
	location := fmt.Sprintf("https://%v%v", request.Host, request.URL.RequestURI())
	resp.Header.Set("Location", location)
	// Best effort: the connection is closed right after and there is no
	// caller to report a write or close failure to.
	_ = resp.Write(c.Conn)
	_ = c.Close()
	c.firstBuf = nil
	return true
}

// Read implements io.Reader. The first call triggers the one-time sniff. After
// that, Read replays any sniffed bytes, then returns the error (if any) from
// the sniffing read, and finally delegates to the underlying connection.
func (c *AutoHttpsConn) Read(buf []byte) (int, error) {
	c.readRequestOnce.Do(func() {
		c.readRequest()
	})
	if c.bufStart < len(c.firstBuf) {
		n := copy(buf, c.firstBuf[c.bufStart:])
		c.bufStart += n
		if c.bufStart >= len(c.firstBuf) {
			// Drop the reference so the sniff buffer can be collected.
			c.firstBuf = nil
		}
		return n, nil
	}
	if err := c.firstErr; err != nil {
		c.firstErr = nil
		return 0, err
	}
	return c.Conn.Read(buf)
}