CS61B – Lab 06 BSTMap

Lab 06 要求我们基于提供的 Map61B 接口实现一个基于 BST (Binary Search Tree,二叉搜索树)的 Map。接口中包含的是操作 Map 这类数据结构的基本方法,尽管讲义对 Remove 、KeySet 这两类方法不做要求,但出于摸透底层的角度考虑在这还是实现一下为佳。

Binary Search Tree – 二叉搜索树

BST 是一种二叉树结构(废话),它的主要特点是每个节点(Node)的左子树(Left Child)的所有节点值小于该节点值,而其右子树(Right Child)的所有节点值都大于该节点值,从而能够高效的实现有序数据的相关操作。

这里的相关操作包括查找(Find)、插入(Insert)和删除(Del)等, BST 能够在平均情况 $O(\log n)$ 的时间复杂度内完成上述操作,这比线性结构(如数组)更为高效,因此特别适用于频繁进行上述操作的场景。

来源

一般认为 BST 的思想源于 二分查找,通过将数据组织成树形结构,每个节点作为分界点,利用数据的有序性每次对半快速缩小搜索范围。

优势/不足

  • Binary Tree:相较于普通二叉树, BST强制了数据的有序性,提高了搜索效率
  • AVL Tree / 红黑树:BST 因为缺少转置的平衡操作,逻辑上实现更为简单
  • BST 天然支持有序遍历(中序遍历)
  • 但在极端情况下其性能可能退化至 $O(n)$ 且平均性能稍差。

时间/空间复杂度

  • 时间复杂度:
    • 搜索、插入、删除:平均 O(log n),最坏 O(n)。
      原因:在平衡情况下,树高度为 log n,每次操作比较次数与高度成正比;最坏情况(如插入有序数据)树退化为链表,高度为 n。
    • 遍历:O(n),因为需要访问每个节点一次。
  • 空间复杂度:O(n),用于存储 n 个节点,每个节点包含值、左指针和右指针

具体实现

搜索(Find/Search)

  • 步骤1:从根节点开始。
  • 步骤2:比较目标值与当前节点值。
    • 如果相等,返回该节点。
    • 如果目标值更小,进入左子树重复步骤2。
    • 如果目标值更大,进入右子树重复步骤2。
  • 步骤3:如果到达空节点,返回未找到。

插入(Insert)

  • 步骤1:类似搜索,找到应插入的位置(空子树)。
  • 步骤2:创建新节点并链接到父节点。
  • 步骤3:保持 BST 性质(左小右大)。

删除(Delete)

删除的情况比较复杂,需要考虑当前节点的类型

  • 情况1:删除叶子节点(无子节点)→ 直接移除。
  • 情况2:删除有一个子节点的节点 → 用子节点替换。
  • 情况3:删除有两个子节点的节点 → 用中序遍历后继节点(右子树的最小值)替换,然后删除后继节点

Map

Map 是一种很常见的数据结构,类似于数学概念中的单射函数。它允许通过键直接访问对应的值,避免了在数组或列表中线性搜索的低效性。

BST Map

讲义在这的内容是基于 BST 实现 Map 结构,BST 在这的主要优势是保持键的天然有序性,这能使范围查询、最小/最大操作更为高效($O(\log n)$) 但实际上更为常用的 Map 实现似乎是依靠 Hash 进行的,HashMap 虽然无序,但平均性能更好(查询 $O(1)$),而 Map 的遍历操作是不常见的(或者说对 Map 进行频繁遍历是不明智的)。

但是实现是可以的,基本思想和 BST 一致,就是将每个键值对存储为 BST 的一个节点,节点包含键、值、左子指针和右子指针。操作基于键的比较:

  • 删除时,处理节点替换以保持 BST 性质。
  • 插入时,比较键大小决定左右子树路径。
  • 查找时,类似二分搜索缩小范围。

这里还需要注意的就是讲义开篇提及的一个问题,天然有序就要求传入数据必须可以比较。在 Java 中,我们传入的键必须就实现 Comparable 接口或提供比较器,否则无法构建 BST。如果键不可比较,会抛出 ClassCastException。建议在构造函数中要求键类型实现 Comparable 或传入自定义比较器。

Java 中主要提供两种方法来保证键可比较:

  • 使用泛型约束(要求键实现 Comparable 接口)
class BSTMap<K extends Comparable<K>, V> {
    // 实现细节...
}
  • 使用自定义比较器(通过 Comparator 对象)
class BSTMap<K, V> {
    private final Comparator<K> comparator;
    
    public BSTMap(Comparator<K> comparator) {
        this.comparator = comparator;
    }
    
    // 在比较时使用 comparator.compare(key1, key2) 而不是 key1.compareTo(key2)
}

