Browse Source

FIX hashStorage

Hamidreza Ghavami 1 year ago
parent
commit
786a3ac992
2 changed files with 22 additions and 7 deletions
  1. 13 3
      web/global/hashStorage.go
  2. 9 4
      web/service/tgbot.go

+ 13 - 3
web/global/hashStorage.go

@@ -3,8 +3,10 @@ package global
 import (
 	"crypto/md5"
 	"encoding/hex"
+	"regexp"
 	"sync"
 	"time"
+	"x-ui/util/common"
 )
 
 type HashEntry struct {
@@ -59,15 +61,23 @@ func (h *HashStorage) saveValue(query string) string {
 	return md5HashString
 }
 
-func (h *HashStorage) GetValue(hash string) string {
+func (h *HashStorage) GetValue(hash string) (string, error) {
 	h.RLock()
 	defer h.RUnlock()
 
 	entry, exists := h.Data[hash]
 	if !exists {
-		return hash
+		if h.isMD5(hash) {
+			return "", common.NewError("hash not found in storage!")
+		}
+		return hash, nil
 	}
-	return entry.Value
+	return entry.Value, nil
+}
+
+func (h *HashStorage) isMD5(hash string) bool {
+	match, _ := regexp.MatchString("^[a-f0-9]{32}$", hash)
+	return match
 }
 
 func (h *HashStorage) RemoveExpiredHashes() {

+ 9 - 4
web/service/tgbot.go

@@ -61,8 +61,9 @@ func (t *Tgbot) Start(i18nFS embed.FS) error {
 		return err
 	}
 
-	// init hash storage
-	t.hashStorage = global.NewHashStorage(5*time.Minute, false)
+	// init hash storage => store callback queries
+	// NOTE: it only save the query if its length is more than 64 chars.
+	t.hashStorage = global.NewHashStorage(20*time.Minute, false)
 
 	tgBottoken, err := t.settingService.GetTgBotToken()
 	if err != nil || tgBottoken == "" {
@@ -199,8 +200,12 @@ func (t *Tgbot) asnwerCallback(callbackQuery *telego.CallbackQuery, isAdmin bool
 	chatId := callbackQuery.Message.Chat.ID
 
 	if isAdmin {
-		// get query from hash storage (if the query was <= 64 chars hash storage dont save the hash and return data itself)
-		decodedQuery := t.hashStorage.GetValue(callbackQuery.Data)
+		// get query from hash storage
+		decodedQuery, err := t.hashStorage.GetValue(callbackQuery.Data)
+		if err != nil {
+			t.SendMsgToTgbot(chatId, "Query not found! Please use the command again!")
+			return
+		}
 		dataArray := strings.Split(decodedQuery, " ")
 
 		if len(dataArray) >= 2 && len(dataArray[1]) > 0 {