diff --git a/src/accessibility/RovingTabIndex.tsx b/src/accessibility/RovingTabIndex.tsx index 2fb22e9f8f..6682e1068b 100644 --- a/src/accessibility/RovingTabIndex.tsx +++ b/src/accessibility/RovingTabIndex.tsx @@ -10,7 +10,6 @@ import React, { createContext, useCallback, useContext, - useEffect, useMemo, useRef, useReducer, @@ -22,7 +21,7 @@ import React, { import { getKeyBindingsManager } from "../KeyBindingsManager"; import { KeyBindingAction } from "./KeyboardShortcuts"; -import { FocusHandler, Ref } from "./roving/types"; +import { FocusHandler } from "./roving/types"; /** * Module to simplify implementing the Roving TabIndex accessibility technique @@ -49,8 +48,8 @@ export function checkInputableElement(el: HTMLElement): boolean { } export interface IState { - activeRef?: Ref; - refs: Ref[]; + activeNode?: HTMLElement; + nodes: HTMLElement[]; } export interface IContext { @@ -60,7 +59,7 @@ export interface IContext { export const RovingTabIndexContext = createContext({ state: { - refs: [], // list of refs in DOM order + nodes: [], // list of nodes in DOM order }, dispatch: () => {}, }); @@ -76,7 +75,7 @@ export enum Type { export interface IAction { type: Exclude; payload: { - ref: Ref; + node: HTMLElement; }; } @@ -87,12 +86,12 @@ interface UpdateAction { type Action = IAction | UpdateAction; -const refSorter = (a: Ref, b: Ref): number => { +const nodeSorter = (a: HTMLElement, b: HTMLElement): number => { if (a === b) { return 0; } - const position = a.current!.compareDocumentPosition(b.current!); + const position = a.compareDocumentPosition(b); if (position & Node.DOCUMENT_POSITION_FOLLOWING || position & Node.DOCUMENT_POSITION_CONTAINED_BY) { return -1; @@ -106,54 +105,56 @@ const refSorter = (a: Ref, b: Ref): number => { export const reducer: Reducer = (state: IState, action: Action) => { switch (action.type) { case Type.Register: { - if (!state.activeRef) { - // Our list of refs was empty, set activeRef to this first item - state.activeRef = action.payload.ref; + if (!state.activeNode) { + // Our list of nodes was empty, set activeNode to this first item + state.activeNode = action.payload.node; } + if (state.nodes.includes(action.payload.node)) return state; + // Sadly due to the potential of DOM elements swapping order we can't do anything fancy like a binary insert - state.refs.push(action.payload.ref); - state.refs.sort(refSorter); + state.nodes.push(action.payload.node); + state.nodes.sort(nodeSorter); return { ...state }; } case Type.Unregister: { - const oldIndex = state.refs.findIndex((r) => r === action.payload.ref); + const oldIndex = state.nodes.findIndex((r) => r === action.payload.node); if (oldIndex === -1) { return state; // already removed, this should not happen } - if (state.refs.splice(oldIndex, 1)[0] === state.activeRef) { - // we just removed the active ref, need to replace it - // pick the ref closest to the index the old ref was in - if (oldIndex >= state.refs.length) { - state.activeRef = findSiblingElement(state.refs, state.refs.length - 1, true); + if (state.nodes.splice(oldIndex, 1)[0] === state.activeNode) { + // we just removed the active node, need to replace it + // pick the node closest to the index the old node was in + if (oldIndex >= state.nodes.length) { + state.activeNode = findSiblingElement(state.nodes, state.nodes.length - 1, true); } else { - state.activeRef = - findSiblingElement(state.refs, oldIndex) || findSiblingElement(state.refs, oldIndex, true); + state.activeNode = + findSiblingElement(state.nodes, oldIndex) || findSiblingElement(state.nodes, oldIndex, true); } if (document.activeElement === document.body) { // if the focus got reverted to the body then the user was likely focused on the unmounted element - setTimeout(() => state.activeRef?.current?.focus(), 0); + setTimeout(() => state.activeNode?.focus(), 0); } } - // update the refs list + // update the nodes list return { ...state }; } case Type.SetFocus: { - // if the ref doesn't change just return the same object reference to skip a re-render - if (state.activeRef === action.payload.ref) return state; - // update active ref - state.activeRef = action.payload.ref; + // if the node doesn't change just return the same object reference to skip a re-render + if (state.activeNode === action.payload.node) return state; + // update active node + state.activeNode = action.payload.node; return { ...state }; } case Type.Update: { - state.refs.sort(refSorter); + state.nodes.sort(nodeSorter); return { ...state }; } @@ -174,28 +175,28 @@ interface IProps { } export const findSiblingElement = ( - refs: RefObject[], + nodes: HTMLElement[], startIndex: number, backwards = false, loop = false, -): RefObject | undefined => { +): HTMLElement | undefined => { if (backwards) { - for (let i = startIndex; i < refs.length && i >= 0; i--) { - if (refs[i].current?.offsetParent !== null) { - return refs[i]; + for (let i = startIndex; i < nodes.length && i >= 0; i--) { + if (nodes[i]?.offsetParent !== null) { + return nodes[i]; } } if (loop) { - return findSiblingElement(refs.slice(startIndex + 1), refs.length - 1, true, false); + return findSiblingElement(nodes.slice(startIndex + 1), nodes.length - 1, true, false); } } else { - for (let i = startIndex; i < refs.length && i >= 0; i++) { - if (refs[i].current?.offsetParent !== null) { - return refs[i]; + for (let i = startIndex; i < nodes.length && i >= 0; i++) { + if (nodes[i]?.offsetParent !== null) { + return nodes[i]; } } if (loop) { - return findSiblingElement(refs.slice(0, startIndex), 0, false, false); + return findSiblingElement(nodes.slice(0, startIndex), 0, false, false); } } }; @@ -211,7 +212,7 @@ export const RovingTabIndexProvider: React.FC = ({ onKeyDown, }) => { const [state, dispatch] = useReducer>(reducer, { - refs: [], + nodes: [], }); const context = useMemo(() => ({ state, dispatch }), [state]); @@ -227,17 +228,17 @@ export const RovingTabIndexProvider: React.FC = ({ let handled = false; const action = getKeyBindingsManager().getAccessibilityAction(ev); - let focusRef: RefObject | undefined; + let focusNode: HTMLElement | undefined; // Don't interfere with input default keydown behaviour // but allow people to move focus from it with Tab. if (!handleInputFields && checkInputableElement(ev.target as HTMLElement)) { switch (action) { case KeyBindingAction.Tab: handled = true; - if (context.state.refs.length > 0) { - const idx = context.state.refs.indexOf(context.state.activeRef!); - focusRef = findSiblingElement( - context.state.refs, + if (context.state.nodes.length > 0) { + const idx = context.state.nodes.indexOf(context.state.activeNode!); + focusNode = findSiblingElement( + context.state.nodes, idx + (ev.shiftKey ? -1 : 1), ev.shiftKey, ); @@ -251,7 +252,7 @@ export const RovingTabIndexProvider: React.FC = ({ if (handleHomeEnd) { handled = true; // move focus to first (visible) item - focusRef = findSiblingElement(context.state.refs, 0); + focusNode = findSiblingElement(context.state.nodes, 0); } break; @@ -259,7 +260,7 @@ export const RovingTabIndexProvider: React.FC = ({ if (handleHomeEnd) { handled = true; // move focus to last (visible) item - focusRef = findSiblingElement(context.state.refs, context.state.refs.length - 1, true); + focusNode = findSiblingElement(context.state.nodes, context.state.nodes.length - 1, true); } break; @@ -270,9 +271,9 @@ export const RovingTabIndexProvider: React.FC = ({ (action === KeyBindingAction.ArrowRight && handleLeftRight) ) { handled = true; - if (context.state.refs.length > 0) { - const idx = context.state.refs.indexOf(context.state.activeRef!); - focusRef = findSiblingElement(context.state.refs, idx + 1, false, handleLoop); + if (context.state.nodes.length > 0) { + const idx = context.state.nodes.indexOf(context.state.activeNode!); + focusNode = findSiblingElement(context.state.nodes, idx + 1, false, handleLoop); } } break; @@ -284,9 +285,9 @@ export const RovingTabIndexProvider: React.FC = ({ (action === KeyBindingAction.ArrowLeft && handleLeftRight) ) { handled = true; - if (context.state.refs.length > 0) { - const idx = context.state.refs.indexOf(context.state.activeRef!); - focusRef = findSiblingElement(context.state.refs, idx - 1, true, handleLoop); + if (context.state.nodes.length > 0) { + const idx = context.state.nodes.indexOf(context.state.activeNode!); + focusNode = findSiblingElement(context.state.nodes, idx - 1, true, handleLoop); } } break; @@ -298,17 +299,17 @@ export const RovingTabIndexProvider: React.FC = ({ ev.stopPropagation(); } - if (focusRef) { - focusRef.current?.focus(); + if (focusNode) { + focusNode?.focus(); // programmatic focus doesn't fire the onFocus handler, so we must do the do ourselves dispatch({ type: Type.SetFocus, payload: { - ref: focusRef, + node: focusNode, }, }); if (scrollIntoView) { - focusRef.current?.scrollIntoView(scrollIntoView); + focusNode?.scrollIntoView(scrollIntoView); } } }, @@ -337,46 +338,57 @@ export const RovingTabIndexProvider: React.FC = ({ ); }; -// Hook to register a roving tab index -// inputRef parameter specifies the ref to use -// onFocus should be called when the index gained focus in any manner -// isActive should be used to set tabIndex in a manner such as `tabIndex={isActive ? 0 : -1}` -// ref should be passed to a DOM node which will be used for DOM compareDocumentPosition +/** + * Hook to register a roving tab index. + * + * inputRef is an optional argument; when passed this ref points to the DOM element + * to which the callback ref is attached. + * + * Returns: + * onFocus should be called when the index gained focus in any manner. + * isActive should be used to set tabIndex in a manner such as `tabIndex={isActive ? 0 : -1}`. + * ref is a callback ref that should be passed to a DOM node which will be used for DOM compareDocumentPosition. + * nodeRef is a ref that points to the DOM element to which the ref mentioned above is attached. + * + * nodeRef = inputRef when inputRef argument is provided. + */ export const useRovingTabIndex = ( inputRef?: RefObject, -): [FocusHandler, boolean, RefObject] => { +): [FocusHandler, boolean, (node: T | null) => void, RefObject] => { const context = useContext(RovingTabIndexContext); - let ref = useRef(null); + + let nodeRef = useRef(null); if (inputRef) { // if we are given a ref, use it instead of ours - ref = inputRef; + nodeRef = inputRef; } - // setup (after refs) - useEffect(() => { - context.dispatch({ - type: Type.Register, - payload: { ref }, - }); - // teardown - return () => { + const ref = useCallback((node: T | null) => { + if (node) { + nodeRef.current = node; + context.dispatch({ + type: Type.Register, + payload: { node }, + }); + } else { context.dispatch({ type: Type.Unregister, - payload: { ref }, + payload: { node: nodeRef.current! }, }); - }; + nodeRef.current = null; + } }, []); // eslint-disable-line react-hooks/exhaustive-deps const onFocus = useCallback(() => { context.dispatch({ type: Type.SetFocus, - payload: { ref }, + payload: { node: nodeRef.current } as { node: T }, }); }, []); // eslint-disable-line react-hooks/exhaustive-deps - const isActive = context.state.activeRef === ref; - return [onFocus, isActive, ref]; + const isActive = context.state.activeNode === nodeRef.current; + return [onFocus, isActive, ref, nodeRef]; }; // re-export the semantic helper components for simplicity diff --git a/src/accessibility/roving/RovingTabIndexWrapper.tsx b/src/accessibility/roving/RovingTabIndexWrapper.tsx index 93436ef4b5..c6a67ac783 100644 --- a/src/accessibility/roving/RovingTabIndexWrapper.tsx +++ b/src/accessibility/roving/RovingTabIndexWrapper.tsx @@ -13,7 +13,11 @@ import { FocusHandler, Ref } from "./types"; interface IProps { inputRef?: Ref; - children(renderProps: { onFocus: FocusHandler; isActive: boolean; ref: Ref }): ReactElement; + children(renderProps: { + onFocus: FocusHandler; + isActive: boolean; + ref: (node: HTMLElement | null) => void; + }): ReactElement; } // Wrapper to allow use of useRovingTabIndex outside of React Functional Components. diff --git a/test/unit-tests/accessibility/RovingTabIndex-test.tsx b/test/unit-tests/accessibility/RovingTabIndex-test.tsx index 8130167db4..d3d75397b7 100644 --- a/test/unit-tests/accessibility/RovingTabIndex-test.tsx +++ b/test/unit-tests/accessibility/RovingTabIndex-test.tsx @@ -28,6 +28,12 @@ const checkTabIndexes = (buttons: NodeListOf, expectations: number[ expect([...buttons].map((b) => b.tabIndex)).toStrictEqual(expectations); }; +const createButtonElement = (text: string): HTMLButtonElement => { + const button = document.createElement("button"); + button.textContent = text; + return button; +}; + // give the buttons keys for the fibre reconciler to not treat them all as the same const button1 = ; const button2 = ; @@ -123,11 +129,7 @@ describe("RovingTabIndex", () => { {button2} {({ onFocus, isActive, ref }) => ( - )} @@ -147,75 +149,75 @@ describe("RovingTabIndex", () => { describe("reducer functions as expected", () => { it("SetFocus works as expected", () => { - const ref1 = React.createRef(); - const ref2 = React.createRef(); + const node1 = createButtonElement("Button 1"); + const node2 = createButtonElement("Button 2"); expect( reducer( { - activeRef: ref1, - refs: [ref1, ref2], + activeNode: node1, + nodes: [node1, node2], }, { type: Type.SetFocus, payload: { - ref: ref2, + node: node2, }, }, ), ).toStrictEqual({ - activeRef: ref2, - refs: [ref1, ref2], + activeNode: node2, + nodes: [node1, node2], }); }); it("Unregister works as expected", () => { - const ref1 = React.createRef(); - const ref2 = React.createRef(); - const ref3 = React.createRef(); - const ref4 = React.createRef(); + const button1 = createButtonElement("Button 1"); + const button2 = createButtonElement("Button 2"); + const button3 = createButtonElement("Button 3"); + const button4 = createButtonElement("Button 4"); let state: IState = { - refs: [ref1, ref2, ref3, ref4], + nodes: [button1, button2, button3, button4], }; state = reducer(state, { type: Type.Unregister, payload: { - ref: ref2, + node: button2, }, }); expect(state).toStrictEqual({ - refs: [ref1, ref3, ref4], + nodes: [button1, button3, button4], }); state = reducer(state, { type: Type.Unregister, payload: { - ref: ref3, + node: button3, }, }); expect(state).toStrictEqual({ - refs: [ref1, ref4], + nodes: [button1, button4], }); state = reducer(state, { type: Type.Unregister, payload: { - ref: ref4, + node: button4, }, }); expect(state).toStrictEqual({ - refs: [ref1], + nodes: [button1], }); state = reducer(state, { type: Type.Unregister, payload: { - ref: ref1, + node: button1, }, }); expect(state).toStrictEqual({ - refs: [], + nodes: [], }); }); @@ -235,122 +237,122 @@ describe("RovingTabIndex", () => { ); let state: IState = { - refs: [], + nodes: [], }; state = reducer(state, { type: Type.Register, payload: { - ref: ref1, + node: ref1.current!, }, }); expect(state).toStrictEqual({ - activeRef: ref1, - refs: [ref1], + activeNode: ref1.current, + nodes: [ref1.current], }); state = reducer(state, { type: Type.Register, payload: { - ref: ref2, + node: ref2.current!, }, }); expect(state).toStrictEqual({ - activeRef: ref1, - refs: [ref1, ref2], + activeNode: ref1.current, + nodes: [ref1.current, ref2.current], }); state = reducer(state, { type: Type.Register, payload: { - ref: ref3, + node: ref3.current!, }, }); expect(state).toStrictEqual({ - activeRef: ref1, - refs: [ref1, ref2, ref3], + activeNode: ref1.current, + nodes: [ref1.current, ref2.current, ref3.current], }); state = reducer(state, { type: Type.Register, payload: { - ref: ref4, + node: ref4.current!, }, }); expect(state).toStrictEqual({ - activeRef: ref1, - refs: [ref1, ref2, ref3, ref4], + activeNode: ref1.current, + nodes: [ref1.current, ref2.current, ref3.current, ref4.current], }); // test that the automatic focus switch works for unmounting state = reducer(state, { type: Type.SetFocus, payload: { - ref: ref2, + node: ref2.current!, }, }); expect(state).toStrictEqual({ - activeRef: ref2, - refs: [ref1, ref2, ref3, ref4], + activeNode: ref2.current, + nodes: [ref1.current, ref2.current, ref3.current, ref4.current], }); state = reducer(state, { type: Type.Unregister, payload: { - ref: ref2, + node: ref2.current!, }, }); expect(state).toStrictEqual({ - activeRef: ref3, - refs: [ref1, ref3, ref4], + activeNode: ref3.current, + nodes: [ref1.current, ref3.current, ref4.current], }); // test that the insert into the middle works as expected state = reducer(state, { type: Type.Register, payload: { - ref: ref2, + node: ref2.current!, }, }); expect(state).toStrictEqual({ - activeRef: ref3, - refs: [ref1, ref2, ref3, ref4], + activeNode: ref3.current, + nodes: [ref1.current, ref2.current, ref3.current, ref4.current], }); // test that insertion at the edges works state = reducer(state, { type: Type.Unregister, payload: { - ref: ref1, + node: ref1.current!, }, }); state = reducer(state, { type: Type.Unregister, payload: { - ref: ref4, + node: ref4.current!, }, }); expect(state).toStrictEqual({ - activeRef: ref3, - refs: [ref2, ref3], + activeNode: ref3.current, + nodes: [ref2.current, ref3.current], }); state = reducer(state, { type: Type.Register, payload: { - ref: ref1, + node: ref1.current!, }, }); state = reducer(state, { type: Type.Register, payload: { - ref: ref4, + node: ref4.current!, }, }); expect(state).toStrictEqual({ - activeRef: ref3, - refs: [ref1, ref2, ref3, ref4], + activeNode: ref3.current, + nodes: [ref1.current, ref2.current, ref3.current, ref4.current], }); }); });