import React from 'react';

import { getFocusableElements, getElementToFocusInside } from './helpers';

/**
 * Focus handler for the leading sentinel of the trap: moves focus to the last
 * focusable element of the trapped content (the one just before the trailing
 * sentinel). Does nothing when the container holds only the two sentinels.
 *
 * NOTE(review): assumes getFocusableElements returns elements in DOM order with
 * the two sentinels at index 0 and at the final index — confirm in ./helpers.
 */
const focusLastElement = (event: React.FocusEvent<HTMLDivElement>) => {
    const candidates = getFocusableElements(event.target.parentNode);
    // Skip the trailing sentinel; index 0 is the leading sentinel, so any
    // index above 0 means there is real content to land on.
    const lastContentIndex = candidates.length - 2;
    if (lastContentIndex > 0) {
        candidates[lastContentIndex].focus();
    }
};

/**
 * Focus handler for the trailing sentinel of the trap: moves focus to the first
 * focusable element of the trapped content (the one just after the leading
 * sentinel). Does nothing when the container holds only the two sentinels.
 *
 * NOTE(review): assumes getFocusableElements returns elements in DOM order with
 * the leading sentinel at index 0 — confirm in ./helpers.
 */
const focusFirstElement = (event: React.FocusEvent<HTMLDivElement>) => {
    const candidates = getFocusableElements(event.target.parentNode);
    // More than two entries means at least one element besides the sentinels.
    if (candidates.length > 2) {
        candidates[1].focus();
    }
};

/**
 * Props read by the focus-trap HOC itself. They are forwarded to the wrapped
 * component together with every other prop.
 */
export interface WithFocusTrapProps {
    // When true, the wrapper does not move focus into the trap on mount.
    shouldNotFocusFirstElementOnMount?: boolean;
}

/** Internal React state of the focus-trap wrapper component. */
type WithFocusTrapState = {
    /** Becomes true after the first Tab keydown inside the trap (only ever set to true). */
    tabPressed: boolean;
};

/**
 * Higher-order component that keeps keyboard focus inside the wrapped component.
 *
 * The wrapped component is rendered between two `tabIndex={0}` sentinel divs.
 * When the leading sentinel receives focus, focus is sent to the last focusable
 * content element; when the trailing sentinel receives focus, it is sent to the
 * first one. Tabbing past either end therefore wraps to the other end.
 *
 * On mount (unless `shouldNotFocusFirstElementOnMount` is set), focus moves to the
 * first focusable element inside the container returned by
 * `getElementToFocusInside`, falling back to the first focusable content element.
 *
 * The outer wrapper's class is `focus-trap-tab-not-pressed` until the first Tab
 * keydown inside it, then `focus-trap-tab-pressed` for the rest of its lifetime.
 *
 * @param Component - component to wrap; it receives all props unchanged.
 * @returns a component accepting the same props as `Component`.
 */
function withFocusTrap<P extends WithFocusTrapProps>(
    Component: React.ComponentType<P>
): React.ComponentType<P & WithFocusTrapProps> {
    return class WithFocusTrap extends React.Component<P & WithFocusTrapProps, WithFocusTrapState> {
        state: WithFocusTrapState = {
            tabPressed: false,
        };

        // Trailing sentinel div, used as an anchor to reach the trap container
        // (its parentNode). React resets it to null when the element unmounts.
        lastInput: HTMLElement | null = null;

        componentDidMount() {
            if (this.props.shouldNotFocusFirstElementOnMount) {
                return;
            }
            // Prefer the custom target; only run the fallback DOM query when
            // no custom target is available.
            const elementToFocus = this.getCustomElementToFocus() || this.getFirstFocusableElement();
            if (elementToFocus) {
                elementToFocus.focus();
            }
        }

        /**
         * Returns the first focusable element after the leading sentinel, or null
         * when the sentinel is not mounted or the trap holds no real content.
         */
        getFirstFocusableElement = () => {
            if (!this.lastInput) {
                return null;
            }
            const focusableElements = getFocusableElements(this.lastInput.parentNode);
            // Only the two sentinels present -> nothing real to focus.
            if (focusableElements.length <= 2) {
                return null;
            }
            // Index 0 is the leading sentinel (assumes DOM order from the helper).
            return focusableElements[1];
        };

        /**
         * Returns the first focusable element inside the container located by
         * `getElementToFocusInside`, or null when there is no such container or
         * it has no focusable descendants.
         */
        getCustomElementToFocus = () => {
            if (!this.lastInput) {
                return null;
            }
            const elementToFocusInside = getElementToFocusInside(this.lastInput.parentNode);
            if (!elementToFocusInside) {
                return null;
            }
            const focusableElements = getFocusableElements(elementToFocusInside);
            // Report an empty container as null, like the other "not found" paths,
            // instead of leaking `undefined` from an out-of-range index.
            return focusableElements.length > 0 ? focusableElements[0] : null;
        };

        // Callback ref for the trailing sentinel. React passes null on unmount,
        // so the parameter must accept it.
        setLastInput = (el: HTMLDivElement | null) => {
            this.lastInput = el;
        };

        // Flip `tabPressed` on the first Tab keydown; afterwards bail out early so
        // further key presses do not trigger redundant state updates.
        checkForTabs = (event: React.KeyboardEvent<HTMLDivElement>) => {
            if (this.state.tabPressed) {
                return;
            }
            if (event.key === 'Tab') {
                this.setState({
                    tabPressed: true,
                });
            }
        };

        render() {
            const parentClassName = this.state.tabPressed
                ? 'focus-trap-tab-pressed'
                : 'focus-trap-tab-not-pressed';

            return (
                <div className={parentClassName} onKeyDown={this.checkForTabs}>
                    {/* Leading sentinel: wraps Shift+Tab from the first element to the last. */}
                    <div tabIndex={0} onFocus={focusLastElement} />
                    <Component {...(this.props as P)} />
                    {/* Trailing sentinel: wraps Tab from the last element to the first. */}
                    <div tabIndex={0} onFocus={focusFirstElement} ref={this.setLastInput} />
                </div>
            );
        }
    };
}

export default withFocusTrap;
