import type { Dispatch, SetStateAction } from 'react';
import { useState, useCallback, useMemo } from 'react';
import type { TreeNode, FetchChildrenFunction } from './tree';

/** Node id → whether that node is currently expanded. */
export interface ExpandedState {
  [key: string]: boolean;
}

/** Node id → whether a child fetch for that node is currently in flight. */
interface LoadingState {
  [key: string]: boolean;
}

// NOTE(review): generic arguments here and on TreeNode / FetchChildrenFunction
// were reconstructed from usage (`nodeContext: Context[Type]`); confirm they
// match the declarations in './tree'.
/** Node id → children to render for that node (static merged with fetched). */
interface LoadedChildren<
  Type extends string,
  Context extends Record<Type, unknown>,
> {
  [key: string]: TreeNode<Type, Context>[];
}

/** Node id → whether more children are still expected from a pending fetch. */
interface HasMoreChildrenState {
  [key: string]: boolean;
}

/** Node id → whether children for that node have been fetched successfully. */
interface FetchedState {
  [key: string]: boolean;
}

/**
 * State and actions for an expandable tree whose children may be partly
 * static (passed in at toggle time) and partly fetched on demand.
 *
 * Expansion can be controlled: when `expanded` is supplied it is used instead
 * of internal state, and when `setExpanded` is supplied expansion updates are
 * forwarded to it instead of the internal setter.
 *
 * @param fetchChildren - Optional loader invoked when a node is expanded and
 *   its children have not yet been fetched successfully.
 * @param expanded - Optional controlled expansion map.
 * @param setExpanded - Optional controlled expansion setter.
 * @returns Per-node maps (`expanded`, `loading`, `loadedChildren`,
 *   `hasMoreChildren`) keyed by node id, plus the `toggleNode` action.
 */
export function useTree<
  Type extends string,
  Context extends Record<Type, unknown>,
>({
  fetchChildren,
  expanded: expandedProp,
  setExpanded: setExpandedProp,
}: {
  fetchChildren?: FetchChildrenFunction<Type, Context>;
  expanded?: ExpandedState;
  setExpanded?: Dispatch<SetStateAction<ExpandedState>>;
}) {
  const [expandedInternal, setExpandedInternal] = useState<ExpandedState>({});

  // A controlled `expanded` prop takes precedence over internal state.
  const expanded = useMemo(
    () => expandedProp ?? expandedInternal,
    [expandedProp, expandedInternal]
  );

  const setExpanded = useCallback(
    (value: SetStateAction<ExpandedState>) => {
      if (setExpandedProp) {
        setExpandedProp(value);
      } else {
        setExpandedInternal(value);
      }
    },
    [setExpandedProp, setExpandedInternal]
  );

  const [loading, setLoading] = useState<LoadingState>({});
  const [loadedChildren, setLoadedChildren] = useState<
    LoadedChildren<Type, Context>
  >({});
  const [hasMoreChildren, setHasMoreChildren] =
    useState<HasMoreChildrenState>({});
  // Tracked separately from `loadedChildren`: that map is pre-filled with
  // static children, so a non-empty entry does not prove a fetch succeeded.
  // Without this, a failed fetch for a node with static children was never
  // retried.
  const [fetched, setFetched] = useState<FetchedState>({});

  /**
   * Combines static and fetched children. Static children whose id also
   * appears among the fetched ones are dropped; remaining static children
   * come first, followed by all fetched children.
   */
  const mergeChildren = useCallback(
    (
      staticChildren: TreeNode<Type, Context>[] = [],
      fetchedChildren: TreeNode<Type, Context>[] = []
    ) => {
      const fetchedChildrenIds = new Set(
        fetchedChildren.map((child) => child.id)
      );
      const uniqueStaticChildren = staticChildren.filter(
        (child) => !fetchedChildrenIds.has(child.id)
      );
      return [...uniqueStaticChildren, ...fetchedChildren];
    },
    []
  );

  /**
   * Collapses `nodeId` if it is expanded; otherwise expands it.
   *
   * On expand, static and previously loaded children are shown immediately.
   * If this node's children have not been fetched successfully yet and no
   * fetch for it is already in flight, `fetchChildren` is called and the
   * result is merged with `staticChildren`. Fetch errors are logged; the
   * fetch is attempted again on the next expansion.
   */
  const toggleNode = useCallback(
    async (
      nodeId: string,
      nodeType: Type,
      nodeContext: Context[Type],
      staticChildren?: TreeNode<Type, Context>[]
    ) => {
      if (expanded[nodeId]) {
        // Collapsing only changes expansion state; loaded children are kept.
        setExpanded((prev) => ({ ...prev, [nodeId]: false }));
        return;
      }

      const alreadyFetched = fetched[nodeId] === true;
      const previouslyLoadedChildren = loadedChildren[nodeId] ?? [];

      if (staticChildren?.length) {
        const mergedChildren = mergeChildren(
          staticChildren,
          previouslyLoadedChildren
        );
        setLoadedChildren((prev) => ({ ...prev, [nodeId]: mergedChildren }));
        // Signal "more loading" only while a successful fetch is still pending.
        setHasMoreChildren((prev) => ({ ...prev, [nodeId]: !alreadyFetched }));
      }

      // Expand right away so static / previously loaded children show at once.
      setExpanded((prev) => ({ ...prev, [nodeId]: true }));

      // Skip when already fetched, or when a fetch for this node is in flight
      // (e.g. a quick collapse/re-expand) to avoid duplicate requests.
      if (alreadyFetched || loading[nodeId]) {
        return;
      }

      setLoading((prev) => ({ ...prev, [nodeId]: true }));
      try {
        const fetchedChildren = await fetchChildren?.(
          nodeId,
          nodeType,
          nodeContext
        );
        const allChildren = mergeChildren(
          staticChildren ?? [],
          fetchedChildren
        );
        setLoadedChildren((prev) => ({ ...prev, [nodeId]: allChildren }));
        setFetched((prev) => ({ ...prev, [nodeId]: true }));
      } catch (error) {
        console.error('Error loading children:', error);
      } finally {
        // Clear the pending indicator on failure too, or it stays on forever.
        setHasMoreChildren((prev) => ({ ...prev, [nodeId]: false }));
        setLoading((prev) => ({ ...prev, [nodeId]: false }));
      }
    },
    [
      expanded,
      fetched,
      loading,
      loadedChildren,
      fetchChildren,
      mergeChildren,
      setExpanded,
    ]
  );

  return {
    expanded,
    loading,
    loadedChildren,
    hasMoreChildren,
    toggleNode,
  };
}