所以实现起来就很容易了

具体实现

无论是哪种比较实现,有一点是肯定的,我们需要一个 Node 私有内部类来帮助我们构建树

public class BSTMap<K, V> implements Map61B<K, V> {

    private class BSTNode {
        final K key;
        V value;
        BSTNode left;
        BSTNode right;
        int height;  // 可选:用于平衡检查

        BSTNode(K key, V value){
            this.key = key;
            this.value = value;
            this.left = null;
            this.right = null;
            this.height = 1;
        }
    }
 
   // 其他实现...
}

那么接下来分别考虑两种比较实现方式

基于泛型约束

class BSTNode&lt;K extends Comparable&lt;K>, V> {
    K key;
    V value;
    BSTNode&lt;K, V> left;
    BSTNode&lt;K, V> right;

    BSTNode(K key, V value) {
        this.key = key;
        this.value = value;
    }
}

public class BSTMap&lt;K extends Comparable&lt;K>, V> {
    private BSTNode&lt;K, V> root;
    private int size;

    // 插入或更新键值对
    public void put(K key, V value) {
        root = putRec(root, key, value);
    }

    private BSTNode&lt;K, V> putRec(BSTNode&lt;K, V> node, K key, V value) {
        if (node == null) {
            size++;
            return new BSTNode&lt;>(key, value);
        }
        int cmp = key.compareTo(node.key);
        if (cmp &lt; 0) {
            node.left = putRec(node.left, key, value);
        } else if (cmp > 0) {
            node.right = putRec(node.right, key, value);
        } else {
            // 键已存在,更新值
            node.value = value;
        }
        return node;
    }

    // 查找键对应的值
    public V get(K key) {
        return getRec(root, key);
    }

    private V getRec(BSTNode&lt;K, V> node, K key) {
        if (node == null) {
            return null;
        }
        int cmp = key.compareTo(node.key);
        if (cmp &lt; 0) {
            return getRec(node.left, key);
        } else if (cmp > 0) {
            return getRec(node.right, key);
        } else {
            return node.value;
        }
    }

    // 删除键值对
    public void remove(K key) {
        root = removeRec(root, key);
    }

    private BSTNode&lt;K, V> removeRec(BSTNode&lt;K, V> node, K key) {
        if (node == null) {
            return null;
        }
        int cmp = key.compareTo(node.key);
        if (cmp &lt; 0) {
            node.left = removeRec(node.left, key);
        } else if (cmp > 0) {
            node.right = removeRec(node.right, key);
        } else {
            // 找到要删除的节点
            size--;
            // 情况1: 无左子节点
            if (node.left == null) {
                return node.right;
            }
            // 情况2: 无右子节点
            if (node.right == null) {
                return node.left;
            }
            // 情况3: 有两个子节点
            BSTNode&lt;K, V> successor = findMin(node.right);
            node.key = successor.key;
            node.value = successor.value;
            node.right = removeRec(node.right, successor.key);
        }
        return node;
    }

    private BSTNode&lt;K, V> findMin(BSTNode&lt;K, V> node) {
        while (node.left != null) {
            node = node.left;
        }
        return node;
    }

    // 中序遍历(用于测试有序性)
    public void inorder() {
        inorderRec(root);
        System.out.println();
    }

    private void inorderRec(BSTNode&lt;K, V> node) {
        if (node != null) {
            inorderRec(node.left);
            System.out.print("(" + node.key + ": " + node.value + ") ");
            inorderRec(node.right);
        }
    }

    public boolean containsKey(K key) {
        if(key == null) throw new IllegalArgumentException("Invalid Argument:Key is null");
        return get(root, key) != null;
    }

    public int size() {
        return size;
    }

    public void clear() {
        root = null;
        size = 0;
    }
}

以及基于自定义比较器

import java.util.*;

public class BSTMap<K, V> implements Map61B<K, V> {

    private class BSTNode {
        final K key;
        V value;
        BSTNode left;
        BSTNode right;
        int height;  // 可选:用于平衡检查

        BSTNode(K key, V value){
            this.key = key;
            this.value = value;
            this.left = null;
            this.right = null;
            this.height = 1;
        }

        // 更新
        void updateInfo(){
            int leftHeight = (left == null) ? 0 : left.height;
            int rightHeight = (right == null) ? 0 : right.height;
            this.height = 1 + Math.max(leftHeight, rightHeight);
        }
    }

    private BSTNode root;
    private int size;
    private final Comparator<? super K> comparator;

    public BSTMap(){
        this.comparator = null;
    }

    public BSTMap(Comparator<? super K> comparator){
        this.comparator = comparator;
    }

