Skip to content

Instantly share code, notes, and snippets.

@light0x00
Last active March 5, 2022 14:53
Show Gist options
  • Save light0x00/71ecdcdcf3d137523a934b051ccb9203 to your computer and use it in GitHub Desktop.
Save light0x00/71ecdcdcf3d137523a934b051ccb9203 to your computer and use it in GitHub Desktop.
HashMap 简化版实现
package com.light.practice.algorithm.collections;
import lombok.extern.slf4j.Slf4j;
import java.util.Objects;
/**
 * A simplified hash map: an array of buckets, each bucket a singly linked chain of nodes
 * (separate chaining). The table length is always a power of two, so a bucket index is
 * obtained with a bit mask instead of a modulo.
 *
 * <p>The table doubles once the number of entries reaches the table length (i.e. an effective
 * load factor of 1.0). {@code null} keys and {@code null} values are accepted; a {@code null}
 * key hashes to 0. The class performs no synchronization of its own.
 *
 * @param <K> key type
 * @param <V> value type
 * @author light
 * @since 2019/9/8
 */
@Slf4j
public class MyHashMap<K, V> {

    /** One entry of a bucket's collision chain. */
    private static class Node<K, V> {
        final K key;
        V val;
        /** Spread hash of {@link #key}, cached so that resizing never calls {@code hashCode()} again. */
        final int hash;
        /** Next node in the same bucket, or {@code null} at the end of the chain. */
        Node<K, V> next;

        Node(K key, V val, int hash) {
            this.key = key;
            this.val = val;
            this.hash = hash;
        }
    }

    /**
     * Upper bound for the table length.
     *
     * <p>{@code 1 << 30} is {@code 01000000 00000000 00000000 00000000} in binary, i.e. 2^30,
     * the largest power of two a 32-bit signed int can hold. {@link #resize()} relies on the
     * capacity being a power of two, so 2^30 is the largest capacity the table can reach.
     */
    private static final int MAXIMUM_CAPACITY = 1 << 30;

    /** Number of key/value pairs currently stored. */
    private int size;
    /** Current table length; always a power of two. */
    private int capacity;
    /** The bucket array; each slot holds the head of a chain or {@code null}. */
    private Node<K, V>[] table;

    /**
     * Creates an empty map.
     *
     * @param initialCapacity requested capacity; rounded up to the next power of two, values
     *                        {@code <= 1} yield a capacity of 1 and values above
     *                        {@code 1 << 30} are clamped to {@code 1 << 30}
     */
    public MyHashMap(int initialCapacity) {
        this.capacity = tableSizeFor(initialCapacity);
        this.table = newTable(capacity);
    }

    /** Allocates an empty bucket array of the given length. */
    @SuppressWarnings("unchecked") // generic array creation: the array only ever holds Node<K, V>
    private static <K, V> Node<K, V>[] newTable(int length) {
        return (Node<K, V>[]) new Node[length];
    }

    /** Maps a spread hash to a bucket index; valid because {@code capacity} is a power of two. */
    private int hashToIndex(int hash) {
        return hash & (capacity - 1);
    }

    /**
     * Computes the spread hash of a key ({@code 0} for {@code null}).
     *
     * <p>While the table length is below 2^16, {@code hash & (length - 1)} only looks at the low
     * bits because the high bits of the mask are all zero. XOR-ing the upper 16 bits into the
     * lower 16 lets the upper bits influence the bucket index as well.
     */
    private static int hash(Object key) {
        int h;
        return (key == null) ? 0 : (h = key.hashCode()) ^ (h >>> 16);
    }

    /** Tells whether {@code node} holds {@code key}; the cached hash is compared first as a cheap filter. */
    private static boolean matches(Node<?, ?> node, int hash, Object key) {
        return node.hash == hash && Objects.equals(key, node.key);
    }

    /**
     * Finds the node that holds {@code key}.
     *
     * @return the node, or {@code null} when the key is not present
     */
    private Node<K, V> findNode(Object key) {
        int hash = hash(key);
        for (Node<K, V> e = table[hashToIndex(hash)]; e != null; e = e.next) {
            if (matches(e, hash, key)) {
                return e;
            }
        }
        return null;
    }

    /**
     * Returns the value mapped to a key.
     *
     * @param k the key to look up (may be {@code null})
     * @return the mapped value, or {@code null} when the key is absent; a {@code null} result
     *         may also mean the key is mapped to {@code null} - use {@link #contains(Object)}
     *         to tell the two apart
     */
    public V get(K k) {
        Node<K, V> node = findNode(k);
        return node == null ? null : node.val;
    }

