package addrgen

import (
	"errors"
	"fmt"
	"math/big"
	"math/rand"
	"net"
	"slices"
	"strconv"
	"strings"

	"github.com/SenseUnit/dtlspipe/randpool"
)

// AddrGen is a source of address strings.
//
// Within this file it is consumed by AddrSet: Addr supplies the host part
// of a generated endpoint and Power supplies the weight used when choosing
// between several address ranges.
type AddrGen interface {
	// Addr returns one address as a string. AddrSet.Endpoint passes the
	// result to net.JoinHostPort as the host.
	Addr() string
	// Power returns the weight of this generator.
	// NOTE(review): presumably the number of distinct addresses the
	// generator can produce - confirm against the implementations.
	Power() *big.Int
}

// PortGen is a source of port numbers.
//
// Within this file it is consumed by AddrSet, which calls Port for every
// generated endpoint and multiplies Power into its own Power.
type PortGen interface {
	// Port returns one port number.
	Port() uint16
	// Power returns the weight of this generator.
	// NOTE(review): presumably the number of distinct ports it can
	// produce; if so, a full 0-65535 range (65536 values) would not fit
	// in uint16 - confirm against the implementations.
	Power() uint16
}

// EndpointGen is a source of endpoint strings. AddrSet,
// EqualMultiEndpointGen and SingleEndpoint in this file implement it.
type EndpointGen interface {
	// Endpoint returns one endpoint as a string.
	Endpoint() string
	// Power returns the weight of this generator.
	// NOTE(review): presumably the number of distinct endpoints it can
	// produce - confirm against the implementations.
	Power() *big.Int
}

// Compile-time check that *AddrSet implements EndpointGen.
var _ EndpointGen = (*AddrSet)(nil)

// AddrSet generates endpoints by combining an address drawn from one of
// several address ranges with a port drawn from a single port range.
// Values are built by ParseAddrSet.
type AddrSet struct {
	// portRange supplies the port of every generated endpoint.
	portRange  PortGen
	// addrRanges holds the address generators to choose from.
	addrRanges []AddrGen
	// cumWeights[i] is the sum of Power() over addrRanges[0..i]; the last
	// element is therefore the total power of all address ranges.
	cumWeights []*big.Int
}

// ParseAddrSet builds an AddrSet from a textual specification.
//
// The text after the last colon in spec is handed to ParsePortRangeSpec;
// the text before it is split on commas and every piece is handed to
// ParseAddrRangeSpec. An error is returned when spec contains no colon or
// when any of the parts fails to parse.
func ParseAddrSet(spec string) (*AddrSet, error) {
	sep := strings.LastIndexByte(spec, ':')
	if sep < 0 {
		return nil, errors.New("port specification not found - colon is missing")
	}
	hostsSpec, portsSpec := spec[:sep], spec[sep+1:]

	ports, err := ParsePortRangeSpec(portsSpec)
	if err != nil {
		return nil, fmt.Errorf("unable to parse port part: %w", err)
	}

	pieces := strings.Split(hostsSpec, ",")
	gens := make([]AddrGen, 0, len(pieces))
	for _, piece := range pieces {
		gen, err := ParseAddrRangeSpec(piece)
		if err != nil {
			return nil, fmt.Errorf("addr range spec %q parse failed: %w", piece, err)
		}
		gens = append(gens, gen)
	}
	if len(gens) == 0 {
		return nil, errors.New("no valid address ranges specified")
	}

	// Running totals of the range powers; Endpoint searches this slice to
	// map a random number onto a range.
	weights := make([]*big.Int, 0, len(gens))
	total := new(big.Int)
	for _, gen := range gens {
		total.Add(total, gen.Power())
		weights = append(weights, new(big.Int).Set(total))
	}

	set := &AddrSet{
		portRange:  ports,
		addrRanges: gens,
		cumWeights: weights,
	}
	return set, nil
}

