diff --git a/src/notifications/senders/AppleNotificationSender.ts b/src/notifications/senders/AppleNotificationSender.ts index 4aac769..ba03989 100644 --- a/src/notifications/senders/AppleNotificationSender.ts +++ b/src/notifications/senders/AppleNotificationSender.ts @@ -1,5 +1,6 @@ import jwt from "jsonwebtoken"; import http2 from "http2"; +import { ClientHttp2Session } from "node:http2"; interface APNsUrl { fullUrl: string; @@ -17,10 +18,20 @@ export class AppleNotificationSender { private apnsToken: string | undefined = undefined; private _lastRefreshedTimeMs: number | undefined = undefined; - constructor(private shouldActuallySendNotifications = true) { + constructor( + private shouldActuallySendNotifications = true, + private client: ClientHttp2Session | undefined = undefined, + ) { this.sendNotificationImmediately = this.sendNotificationImmediately.bind(this); this.lastReloadedTimeForAPNsIsTooRecent = this.lastReloadedTimeForAPNsIsTooRecent.bind(this); this.reloadAPNsTokenIfTimePassed = this.reloadAPNsTokenIfTimePassed.bind(this); + this.openConnectionIfNoneExists = this.openConnectionIfNoneExists.bind(this); + this.closeConnectionIfExists = this.closeConnectionIfExists.bind(this); + this.registerClosureEventsForClient = this.registerClosureEventsForClient.bind(this); + + if (this.client !== undefined) { + this.registerClosureEventsForClient(); + } } get lastRefreshedTimeMs(): number | undefined { @@ -83,7 +94,9 @@ export class AppleNotificationSender { throw new Error("APNS_BUNDLE_ID environment variable is not set correctly"); } - const { path, host } = AppleNotificationSender.getAPNsFullUrlToUse(deviceId); + this.openConnectionIfNoneExists(); + + const { path } = AppleNotificationSender.getAPNsFullUrlToUse(deviceId); const headers = { ':method': 'POST', @@ -95,8 +108,8 @@ export class AppleNotificationSender { "apns-topic": bundleId, }; try { - const client = http2.connect(host); - const req = client.request(headers); + if (!this.client) { return false } + const req = this.client.request(headers); req.setEncoding('utf8'); await new Promise((resolve, reject) => { @@ -131,15 +144,29 @@ export class AppleNotificationSender { } } - public static getAPNsFullUrlToUse(deviceId: string): APNsUrl { - // Construct the fetch request - const devBaseUrl = "https://api.development.push.apple.com" - const prodBaseUrl = "https://api.push.apple.com" + private openConnectionIfNoneExists() { + const host = AppleNotificationSender.getAPNsHostToUse(); - let hostToUse = devBaseUrl; - if (process.env.APNS_IS_PRODUCTION === "1") { - hostToUse = prodBaseUrl; + if (!this.client) { + this.client = http2.connect(host); + this.registerClosureEventsForClient(); } + } + + private registerClosureEventsForClient() { + this.client?.on('close', this.closeConnectionIfExists); + this.client?.on('error', this.closeConnectionIfExists); + this.client?.on('goaway', this.closeConnectionIfExists); + this.client?.on('timeout', this.closeConnectionIfExists); + } + + private closeConnectionIfExists() { + this.client?.close(); + this.client = undefined; + } + + public static getAPNsFullUrlToUse(deviceId: string): APNsUrl { + let hostToUse = this.getAPNsHostToUse(); const path = "/3/device/" + deviceId; const fullUrl = hostToUse + path; @@ -151,4 +178,15 @@ export class AppleNotificationSender { }; } + public static getAPNsHostToUse() { + // Construct the fetch request + const devBaseUrl = "https://api.development.push.apple.com" + const prodBaseUrl = "https://api.push.apple.com" + + let hostToUse = devBaseUrl; + if (process.env.APNS_IS_PRODUCTION === "1") { + hostToUse = prodBaseUrl; + } + return hostToUse; + } } diff --git a/test/notifications/senders/AppleNotificationSenderTests.test.ts b/test/notifications/senders/AppleNotificationSenderTests.test.ts index 948be7d..bd7f9ef 100644 --- a/test/notifications/senders/AppleNotificationSenderTests.test.ts +++ b/test/notifications/senders/AppleNotificationSenderTests.test.ts @@ -5,29 +5,36 @@ import { AppleNotificationSender, NotificationAlertArguments } from "../../../src/notifications/senders/AppleNotificationSender"; +import { ClientHttp2Session } from "node:http2"; jest.mock("http2"); const sampleKeyBase64 = "LS0tLS1CRUdJTiBQUklWQVRFIEtFWS0tLS0tCk1JR1RBZ0VBTUJNR0J5cUdTTTQ5QWdFR0NDcUdTTTQ5QXdFSEJIa3dkd0lCQVFRZ3NybVNBWklhZ09mQ1A4c0IKV2kyQ0JYRzFPbzd2MWJpc3BJWkN3SXI0UkRlZ0NnWUlLb1pJemowREFRZWhSQU5DQUFUWkh4VjJ3UUpMTUJxKwp5YSt5ZkdpM2cyWlV2NmhyZmUrajA4eXRla1BIalhTMHF6Sm9WRUx6S0hhNkVMOVlBb1pEWEJ0QjZoK2ZHaFhlClNPY09OYmFmCi0tLS0tRU5EIFBSSVZBVEUgS0VZLS0tLS0K"; -function mockHttp2Connect(status: number) { - class MockClient extends EventEmitter { - request = jest.fn((_) => { - const mockRequest: any = new EventEmitter(); - mockRequest.setEncoding = jest.fn(); - mockRequest.write = jest.fn(); - mockRequest.end = jest.fn(() => { - setTimeout(() => { - mockRequest.emit('response', { ':status': status }); - }, 10); - }); - return mockRequest; - }); - - close() {}; +class MockClient extends EventEmitter { + constructor( + private status: number, + ) { + super() } - (http2.connect as jest.Mock) = jest.fn(() => new MockClient()); + request = jest.fn((_) => { + const mockRequest: any = new EventEmitter(); + mockRequest.setEncoding = jest.fn(); + mockRequest.write = jest.fn(); + mockRequest.end = jest.fn(() => { + setTimeout(() => { + mockRequest.emit('response', { ':status': this.status }); + }, 10); + }); + return mockRequest; + }); + + close = jest.fn(() => {}); +} + +function mockHttp2Connect(status: number) { + (http2.connect as jest.Mock) = jest.fn(() => new MockClient(status)); } describe("AppleNotificationSender", () => { @@ -91,7 +98,7 @@ describe("AppleNotificationSender", () => { }); describe("sendNotificationImmediately", () => { - it('makes the connection to the http server if the notification should be sent', async () => { + it('makes the connection to the http server if sending a notification for the first time', async () => { const notificationArguments: NotificationAlertArguments = { title: 'Test notification', body: 'This notification will send', @@ -103,6 +110,20 @@ describe("AppleNotificationSender", () => { expect(result).toBe(true); }); + it('reuses the existing connection if sending another notification', async () => { + const notificationArguments: NotificationAlertArguments = { + title: 'Test notification', + body: 'This notification will send', + } + + const result1 = await notificationSender.sendNotificationImmediately('1', notificationArguments); + const result2 = await notificationSender.sendNotificationImmediately('1', notificationArguments); + + expect(http2.connect).toHaveBeenCalledTimes(1); + expect(result1).toBe(true); + expect(result2).toBe(true); + }); + it('throws an error if the bundle ID is not set correctly', async () => { process.env = { ...process.env, @@ -145,5 +166,25 @@ describe("AppleNotificationSender", () => { expect(http2.connect).not.toHaveBeenCalled(); expect(result).toBe(true); }); + + it("registers a handler to close the connection if `close` event fired", async () => { + const connectionCloseEvents = ['close', 'goaway', 'error', 'timeout']; + + await Promise.all(connectionCloseEvents.map(async (event) => { + const mockClient = new MockClient(200); + notificationSender = new AppleNotificationSender(true, mockClient as unknown as ClientHttp2Session); + + const notificationArguments: NotificationAlertArguments = { + title: 'Test notification', + body: '' + }; + + await notificationSender.sendNotificationImmediately('1', notificationArguments); + + mockClient.emit(event); + + expect(mockClient.close).toHaveBeenCalled(); + })); + }); }); });