    /**
     * Associates {@code value} with {@code key}, replacing any existing mapping.
     *
     * @param key   the key (may be {@code null})
     * @param value the value (may be {@code null})
     * @return the previous value of an existing mapping, or {@code null} when the key was new
     */
    public V put(K key, V value) {
        int hash = hash(key);
        int index = hashToIndex(hash);
        log.debug("put `{}(hash:{})` to bulk:{}", key, hash, index);

        Node<K, V> head = table[index];
        if (head == null) {
            // Empty bucket: the new node becomes the head.
            table[index] = new Node<>(key, value, hash);
        } else {
            Node<K, V> e = head;
            while (true) {
                if (matches(e, hash, key)) {
                    // Existing key: overwrite in place, size is unchanged.
                    V oldVal = e.val;
                    e.val = value;
                    return oldVal;
                }
                if (e.next == null) {
                    // Reached the tail without a match: append.
                    e.next = new Node<>(key, value, hash);
                    break;
                }
                e = e.next;
            }
        }

        ++size;
        // Grow once the entry count reaches the table length; nothing to do at the maximum.
        if (size >= capacity && capacity < MAXIMUM_CAPACITY) {
            // Dumping the whole table is O(n) and writes to stdout, so only do it when
            // debug logging is actually enabled.
            boolean dumpTable = log.isDebugEnabled();
            if (dumpTable) {
                debug();
            }
            resize();
            if (dumpTable) {
                debug();
            }
        }
        return null;
    }

    /**
     * @return the number of key/value pairs in the map
     */
    public int size() {
        return this.size;
    }

    /**
     * Doubles the table and redistributes every node; does nothing once the capacity has
     * reached {@link #MAXIMUM_CAPACITY}.
     *
     * <p>Because the capacity is a power of two, doubling it adds exactly one bit to the index
     * mask - the bit with value {@code oldCap}. A node in old bucket {@code i} therefore either
     * stays at {@code i} (that bit of its hash is 0) or moves to {@code i + oldCap} (bit is 1).
     * Each old chain is split into a "lo" and a "hi" chain, preserving relative order.
     */
    private void resize() {
        if (capacity >= MAXIMUM_CAPACITY) {
            return;
        }
        int oldCap = capacity;
        int newCap = oldCap << 1;
        Node<K, V>[] oldTab = table;
        Node<K, V>[] newTab = newTable(newCap);

        for (int i = 0; i < oldTab.length; i++) {
            Node<K, V> loHead = null;
            Node<K, V> loTail = null;
            Node<K, V> hiHead = null;
            Node<K, V> hiTail = null;

            for (Node<K, V> e = oldTab[i]; e != null; e = e.next) {
                if ((e.hash & oldCap) == 0) {
                    if (loHead == null) {
                        loHead = e;
                    } else {
                        loTail.next = e;
                    }
                    loTail = e;
                } else {
                    if (hiHead == null) {
                        hiHead = e;
                    } else {
                        hiTail.next = e;
                    }
                    hiTail = e;
                }
            }

            // Cut the stale links left over from the old chain so the two new chains
            // do not run into each other.
            if (loTail != null) {
                loTail.next = null;
            }
            if (hiTail != null) {
                hiTail.next = null;
            }
            newTab[i] = loHead;
            newTab[i + oldCap] = hiHead;
        }

        capacity = newCap;
        table = newTab;
    }

    /**
     * Removes the mapping for a key, if present.
     *
     * @param key the key whose mapping is to be removed (may be {@code null})
     * @return the value that was mapped to the key, or {@code null} when the key was absent
     */
    public V remove(K key) {
        int hash = hash(key);
        int index = hashToIndex(hash);

        Node<K, V> prev = null;
        for (Node<K, V> e = table[index]; e != null; prev = e, e = e.next) {
            // Both the hash AND the key must match: distinct keys can share a hash, and
            // matching on the hash alone would remove the wrong entry.
            if (matches(e, hash, key)) {
                if (prev == null) {
                    table[index] = e.next; // target is the bucket head
                } else {
                    prev.next = e.next;    // unlink from the middle or the tail
                }
                size--;
                return e.val;
            }
        }
        return null;
    }

    /**
     * Tells whether the map has a mapping for a key, including a mapping to {@code null}.
     *
     * @param key the key to test (may be {@code null})
     * @return {@code true} when the key is present
     */
    public boolean contains(K key) {
        return findNode(key) != null;
    }

    /**
     * Returns the smallest power of two that is greater than or equal to {@code cap}, clamped
     * to the range {@code [1, MAXIMUM_CAPACITY]}.
     */
    private static int tableSizeFor(int cap) {
        // Boundary cases first.
        if (cap <= 1) {
            return 1;
        } else if (cap >= MAXIMUM_CAPACITY) {
            return MAXIMUM_CAPACITY;
        }
        // n is in 1 .. 2^30 - 2 here. Smear the highest set bit into every lower position,
        // producing 2^k - 1; adding one gives the power of two.
        int n = cap - 1;
        n |= n >>> 1;
        n |= n >>> 2;
        n |= n >>> 4;
        n |= n >>> 8;
        n |= n >>> 16;
        return n + 1;
    }

    /**
     * Prints the table to standard output, one line per bucket in the form
     * {@code index:(key:value)(key:value)...}.
     */
    public void debug() {
        for (int i = 0; i < table.length; i++) {
            StringBuilder line = new StringBuilder().append(i).append(':');
            for (Node<K, V> e = table[i]; e != null; e = e.next) {
                line.append(String.format("(%s:%s)", e.key, e.val));
            }
            System.out.println(line);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment