trie树Java实现

本人脑子愚笨，对 double array trie 还是没有理解，如果有大神看到这段话，希望能指点一下。

传统trie树


import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

public class Trie {
	/** Number of child slots per vertex: one for each ASCII letter 'a'..'z'. */
	private static final int ALPHABET_SIZE = 26;

	/** The root vertex; it represents the empty string. */
	private Vertex root;

	/**
	 * A trie vertex. Declared {@code static} because it never uses the
	 * enclosing {@code Trie}, so it should not carry a hidden reference to it.
	 */
	protected static class Vertex {
		/** How many times a word ending exactly at this vertex was added. */
		protected int words;
		/** How many added words pass through this vertex and continue beyond it. */
		protected int prefixes;
		/** Children indexed by {@code letter - 'a'}; a slot is {@code null} if absent. */
		protected Vertex[] edges;

		Vertex() {
			words = 0;
			prefixes = 0;
			// New array slots are already null, so no explicit initialisation loop is needed.
			edges = new Vertex[ALPHABET_SIZE];
		}
	}

	public Trie() {
		root = new Vertex();
	}

	/**
	 * Lists every distinct word stored in the trie, lower-cased, in
	 * lexicographic order (the empty string first, if it was added).
	 *
	 * <p>A word that is also a prefix of a longer word (e.g. "chin" next to
	 * "china") is included; a leaves-only traversal would miss it.
	 *
	 * @return a new mutable list with one entry per distinct word
	 */
	public List<String> listAllWords() {
		List<String> words = new ArrayList<String>();
		depthFirstSearchWords(words, root, new StringBuilder());
		return words;
	}

	/**
	 * Pre-order depth-first walk. It records the word spelled by {@code path}
	 * if one ends at {@code vertex}, then visits the children in alphabetical
	 * order. {@code path} is restored to its original contents before returning.
	 */
	private void depthFirstSearchWords(List<String> words, Vertex vertex,
			StringBuilder path) {
		if (vertex.words > 0) {
			words.add(path.toString());
		}
		Vertex[] edges = vertex.edges;
		for (int i = 0; i < edges.length; i++) {
			if (edges[i] != null) {
				path.append((char) ('a' + i));
				depthFirstSearchWords(words, edges[i], path);
				path.setLength(path.length() - 1);
			}
		}
	}

	/**
	 * Counts the added words (duplicates included) that start with
	 * {@code prefix} and are strictly longer than it. A word equal to
	 * {@code prefix} is counted by {@link #countWords(String)} instead.
	 * Matching is case-insensitive.
	 *
	 * @param prefix the prefix to look up
	 * @return the count; 0 if no word has this prefix or it contains a
	 *         character that is not an ASCII letter
	 */
	public int countPrefixes(String prefix) {
		Vertex vertex = find(prefix);
		return vertex == null ? 0 : vertex.prefixes;
	}

	/**
	 * Returns how many times {@code word} was added. Matching is
	 * case-insensitive, consistent with {@link #addWord(String)}.
	 *
	 * @param word the word to look up
	 * @return the count; 0 if absent or if it contains a character that is not
	 *         an ASCII letter
	 */
	public int countWords(String word) {
		Vertex vertex = find(word);
		return vertex == null ? 0 : vertex.words;
	}

	/**
	 * Adds one occurrence of {@code word}, folding letters to lower case.
	 *
	 * @param word the word to add
	 * @throws IllegalArgumentException if {@code word} contains a character that
	 *         is not 'a'..'z' after lower-casing; the trie is left unchanged
	 */
	public void addWord(String word) {
		// Validate up front so a rejected word leaves no partial path or counts behind.
		for (int i = 0; i < word.length(); i++) {
			if (indexOf(word.charAt(i)) < 0) {
				throw new IllegalArgumentException("Unsupported character '"
						+ word.charAt(i) + "' in word \"" + word + "\"");
			}
		}
		Vertex vertex = root;
		for (int i = 0; i < word.length(); i++) {
			vertex.prefixes++; // the word continues past this vertex
			int index = indexOf(word.charAt(i));
			if (vertex.edges[index] == null) {
				vertex.edges[index] = new Vertex();
			}
			vertex = vertex.edges[index];
		}
		vertex.words++;
	}

