Procházet zdrojové kódy

Rewrite RateLimitInterceptor (#7889)

stevenyomi před 2 roky
rodič
revize
532f662b05

+ 55 - 30
app/src/main/java/eu/kanade/tachiyomi/network/interceptor/RateLimitInterceptor.kt

@@ -5,6 +5,8 @@ import okhttp3.Interceptor
 import okhttp3.OkHttpClient
 import okhttp3.Response
 import java.io.IOException
+import java.util.ArrayDeque
+import java.util.concurrent.Semaphore
 import java.util.concurrent.TimeUnit
 
 /**
@@ -25,54 +27,77 @@ fun OkHttpClient.Builder.rateLimit(
     permits: Int,
     period: Long = 1,
     unit: TimeUnit = TimeUnit.SECONDS,
-) = addInterceptor(RateLimitInterceptor(permits, period, unit))
+) = addInterceptor(RateLimitInterceptor(null, permits, period, unit))
 
-private class RateLimitInterceptor(
+/** We can probably accept domains or wildcards by comparing with [endsWith], etc. */
+@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
+internal class RateLimitInterceptor(
+    private val host: String?,
     private val permits: Int,
     period: Long,
     unit: TimeUnit,
 ) : Interceptor {
 
-    private val requestQueue = ArrayList<Long>(permits)
+    private val requestQueue = ArrayDeque<Long>(permits)
     private val rateLimitMillis = unit.toMillis(period)
+    private val fairLock = Semaphore(1, true)
 
     override fun intercept(chain: Interceptor.Chain): Response {
-        // Ignore canceled calls, otherwise they would jam the queue
-        if (chain.call().isCanceled()) {
-            throw IOException()
+        val call = chain.call()
+        if (call.isCanceled()) throw IOException("Canceled")
+
+        val request = chain.request()
+        when (host) {
+            null, request.url.host -> {} // need rate limit
+            else -> return chain.proceed(request)
+        }
+
+        try {
+            fairLock.acquire()
+        } catch (e: InterruptedException) {
+            throw IOException(e)
         }
 
-        synchronized(requestQueue) {
-            val now = SystemClock.elapsedRealtime()
-            val waitTime = if (requestQueue.size < permits) {
-                0
-            } else {
-                val oldestReq = requestQueue[0]
-                val newestReq = requestQueue[permits - 1]
+        val requestQueue = this.requestQueue
+        val timestamp: Long
 
-                if (newestReq - oldestReq > rateLimitMillis) {
-                    0
-                } else {
-                    oldestReq + rateLimitMillis - now // Remaining time
+        try {
+            synchronized(requestQueue) {
+                while (requestQueue.size >= permits) { // queue is full, remove expired entries
+                    val periodStart = SystemClock.elapsedRealtime() - rateLimitMillis
+                    var hasRemovedExpired = false
+                    while (requestQueue.isEmpty().not() && requestQueue.first <= periodStart) {
+                        requestQueue.removeFirst()
+                        hasRemovedExpired = true
+                    }
+                    if (call.isCanceled()) {
+                        throw IOException("Canceled")
+                    } else if (hasRemovedExpired) {
+                        break
+                    } else try { // wait for the first entry to expire, or notified by cached response
+                        (requestQueue as Object).wait(requestQueue.first - periodStart)
+                    } catch (_: InterruptedException) {
+                        continue
+                    }
                 }
-            }
 
-            // Final check
-            if (chain.call().isCanceled()) {
-                throw IOException()
+                // add request to queue
+                timestamp = SystemClock.elapsedRealtime()
+                requestQueue.addLast(timestamp)
             }
+        } finally {
+            fairLock.release()
+        }
 
-            if (requestQueue.size == permits) {
-                requestQueue.removeAt(0)
-            }
-            if (waitTime > 0) {
-                requestQueue.add(now + waitTime)
-                Thread.sleep(waitTime) // Sleep inside synchronized to pause queued requests
-            } else {
-                requestQueue.add(now)
+        val response = chain.proceed(request)
+        if (response.networkResponse == null) { // response is cached, remove it from queue
+            synchronized(requestQueue) {
+                if (requestQueue.isEmpty() || timestamp < requestQueue.first) return@synchronized
+                requestQueue.removeFirstOccurrence(timestamp)
+                (requestQueue as Object).notifyAll()
             }
         }
 
-        return chain.proceed(chain.request())
+        return response
     }
 }

+ 1 - 59
app/src/main/java/eu/kanade/tachiyomi/network/interceptor/SpecificHostRateLimitInterceptor.kt

@@ -1,11 +1,7 @@
 package eu.kanade.tachiyomi.network.interceptor
 
-import android.os.SystemClock
 import okhttp3.HttpUrl
-import okhttp3.Interceptor
 import okhttp3.OkHttpClient
-import okhttp3.Response
-import java.io.IOException
 import java.util.concurrent.TimeUnit
 
 /**
@@ -28,58 +24,4 @@ fun OkHttpClient.Builder.rateLimitHost(
     permits: Int,
     period: Long = 1,
     unit: TimeUnit = TimeUnit.SECONDS,
-) = addInterceptor(SpecificHostRateLimitInterceptor(httpUrl, permits, period, unit))
-
-class SpecificHostRateLimitInterceptor(
-    httpUrl: HttpUrl,
-    private val permits: Int,
-    period: Long,
-    unit: TimeUnit,
-) : Interceptor {
-
-    private val requestQueue = ArrayList<Long>(permits)
-    private val rateLimitMillis = unit.toMillis(period)
-    private val host = httpUrl.host
-
-    override fun intercept(chain: Interceptor.Chain): Response {
-        // Ignore canceled calls, otherwise they would jam the queue
-        if (chain.call().isCanceled()) {
-            throw IOException()
-        } else if (chain.request().url.host != host) {
-            return chain.proceed(chain.request())
-        }
-
-        synchronized(requestQueue) {
-            val now = SystemClock.elapsedRealtime()
-            val waitTime = if (requestQueue.size < permits) {
-                0
-            } else {
-                val oldestReq = requestQueue[0]
-                val newestReq = requestQueue[permits - 1]
-
-                if (newestReq - oldestReq > rateLimitMillis) {
-                    0
-                } else {
-                    oldestReq + rateLimitMillis - now // Remaining time
-                }
-            }
-
-            // Final check
-            if (chain.call().isCanceled()) {
-                throw IOException()
-            }
-
-            if (requestQueue.size == permits) {
-                requestQueue.removeAt(0)
-            }
-            if (waitTime > 0) {
-                requestQueue.add(now + waitTime)
-                Thread.sleep(waitTime) // Sleep inside synchronized to pause queued requests
-            } else {
-                requestQueue.add(now)
-            }
-        }
-
-        return chain.proceed(chain.request())
-    }
-}
+) = addInterceptor(RateLimitInterceptor(httpUrl.host, permits, period, unit))