diff --git a/src/utils/batchDiscordRequests.ts b/src/utils/batchDiscordRequests.ts new file mode 100644 index 00000000..b413c3fc --- /dev/null +++ b/src/utils/batchDiscordRequests.ts @@ -0,0 +1,109 @@ +import JSONResponse from "./JsonResponse"; +import { addDelay, convertSecondsToMillis } from "./timeUtils"; +export const DISCORD_HEADERS = { + RATE_LIMIT_RESET_AFTER: "X-RateLimit-Reset-After", + RATE_LIMIT_REMAINING: "X-RateLimit-Remaining", + RETRY_AFTER: "Retry-After", +}; + +const MAX_RETRY = 1; + +interface RequestDetails { + retries: number; + request: () => Promise; + index: number; +} +interface ResponseDetails { + response: Response; + data: RequestDetails; +} + +export const batchDiscordRequests = async ( + requests: { (): Promise }[] +): Promise => { + try { + const requestsQueue: RequestDetails[] = requests.map((request, index) => { + return { + retries: 0, + request: request, + index: index, + }; + }); + + const responseList: Response[] = new Array(requestsQueue.length); + let resetAfter = 0; + let rateLimitRemaining: number | null = null; + let retryAfter: number | null = null; + const handleResponse = async ( + response: JSONResponse, + data: RequestDetails + ): Promise => { + if (response.ok) { + resetAfter = Number.parseFloat( + response.headers.get(DISCORD_HEADERS.RATE_LIMIT_RESET_AFTER) || "0" + ); + rateLimitRemaining = Number.parseInt( + response.headers.get(DISCORD_HEADERS.RATE_LIMIT_REMAINING) || "0" + ); + responseList[data.index] = response; + } else { + retryAfter = Number.parseFloat( + response.headers.get(DISCORD_HEADERS.RETRY_AFTER) || "0" + ); + if (data.retries >= MAX_RETRY) { + responseList[data.index] = response; + } else { + data.retries++; + requestsQueue.push(data); + } + } + }; + + const executeRequest = async ( + data: RequestDetails + ): Promise<{ response: Response; data: RequestDetails }> => { + let response; + try { + response = await data.request(); + } catch (e: unknown) { + response = new JSONResponse({ error: e }, { status: 500 }); + } + return { response, data }; + }; + + let promises: Promise<{ response: Response; data: RequestDetails }>[] = []; + + while (requestsQueue.length > 0) { + const requestData = requestsQueue.pop(); + if (!requestData) continue; + promises.push(executeRequest(requestData)); + if (rateLimitRemaining) { + rateLimitRemaining--; + } + if ( + !rateLimitRemaining || + rateLimitRemaining <= 0 || + requestsQueue.length === 0 + ) { + const resultList: ResponseDetails[] = await Promise.all(promises); + promises = []; + for (const result of resultList) { + const { response, data } = result; + await handleResponse(response, data); + } + if (rateLimitRemaining && rateLimitRemaining <= 0 && resetAfter) { + await addDelay(convertSecondsToMillis(resetAfter)); + rateLimitRemaining = null; + } else if (retryAfter && retryAfter > 0) { + await addDelay(convertSecondsToMillis(retryAfter)); + retryAfter = null; + } + } + } + + return responseList; + } catch (e) { + console.error(e); + throw e; + } +}; diff --git a/src/utils/timeUtils.ts b/src/utils/timeUtils.ts new file mode 100644 index 00000000..065047f9 --- /dev/null +++ b/src/utils/timeUtils.ts @@ -0,0 +1,7 @@ +export const addDelay = async (millisecond: number): Promise => { + await new Promise((resolve) => setTimeout(resolve, millisecond)); +}; + +export const convertSecondsToMillis = (seconds: number): number => { + return Math.ceil(seconds * 1000); +}; diff --git a/tests/unit/utils/batchDiscordRequests.test.ts b/tests/unit/utils/batchDiscordRequests.test.ts new file mode 100644 index 00000000..65f7b6c4 --- /dev/null +++ b/tests/unit/utils/batchDiscordRequests.test.ts @@ -0,0 +1,201 @@ +import { + batchDiscordRequests, + DISCORD_HEADERS, +} from "../../../src/utils/batchDiscordRequests"; +import JSONResponse from "../../../src/utils/JsonResponse"; + +describe("Utils | batchDiscordRequests", () => { + const rateLimitingHeaders = { + [DISCORD_HEADERS.RATE_LIMIT_REMAINING]: "9", + [DISCORD_HEADERS.RATE_LIMIT_RESET_AFTER]: "1.1", // seconds + }; + + const rateLimitExceededHeaders = { + [DISCORD_HEADERS.RETRY_AFTER]: "1.2", // seconds + }; + + let fetchSpy: jest.SpyInstance; + let setTimeoutSpy: jest.SpyInstance; + + beforeEach(() => { + fetchSpy = jest.spyOn(global, "fetch"); + setTimeoutSpy = jest.spyOn(global, "setTimeout"); + }); + + afterEach(() => { + jest.resetAllMocks(); + jest.restoreAllMocks(); + }); + + test("should execute requests when there are no headers", async () => { + fetchSpy.mockImplementation(() => + Promise.resolve(new JSONResponse({}, {})) + ); + const singleRequest = () => fetch("/abc", { method: "GET" }); + await batchDiscordRequests([singleRequest]); + expect(global.fetch).toHaveBeenCalledWith("/abc", { method: "GET" }); + expect(global.fetch).toBeCalledTimes(1); + }); + + test("should execute multiple requests when there are no headers", async () => { + fetchSpy.mockImplementation(() => + Promise.resolve(new JSONResponse({}, {})) + ); + const singleRequest = () => fetch("/abc", { method: "GET" }); + await batchDiscordRequests(new Array(20).fill(singleRequest)); + expect(global.fetch).toHaveBeenCalledWith("/abc", { method: "GET" }); + expect(global.fetch).toBeCalledTimes(20); + }); + + test("should execute requests when there are headers and input size is 40 with a limit of 3", async () => { + const maxRateLimit = 3; + const inputSize = 40; + let remainingRateLimit = maxRateLimit; + const headers = { ...rateLimitingHeaders }; + fetchSpy.mockImplementation( + () => + new Promise((resolve) => { + headers[DISCORD_HEADERS.RATE_LIMIT_REMAINING] = + remainingRateLimit.toString(); + remainingRateLimit--; + return resolve(new JSONResponse({}, { headers: headers })); + }) + ); + setTimeoutSpy.mockImplementation((resolve: any) => { + remainingRateLimit = maxRateLimit; + return resolve(); + }); + const singleRequest = () => fetch("/abc", { method: "GET" }); + await batchDiscordRequests(new Array(inputSize).fill(singleRequest)); + expect(global.fetch).toHaveBeenCalledWith("/abc", { method: "GET" }); + expect(global.fetch).toBeCalledTimes(inputSize); + }); + + test("should execute requests when there are headers and input size is 6 with a limit of 2", async () => { + const maxRateLimit = 3; + const inputSize = 6; + let remainingRateLimit = maxRateLimit; + const headers = { ...rateLimitingHeaders }; + fetchSpy.mockImplementation( + () => + new Promise((resolve) => { + headers[DISCORD_HEADERS.RATE_LIMIT_REMAINING] = + remainingRateLimit.toString(); + remainingRateLimit--; + return resolve(new JSONResponse({}, { headers: headers })); + }) + ); + setTimeoutSpy.mockImplementation((resolve: any) => { + remainingRateLimit = maxRateLimit; + return resolve(); + }); + const singleRequest = () => fetch("/abc", { method: "GET" }); + await batchDiscordRequests(new Array(inputSize).fill(singleRequest)); + expect(global.fetch).toHaveBeenCalledWith("/abc", { method: "GET" }); + expect(global.fetch).toBeCalledTimes(inputSize); + }); + + test("should retry fetch call when the API fails", async () => { + const headers = { ...rateLimitExceededHeaders }; + fetchSpy.mockImplementation( + () => + new Promise((resolve) => { + return resolve( + new JSONResponse({}, { headers: headers, status: 500 }) + ); + }) + ); + setTimeoutSpy.mockImplementation((resolve: any) => { + return resolve(); + }); + const singleRequest = () => fetch("/abc", { method: "GET" }); + await batchDiscordRequests([singleRequest]); + expect(global.fetch).toHaveBeenCalledWith("/abc", { method: "GET" }); + expect(global.fetch).toBeCalledTimes(2); + }); + + test("should retry only failed fetch calls", async () => { + const maxRateLimit = 3; + const inputSize = 6; + let remainingRateLimit = maxRateLimit; + let retries = 5; + const headers = { ...rateLimitingHeaders }; + fetchSpy.mockImplementation( + () => + new Promise((resolve) => { + const status = retries > 0 ? 500 : 200; + retries--; + headers[DISCORD_HEADERS.RATE_LIMIT_REMAINING] = + remainingRateLimit.toString(); + remainingRateLimit--; + return resolve( + new JSONResponse({}, { headers: headers, status: status }) + ); + }) + ); + setTimeoutSpy.mockImplementation((resolve: any) => { + remainingRateLimit = maxRateLimit; + return resolve(); + }); + const singleRequest = () => fetch("/abc", { method: "GET" }); + await batchDiscordRequests(new Array(inputSize).fill(singleRequest)); + expect(global.fetch).toHaveBeenCalledWith("/abc", { method: "GET" }); + expect(global.fetch).toBeCalledTimes(inputSize + 3); + }); + test("should retry only failed fetch calls", async () => { + const maxRateLimit = 3; + const inputSize = 6; + let remainingRateLimit = maxRateLimit; + let retries = 5; + const headers = { ...rateLimitingHeaders }; + fetchSpy.mockImplementation( + () => + new Promise((resolve) => { + const status = retries > 0 ? 500 : 200; + retries--; + headers[DISCORD_HEADERS.RATE_LIMIT_REMAINING] = + remainingRateLimit.toString(); + remainingRateLimit--; + return resolve( + new JSONResponse({}, { headers: headers, status: status }) + ); + }) + ); + setTimeoutSpy.mockImplementation((resolve: any) => { + remainingRateLimit = maxRateLimit; + return resolve(); + }); + const singleRequest = () => fetch("/abc", { method: "GET" }); + await batchDiscordRequests(new Array(inputSize).fill(singleRequest)); + expect(global.fetch).toHaveBeenCalledWith("/abc", { method: "GET" }); + expect(global.fetch).toBeCalledTimes(inputSize + 3); + }); + + test("should retry even for the rate limited exceeded headers", async () => { + const inputSize = 4; + const headers = { ...rateLimitExceededHeaders }; + fetchSpy.mockImplementation( + () => + new Promise((resolve) => { + return resolve( + new JSONResponse({}, { headers: headers, status: 500 }) + ); + }) + ); + setTimeoutSpy.mockImplementation((resolve: any) => { + return resolve(); + }); + const singleRequest = () => fetch("/abc", { method: "GET" }); + await batchDiscordRequests(new Array(inputSize).fill(singleRequest)); + expect(global.fetch).toHaveBeenCalledWith("/abc", { method: "GET" }); + expect(global.fetch).toBeCalledTimes(inputSize * 2); + }); + test("should handle network errors and continue processing", async () => { + const inputSize = 3; + fetchSpy.mockImplementation(() => Promise.reject("Network error")); + const singleRequest = () => fetch("/abc", { method: "GET" }); + await batchDiscordRequests(new Array(inputSize).fill(singleRequest)); + expect(global.fetch).toHaveBeenCalledWith("/abc", { method: "GET" }); + expect(global.fetch).toBeCalledTimes(inputSize * 2); + }); +});