addrgen.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. package addrgen
  2. import (
  3. "errors"
  4. "fmt"
  5. "math/big"
  6. "math/rand"
  7. "net"
  8. "slices"
  9. "strconv"
  10. "strings"
  11. "github.com/SenseUnit/dtlspipe/randpool"
  12. )
  13. type AddrGen interface {
  14. Addr() string
  15. Power() *big.Int
  16. }
  17. type PortGen interface {
  18. Port() uint16
  19. Power() uint16
  20. }
  21. type EndpointGen interface {
  22. Endpoint() string
  23. Power() *big.Int
  24. }
  25. var _ EndpointGen = &AddrSet{}
  26. type AddrSet struct {
  27. portRange PortGen
  28. addrRanges []AddrGen
  29. cumWeights []*big.Int
  30. }
  31. func ParseAddrSet(spec string) (*AddrSet, error) {
  32. lastColonIdx := strings.LastIndex(spec, ":")
  33. if lastColonIdx == -1 {
  34. return nil, errors.New("port specification not found - colon is missing")
  35. }
  36. addrPart := spec[:lastColonIdx]
  37. portPart := spec[lastColonIdx+1:]
  38. portRange, err := ParsePortRangeSpec(portPart)
  39. if err != nil {
  40. return nil, fmt.Errorf("unable to parse port part: %w", err)
  41. }
  42. terms := strings.Split(addrPart, ",")
  43. addrRanges := make([]AddrGen, 0, len(terms))
  44. for _, addrRangeSpec := range terms {
  45. r, err := ParseAddrRangeSpec(addrRangeSpec)
  46. if err != nil {
  47. return nil, fmt.Errorf("addr range spec %q parse failed: %w", addrRangeSpec, err)
  48. }
  49. addrRanges = append(addrRanges, r)
  50. }
  51. if len(addrRanges) == 0 {
  52. return nil, errors.New("no valid address ranges specified")
  53. }
  54. cumWeights := make([]*big.Int, len(addrRanges))
  55. currSum := new(big.Int)
  56. for i, r := range addrRanges {
  57. currSum.Add(currSum, r.Power())
  58. cumWeights[i] = new(big.Int).Set(currSum)
  59. }
  60. return &AddrSet{
  61. portRange: portRange,
  62. addrRanges: addrRanges,
  63. cumWeights: cumWeights,
  64. }, nil
  65. }
  66. func (as *AddrSet) Endpoint() string {
  67. port := as.portRange.Port()
  68. count := len(as.addrRanges)
  69. limit := as.cumWeights[count-1]
  70. random := new(big.Int)
  71. randpool.Borrow(func(r *rand.Rand) {
  72. random.Rand(r, limit)
  73. })
  74. idx, found := slices.BinarySearchFunc(as.cumWeights, random, func(elem, target *big.Int) int {
  75. return elem.Cmp(target)
  76. })
  77. if found {
  78. idx++
  79. }
  80. addr := as.addrRanges[idx].Addr()
  81. return net.JoinHostPort(addr, strconv.FormatUint(uint64(port), 10))
  82. }
  83. func (as *AddrSet) Power() *big.Int {
  84. power := big.NewInt(int64(as.portRange.Power()))
  85. power.Mul(power, as.cumWeights[len(as.addrRanges)-1])
  86. return power
  87. }
  88. var _ EndpointGen = EqualMultiEndpointGen(nil)
  89. type EqualMultiEndpointGen []EndpointGen
  90. func NewEqualMultiEndpointGen(gens ...EndpointGen) (EqualMultiEndpointGen, error) {
  91. if len(gens) < 1 {
  92. return nil, errors.New("no generators provides")
  93. }
  94. return EqualMultiEndpointGen(gens), nil
  95. }
  96. func EqualMultiEndpointGenFromSpecs(specs []string) (EqualMultiEndpointGen, error) {
  97. gens := make([]EndpointGen, 0, len(specs))
  98. for _, spec := range specs {
  99. g, err := ParseAddrSet(spec)
  100. if err != nil {
  101. return nil, fmt.Errorf("can't create endpoint gen from spec %q: %w", spec, err)
  102. }
  103. gens = append(gens, g)
  104. }
  105. return NewEqualMultiEndpointGen(gens...)
  106. }
  107. func (g EqualMultiEndpointGen) Endpoint() string {
  108. var ret string
  109. randpool.Borrow(func(r *rand.Rand) {
  110. ret = g[r.Intn(len(g))].Endpoint()
  111. })
  112. return ret
  113. }
  114. func (g EqualMultiEndpointGen) Power() *big.Int {
  115. sum := new(big.Int)
  116. for _, sg := range g {
  117. sum.Add(sum, sg.Power())
  118. }
  119. return sum
  120. }
// Compile-time check that SingleEndpoint satisfies EndpointGen; the build
// fails if the method set ever drifts from the interface.
var _ EndpointGen = SingleEndpoint("")
  122. type SingleEndpoint string
  123. func (e SingleEndpoint) Endpoint() string {
  124. return string(e)
  125. }
  126. func (e SingleEndpoint) Power() *big.Int {
  127. return big.NewInt(1)
  128. }