字典树查找,Trie,又称字典树、单词查找树,是一种树形结构,用于保存大量的字符串。它的优点是:利用字符串的公共前缀来节约存储空间。
package com.jwetherell.algorithms.data_structures;
import java.util.Arrays;
/**
* A trie, or prefix tree, is an ordered tree data structure that is used to
* store an associative array where the keys are usually strings.
*
* == This is NOT a compact Trie. ==
*
* http://en.wikipedia.org/wiki/Trie
*
* @author Justin Wetherell <phishman3579@gmail.com>
*/
public class Trie<C extends CharSequence> {
private int size = 0;
protected INodeCreator creator = null;
protected Node root = null;
/**
* Default constructor.
*/
public Trie() { }
/**
* Constructor with external Node creator.
*/
public Trie(INodeCreator creator) {
this.creator = creator;
}
/**
* Create a new node for sequence.
*
* @param parent node of the new node.
* @param character which represents this node.
* @param isWord signifies if the node represents a word.
* @return Node which was created.
*/
protected Node createNewNode(Node parent, Character character, boolean isWord) {
return (new Node(parent, character, isWord));
}
/**
* Add sequence to trie.
*
* @param seq to add to the trie.
* @return True if sequence is added to trie or false if it already exists.
*/
public boolean add(C seq) {
return (this.addSequence(seq)!=null);
}
/**
* Add sequence to trie.
*
* @param seq to add to the trie.
* @return Node which was added to trie or null if it already exists.
*/
protected Node addSequence(C seq) {
if (root==null) {
if (this.creator==null) root = createNewNode(null, null, false);
else root = this.creator.createNewNode(null, null, false);
}
int length = (seq.length() - 1);
Node prev = root;
//For each Character in the input, we'll either go to an already define
// child or create a child if one does not exist
for (int i = 0; i < length; i++) {
Node n = null;
Character c = seq.charAt(i);
int index = prev.childIndex(c);
//If 'prev' has a child which starts with Character c
if (index >= 0) {
//Go to the child
n = prev.getChild(index);
} else {
//Create a new child for the character
if (this.creator==null) n = createNewNode(prev, c, false);
else n = this.creator.createNewNode(prev, c, false);
prev.addChild(n);
}
prev = n;
}
//Deal with the first character of the input string not found in the trie
Node n = null;
Character c = seq.charAt(length);
int index = prev.childIndex(c);
//If 'prev' already contains a child with the last Character
if (index >= 0) {
n = prev.getChild(index);
//If the node doesn't represent a string already
if (n.isWord == false) {
//Set the string to equal the full input string
n.character = c;
n.isWord = true;
size++;
return n;
} else {
//String already exists in Trie
return null;
}
} else {
//Create a new node for the input string
if (this.creator==null) n = createNewNode(prev, c, true);
else n = this.creator.createNewNode(prev, c, true);
prev.addChild(n);
size++;
return n;
}
}
/**
* Remove sequence from the trie.
*
* @param sequence to remove from the trie.
* @return True if sequence was remove or false if sequence is not found.
*/
public boolean remove(C sequence) {
if (root == null) return false;
//Find the key in the Trie
Node previous = null;
Node node = root;
int length = (sequence.length() - 1);
for (int i = 0; i <= length; i++) {
char c = sequence.charAt(i);
int index = node.childIndex(c);
if (index >= 0) {
previous = node;
node = node.getChild(index);
} else {
return false;
}
}
if (node.childrenSize > 0) {
//The node which contains the input string and has children, just NULL out the string
node.isWord = false;
} else {
//The node which contains the input string does NOT have children
int index = previous.childIndex(node.character);
//Remove node from previous node
previous.removeChild(index);
//Go back up the trie removing nodes until you find a node which represents a string
while (previous != null && previous.isWord==false && previous.childrenSize == 0) {
if (previous.parent != null) {
int idx = previous.parent.childIndex(previous.character);
if (idx >= 0) previous.parent.removeChild(idx);
}
previous = previous.parent;
}
}
size--;
return true;
}
/**
* Get node which represents the sequence in the trie.
*
* @param seq to find a node for.
* @return Node which represents the sequence or NULL if not found.
*/
protected Node getNode(C seq) {
if (root == null) return null;
//Find the string in the trie
Node n = root;
int length = (seq.length() - 1);
for (int i = 0; i <= length; i++) {
char c = seq.charAt(i);
int index = n.childIndex(c);
if (index >= 0) {
n = n.getChild(index);
} else {
//string does not exist in trie
return null;
}
}
return n;
}
/**
* Does the trie contain the sequence.
*
* @param seq to locate in the trie.
* @return True if sequence is in the trie.
*/
public boolean contains(C seq) {
Node n = this.getNode(seq);
if (n==null || !n.isWord) return false;
//If the node found in the trie does not have it's string
// field defined then input string was not found
return n.isWord;
}
/**
* Number of sequences in the trie.
*
* @return number of sequences in the trie.
*/
public int size() {
return size;
}
/**
* {@inheritDoc}
*/
@Override
public String toString() {
return TriePrinter.getString(this);
}
protected static class Node {
private static final int MINIMUM_SIZE = 2;
protected Node[] children = new Node[MINIMUM_SIZE];
protected int childrenSize = 0;
protected Node parent = null;
protected boolean isWord = false; //Signifies this node represents a word
protected Character character = null; //First character that is different the parent's string
protected Node(Node parent, Character character, boolean isWord) {
this.parent = parent;
this.character = character;
this.isWord = isWord;
}
protected void addChild(Node node) {
if (childrenSize>=children.length) {
children = Arrays.copyOf(children, ((children.length*3)/2)+1);
}
children[childrenSize++] = node;
}
protected boolean removeChild(int index) {
if (index>=childrenSize) return false;
children[index] = null;
childrenSize--;
System.arraycopy(children, index+1, children, index, childrenSize-index);
if (childrenSize>=MINIMUM_SIZE && childrenSize<children.length/2) {
children = Arrays.copyOf(children, childrenSize);
}
return true;
}
protected int childIndex(Character character) {
for (int i = 0; i < childrenSize; i++) {
Node c = children[i];
if (c.character.equals(character)) return i;
}
return Integer.MIN_VALUE;
}
protected Node getChild(int index) {
if (index>=childrenSize) return null;
return children[index];
}
protected int getChildrenSize() {
return childrenSize;
}
/**
* {@inheritDoc}
*/
@Override
public String toString() {
StringBuilder builder = new StringBuilder();
if (isWord == true) builder.append("Node=").append(isWord).append("\n");
for (int i=0; i<childrenSize; i++) {
Node c = children[i];
builder.append(c.toString());
}
return builder.toString();
}
}
protected static interface INodeCreator {
/**
* Create a new node for sequence.
*
* @param parent node of the new node.
* @param character which represents this node.
* @param isWord signifies if the node represents a word.
* @return Node which was created.
*/
public Node createNewNode(Node parent, Character character, boolean type);
}
protected static class TriePrinter {
public static <C extends CharSequence> void print(Trie<C> trie) {
System.out.println(getString(trie));
}
public static <C extends CharSequence> String getString(Trie<C> tree) {
return getString(tree.root, "", null, true);
}
protected static <C extends CharSequence> String getString(Node node, String prefix, String previousString, boolean isTail) {
StringBuilder builder = new StringBuilder();
String string = null;
if (node.character!=null) {
String temp = String.valueOf(node.character);
if (previousString!=null) string = previousString + temp;
else string = temp;
}
builder.append(prefix + (isTail ? "└── " : "├── ") +
((node.isWord == true) ?
("(" + node.character + ") " + string)
:
node.character) +
"\n");
if (node.children != null) {
for (int i = 0; i < node.childrenSize - 1; i++) {
builder.append(getString(node.children[i], prefix + (isTail ? " " : "│ "), string, false));
}
if (node.childrenSize >= 1) {
builder.append(getString(node.children[node.childrenSize - 1], prefix + (isTail ? " " : "│ "), string, true));
}
}
return builder.toString();
}
}
}
测试代码
private static boolean testTrie() {
{
long count = 0;
long addTime = 0L;
long removeTime = 0L;
long beforeAddTime = 0L;
long afterAddTime = 0L;
long beforeRemoveTime = 0L;
long afterRemoveTime = 0L;
long memory = 0L;
long beforeMemory = 0L;
long afterMemory = 0L;
//Trie.
if (debug>1) System.out.println("Trie.");
testNames[testIndex] = "Trie";
count++;
if (debugMemory) beforeMemory = DataStructures.getMemoryUse();
if (debugTime) beforeAddTime = System.currentTimeMillis();
Trie<String> trie = new Trie<String>();
for (int i=0; i<unsorted.length; i++) {
int item = unsorted[i];
String string = String.valueOf(item);
trie.add(string);
if (validateStructure && !(trie.size()==i+1)) {
System.err.println("YIKES!! "+item+" caused a size mismatch.");
handleError(trie);
return false;
}
if (validateContents && !trie.contains(string)) {
System.err.println("YIKES!! "+string+" doesn't exist.");
handleError(trie);
return false;
}
}
if (debugTime) {
afterAddTime = System.currentTimeMillis();
addTime += afterAddTime-beforeAddTime;
if (debug>0) System.out.println("Trie add time = "+addTime/count+" ms");
}
if (debugMemory) {
afterMemory = DataStructures.getMemoryUse();
memory += afterMemory-beforeMemory;
if (debug>0) System.out.println("Trie memory use = "+(memory/count)+" bytes");
}
String invalid = INVALID.toString();
boolean contains = trie.contains(invalid);
boolean removed = trie.remove(invalid);
if (contains || removed) {
System.err.println("Trie invalidity check. contains="+contains+" removed="+removed);
return false;
} else System.out.println("Trie invalidity check. contains="+contains+" removed="+removed);
if (debug>1) System.out.println(trie.toString());
long lookupTime = 0L;
long beforeLookupTime = 0L;
long afterLookupTime = 0L;
if (debugTime) beforeLookupTime = System.currentTimeMillis();
for (int item : unsorted) {
String string = String.valueOf(item);
trie.contains(string);
}
if (debugTime) {
afterLookupTime = System.currentTimeMillis();
lookupTime += afterLookupTime-beforeLookupTime;
if (debug>0) System.out.println("Trie lookup time = "+lookupTime/count+" ms");
}
if (debugTime) beforeRemoveTime = System.currentTimeMillis();
for (int i=0; i<unsorted.length; i++) {
int item = unsorted[i];
String string = String.valueOf(item);
trie.remove(string);
if (validateStructure && !(trie.size()==unsorted.length-(i+1))) {
System.err.println("YIKES!! "+item+" caused a size mismatch.");
handleError(trie);
return false;
}
if (validateContents && trie.contains(string)) {
System.err.println("YIKES!! "+string+" still exists.");
handleError(trie);
return false;
}
}
if (debugTime) {
afterRemoveTime = System.currentTimeMillis();
removeTime += afterRemoveTime-beforeRemoveTime;
if (debug>0) System.out.println("Trie remove time = "+removeTime/count+" ms");
}
contains = trie.contains(invalid);
removed = trie.remove(invalid);
if (contains || removed) {
System.err.println("Trie invalidity check. contains="+contains+" removed="+removed);
return false;
} else System.out.println("Trie invalidity check. contains="+contains+" removed="+removed);
count++;
if (debugMemory) beforeMemory = DataStructures.getMemoryUse();
if (debugTime) beforeAddTime = System.currentTimeMillis();
for (int i=unsorted.length-1; i>=0; i--) {
int item = unsorted[i];
String string = String.valueOf(item);
trie.add(string);
if (validateStructure && !(trie.size()==unsorted.length-i)) {
System.err.println("YIKES!! "+item+" caused a size mismatch.");
handleError(trie);
return false;
}
if (validateContents && !trie.contains(string)) {
System.err.println("YIKES!! "+string+" doesn't exists.");
handleError(trie);
return false;
}
}
if (debugTime) {
afterAddTime = System.currentTimeMillis();
addTime += afterAddTime-beforeAddTime;
if (debug>0) System.out.println("Trie add time = "+addTime/count+" ms");
}
if (debugMemory) {
afterMemory = DataStructures.getMemoryUse();
memory += afterMemory-beforeMemory;
if (debug>0) System.out.println("Trie memory use = "+(memory/count)+" bytes");
}
contains = trie.contains(invalid);
removed = trie.remove(invalid);
if (contains || removed) {
System.err.println("Trie invalidity check. contains="+contains+" removed="+removed);
return false;
} else System.out.println("Trie invalidity check. contains="+contains+" removed="+removed);
if (debug>1) System.out.println(trie.toString());
lookupTime = 0L;
beforeLookupTime = 0L;
afterLookupTime = 0L;
if (debugTime) beforeLookupTime = System.currentTimeMillis();
for (int item : unsorted) {
String string = String.valueOf(item);
trie.contains(string);
}
if (debugTime) {
afterLookupTime = System.currentTimeMillis();
lookupTime += afterLookupTime-beforeLookupTime;
if (debug>0) System.out.println("Trie lookup time = "+lookupTime/count+" ms");
}
if (debugTime) beforeRemoveTime = System.currentTimeMillis();
for (int i=0; i<unsorted.length; i++) {
int item = unsorted[i];
String string = String.valueOf(item);
trie.remove(string);
if (validateStructure && !(trie.size()==unsorted.length-(i+1))) {
System.err.println("YIKES!! "+item+" caused a size mismatch.");
handleError(trie);
return false;
}
if (validateContents && trie.contains(string)) {
System.err.println("YIKES!! "+string+" still exists.");
handleError(trie);
return false;
}
}
if (debugTime) {
afterRemoveTime = System.currentTimeMillis();
removeTime += afterRemoveTime-beforeRemoveTime;
if (debug>0) System.out.println("Trie remove time = "+removeTime/count+" ms");
}
contains = trie.contains(invalid);
removed = trie.remove(invalid);
if (contains || removed) {
System.err.println("Trie invalidity check. contains="+contains+" removed="+removed);
return false;
} else System.out.println("Trie invalidity check. contains="+contains+" removed="+removed);
//sorted
long addSortedTime = 0L;
long removeSortedTime = 0L;
long beforeAddSortedTime = 0L;
long afterAddSortedTime = 0L;
long beforeRemoveSortedTime = 0L;
long afterRemoveSortedTime = 0L;
if (debugMemory) beforeMemory = DataStructures.getMemoryUse();
if (debugTime) beforeAddSortedTime = System.currentTimeMillis();
for (int i=0; i<sorted.length; i++) {
int item = sorted[i];
String string = String.valueOf(item);
trie.add(string);
if (validateStructure && !(trie.size()==(i+1))) {
System.err.println("YIKES!! "+item+" caused a size mismatch.");
handleError(trie);
return false;
}
if (validateContents && !trie.contains(string)) {
System.err.println("YIKES!! "+item+" doesn't exist.");
handleError(trie);
return false;
}
}
if (debugTime) {
afterAddSortedTime = System.currentTimeMillis();
addSortedTime += afterAddSortedTime-beforeAddSortedTime;
if (debug>0) System.out.println("Trie add time = "+addSortedTime+" ms");
}
if (debugMemory) {
afterMemory = DataStructures.getMemoryUse();
memory += afterMemory-beforeMemory;
if (debug>0) System.out.println("Trie memory use = "+(memory/(count+1))+" bytes");
}
contains = trie.contains(invalid);
removed = trie.remove(invalid);
if (contains || removed) {
System.err.println("Trie invalidity check. contains="+contains+" removed="+removed);
return false;
} else System.out.println("Trie invalidity check. contains="+contains+" removed="+removed);
if (debug>1) System.out.println(trie.toString());
lookupTime = 0L;
beforeLookupTime = 0L;
afterLookupTime = 0L;
if (debugTime) beforeLookupTime = System.currentTimeMillis();
for (int item : sorted) {
String string = String.valueOf(item);
trie.contains(string);
}
if (debugTime) {
afterLookupTime = System.currentTimeMillis();
lookupTime += afterLookupTime-beforeLookupTime;
if (debug>0) System.out.println("Trie lookup time = "+lookupTime/(count+1)+" ms");
}
if (debugTime) beforeRemoveSortedTime = System.currentTimeMillis();
for (int i=sorted.length-1; i>=0; i--) {
int item = sorted[i];
String string = String.valueOf(item);
trie.remove(string);
if (validateStructure && !(trie.size()==i)) {
System.err.println("YIKES!! "+item+" caused a size mismatch.");
handleError(trie);
return false;
}
if (validateContents && trie.contains(string)) {
System.err.println("YIKES!! "+item+" still exists.");
handleError(trie);
return false;
}
}
if (debugTime) {
afterRemoveSortedTime = System.currentTimeMillis();
removeSortedTime += afterRemoveSortedTime-beforeRemoveSortedTime;
if (debug>0) System.out.println("Trie remove time = "+removeSortedTime+" ms");
}
contains = trie.contains(invalid);
removed = trie.remove(invalid);
if (contains || removed) {
System.err.println("Trie invalidity check. contains="+contains+" removed="+removed);
return false;
} else System.out.println("Trie invalidity check. contains="+contains+" removed="+removed);
if (testResults[testIndex]==null) testResults[testIndex] = new long[6];
testResults[testIndex][0]+=addTime/count;
testResults[testIndex][1]+=removeTime/count;
testResults[testIndex][2]+=addSortedTime;
testResults[testIndex][3]+=removeSortedTime;
testResults[testIndex][4]+=lookupTime/(count+1);
testResults[testIndex++][5]+=memory/(count+1);
if (debug>1) System.out.println();
}
return true;
}