123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150 |
- package addrgen
- import (
- "errors"
- "fmt"
- "math/big"
- "math/rand"
- "net"
- "slices"
- "strconv"
- "strings"
- "github.com/SenseUnit/dtlspipe/randpool"
- )
// AddrGen is a source of addresses.
//
// AddrSet holds a list of AddrGen values and uses each one's Power as its
// selection weight.
type AddrGen interface {
	// Addr returns one address in string form.
	Addr() string

	// Power reports the size of the generator as an arbitrary-precision
	// integer (presumably the number of distinct addresses it can yield;
	// confirm against the implementations).
	Power() *big.Int
}
// PortGen is a source of port numbers.
type PortGen interface {
	// Port returns one port number.
	Port() uint16

	// Power reports the size of the generator. The result is a uint16, so
	// it cannot exceed 65535.
	Power() uint16
}
// EndpointGen is a source of network endpoints.
//
// AddrSet, EqualMultiEndpointGen and SingleEndpoint implement it.
type EndpointGen interface {
	// Endpoint returns one endpoint in string form.
	Endpoint() string

	// Power reports the size of the generator as an arbitrary-precision
	// integer.
	Power() *big.Int
}
- var _ EndpointGen = &AddrSet{}
- type AddrSet struct {
- portRange PortGen
- addrRanges []AddrGen
- cumWeights []*big.Int
- }
- func ParseAddrSet(spec string) (*AddrSet, error) {
- lastColonIdx := strings.LastIndex(spec, ":")
- if lastColonIdx == -1 {
- return nil, errors.New("port specification not found - colon is missing")
- }
- addrPart := spec[:lastColonIdx]
- portPart := spec[lastColonIdx+1:]
- portRange, err := ParsePortRangeSpec(portPart)
- if err != nil {
- return nil, fmt.Errorf("unable to parse port part: %w", err)
- }
- terms := strings.Split(addrPart, ",")
- addrRanges := make([]AddrGen, 0, len(terms))
- for _, addrRangeSpec := range terms {
- r, err := ParseAddrRangeSpec(addrRangeSpec)
- if err != nil {
- return nil, fmt.Errorf("addr range spec %q parse failed: %w", addrRangeSpec, err)
- }
- addrRanges = append(addrRanges, r)
- }
- if len(addrRanges) == 0 {
- return nil, errors.New("no valid address ranges specified")
- }
- cumWeights := make([]*big.Int, len(addrRanges))
- currSum := new(big.Int)
- for i, r := range addrRanges {
- currSum.Add(currSum, r.Power())
- cumWeights[i] = new(big.Int).Set(currSum)
- }
- return &AddrSet{
- portRange: portRange,
- addrRanges: addrRanges,
- cumWeights: cumWeights,
- }, nil
- }
- func (as *AddrSet) Endpoint() string {
- port := as.portRange.Port()
- count := len(as.addrRanges)
- limit := as.cumWeights[count-1]
- random := new(big.Int)
- randpool.Borrow(func(r *rand.Rand) {
- random.Rand(r, limit)
- })
- idx, found := slices.BinarySearchFunc(as.cumWeights, random, func(elem, target *big.Int) int {
- return elem.Cmp(target)
- })
- if found {
- idx++
- }
- addr := as.addrRanges[idx].Addr()
- return net.JoinHostPort(addr, strconv.FormatUint(uint64(port), 10))
- }
- func (as *AddrSet) Power() *big.Int {
- power := big.NewInt(int64(as.portRange.Power()))
- power.Mul(power, as.cumWeights[len(as.addrRanges)-1])
- return power
- }
- var _ EndpointGen = EqualMultiEndpointGen(nil)
- type EqualMultiEndpointGen []EndpointGen
- func NewEqualMultiEndpointGen(gens ...EndpointGen) (EqualMultiEndpointGen, error) {
- if len(gens) < 1 {
- return nil, errors.New("no generators provides")
- }
- return EqualMultiEndpointGen(gens), nil
- }
- func EqualMultiEndpointGenFromSpecs(specs []string) (EqualMultiEndpointGen, error) {
- gens := make([]EndpointGen, 0, len(specs))
- for _, spec := range specs {
- g, err := ParseAddrSet(spec)
- if err != nil {
- return nil, fmt.Errorf("can't create endpoint gen from spec %q: %w", spec, err)
- }
- gens = append(gens, g)
- }
- return NewEqualMultiEndpointGen(gens...)
- }
- func (g EqualMultiEndpointGen) Endpoint() string {
- var ret string
- randpool.Borrow(func(r *rand.Rand) {
- ret = g[r.Intn(len(g))].Endpoint()
- })
- return ret
- }
- func (g EqualMultiEndpointGen) Power() *big.Int {
- sum := new(big.Int)
- for _, sg := range g {
- sum.Add(sum, sg.Power())
- }
- return sum
- }
- var _ EndpointGen = SingleEndpoint("")
- type SingleEndpoint string
- func (e SingleEndpoint) Endpoint() string {
- return string(e)
- }
- func (e SingleEndpoint) Power() *big.Int {
- return big.NewInt(1)
- }
|