	/**
	 * Follows the path spelled by {@code s} from the root.
	 *
	 * @return the vertex reached, or {@code null} if the path leaves the trie or
	 *         {@code s} contains a character that is not an ASCII letter
	 */
	private Vertex find(String s) {
		Vertex vertex = root;
		for (int i = 0; i < s.length(); i++) {
			int index = indexOf(s.charAt(i));
			if (index < 0 || vertex.edges[index] == null) {
				return null;
			}
			vertex = vertex.edges[index];
		}
		return vertex;
	}

	/**
	 * Maps a character to its child slot after folding it to lower case.
	 *
	 * @return a slot in {@code [0, 26)}, or -1 if the lower-cased character is
	 *         not 'a'..'z'
	 */
	private static int indexOf(char c) {
		int index = Character.toLowerCase(c) - 'a';
		return (index >= 0 && index < ALPHABET_SIZE) ? index : -1;
	}

	public static void main(String[] args) // manual smoke test only
	{
		Trie trie = new Trie();
		trie.addWord("China");
		trie.addWord("China");
		trie.addWord("China");

		trie.addWord("crawl");
		trie.addWord("crime");
		trie.addWord("ban");
		trie.addWord("China");

		trie.addWord("english");
		trie.addWord("establish");
		trie.addWord("eat");
		System.out.println(trie.root.prefixes);
		System.out.println(trie.root.words);

		for (String s : trie.listAllWords()) {
			System.out.println(s);
		}

		int count = trie.countPrefixes("ch");
		int count1 = trie.countWords("china");
		System.out.println("the count of ch prefixes:" + count);
		System.out.println("the count of china countWords:" + count1);

	}
}

double array trie树