    @Override
    public void put(K key, V value) {
        if(key == null){
            throw new IllegalArgumentException("Invalid Argument:Key is null");
        }
        root = put(root, key, value);
    }

    private BSTNode put(BSTNode node, K key, V value){
        if(node == null){
            size++;
            return new BSTNode(key, value);
        }

        int cmp = compare(key, node.key);
        if (cmp < 0) {
            node.left = put(node.left, key, value);
        } else if (cmp > 0) {
            node.right = put(node.right, key, value);
        } else {
            node.value = value;
        }

        node.updateInfo();
        return node;
    }

    @Override
    public V get(K key) {
        if (key == null) throw new IllegalArgumentException("Invalid Argument:Key is null");
        return get(root, key);
    }

    private V get(BSTNode node, K key){
        if(node == null) return null;

        int cmp = compare(key, node.key);
        if (cmp < 0) {
            return get(node.left, key);
        } else if(cmp > 0) {
            return get(node.right, key);
        } else {
            return node.value;
        }
    }

    @Override
    public boolean containsKey(K key) {
        if(key == null) throw new IllegalArgumentException("Invalid Argument:Key is null");
        return get(root, key) != null;
    }

    @Override
    public int size() {
        return size;
    }

    @Override
    public void clear() {
        root = null;
        size = 0;
    }

    @Override
    public Set<K> keySet() {
        Set<K> keys = new HashSet<>();
        keySet(root, keys);
        return keys;
    }

    private void keySet(BSTNode node, Set<K> keys){
        if(node == null) return;
        keys.add(node.key);
        keySet(node.left, keys);
        keySet(node.right, keys);
    }

    @Override
    public V remove(K key) {
        if (key == null) throw new IllegalArgumentException("Invalid Argument:Key is null");

        V oldValue = get(key);
        if (oldValue != null) {
            root = remove(root, key);
        }
        return oldValue;
    }

    private BSTNode remove(BSTNode node, K key){
        if(node == null) return null;

        int cmp = compare(key, node.key);
        if (cmp < 0) {
            node.left = remove(node.left, key);
        } else if (cmp > 0) {
            node.right = remove(node.right, key);
        } else {
            // 找到要删除的节点
            if (node.left == null) {
                size--;
                return node.right;
            }
            if (node.right == null) {
                size--;
                return node.left;
            }

            // 有两个子节点:用后继节点替换
            BSTNode successor = min(node.right);
            // 创建新节点(因为key是final)
            BSTNode newNode = new BSTNode(successor.key, successor.value);
            newNode.left = node.left;
            newNode.right = removeMin(node.right);
            node = newNode;
            // removeMin中已经减少size,这里不需要再减
        }

        node.updateInfo();
        return node;
    }

    private BSTNode min(BSTNode node){
        if (node == null) return null;
        while (node.left != null){
            node = node.left;
        }
        return node;
    }

    private BSTNode removeMin(BSTNode node){
        if (node == null) return null;
        if (node.left == null){
            size--;
            return node.right;
        }
        node.left = removeMin(node.left);
        node.updateInfo();
        return node;
    }

    @SuppressWarnings("unchecked")
    private int compare(K k1, K k2){
        if(comparator != null){
            return comparator.compare(k1, k2);
        } else {
            return ((Comparable<? super K>)k1).compareTo(k2);
        }
    }
}

Iterator 类的方法比较特殊,他需要返回一个 Iterator 的接口作为迭代器。对于 BST Map,我们设计的 Iterator 应该按键的自然顺序(或比较器定义的顺序)返回键值对。同时需要能够保存当前状态和处理并发情况。

通常情况是通过 Stack 类(或者类 Stack 类,比如 ArrayDeque)来进行实现,并发情况则通过计数器进行检测。

Iterator 下要至少实现 next 和 hasNext 两个方法,且两个方法的复杂度建议 $O(1)$

在迭代器构造函数中,我们需要从根节点开始,将所有左子节点压入栈(直到左子节点为 null)。这样就能确保栈顶总是下一个要访问的最小键节点。

public Iterator<K> iterator() {
    return new BSTIterator();
}

private class BSTIterator implements Iterator<K> {
    private final Stack<BSTNode> stack;

    public BSTIterator(){
        stack = new Stack<>();
        pushLeft(root);
    }

    private void pushLeft(BSTNode node){
        while(node != null){
            stack.push(node);
            node = node.left;
        }
    }

    @Override
    public boolean hasNext() {
        return !stack.isEmpty();
    }

    @Override
    public K next() {
        if (!hasNext()) throw new NoSuchElementException();

        BSTNode node = stack.pop();
        pushLeft(node.right);
        return node.key;
    }
}

以上

评论

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注