api_token.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. package service
  import (
  	"crypto/subtle"
  	"errors"
  	"strings"
  	"unicode/utf8"

  	"github.com/mhsanaei/3x-ui/v3/database"
  	"github.com/mhsanaei/3x-ui/v3/database/model"
  	"github.com/mhsanaei/3x-ui/v3/util/common"
  	"github.com/mhsanaei/3x-ui/v3/util/random"
  )
  // ApiTokenService manages the api_tokens table: listing, creating,
  // deleting, enabling/disabling, and matching presented bearer tokens.
  // It carries no state; every method fetches the DB handle itself.
  type ApiTokenService struct{}

  const (
  	// apiTokenLength is the length passed to random.Seq when a new token
  	// secret is generated.
  	apiTokenLength = 48
  )
  // ApiTokenView is the JSON shape returned to API callers for one token row.
  // Field order is significant: it fixes the key order of the encoded JSON.
  type ApiTokenView struct {
  	// Id is the row's primary key.
  	Id int `json:"id"`
  	// Name is the human-readable label chosen at creation time.
  	Name string `json:"name"`
  	// Token is the plaintext secret copied from the stored row.
  	Token string `json:"token"`
  	// Enabled reports whether Match will consider this token.
  	Enabled bool `json:"enabled"`
  	// CreatedAt is copied verbatim from the row.
  	// NOTE(review): presumably a Unix timestamp — confirm unit in the model.
  	CreatedAt int64 `json:"createdAt"`
  }
  20. func toView(t *model.ApiToken) *ApiTokenView {
  21. return &ApiTokenView{
  22. Id: t.Id,
  23. Name: t.Name,
  24. Token: t.Token,
  25. Enabled: t.Enabled,
  26. CreatedAt: t.CreatedAt,
  27. }
  28. }
  29. func (s *ApiTokenService) List() ([]*ApiTokenView, error) {
  30. db := database.GetDB()
  31. var rows []*model.ApiToken
  32. if err := db.Model(model.ApiToken{}).Order("id asc").Find(&rows).Error; err != nil {
  33. return nil, err
  34. }
  35. out := make([]*ApiTokenView, 0, len(rows))
  36. for _, r := range rows {
  37. out = append(out, toView(r))
  38. }
  39. return out, nil
  40. }
  // Create adds a new, enabled API token labelled name.
  //
  // The name is trimmed of surrounding whitespace. It must be non-empty and at
  // most 64 characters (Unicode code points, not bytes), and no existing row
  // may already use it. The secret is generated with random.Seq(apiTokenLength).
  // The returned view includes the plaintext secret.
  //
  // Validation failures are returned as common.NewError values. Database
  // errors are returned unchanged.
  func (s *ApiTokenService) Create(name string) (*ApiTokenView, error) {
  	name = strings.TrimSpace(name)
  	if name == "" {
  		return nil, common.NewError("token name is required")
  	}
  	// Count runes, not bytes. The limit is advertised in characters, and a
  	// byte count would reject short names in non-Latin scripts (for example,
  	// 30 CJK characters are 90 bytes).
  	if utf8.RuneCountInString(name) > 64 {
  		return nil, common.NewError("token name must be 64 characters or fewer")
  	}
  	db := database.GetDB()
  	// NOTE(review): check-then-insert is not atomic. Two concurrent requests
  	// can both pass this check unless api_tokens.name has a unique index;
  	// confirm against the model/migration.
  	var count int64
  	if err := db.Model(model.ApiToken{}).Where("name = ?", name).Count(&count).Error; err != nil {
  		return nil, err
  	}
  	if count > 0 {
  		return nil, common.NewError("a token with that name already exists")
  	}
  	row := &model.ApiToken{
  		Name: name,
  		// NOTE(review): token secrecy relies on random.Seq drawing from a
  		// cryptographically secure source; confirm in util/random.
  		Token:   random.Seq(apiTokenLength),
  		Enabled: true,
  	}
  	if err := db.Create(row).Error; err != nil {
  		return nil, err
  	}
  	return toView(row), nil
  }
  67. func (s *ApiTokenService) Delete(id int) error {
  68. if id <= 0 {
  69. return common.NewError("invalid token id")
  70. }
  71. db := database.GetDB()
  72. return db.Where("id = ?", id).Delete(model.ApiToken{}).Error
  73. }
  // SetEnabled writes the enabled flag of the token with primary key id.
  // Non-positive ids are rejected up front. Database errors are returned
  // unchanged. If the update touches no rows, the error "token not found"
  // is returned.
  func (s *ApiTokenService) SetEnabled(id int, enabled bool) error {
  	if id <= 0 {
  		return common.NewError("invalid token id")
  	}
  	result := database.GetDB().
  		Model(model.ApiToken{}).
  		Where("id = ?", id).
  		Update("enabled", enabled)
  	switch {
  	case result.Error != nil:
  		return result.Error
  	case result.RowsAffected == 0:
  		return errors.New("token not found")
  	default:
  		return nil
  	}
  }
  // Match reports whether presented equals the secret of any enabled row in
  // api_tokens.
  //
  // An empty string never matches. A database error also yields false, so a
  // failed lookup can never authenticate. Each enabled row is compared with
  // subtle.ConstantTimeCompare and there is no early exit, so a remote
  // attacker cannot recover a token byte-by-byte from response timing. Note
  // that ConstantTimeCompare returns immediately on a length mismatch, so
  // token length itself is not hidden.
  func (s *ApiTokenService) Match(presented string) bool {
  	if presented == "" {
  		return false
  	}
  	var enabled []*model.ApiToken
  	err := database.GetDB().Model(model.ApiToken{}).Where("enabled = ?", true).Find(&enabled).Error
  	if err != nil {
  		return false
  	}
  	candidate := []byte(presented)
  	hit := 0
  	for _, tok := range enabled {
  		// ConstantTimeCompare yields exactly 0 or 1. OR-ing keeps the loop
  		// branch-free and visits every row regardless of an earlier hit.
  		hit |= subtle.ConstantTimeCompare([]byte(tok.Token), candidate)
  	}
  	return hit == 1
  }