import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Double-array trie. A conventional trie is walked while the double array is
 * being built, but that trie is not kept afterwards. This structure is
 * therefore not suited to enumerating every key that starts with a given
 * prefix; use a conventional trie for that.
 *
 * <p>Keys passed to {@link #build(List)} must be sorted in ascending
 * {@code String} order; unsorted input is reported as error code -3.
 *
 * @author zhangchaoyang
 *
 */
public class DoubleArrayTrie {
	/** Buffer size, in bytes, of the input stream used by {@link #open(String)}. */
	private final static int BUF_SIZE = 16384;
	private final static int UNIT_SIZE = 8; // size of int + int

	/**
	 * Node of the temporary conventional trie walked during {@link #build}.
	 */
	private static class Node {
		/** Character code + 1 of the edge leading to this node; 0 marks end of key. */
		int code;
		/** Depth of this node in the trie. */
		int depth;
		/** First index (inclusive) of the keys covered by this node. */
		int left;
		/** Last index (exclusive) of the keys covered by this node. */
		int right;
	}

	/** check[p] holds the offset ("begin") of the sibling group owning slot p; 0 means free. */
	private int check[];
	/** Inner node: offset of its children group. Leaf: -value - 1. */
	private int base[];

	/** Offsets already taken as some group's begin; only needed during build. */
	private boolean used[];
	/** One past the highest slot written. */
	private int size;
	/** Current length of the base/check/used arrays. */
	private int allocSize;
	/** All keys (sorted); only held during build. */
	private List<String> key;
	/** Number of keys to build from. */
	private int keySize;
	/** Optional per-key lengths; when null, key.get(i).length() is used. */
	private int length[];
	/** Optional per-key values; when null, the key index is stored instead. */
	private int value[];
	/** Number of leaves placed so far; drives the array growth factor. */
	private int progress;
	/** Slot from which the search for free space starts. */
	private int nextCheckPos;
	/** 0, or a negative error code set during build. */
	int error_;

	/**
	 * Grows the base, check and used arrays, keeping existing contents.
	 *
	 * @param newSize
	 *            new length; must not be smaller than the current allocSize,
	 *            otherwise System.arraycopy throws
	 * @return the new allocSize
	 */
	private int resize(int newSize) {
		int[] base2 = new int[newSize];
		int[] check2 = new int[newSize];
		boolean used2[] = new boolean[newSize];
		if (allocSize > 0) {
			System.arraycopy(base, 0, base2, 0, allocSize);
			System.arraycopy(check, 0, check2, 0, allocSize);
			System.arraycopy(used, 0, used2, 0, allocSize);
		}

		base = base2;
		check = check2;
		used = used2;

		return allocSize = newSize;
	}

	/**
	 * Computes the children of {@code parent} in the conventional trie.
	 *
	 * @param parent
	 *            parent node
	 * @param siblings
	 *            output list that receives the children, in ascending code order
	 * @return number of children found; 0 for a leaf or when an error is set
	 */
	private int fetch(Node parent, List<Node> siblings) {
		if (error_ < 0)
			return 0;

		int prev = 0;
		// Visit every key covered by the parent.
		for (int i = parent.left; i < parent.right; i++) {
			// Keys shorter than the parent's depth cannot continue below it.
			if ((length != null ? length[i] : key.get(i).length()) < parent.depth)
				continue;
			String tmp = key.get(i);

			int cur = 0;
			// A key longer than the parent's depth contributes edge charAt(depth) + 1;
			// a key ending exactly here contributes code 0 (end of key).
			if ((length != null ? length[i] : tmp.length()) != parent.depth)
				cur = (int) tmp.charAt(parent.depth) + 1;
			// Codes can only decrease if the keys are not sorted.
			if (prev > cur) {
				error_ = -3;
				return 0;
			}
			// Each new code starts a new child covering keys from i onwards.
			if (cur != prev || siblings.size() == 0) {
				Node tmp_node = new Node();
				tmp_node.depth = parent.depth + 1;
				tmp_node.code = cur;
				tmp_node.left = i;
				if (siblings.size() != 0)
					siblings.get(siblings.size() - 1).right = i;

				siblings.add(tmp_node);
			}

			prev = cur;
		}

		if (siblings.size() != 0)
			siblings.get(siblings.size() - 1).right = parent.right;

		return siblings.size();
	}

	/**
	 * Places one sibling group into the double array and recursively places
	 * its descendants.
	 *
	 * <p>Scans upward from {@code nextCheckPos} for an offset {@code begin} that
	 * is not already used and for which every slot {@code begin + code} is free.
	 * Writes {@code begin} into those check slots, then fills base: a leaf gets
	 * {@code -value - 1} (or {@code -keyIndex - 1}), and an inner node gets the
	 * offset returned by the recursive call for its children.
	 *
	 * @param siblings
	 *            non-empty sibling group, in ascending code order
	 * @return the chosen offset, or 0 when an error is set
	 */
	private int insert(List<Node> siblings) {
		if (error_ < 0)
			return 0;

		int begin = 0;
		int pos = ((siblings.get(0).code + 1 > nextCheckPos) ? siblings.get(0).code + 1
				: nextCheckPos) - 1;
		int nonzero_num = 0;
		int first = 0;

		if (allocSize <= pos)
			resize(pos + 1);

		outer: while (true) {
			pos++;

			if (allocSize <= pos)
				resize(pos + 1);

			if (check[pos] != 0) {
				nonzero_num++;
				continue;
			} else if (first == 0) {
				nextCheckPos = pos;
				first = 1;
			}

			begin = pos - siblings.get(0).code;
			if (allocSize <= (begin + siblings.get(siblings.size() - 1).code)) {
				// progress can be zero
				double l = (1.05 > 1.0 * keySize / (progress + 1)) ? 1.05 : 1.0
						* keySize / (progress + 1);
				resize((int) (allocSize * l));
			}

			if (used[begin])
				continue;

			for (int i = 1; i < siblings.size(); i++)
				if (check[begin + siblings.get(i).code] != 0)
					continue outer;

			break;
		}

		// -- Simple heuristics --
		// If the share of occupied slots between nextCheckPos and pos is at
		// least 0.95, start future searches from pos instead.
		if (1.0 * nonzero_num / (pos - nextCheckPos + 1) >= 0.95)
			nextCheckPos = pos;

		used[begin] = true;
		size = (size > begin + siblings.get(siblings.size() - 1).code + 1) ? size
				: begin + siblings.get(siblings.size() - 1).code + 1;

		for (int i = 0; i < siblings.size(); i++)
			check[begin + siblings.get(i).code] = begin;

		for (int i = 0; i < siblings.size(); i++) {
			List<Node> new_siblings = new ArrayList<Node>();

			if (fetch(siblings.get(i), new_siblings) == 0) {
				base[begin + siblings.get(i).code] = (value != null) ? (-value[siblings
						.get(i).left] - 1) : (-siblings.get(i).left - 1);

				if (value != null && (-value[siblings.get(i).left] - 1) >= 0) {
					error_ = -2;
					return 0;
				}

				progress++;
			} else {
				int h = insert(new_siblings);
				base[begin + siblings.get(i).code] = h;
			}
		}
		return begin;
	}

	/**
	 * Creates an empty trie; call {@link #build} or {@link #open} before searching.
	 */
	public DoubleArrayTrie() {
		check = null;
		base = null;
		used = null;
		size = 0;
		allocSize = 0;
		error_ = 0;
	}

	/** Drops the arrays and resets size and allocSize. */
	void clear() {
		check = null;
		base = null;
		used = null;
		allocSize = 0;
		size = 0;
	}

	public int getUnitSize() {
		return UNIT_SIZE;
	}

	public int getSize() {
		return size;
	}

	public int getTotalSize() {
		return size * UNIT_SIZE;
	}

	public int getNonzeroSize() {
		int result = 0;
		for (int i = 0; i < size; i++)
			if (check[i] != 0)
				result++;
		return result;
	}

	/**
	 * Builds the double array from sorted keys. Each key's value is its index.
	 *
	 * @param key
	 *            all keys, sorted ascending
	 * @return 0 on success, or a negative error code
	 */
	public int build(List<String> key) {
		return build(key, null, null, key == null ? 0 : key.size());
	}

	/**
	 * Builds the double array, replacing any previous contents.
	 *
	 * @param _key
	 *            all keys, sorted ascending
	 * @param _length
	 *            optional per-key lengths (null: use each key's length)
	 * @param _value
	 *            optional non-negative per-key values (null: use key index)
	 * @param _keySize
	 *            number of keys to use
	 * @return 0 on success (also returned, without building, when
	 *         {@code _key} is null or {@code _keySize > _key.size()}); -2 if a
	 *         value is negative; -3 if the keys are not sorted
	 */
	public int build(List<String> _key, int _length[], int _value[],
			int _keySize) {
		// Null check must come before the dereference.
		if (_key == null || _keySize > _key.size())
			return 0;

		// Start clean: a previous build nulls `used` (resize would then throw
		// NullPointerException), and stale size/error_ would leak into this build.
		clear();
		error_ = 0;

		key = _key;
		length = _length;
		keySize = _keySize;
		value = _value;
		progress = 0;

		resize(65536 * 32);

		base[0] = 1;
		nextCheckPos = 0;
		// Root of the temporary conventional trie.
		Node root_node = new Node();
		root_node.left = 0;
		root_node.right = keySize;
		root_node.depth = 0;
		// Children of the root.
		List<Node> siblings = new ArrayList<Node>();
		fetch(root_node, siblings);
		// insert() needs at least one sibling; no keys yields an empty trie.
		if (!siblings.isEmpty())
			insert(siblings);

		used = null;
		key = null;

		return error_;
	}

	/**
	 * Loads base/check pairs previously written by {@link #save(String)}.
	 */
	public void open(String fileName) throws IOException {
		File file = new File(fileName);
		// Divide before narrowing; casting first truncates lengths above 2 GiB.
		size = (int) (file.length() / UNIT_SIZE);
		check = new int[size];
		base = new int[size];

		DataInputStream is = null;
		try {
			is = new DataInputStream(new BufferedInputStream(
					new FileInputStream(file), BUF_SIZE));
			for (int i = 0; i < size; i++) {
				base[i] = is.readInt();
				check[i] = is.readInt();
			}
		} finally {
			if (is != null)
				is.close();
		}
	}

	/**
	 * Writes the first {@code size} base/check pairs as big-endian ints.
	 */
	public void save(String fileName) throws IOException {
		DataOutputStream out = null;
		try {
			out = new DataOutputStream(new BufferedOutputStream(
					new FileOutputStream(fileName)));
			for (int i = 0; i < size; i++) {
				out.writeInt(base[i]);
				out.writeInt(check[i]);
			}
		} finally {
			// Closing also flushes the buffer; a flush failure still propagates.
			if (out != null)
				out.close();
		}
	}

	public int exactMatchSearch(String key) {
		return exactMatchSearch(key, 0, 0, 0);
	}

	/**
	 * Looks up {@code key.substring(pos, len)} as a complete key.
	 *
	 * @return the stored value (key index by default), or -1 if it is not a key
	 */
	public int exactMatchSearch(String key, int pos, int len, int nodePos) {
		if (len <= 0)
			len = key.length();
		if (nodePos <= 0)
			nodePos = 0;

		int result = -1;

		int b = base[nodePos];
		int p;

		for (int i = pos; i < len; i++) {
			p = b + (int) key.charAt(i) + 1;
			if (!hasTransition(b, p))
				return result;
			b = base[p];
		}

		p = b;
		if (hasTransition(b, p) && base[p] < 0) {
			result = -base[p] - 1;
		}
		return result;
	}

	public List<Integer> commonPrefixSearch(String key) {
		return commonPrefixSearch(key, 0, 0, 0);
	}

	/**
	 * Finds every stored key that is a prefix of {@code key.substring(pos, len)}.
	 *
	 * @return the values of the matching keys, shortest first
	 */
	public List<Integer> commonPrefixSearch(String key, int pos, int len,
			int nodePos) {
		if (len <= 0)
			len = key.length();
		if (nodePos <= 0)
			nodePos = 0;

		List<Integer> result = new ArrayList<Integer>();

		int b = base[nodePos];
		int p;

		for (int i = pos; i < len; i++) {
			p = b;
			if (hasTransition(b, p) && base[p] < 0) {
				result.add(-base[p] - 1);
			}

			p = b + (int) key.charAt(i) + 1;
			if (!hasTransition(b, p))
				return result;
			b = base[p];
		}

		p = b;
		if (hasTransition(b, p) && base[p] < 0) {
			result.add(-base[p] - 1);
		}

		return result;
	}

	/**
	 * Returns true if slot {@code p} exists and belongs to the group whose
	 * offset is {@code b}. Out-of-range slots cannot hold a transition; without
	 * this check, an unseen character (e.g. after {@link #open}, whose arrays
	 * are exactly {@code size} long) would throw ArrayIndexOutOfBoundsException.
	 */
	private boolean hasTransition(int b, int p) {
		return p >= 0 && p < check.length && check[p] == b;
	}

	// debug
	public void dump() {
		for (int i = 0; i < size; i++) {
			System.err.println("i: " + i + " [" + base[i] + ", " + check[i]
					+ "]");
		}
	}

	public static void main(String[] args) {
		DoubleArrayTrie adt = new DoubleArrayTrie();
		List<String> list = new ArrayList<String>();
		list.add("阿胶");
		list.add("阿拉伯");
		list.add("阿拉伯人");
		list.add("埃及");
		Collections.sort(list); // keys must be sorted before building

		adt.build(list);
		String key = "一个来自于埃及的乞丐自称是来自阿拉伯是正宗的阿拉伯人最喜欢吃阿胶";
		List<Integer> rect = adt.commonPrefixSearch(key); // dictionary keys that are prefixes of key
		for (int index : rect) {
			System.out.println(list.get(index));
		}
		int index = adt.exactMatchSearch(key); // is key itself a dictionary key?
		if (index >= 0) {
			System.out.println(list.get(index));
		}
	}
}
    原文作者:Trie树
    原文地址: https://blog.csdn.net/zsr_251/article/details/47275315
    本文转自网络文章，转载此文章仅为分享知识，如有侵权，请联系博主进行删除。
点赞