// Endpoint returns a randomly generated "host:port" string.
//
// The port comes from the port range. The address range is picked by
// drawing a random number in [0, total power) and locating it among the
// cumulative weights, so each range is chosen with probability
// proportional to its Power(); the host is then whatever that range's
// Addr() returns. Host and port are combined with net.JoinHostPort.
func (as *AddrSet) Endpoint() string {
	port := as.portRange.Port()
	count := len(as.addrRanges)
	limit := as.cumWeights[count-1]
	random := new(big.Int)
	randpool.Borrow(func(r *rand.Rand) {
		random.Rand(r, limit)
	})
	// Range i owns the half-open interval [cumWeights[i-1], cumWeights[i]),
	// so the wanted index is the first cumulative weight strictly greater
	// than the drawn number. The comparator never reports equality, which
	// makes the search return exactly that index and - unlike an
	// "exact match, then step one forward" approach - skips over ranges
	// whose power is zero (duplicate cumulative weights).
	idx, _ := slices.BinarySearchFunc(as.cumWeights, random, func(elem, target *big.Int) int {
		if elem.Cmp(target) <= 0 {
			return -1
		}
		return 1
	})
	addr := as.addrRanges[idx].Addr()
	return net.JoinHostPort(addr, strconv.FormatUint(uint64(port), 10))
}

// Power returns the port range's power multiplied by the combined power of
// all address ranges (the last cumulative weight). The result is a newly
// allocated value owned by the caller.
func (as *AddrSet) Power() *big.Int {
	ports := new(big.Int).SetInt64(int64(as.portRange.Power()))
	addrs := as.cumWeights[len(as.addrRanges)-1]
	return ports.Mul(ports, addrs)
}

// Compile-time check that EqualMultiEndpointGen implements EndpointGen.
var _ EndpointGen = EqualMultiEndpointGen(nil)

// EqualMultiEndpointGen is a list of endpoint generators that is itself an
// EndpointGen. Its Endpoint method picks one member at random with equal
// probability for every member and delegates to it.
type EqualMultiEndpointGen []EndpointGen

// NewEqualMultiEndpointGen wraps the given generators into an
// EqualMultiEndpointGen. At least one generator is required; otherwise an
// error is returned.
//
// The result is a plain conversion of gens, so when the caller passes an
// existing slice with "..." the returned value shares its backing array.
func NewEqualMultiEndpointGen(gens ...EndpointGen) (EqualMultiEndpointGen, error) {
	if len(gens) < 1 {
		return nil, errors.New("no generators provided")
	}
	return EqualMultiEndpointGen(gens), nil
}

// EqualMultiEndpointGenFromSpecs parses every spec with ParseAddrSet and
// combines the resulting generators via NewEqualMultiEndpointGen. The first
// spec that fails to parse aborts the call with an error naming that spec;
// an empty specs slice is rejected by NewEqualMultiEndpointGen.
func EqualMultiEndpointGenFromSpecs(specs []string) (EqualMultiEndpointGen, error) {
	parsed := make([]EndpointGen, len(specs))
	for i := range specs {
		set, err := ParseAddrSet(specs[i])
		if err != nil {
			return nil, fmt.Errorf("can't create endpoint gen from spec %q: %w", specs[i], err)
		}
		parsed[i] = set
	}
	return NewEqualMultiEndpointGen(parsed...)
}

// Endpoint picks one of the member generators uniformly at random -
// regardless of the members' Power() - and returns that member's Endpoint().
//
// g must not be empty: rand.Intn panics for a non-positive argument.
func (g EqualMultiEndpointGen) Endpoint() string {
	var idx int
	randpool.Borrow(func(r *rand.Rand) {
		idx = r.Intn(len(g))
	})
	// Delegate only after the borrowed RNG has been handed back: members
	// such as AddrSet borrow from randpool themselves, and calling them
	// from inside the callback would nest one borrow inside another.
	return g[idx].Endpoint()
}

// Power returns the sum of Power() over all member generators as a newly
// allocated value; for an empty g the result is zero.
func (g EqualMultiEndpointGen) Power() *big.Int {
	total := big.NewInt(0)
	for i := range g {
		total.Add(total, g[i].Power())
	}
	return total
}

// Compile-time check that SingleEndpoint implements EndpointGen.
var _ EndpointGen = SingleEndpoint("")

// SingleEndpoint is a generator that always yields the same endpoint: the
// string value itself.
type SingleEndpoint string

// Endpoint returns the stored string unchanged.
func (se SingleEndpoint) Endpoint() string {
	fixed := string(se)
	return fixed
}

// Power always returns a freshly allocated big.Int holding 1.
func (SingleEndpoint) Power() *big.Int {
	return new(big.Int).SetInt64(1)
}