LeetCode 95. Unique Binary Search Trees II (counting binary search trees, Catalan numbers)

Given an integer n, generate all structurally unique BSTs (binary search trees) that store values 1 … n.

Example:

Input: 3
Output:
[
  [1,null,3,2],
  [3,2,null,1],
  [3,1,null,null,2],
  [2,1,3],
  [1,null,2,null,3]
]
Explanation:
The above output corresponds to the 5 unique BSTs shown below:

   1         3     3      2      1
    \       /     /      / \      \
     3     2     1      1   3      2
    /     /       \                 \
   2     1         2                 3
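The bracketed lists in the output use LeetCode's level-order serialization: nodes are listed breadth-first, a missing child is written as null, and trailing nulls are dropped. A minimal sketch of that encoding, assuming a TreeNode class with val/left/right fields like LeetCode's (redefined here so the snippet runs on its own); the helper name serialize is hypothetical, and Python's None stands in for null:

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    """Binary tree node with the same fields as LeetCode's definition."""
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def serialize(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Hypothetical helper: encode a tree in LeetCode's level-order format."""
    out: List[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)  # a missing child is written as null
            continue
        out.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while out and out[-1] is None:  # trailing nulls are dropped
        out.pop()
    return out


# First tree of the example: 1 with right child 3, and 3 with left child 2.
root = TreeNode(1)
root.right = TreeNode(3)
root.right.left = TreeNode(2)
print(serialize(root))  # [1, None, 3, 2]
```

Note that a None inside the list still occupies a slot, which is why [3,1,null,null,2] needs two nulls before the 2: both children of 1's parent level are listed before 1's own children.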

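Solution
I took a long detour on this problem. My original approach was to pick a node as the root, recursively build its left and right subtrees, and then combine them. (I first wrote the combining step as a DFS, but it kept getting more complicated, so I reworked it after reading someone else's code.) The resulting implementation is rather cumbersome:
(The build function here takes the current node as a parameter, which can be optimized away.)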
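The combining step is also where the Catalan numbers in the title come from. With root i, the left subtree holds the i-1 values smaller than i and the right subtree the n-i larger ones, and every left tree can be paired with every right tree, so the number G(n) of unique BSTs on n values satisfies:

```latex
G(0) = 1, \qquad
G(n) = \sum_{i=1}^{n} G(i-1)\,G(n-i) = \frac{1}{n+1}\binom{2n}{n}

% n = 3:  G(3) = G(0)G(2) + G(1)G(1) + G(2)G(0) = 2 + 1 + 2 = 5
```

For n = 3 this gives 5, matching the five trees in the example.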

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None

class Solution:
    def generateTrees(self, n):
        """
        :type n: int
        :rtype: List[TreeNode]
        """
        ans = []
        vis = [0 for _ in range(0, n + 1)]

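        def build(cur, lower, upper):
            # Return every BST rooted at value `cur` whose values lie in [lower, upper).
            # vis[i] == 1 marks values already placed during the current search.
            res = []
            root = TreeNode(cur)
            if sum(vis) == n:  # all n values are placed, so cur must be a leaf
                return [root]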
            L = []
            R = []
            for i in range(lower, cur):
                if vis[i] == 0:
                    L.append(i)

            for i in range(cur, upper):
                if vis[i] == 0:
                    R.append(i)

            if len(L) == 0 and len(R) == 0:
                return [root]

            if len(L) == 0:
                for r in R:
                    vis[r] = 1
                    left = [None]
                    right = build(r, cur+1, upper)
                    vis[r] = 0
                    pairs = [(x, y) for x in left for y in right]
                    for pair in pairs:
                        root = TreeNode(cur)
                        root.left, root.right = pair
                        res.append(root)
            elif len(R) == 0:
                for l in L:
                    vis[l] = 1
                    right = [None]
                    left = build(l, lower, cur)
                    vis[l] = 0
                    pairs = [(x, y) for x in left for y in right]
                    for pair in pairs:
                        root = TreeNode(cur)
                        root.left, root.right = pair
                        res.append(root)
            else:
                for l in L:
                    for r in R:
                        vis[l] = 1
                        vis[r] = 1
                        left = build(l, lower, cur)
                        right = build(r, cur+1, upper)
                        vis[l] = 0
                        vis[r] = 0
                        pairs = [(x, y) for x in left for y in right]
                        for pair in pairs:
                            root = TreeNode(cur)
                            root.left, root.right = pair
                            res.append(root)
            return res

        for i in range(1, n + 1):
            vis[i] = 1
            res = build(i, 1, n+1)
            ans += res
            vis[i] = 0
        return ans

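Optimized code: build(lower, upper) returns every BST on the values lower..upper by trying each value as the root and pairing every possible left subtree with every possible right subtree.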

class Solution:
    def generateTrees(self, n):
        """
        :type n: int
        :rtype: List[TreeNode]
        """

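        def build(lower, upper):
            # Return every BST whose values are exactly lower..upper (inclusive).
            res = []
            if lower > upper:
                return [None]  # empty range: the only subtree is the empty one
            if lower == upper:
                return [TreeNode(lower)]
            for i in range(lower, upper+1):  # try each value i as the root
                left_subtree = build(lower, i-1)   # all BSTs on lower..i-1
                right_subtree = build(i+1, upper)  # all BSTs on i+1..upper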
                pairs = [(l, r) for l in left_subtree for r in right_subtree]
                for pair in pairs:
                    node = TreeNode(i)
                    node.left, node.right = pair
                    res.append(node)
            return res

        return [] if n==0 else build(1, n)
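As a sanity check, the number of trees returned should equal the nth Catalan number mentioned in the title. The sketch below restates the optimized recursion (dropping the lower == upper shortcut, which the general loop already covers) and compares its output size against both the Catalan recurrence and the closed form C(2n, n) / (n + 1). It assumes a local TreeNode class (LeetCode normally supplies one) and Python 3.8+ for math.comb; the helper name catalan is hypothetical.

```python
from math import comb


class TreeNode:
    # Same fields as LeetCode's TreeNode (LeetCode normally provides this class).
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


class Solution:
    # The optimized recursion, restated so this snippet runs on its own.
    def generateTrees(self, n):
        def build(lower, upper):
            if lower > upper:
                return [None]  # empty range: one empty subtree
            res = []
            for i in range(lower, upper + 1):  # value i as the root
                lefts = build(lower, i - 1)
                rights = build(i + 1, upper)
                for l in lefts:
                    for r in rights:
                        node = TreeNode(i)
                        node.left, node.right = l, r
                        res.append(node)
            return res

        return [] if n == 0 else build(1, n)


def catalan(n):
    # Hypothetical helper: G(0) = 1, G(m) = sum over root i of G(i-1) * G(m-i).
    g = [1] + [0] * n
    for m in range(1, n + 1):
        g[m] = sum(g[i - 1] * g[m - i] for i in range(1, m + 1))
    return g[n]


for n in range(1, 8):
    count = len(Solution().generateTrees(n))
    # Tree count, recurrence, and closed form should all agree.
    assert count == catalan(n) == comb(2 * n, n) // (n + 1)
print([catalan(n) for n in range(1, 8)])  # [1, 2, 5, 14, 42, 132, 429]
```

One design consequence worth knowing: the same subtree objects from the left and right lists are attached to several different roots, so the returned trees share nodes. That is fine for LeetCode's read-only checking, but a tree should be copied before it is mutated.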