Kth Smallest Element in a BST

Problem

https://leetcode.com/problems/kth-smallest-element-in-a-bst/

Given the root of a binary search tree and an integer k, return the k-th smallest value (1-indexed) of all the values of the nodes in the tree.

Example 1:

image1

Input: root = [3,1,4,null,2], k = 1
Output: 1

Example 2:

image2

Input: root = [5,3,6,2,4,null,null,1], k = 3
Output: 3

Constraints:

  • The number of nodes in the tree is n.

  • 1 <= k <= n <= 10^4

  • 0 <= Node.val <= 10^4

Follow up: If the BST is modified often (i.e., we can do insert and delete operations) and you need to find the kth smallest frequently, how would you optimize?

Pattern

Tree, Depth-First Search, Binary Search Tree, Binary Tree

Approaches

Code

from __future__ import annotations

from collections import deque


class TreeNode:
    """Node in a binary tree.

    Each node stores a value in ``val`` and references to its children in
    ``left`` and ``right`` (``None`` when a child is absent).
    """

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

    @classmethod
    def from_list(cls, vals: list[int | None]) -> TreeNode | None:
        """Build a tree from a level-order list of values.

        ``None`` entries mark a missing child. Children are only listed for
        nodes that exist, so the list is consumed two entries (left, right)
        per existing node in breadth-first order.

        Args:
            vals: Level-order values, e.g. ``[3, 1, 4, None, 2]``.

        Returns:
            The root node, or ``None`` when ``vals`` is empty.

        Raises:
            IndexError: If ``vals`` lists children for more nodes than exist.
        """
        if not vals:
            return None
        # Use ``cls`` so subclasses get nodes of their own type.
        root = cls(vals[0])
        # deque gives O(1) pops from the front; list.pop(0) is O(n).
        queue = deque([root])
        i = 1
        while i < len(vals):
            node = queue.popleft()
            # The loop condition already guarantees vals[i] exists here.
            if vals[i] is not None:
                node.left = cls(vals[i])
                queue.append(node.left)
            i += 1
            # The right child may be absent when the list has an odd tail.
            if i < len(vals) and vals[i] is not None:
                node.right = cls(vals[i])
                queue.append(node.right)
            i += 1
        return root

    def to_list(self) -> list:
        """Serialize the tree rooted at this node to a level-order list.

        Missing children are emitted as ``None`` so positions stay aligned;
        trailing ``None`` entries are stripped from the result.

        Returns:
            Level-order values with ``None`` placeholders for absent children.
        """
        result = []
        queue = deque([self])
        while queue:
            node = queue.popleft()
            if node:
                result.append(node.val)
                # Enqueue both children, even when absent, so gaps are recorded.
                queue.append(node.left)
                queue.append(node.right)
            else:
                result.append(None)
        # Drop the placeholders that follow the last real value.
        while result and result[-1] is None:
            result.pop()
        return result


def kthSmallest(root: TreeNode | None, k: int) -> int:
    """Return the k-th smallest value (1-indexed) in a binary search tree.

    Performs an iterative in-order traversal with an explicit stack and
    stops as soon as the k-th node has been visited. If the traversal ends
    before k nodes are visited (empty tree or k out of range), the result
    is ``None``.
    """
    pending = []
    node = root
    visited = 0

    while True:
        # Descend along the left spine, remembering each ancestor.
        while node:
            pending.append(node)
            node = node.left

        # Nothing left to visit: traversal finished without reaching k.
        if not pending:
            return None

        node = pending.pop()
        visited += 1
        if visited == k:
            return node.val

        # Continue with the right subtree of the node just visited.
        node = node.right

Test

>>> from kth_smallest_element_in_a_bst__inorder_traversal import kthSmallest, TreeNode
>>> kthSmallest(TreeNode.from_list([3, 1, 4, None, 2]), 1)
1
>>> kthSmallest(TreeNode.from_list([5, 3, 6, 2, 4, None, None, 1]), 3)
3
class kth_smallest_element_in_a_bst__inorder_traversal.TreeNode(val=0, left=None, right=None)

Bases: object

Node in a binary tree.

classmethod from_list(vals: list[int | None]) TreeNode | None
to_list() list
kth_smallest_element_in_a_bst__inorder_traversal.kthSmallest(root: TreeNode | None, k: int) int