diff --git a/src/accessibility/RovingTabIndex.tsx b/src/accessibility/RovingTabIndex.tsx
index 2fb22e9f8f..6682e1068b 100644
--- a/src/accessibility/RovingTabIndex.tsx
+++ b/src/accessibility/RovingTabIndex.tsx
@@ -10,7 +10,6 @@ import React, {
     createContext,
     useCallback,
     useContext,
-    useEffect,
     useMemo,
     useRef,
     useReducer,
@@ -22,7 +21,7 @@ import React, {
 
 import { getKeyBindingsManager } from "../KeyBindingsManager";
 import { KeyBindingAction } from "./KeyboardShortcuts";
-import { FocusHandler, Ref } from "./roving/types";
+import { FocusHandler } from "./roving/types";
 
 /**
  * Module to simplify implementing the Roving TabIndex accessibility technique
@@ -49,8 +48,8 @@ export function checkInputableElement(el: HTMLElement): boolean {
 }
 
 export interface IState {
-    activeRef?: Ref;
-    refs: Ref[];
+    activeNode?: HTMLElement;
+    nodes: HTMLElement[];
 }
 
 export interface IContext {
@@ -60,7 +59,7 @@ export interface IContext {
 
 export const RovingTabIndexContext = createContext<IContext>({
     state: {
-        refs: [], // list of refs in DOM order
+        nodes: [], // list of nodes in DOM order
     },
     dispatch: () => {},
 });
@@ -76,7 +75,7 @@ export enum Type {
 export interface IAction {
     type: Exclude<Type, Type.Update>;
     payload: {
-        ref: Ref;
+        node: HTMLElement;
     };
 }
 
@@ -87,12 +86,12 @@ interface UpdateAction {
 
 type Action = IAction | UpdateAction;
 
-const refSorter = (a: Ref, b: Ref): number => {
+const nodeSorter = (a: HTMLElement, b: HTMLElement): number => {
     if (a === b) {
         return 0;
     }
 
-    const position = a.current!.compareDocumentPosition(b.current!);
+    const position = a.compareDocumentPosition(b);
 
     if (position & Node.DOCUMENT_POSITION_FOLLOWING || position & Node.DOCUMENT_POSITION_CONTAINED_BY) {
         return -1;
@@ -106,54 +105,56 @@ const refSorter = (a: Ref, b: Ref): number => {
 export const reducer: Reducer<IState, Action> = (state: IState, action: Action) => {
     switch (action.type) {
         case Type.Register: {
-            if (!state.activeRef) {
-                // Our list of refs was empty, set activeRef to this first item
-                state.activeRef = action.payload.ref;
+            if (!state.activeNode) {
+                // Our list of nodes was empty, set activeNode to this first item
+                state.activeNode = action.payload.node;
             }
 
+            if (state.nodes.includes(action.payload.node)) return state;
+
             // Sadly due to the potential of DOM elements swapping order we can't do anything fancy like a binary insert
-            state.refs.push(action.payload.ref);
-            state.refs.sort(refSorter);
+            state.nodes.push(action.payload.node);
+            state.nodes.sort(nodeSorter);
 
             return { ...state };
         }
 
         case Type.Unregister: {
-            const oldIndex = state.refs.findIndex((r) => r === action.payload.ref);
+            const oldIndex = state.nodes.findIndex((r) => r === action.payload.node);
 
             if (oldIndex === -1) {
                 return state; // already removed, this should not happen
             }
 
-            if (state.refs.splice(oldIndex, 1)[0] === state.activeRef) {
-                // we just removed the active ref, need to replace it
-                // pick the ref closest to the index the old ref was in
-                if (oldIndex >= state.refs.length) {
-                    state.activeRef = findSiblingElement(state.refs, state.refs.length - 1, true);
+            if (state.nodes.splice(oldIndex, 1)[0] === state.activeNode) {
+                // we just removed the active node, need to replace it
+                // pick the node closest to the index the old node was in
+                if (oldIndex >= state.nodes.length) {
+                    state.activeNode = findSiblingElement(state.nodes, state.nodes.length - 1, true);
                 } else {
-                    state.activeRef =
-                        findSiblingElement(state.refs, oldIndex) || findSiblingElement(state.refs, oldIndex, true);
+                    state.activeNode =
+                        findSiblingElement(state.nodes, oldIndex) || findSiblingElement(state.nodes, oldIndex, true);
                 }
                 if (document.activeElement === document.body) {
                     // if the focus got reverted to the body then the user was likely focused on the unmounted element
-                    setTimeout(() => state.activeRef?.current?.focus(), 0);
+                    setTimeout(() => state.activeNode?.focus(), 0);
                 }
             }
 
-            // update the refs list
+            // update the nodes list
             return { ...state };
         }
 
         case Type.SetFocus: {
-            // if the ref doesn't change just return the same object reference to skip a re-render
-            if (state.activeRef === action.payload.ref) return state;
-            // update active ref
-            state.activeRef = action.payload.ref;
+            // if the node doesn't change just return the same object reference to skip a re-render
+            if (state.activeNode === action.payload.node) return state;
+            // update active node
+            state.activeNode = action.payload.node;
             return { ...state };
         }
 
         case Type.Update: {
-            state.refs.sort(refSorter);
+            state.nodes.sort(nodeSorter);
             return { ...state };
         }
 
@@ -174,28 +175,28 @@ interface IProps {
 }
 
 export const findSiblingElement = (
-    refs: RefObject<HTMLElement>[],
+    nodes: HTMLElement[],
     startIndex: number,
     backwards = false,
     loop = false,
-): RefObject<HTMLElement> | undefined => {
+): HTMLElement | undefined => {
     if (backwards) {
-        for (let i = startIndex; i < refs.length && i >= 0; i--) {
-            if (refs[i].current?.offsetParent !== null) {
-                return refs[i];
+        for (let i = startIndex; i < nodes.length && i >= 0; i--) {
+            if (nodes[i]?.offsetParent !== null) {
+                return nodes[i];
             }
         }
         if (loop) {
-            return findSiblingElement(refs.slice(startIndex + 1), refs.length - 1, true, false);
+            return findSiblingElement(nodes.slice(startIndex + 1), nodes.length - 1, true, false);
         }
     } else {
-        for (let i = startIndex; i < refs.length && i >= 0; i++) {
-            if (refs[i].current?.offsetParent !== null) {
-                return refs[i];
+        for (let i = startIndex; i < nodes.length && i >= 0; i++) {
+            if (nodes[i]?.offsetParent !== null) {
+                return nodes[i];
             }
         }
         if (loop) {
-            return findSiblingElement(refs.slice(0, startIndex), 0, false, false);
+            return findSiblingElement(nodes.slice(0, startIndex), 0, false, false);
         }
     }
 };
@@ -211,7 +212,7 @@ export const RovingTabIndexProvider: React.FC<IProps> = ({
     onKeyDown,
 }) => {
     const [state, dispatch] = useReducer<Reducer<IState, Action>>(reducer, {
-        refs: [],
+        nodes: [],
     });
 
     const context = useMemo<IContext>(() => ({ state, dispatch }), [state]);
@@ -227,17 +228,17 @@ export const RovingTabIndexProvider: React.FC<IProps> = ({
 
             let handled = false;
             const action = getKeyBindingsManager().getAccessibilityAction(ev);
-            let focusRef: RefObject<HTMLElement> | undefined;
+            let focusNode: HTMLElement | undefined;
             // Don't interfere with input default keydown behaviour
             // but allow people to move focus from it with Tab.
             if (!handleInputFields && checkInputableElement(ev.target as HTMLElement)) {
                 switch (action) {
                     case KeyBindingAction.Tab:
                         handled = true;
-                        if (context.state.refs.length > 0) {
-                            const idx = context.state.refs.indexOf(context.state.activeRef!);
-                            focusRef = findSiblingElement(
-                                context.state.refs,
+                        if (context.state.nodes.length > 0) {
+                            const idx = context.state.nodes.indexOf(context.state.activeNode!);
+                            focusNode = findSiblingElement(
+                                context.state.nodes,
                                 idx + (ev.shiftKey ? -1 : 1),
                                 ev.shiftKey,
                             );
@@ -251,7 +252,7 @@ export const RovingTabIndexProvider: React.FC<IProps> = ({
                         if (handleHomeEnd) {
                             handled = true;
                             // move focus to first (visible) item
-                            focusRef = findSiblingElement(context.state.refs, 0);
+                            focusNode = findSiblingElement(context.state.nodes, 0);
                         }
                         break;
 
@@ -259,7 +260,7 @@ export const RovingTabIndexProvider: React.FC<IProps> = ({
                         if (handleHomeEnd) {
                             handled = true;
                             // move focus to last (visible) item
-                            focusRef = findSiblingElement(context.state.refs, context.state.refs.length - 1, true);
+                            focusNode = findSiblingElement(context.state.nodes, context.state.nodes.length - 1, true);
                         }
                         break;
 
@@ -270,9 +271,9 @@ export const RovingTabIndexProvider: React.FC<IProps> = ({
                             (action === KeyBindingAction.ArrowRight && handleLeftRight)
                         ) {
                             handled = true;
-                            if (context.state.refs.length > 0) {
-                                const idx = context.state.refs.indexOf(context.state.activeRef!);
-                                focusRef = findSiblingElement(context.state.refs, idx + 1, false, handleLoop);
+                            if (context.state.nodes.length > 0) {
+                                const idx = context.state.nodes.indexOf(context.state.activeNode!);
+                                focusNode = findSiblingElement(context.state.nodes, idx + 1, false, handleLoop);
                             }
                         }
                         break;
@@ -284,9 +285,9 @@ export const RovingTabIndexProvider: React.FC<IProps> = ({
                             (action === KeyBindingAction.ArrowLeft && handleLeftRight)
                         ) {
                             handled = true;
-                            if (context.state.refs.length > 0) {
-                                const idx = context.state.refs.indexOf(context.state.activeRef!);
-                                focusRef = findSiblingElement(context.state.refs, idx - 1, true, handleLoop);
+                            if (context.state.nodes.length > 0) {
+                                const idx = context.state.nodes.indexOf(context.state.activeNode!);
+                                focusNode = findSiblingElement(context.state.nodes, idx - 1, true, handleLoop);
                             }
                         }
                         break;
@@ -298,17 +299,17 @@ export const RovingTabIndexProvider: React.FC<IProps> = ({
                 ev.stopPropagation();
             }
 
-            if (focusRef) {
-                focusRef.current?.focus();
+            if (focusNode) {
+                focusNode?.focus();
                 // programmatic focus doesn't fire the onFocus handler, so we must do the do ourselves
                 dispatch({
                     type: Type.SetFocus,
                     payload: {
-                        ref: focusRef,
+                        node: focusNode,
                     },
                 });
                 if (scrollIntoView) {
-                    focusRef.current?.scrollIntoView(scrollIntoView);
+                    focusNode?.scrollIntoView(scrollIntoView);
                 }
             }
         },
@@ -337,46 +338,57 @@ export const RovingTabIndexProvider: React.FC<IProps> = ({
     );
 };
 
-// Hook to register a roving tab index
-// inputRef parameter specifies the ref to use
-// onFocus should be called when the index gained focus in any manner
-// isActive should be used to set tabIndex in a manner such as `tabIndex={isActive ? 0 : -1}`
-// ref should be passed to a DOM node which will be used for DOM compareDocumentPosition
+/**
+ * Hook to register a roving tab index.
+ *
+ * inputRef is an optional argument; when passed this ref points to the DOM element
+ * to which the callback ref is attached.
+ *
+ * Returns:
+ * onFocus should be called when the index gained focus in any manner.
+ * isActive should be used to set tabIndex in a manner such as `tabIndex={isActive ? 0 : -1}`.
+ * ref is a callback ref that should be passed to a DOM node which will be used for DOM compareDocumentPosition.
+ * nodeRef is a ref that points to the DOM element to which the ref mentioned above is attached.
+ *
+ * nodeRef = inputRef when inputRef argument is provided.
+ */
 export const useRovingTabIndex = <T extends HTMLElement>(
     inputRef?: RefObject<T>,
-): [FocusHandler, boolean, RefObject<T>] => {
+): [FocusHandler, boolean, (node: T | null) => void, RefObject<T | null>] => {
     const context = useContext(RovingTabIndexContext);
-    let ref = useRef<T>(null);
+
+    let nodeRef = useRef<T | null>(null);
 
     if (inputRef) {
         // if we are given a ref, use it instead of ours
-        ref = inputRef;
+        nodeRef = inputRef;
     }
 
-    // setup (after refs)
-    useEffect(() => {
-        context.dispatch({
-            type: Type.Register,
-            payload: { ref },
-        });
-        // teardown
-        return () => {
+    const ref = useCallback((node: T | null) => {
+        if (node) {
+            nodeRef.current = node;
+            context.dispatch({
+                type: Type.Register,
+                payload: { node },
+            });
+        } else {
             context.dispatch({
                 type: Type.Unregister,
-                payload: { ref },
+                payload: { node: nodeRef.current! },
             });
-        };
+            nodeRef.current = null;
+        }
     }, []); // eslint-disable-line react-hooks/exhaustive-deps
 
     const onFocus = useCallback(() => {
         context.dispatch({
             type: Type.SetFocus,
-            payload: { ref },
+            payload: { node: nodeRef.current } as { node: T },
         });
     }, []); // eslint-disable-line react-hooks/exhaustive-deps
 
-    const isActive = context.state.activeRef === ref;
-    return [onFocus, isActive, ref];
+    const isActive = context.state.activeNode === nodeRef.current;
+    return [onFocus, isActive, ref, nodeRef];
 };
 
 // re-export the semantic helper components for simplicity
diff --git a/src/accessibility/roving/RovingTabIndexWrapper.tsx b/src/accessibility/roving/RovingTabIndexWrapper.tsx
index 93436ef4b5..c6a67ac783 100644
--- a/src/accessibility/roving/RovingTabIndexWrapper.tsx
+++ b/src/accessibility/roving/RovingTabIndexWrapper.tsx
@@ -13,7 +13,11 @@ import { FocusHandler, Ref } from "./types";
 
 interface IProps {
     inputRef?: Ref;
-    children(renderProps: { onFocus: FocusHandler; isActive: boolean; ref: Ref }): ReactElement<any, any>;
+    children(renderProps: {
+        onFocus: FocusHandler;
+        isActive: boolean;
+        ref: (node: HTMLElement | null) => void;
+    }): ReactElement<any, any>;
 }
 
 // Wrapper to allow use of useRovingTabIndex outside of React Functional Components.
diff --git a/test/unit-tests/accessibility/RovingTabIndex-test.tsx b/test/unit-tests/accessibility/RovingTabIndex-test.tsx
index 8130167db4..d3d75397b7 100644
--- a/test/unit-tests/accessibility/RovingTabIndex-test.tsx
+++ b/test/unit-tests/accessibility/RovingTabIndex-test.tsx
@@ -28,6 +28,12 @@ const checkTabIndexes = (buttons: NodeListOf<HTMLElement>, expectations: number[
     expect([...buttons].map((b) => b.tabIndex)).toStrictEqual(expectations);
 };
 
+const createButtonElement = (text: string): HTMLButtonElement => {
+    const button = document.createElement("button");
+    button.textContent = text;
+    return button;
+};
+
 // give the buttons keys for the fibre reconciler to not treat them all as the same
 const button1 = <Button key={1}>a</Button>;
 const button2 = <Button key={2}>b</Button>;
@@ -123,11 +129,7 @@ describe("RovingTabIndex", () => {
                         {button2}
                         <RovingTabIndexWrapper>
                             {({ onFocus, isActive, ref }) => (
-                                <button
-                                    onFocus={onFocus}
-                                    tabIndex={isActive ? 0 : -1}
-                                    ref={ref as React.RefObject<HTMLButtonElement>}
-                                >
+                                <button onFocus={onFocus} tabIndex={isActive ? 0 : -1} ref={ref}>
                                     .
                                 </button>
                             )}
@@ -147,75 +149,75 @@ describe("RovingTabIndex", () => {
 
     describe("reducer functions as expected", () => {
         it("SetFocus works as expected", () => {
-            const ref1 = React.createRef<HTMLElement>();
-            const ref2 = React.createRef<HTMLElement>();
+            const node1 = createButtonElement("Button 1");
+            const node2 = createButtonElement("Button 2");
             expect(
                 reducer(
                     {
-                        activeRef: ref1,
-                        refs: [ref1, ref2],
+                        activeNode: node1,
+                        nodes: [node1, node2],
                     },
                     {
                         type: Type.SetFocus,
                         payload: {
-                            ref: ref2,
+                            node: node2,
                         },
                     },
                 ),
             ).toStrictEqual({
-                activeRef: ref2,
-                refs: [ref1, ref2],
+                activeNode: node2,
+                nodes: [node1, node2],
             });
         });
 
         it("Unregister works as expected", () => {
-            const ref1 = React.createRef<HTMLElement>();
-            const ref2 = React.createRef<HTMLElement>();
-            const ref3 = React.createRef<HTMLElement>();
-            const ref4 = React.createRef<HTMLElement>();
+            const button1 = createButtonElement("Button 1");
+            const button2 = createButtonElement("Button 2");
+            const button3 = createButtonElement("Button 3");
+            const button4 = createButtonElement("Button 4");
 
             let state: IState = {
-                refs: [ref1, ref2, ref3, ref4],
+                nodes: [button1, button2, button3, button4],
             };
 
             state = reducer(state, {
                 type: Type.Unregister,
                 payload: {
-                    ref: ref2,
+                    node: button2,
                 },
             });
             expect(state).toStrictEqual({
-                refs: [ref1, ref3, ref4],
+                nodes: [button1, button3, button4],
             });
 
             state = reducer(state, {
                 type: Type.Unregister,
                 payload: {
-                    ref: ref3,
+                    node: button3,
                 },
             });
             expect(state).toStrictEqual({
-                refs: [ref1, ref4],
+                nodes: [button1, button4],
             });
 
             state = reducer(state, {
                 type: Type.Unregister,
                 payload: {
-                    ref: ref4,
+                    node: button4,
                 },
             });
             expect(state).toStrictEqual({
-                refs: [ref1],
+                nodes: [button1],
             });
 
             state = reducer(state, {
                 type: Type.Unregister,
                 payload: {
-                    ref: ref1,
+                    node: button1,
                 },
             });
             expect(state).toStrictEqual({
-                refs: [],
+                nodes: [],
             });
         });
 
@@ -235,122 +237,122 @@ describe("RovingTabIndex", () => {
             );
 
             let state: IState = {
-                refs: [],
+                nodes: [],
             };
 
             state = reducer(state, {
                 type: Type.Register,
                 payload: {
-                    ref: ref1,
+                    node: ref1.current!,
                 },
             });
             expect(state).toStrictEqual({
-                activeRef: ref1,
-                refs: [ref1],
+                activeNode: ref1.current,
+                nodes: [ref1.current],
             });
 
             state = reducer(state, {
                 type: Type.Register,
                 payload: {
-                    ref: ref2,
+                    node: ref2.current!,
                 },
             });
             expect(state).toStrictEqual({
-                activeRef: ref1,
-                refs: [ref1, ref2],
+                activeNode: ref1.current,
+                nodes: [ref1.current, ref2.current],
             });
 
             state = reducer(state, {
                 type: Type.Register,
                 payload: {
-                    ref: ref3,
+                    node: ref3.current!,
                 },
             });
             expect(state).toStrictEqual({
-                activeRef: ref1,
-                refs: [ref1, ref2, ref3],
+                activeNode: ref1.current,
+                nodes: [ref1.current, ref2.current, ref3.current],
             });
 
             state = reducer(state, {
                 type: Type.Register,
                 payload: {
-                    ref: ref4,
+                    node: ref4.current!,
                 },
             });
             expect(state).toStrictEqual({
-                activeRef: ref1,
-                refs: [ref1, ref2, ref3, ref4],
+                activeNode: ref1.current,
+                nodes: [ref1.current, ref2.current, ref3.current, ref4.current],
             });
 
             // test that the automatic focus switch works for unmounting
             state = reducer(state, {
                 type: Type.SetFocus,
                 payload: {
-                    ref: ref2,
+                    node: ref2.current!,
                 },
             });
             expect(state).toStrictEqual({
-                activeRef: ref2,
-                refs: [ref1, ref2, ref3, ref4],
+                activeNode: ref2.current,
+                nodes: [ref1.current, ref2.current, ref3.current, ref4.current],
             });
 
             state = reducer(state, {
                 type: Type.Unregister,
                 payload: {
-                    ref: ref2,
+                    node: ref2.current!,
                 },
             });
             expect(state).toStrictEqual({
-                activeRef: ref3,
-                refs: [ref1, ref3, ref4],
+                activeNode: ref3.current,
+                nodes: [ref1.current, ref3.current, ref4.current],
             });
 
             // test that the insert into the middle works as expected
             state = reducer(state, {
                 type: Type.Register,
                 payload: {
-                    ref: ref2,
+                    node: ref2.current!,
                 },
             });
             expect(state).toStrictEqual({
-                activeRef: ref3,
-                refs: [ref1, ref2, ref3, ref4],
+                activeNode: ref3.current,
+                nodes: [ref1.current, ref2.current, ref3.current, ref4.current],
             });
 
             // test that insertion at the edges works
             state = reducer(state, {
                 type: Type.Unregister,
                 payload: {
-                    ref: ref1,
+                    node: ref1.current!,
                 },
             });
             state = reducer(state, {
                 type: Type.Unregister,
                 payload: {
-                    ref: ref4,
+                    node: ref4.current!,
                 },
             });
             expect(state).toStrictEqual({
-                activeRef: ref3,
-                refs: [ref2, ref3],
+                activeNode: ref3.current,
+                nodes: [ref2.current, ref3.current],
             });
 
             state = reducer(state, {
                 type: Type.Register,
                 payload: {
-                    ref: ref1,
+                    node: ref1.current!,
                 },
             });
 
             state = reducer(state, {
                 type: Type.Register,
                 payload: {
-                    ref: ref4,
+                    node: ref4.current!,
                 },
             });
             expect(state).toStrictEqual({
-                activeRef: ref3,
-                refs: [ref1, ref2, ref3, ref4],
+                activeNode: ref3.current,
+                nodes: [ref1.current, ref2.current, ref3.current, ref4.current],
             });
         });
     });