Thursday, December 12, 2013

LeetCode OJ: Recover Binary Search Tree

Two elements of a binary search tree (BST) are swapped by mistake.
Recover the tree without changing its structure. 

An O(n) space algorithm is straightforward: write the in-order traversal of the BST to a linear array, find the two swapped elements, and swap them back.
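A minimal sketch of this array-based approach is below, using the same TreeNode definition as the OJ. For simplicity it rewrites every value from the sorted in-order sequence rather than swapping just the two offenders; the resulting tree is the same.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// TreeNode mirrors the definition used by the OJ.
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;
    TreeNode(int x) { val = x; }
}

class RecoverWithArray {
    // Collect the in-order sequence, sort the values, and write them back
    // in in-order position. Uses O(n) extra space.
    public static void recoverTree(TreeNode root) {
        List<TreeNode> nodes = new ArrayList<>();
        List<Integer> vals = new ArrayList<>();
        inOrder(root, nodes, vals);
        Collections.sort(vals); // sorted order is the correct BST order
        for (int i = 0; i < nodes.size(); i++) {
            nodes.get(i).val = vals.get(i);
        }
    }

    private static void inOrder(TreeNode n, List<TreeNode> nodes, List<Integer> vals) {
        if (n == null) return;
        inOrder(n.left, nodes, vals);
        nodes.add(n);
        vals.add(n.val);
        inOrder(n.right, nodes, vals);
    }
}
```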

A better solution avoids explicitly writing the in-order traversal to an array. We can do an in-order traversal on the BST itself and find the swapped elements along the way. Notice that if the two exchanged nodes are adjacent in in-order, the logic to find them is slightly different; the code handles this case with the afterFirst pointer.

/**
 * Definition for binary tree
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
public class Solution {
    TreeNode first = null;
    TreeNode second = null;
    TreeNode afterFirst = null;

    public void recoverTree(TreeNode root) {
        inOrder(root, null);
        if (second == null) {
            second = afterFirst;
        }
        int temp = first.val;
        first.val = second.val;
        second.val = temp;
    }

    public TreeNode inOrder(TreeNode root, TreeNode prev) {
        if (root == null) {
            return prev;
        }
        prev = inOrder(root.left, prev);
        if (prev != null && prev.val > root.val) {
            if (first == null) {
                first = prev;
                afterFirst = root;
            } else {
                second = root;
            }
        }
        return inOrder(root.right, root);
    }
}

However, since the previous solution uses recursion, its worst-case space complexity is the depth of the tree, which could be as bad as O(n).

A true O(1) space solution requires the use of a threaded binary tree (Morris traversal), which I learned from others' solutions. A threaded binary tree lets one traverse a binary tree with constant extra space. In our case we want a forward in-order traversal, so the right child of any node that has no right child points to its in-order successor. Since we want to preserve the structure, we create these threads on the fly and remove them after use. A threaded-binary-tree solution is as follows. Notice that we need to find the in-order predecessor of each node twice. The total cost of finding the in-order predecessors of all nodes is O(n), since every edge is traversed only a constant number of times. So it is not too bad.

public class Solution {
    TreeNode first = null;
    TreeNode second = null;

    public void recoverTree(TreeNode root) {
        TreeNode prev = null;
        TreeNode cur = root;
        while (cur != null) {
            if (cur.left == null) {
                // No left subtree: visit cur directly.
                if (prev != null && prev.val > cur.val) {
                    second = cur;
                    if (first == null) {
                        first = prev;
                    }
                }
                prev = cur;
                cur = cur.right;
            } else {
                // Find the in-order predecessor of cur.
                TreeNode pred = cur.left;
                while (pred.right != null && pred.right != cur) {
                    pred = pred.right;
                }
                if (pred.right == null) {
                    pred.right = cur; // create the thread
                    cur = cur.left;
                } else {
                    pred.right = null; // remove the thread; visit cur
                    if (prev != null && prev.val > cur.val) {
                        second = cur;
                        if (first == null) {
                            first = prev;
                        }
                    }
                    prev = cur;
                    cur = cur.right;
                }
            }
        }
        int temp = first.val;
        first.val = second.val;
        second.val = temp;
    }
}

Saturday, December 7, 2013

LeetCode OJ: Wildcard Matching

Implement wildcard pattern matching with support for '?' and '*'.
'?' Matches any single character.
'*' Matches any sequence of characters (including the empty sequence).

The matching should cover the entire input string (not partial).

The function prototype should be:
bool isMatch(const char *s, const char *p)

Some examples:
isMatch("aa","a") → false
isMatch("aa","aa") → true
isMatch("aaa","aa") → false
isMatch("aa", "*") → true
isMatch("aa", "a*") → true
isMatch("ab", "?*") → true
isMatch("aab", "c*a*b") → false

A typical DP. However, a naive O(N^3) time, O(N^2) space solution cannot pass the large test cases, so three optimizations are used.

(1) Suppose s and p are of length n and m respectively. We want to compute an m-by-n boolean matrix F, where F(i, j) is true if the first i letters of the pattern match the first j letters of the string. However, we do not need to store the entire matrix: we can compute it row by row, and the computation of any row relies only on the previous row. Only two rows are needed at any time, so the space complexity is reduced to O(n).

(2) It is easy to see that '*' is the cause of the O(N^3) complexity, so a simple optimization is to collapse all adjacent '*' into one, which does not change the language represented by the pattern.

(3) When N is very large, say tens of thousands, even O(N^2) is too slow. Call a non-'*' character in the pattern a hard character: it must match exactly one character of the input string. If there are k hard characters among the first i characters of the pattern, then F(i, j) must be false whenever j < k. In this way we can find lower and upper bounds on j between which F(i, j) can be true for any fixed i. This optimization is especially useful when the pattern is very long but most of its characters are hard.

A solution using the above three optimizations is as follows.

public class Solution {
    public static boolean isMatch(String s, String p) {
        p = collapseStar(p);
        int m = p.length();
        int n = s.length();
        // Find the limit array: nChar[i] = number of hard (non-'*')
        // characters among the first i characters of the pattern.
        int[] nChar = new int[m + 1];
        nChar[0] = 0;
        for (int i = 1; i <= m; i++) {
            if (p.charAt(i - 1) == '*') {
                nChar[i] = nChar[i - 1];
            } else {
                nChar[i] = nChar[i - 1] + 1;
            }
        }
        boolean[] prev = new boolean[n + 1];
        boolean[] cur = new boolean[n + 1];
        int prevLB, prevUB, curLB, curUB;
        prev[0] = true;
        for (int j = 1; j <= n; j++) {
            prev[j] = false;
        }
        prevLB = 0;
        prevUB = 0;
        for (int i = 1; i <= m; i++) {
            cur[0] = false;
            curLB = nChar[i];
            curUB = n - (nChar[m] - nChar[i]);
            for (int j = curLB; j <= curUB; j++) {
                if (p.charAt(i - 1) == '?') {
                    cur[j] = prev[j - 1] && j - 1 >= prevLB && j - 1 <= prevUB;
                } else if (p.charAt(i - 1) == '*') {
                    boolean f = false;
                    for (int k = 0; k <= j; k++) {
                        f |= (prev[k] && k >= prevLB && k <= prevUB);
                    }
                    cur[j] = f;
                } else {
                    cur[j] = prev[j - 1] && j - 1 >= prevLB && j - 1 <= prevUB
                            && (p.charAt(i - 1) == s.charAt(j - 1));
                }
            }
            boolean[] temp = prev;
            prev = cur;
            cur = temp;
            prevLB = curLB;
            prevUB = curUB;
        }
        return prev[n] && n >= prevLB && n <= prevUB;
    }

    public static String collapseStar(String p) {
        StringBuilder sb = new StringBuilder();
        boolean lastStar = false;
        for (int i = 0; i < p.length(); i++) {
            if (p.charAt(i) == '*') {
                if (!lastStar) {
                    sb.append('*');
                    lastStar = true;
                }
            } else {
                sb.append(p.charAt(i));
                lastStar = false;
            }
        }
        return sb.toString();
    }
}

LeetCode OJ: Word Ladder I/II

Given two words (start and end), and a dictionary, find all shortest transformation sequence(s) from start to end, such that:
  1. Only one letter can be changed at a time
  2. Each intermediate word must exist in the dictionary
For example,
Given:
start = "hit"
end = "cog"
dict = ["hot","dot","dog","lot","log"]
Return
  [
    ["hit","hot","dot","dog","cog"],
    ["hit","hot","lot","log","cog"]
  ]
Note:
  • All words have the same length.
  • All words contain only lowercase alphabetic characters.

The above description is for Word Ladder II. Word Ladder I only requires you to return the length of the shortest path, which is 5 in the above example.
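For Word Ladder I alone, a level-by-level BFS suffices, since every edge has weight one: the first level at which end is dequeued is the answer. A sketch under the LeetCode I signature (the answer counts words, not edges) is:

```java
import java.util.ArrayDeque;
import java.util.HashSet;
import java.util.Queue;

class WordLadderLength {
    // Returns the number of words in the shortest ladder, or 0 if none exists.
    public static int ladderLength(String start, String end, HashSet<String> dict) {
        dict.add(end);
        Queue<String> queue = new ArrayDeque<>();
        HashSet<String> visited = new HashSet<>();
        queue.add(start);
        visited.add(start);
        int len = 1; // ladder length counted in words
        while (!queue.isEmpty()) {
            int levelSize = queue.size(); // process one BFS level at a time
            for (int i = 0; i < levelSize; i++) {
                String word = queue.remove();
                if (word.equals(end)) return len;
                char[] letters = word.toCharArray();
                for (int j = 0; j < letters.length; j++) {
                    char save = letters[j];
                    for (char c = 'a'; c <= 'z'; c++) {
                        letters[j] = c;
                        String next = new String(letters);
                        // visited.add returns false if next was already seen
                        if (dict.contains(next) && visited.add(next)) {
                            queue.add(next);
                        }
                    }
                    letters[j] = save;
                }
            }
            len++;
        }
        return 0;
    }
}
```

On the example above, hit → hot → dot → dog → cog (or via lot/log) gives 5.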

The output part is basic DFS; the key part is to build a parent map for the single-source shortest path problem. Normally, given a graph G(V, E) and a single source s, Dijkstra's algorithm builds a parent map P: V -> V, which can be used to recover the shortest path from s to any destination t: P(t) is simply the vertex v such that s -> ... -> v -> t is a shortest path from s to t. In Word Ladder II we are interested not in one shortest path but in all shortest paths, so the parent map must map each vertex to a subset of V, i.e. P: V -> 2^V, as a vertex may have more than one parent and we want all of them.

A simple modification of Dijkstra's algorithm works. Usually, when the tentative minimum distance of a vertex u is updated, the vertex v that triggered the update becomes its parent. Now we maintain a parent set instead: when the minimum distance decreases, we clear the parent set and then add v; when the new distance equals the current minimum distance, we simply add v.

An accepted implementation is as follows:
public class Solution {
    public static class Node {
        public int dist;
        public String value;
        public Node(int dist, String value) {
            this.dist = dist;
            this.value = value;
        }
    }

    public static class NodeComparator implements Comparator<Node> {
        @Override
        public int compare(Node o1, Node o2) {
            return o1.dist - o2.dist;
        }
    }

    public static ArrayList<ArrayList<String>> findLadders(String start, String end, HashSet<String> dict) {
        if (!dict.contains(start)) {
            dict.add(start);
        }
        if (!dict.contains(end)) {
            dict.add(end);
        }
        int n = dict.size();
        // Create adjacency list
        HashMap<String, LinkedList<String>> adj = new HashMap<String, LinkedList<String>>(n);
        for (String word : dict) {
            char[] letters = word.toCharArray();
            LinkedList<String> adjList = new LinkedList<String>();
            for (int i = 0; i < letters.length; i++) {
                char save = letters[i];
                for (char replace = 'a'; replace <= 'z'; replace++) {
                    letters[i] = replace;
                    String neighbor = new String(letters);
                    if (replace != save && dict.contains(neighbor)) {
                        adjList.add(neighbor);
                    }
                }
                letters[i] = save;
            }
            adj.put(word, adjList);
        }
        // Run Dijkstra
        HashMap<String, Integer> minDist = new HashMap<String, Integer>(n);
        for (String word : dict) {
            minDist.put(word, Integer.MAX_VALUE);
        }
        minDist.put(start, 0);
        HashSet<String> visited = new HashSet<String>(n);
        PriorityQueue<Node> pq = new PriorityQueue<Node>(n, new NodeComparator());
        pq.add(new Node(0, start));
        HashMap<String, LinkedList<String>> parent = new HashMap<String, LinkedList<String>>(n);
        for (String word : dict) {
            parent.put(word, new LinkedList<String>());
        }
        while (!pq.isEmpty()) {
            // Skip stale nodes whose key no longer matches the tentative distance.
            boolean found = false;
            Node cur = null;
            while (!pq.isEmpty()) {
                cur = pq.remove();
                if (!visited.contains(cur.value) && cur.dist == minDist.get(cur.value)) {
                    found = true;
                    break;
                }
            }
            if (!found || cur.dist == Integer.MAX_VALUE) {
                break;
            }
            String value = cur.value;
            visited.add(value);
            for (String s : adj.get(value)) {
                int newDist = minDist.get(value) + 1;
                int oldDist = minDist.get(s);
                if (newDist < oldDist) {
                    minDist.put(s, newDist);
                    parent.get(s).clear();
                    parent.get(s).add(value);
                    pq.add(new Node(newDist, s));
                } else if (newDist == oldDist) {
                    parent.get(s).add(value);
                }
            }
        }
        // Recursively find all paths
        int shortest = minDist.get(end);
        ArrayList<ArrayList<String>> result = new ArrayList<ArrayList<String>>();
        if (shortest == Integer.MAX_VALUE) {
            return new ArrayList<ArrayList<String>>();
        } else {
            int size = minDist.get(end) + 1;
            search(result, new ArrayList<String>(shortest), end, 0, size, dict, parent);
            return result;
        }
    }

    public static void search(ArrayList<ArrayList<String>> result, ArrayList<String> current,
            String cur, int step, int maxStep, HashSet<String> dict, HashMap<String, LinkedList<String>> parent) {
        current.add(cur);
        step++;
        if (step == maxStep) {
            result.add(inverseList(current));
        } else {
            for (String p : parent.get(cur)) {
                search(result, current, p, step, maxStep, dict, parent);
            }
        }
        step--;
        current.remove(step);
    }

    public static ArrayList<String> inverseList(ArrayList<String> a) {
        int n = a.size();
        ArrayList<String> ret = new ArrayList<String>(n);
        for (int i = 0; i < n; i++) {
            ret.add(a.get(n - i - 1));
        }
        return ret;
    }
}

Notice that we use an adjacency list and a priority queue for performance; otherwise the solution cannot pass the large test cases.

Suppose we have n words in the dictionary, each at most m characters long. A naive way to create the adjacency list by comparing each pair of words costs O(m*n^2), which is unacceptable for large n like 20000. My implementation instead finds the neighbors of a word by trying to change each of its letters to every other letter, which reduces the complexity to O(26*m*n).

Java's PriorityQueue does not support the decreaseKey operation (decrease the priority of a node given its reference and let the queue rearrange itself to restore the heap property) used in Dijkstra's algorithm. Rather than implement my own priority queue with this operation, I simply insert a new node with the decreased key. As a result, there may be more than one node for the same vertex in the priority queue: one with the newest/smallest key, the rest with stale larger keys. Since we keep track of the smallest tentative distance of every vertex, we can simply ignore any node whose key is larger than that during a removeMin operation. A simple analysis shows that this potentially results in a larger heap, so the complexity increases from O((V + E) log V) to O((V + E) log E), which is essentially the same since log E is at most 2 log V.

BFS can also be used for this problem, since all edge weights are uniformly one. Similarly, one needs to maintain a parent map that is one-to-many instead of one-to-one. A BFS implementation is as follows. Its running time on LeetCode is not better than the Dijkstra version's; both take more than 1000 ms. Maybe Java is just too slow?


public class Solution1 {
    public static class Node {
        public int level;
        public String value;
        public Node(String value, int level) {
            this.level = level;
            this.value = value;
        }
    }

    public static ArrayList<ArrayList<String>> findLadders(String start, String end, HashSet<String> dict) {
        if (!dict.contains(start)) {
            dict.add(start);
        }
        if (!dict.contains(end)) {
            dict.add(end);
        }
        int n = dict.size();
        // Create adjacency list
        HashMap<String, LinkedList<String>> adj = new HashMap<String, LinkedList<String>>(n);
        for (String word : dict) {
            char[] letters = word.toCharArray();
            LinkedList<String> adjList = new LinkedList<String>();
            for (int i = 0; i < letters.length; i++) {
                char save = letters[i];
                for (char replace = 'a'; replace <= 'z'; replace++) {
                    letters[i] = replace;
                    String neighbor = new String(letters);
                    if (replace != save && dict.contains(neighbor)) {
                        adjList.add(neighbor);
                    }
                }
                letters[i] = save;
            }
            adj.put(word, adjList);
        }
        // Run BFS; this finds shortest paths because all edge weights are 1.
        HashMap<String, Integer> minDist = new HashMap<String, Integer>(n);
        for (String word : dict) {
            minDist.put(word, Integer.MAX_VALUE);
        }
        minDist.put(start, 0);
        HashSet<String> visited = new HashSet<String>(n);
        HashMap<String, LinkedList<String>> parent = new HashMap<String, LinkedList<String>>(n);
        for (String word : dict) {
            parent.put(word, new LinkedList<String>());
        }
        LinkedList<Node> queue = new LinkedList<Node>();
        visited.add(start);
        queue.addLast(new Node(start, 0));
        while (!queue.isEmpty()) {
            Node cur = queue.removeFirst();
            String value = cur.value;
            for (String s : adj.get(value)) {
                int newDist = cur.level + 1;
                int oldDist = minDist.get(s);
                if (newDist < oldDist) {
                    minDist.put(s, newDist);
                    parent.get(s).clear();
                    parent.get(s).add(value);
                } else if (newDist == oldDist) {
                    parent.get(s).add(value);
                }
                if (!visited.contains(s)) {
                    queue.add(new Node(s, newDist));
                    visited.add(s);
                }
            }
        }
        // Recursively find all paths
        int shortest = minDist.get(end);
        ArrayList<ArrayList<String>> result = new ArrayList<ArrayList<String>>();
        if (shortest == Integer.MAX_VALUE) {
            return new ArrayList<ArrayList<String>>();
        } else {
            int size = minDist.get(end) + 1;
            search(result, new ArrayList<String>(shortest), end, 0, size, dict, parent);
            return result;
        }
    }

    public static void search(ArrayList<ArrayList<String>> result, ArrayList<String> current,
            String cur, int step, int maxStep, HashSet<String> dict, HashMap<String, LinkedList<String>> parent) {
        current.add(cur);
        step++;
        if (step == maxStep) {
            result.add(inverseList(current));
        } else {
            for (String p : parent.get(cur)) {
                search(result, current, p, step, maxStep, dict, parent);
            }
        }
        step--;
        current.remove(step);
    }

    public static ArrayList<String> inverseList(ArrayList<String> a) {
        int n = a.size();
        ArrayList<String> ret = new ArrayList<String>(n);
        for (int i = 0; i < n; i++) {
            ret.add(a.get(n - i - 1));
        }
        return ret;
    }
}

Tuesday, December 3, 2013

LeetCode OJ: Max Points on a Line

This LeetCode problem has the second-lowest AC rate. It deserves some attention, as it calls for a hash table and a choice of hash function that requires some thought.

Given n points on a 2D plane, find the maximum number of points that lie on the same straight line.
Note: the points have integer coordinates.

A naive O(n^3) algorithm considers all pairs of points (p_i, p_j) and then, for every k != i, j, checks whether p_k is on the line defined by p_i and p_j. Identical points need to be taken care of. An implementation of the O(n^3) algorithm is as follows:

/**
 * Definition for a point.
 * class Point {
 *     int x;
 *     int y;
 *     Point() { x = 0; y = 0; }
 *     Point(int a, int b) { x = a; y = b; }
 * }
 */
public class Solution {
    public int maxPoints(Point[] points) {
        if (points.length == 0) {
            return 0;
        }
        Arrays.sort(points, new PointComparator());
        int cur = 0;
        int i = 1;
        int[] counts = new int[points.length];
        counts[0] = 1;
        while (i < points.length) {
            if (points[cur].x == points[i].x && points[cur].y == points[i].y) {
                i++;
                counts[cur]++;
            } else {
                cur++;
                points[cur].x = points[i].x;
                points[cur].y = points[i].y;
                i++;
                counts[cur] = 1;
            }
        }
        return maxDistinct(points, counts, cur + 1);
    }

    private static class PointComparator implements Comparator<Point> {
        public int compare(Point p1, Point p2) {
            if (p1.x > p2.x) {
                return 1;
            } else if (p1.x < p2.x) {
                return -1;
            } else if (p1.y > p2.y) {
                return 1;
            } else if (p1.y < p2.y) {
                return -1;
            } else {
                return 0;
            }
        }
    }

    public int maxDistinct(Point[] points, int[] counts, int n) {
        if (points.length == 0) {
            return 0;
        } else if (n == 1) {
            return counts[0];
        } else if (n == 2) {
            return counts[0] + counts[1];
        }
        int max = 2;
        for (int i = 0; i < n; i++) {
            Point p1 = points[i];
            for (int j = i + 1; j < n; j++) {
                Point p2 = points[j];
                int val = counts[i] + counts[j];
                for (int k = j + 1; k < n; k++) {
                    Point p3 = points[k];
                    if ((p3.y - p2.y) * (p2.x - p1.x) == (p3.x - p2.x) * (p2.y - p1.y)) {
                        val += counts[k];
                    }
                }
                if (val > max) {
                    max = val;
                }
            }
        }
        return max;
    }
}
A better O(n^2) solution is possible using hash tables. Notice that this problem is known to be 3SUM-hard, so a non-parallel, non-randomized algorithm is unlikely to beat O(n^2).

To implement an O(n^2) algorithm, one needs a hash table whose key is an encoding of a line. One way to encode a line in the 2D plane is the slope-intercept form y = mx + b. However, m and b can be non-integers even when x and y are integers, and it is well known that equality tests on floating-point numbers are error-prone. Instead, I choose the standard form ax + by + c = 0. Given two distinct points (x1, y1) and (x2, y2), we can set a = y2 - y1, b = x1 - x2, c = y2 * x1 - y1 * x2, and all of a, b and c are integers. (Strictly speaking, the constant term of the line through these two points is -c, but since the sign flip is applied consistently to every pair, the triple still identifies a line uniquely.) However, this (a, b, c) is not unique, as any integral multiple of it represents the same line. To guarantee uniqueness, we divide out the gcd and require b to be nonnegative; if b is zero, i.e. the line is vertical, we require a to be positive. Notice that a and b cannot both be zero.

We need a good hash function in order to use this (a, b, c) representation as the key of a hash table. The hash function for integer vectors mentioned in my last post is a good candidate. In particular, for the problem at hand I used:

h(a, b, c) = a + b * 31 + c * 31^2.

(The code below uses the mirrored order a * 31^2 + b * 31 + c, which works equally well.)

The algorithm itself is straightforward. For each point p_i, we consider all points after it, {p_(i+1), ..., p_n}. We create a hash table whose key is the (a, b, c) tuple and whose value is the number of points in {p_(i+1), ..., p_n} that connect to p_i via the line represented by that key. Then the largest value in the hash table is the maximum number of collinear points whose smallest index is i. An implementation is as follows:

/**
 * Definition for a point.
 * class Point {
 *     int x;
 *     int y;
 *     Point() { x = 0; y = 0; }
 *     Point(int a, int b) { x = a; y = b; }
 * }
 */
public class Solution {
    private class Line {
        int a;
        int b;
        int c;

        public Line(Point p1, Point p2) {
            this(p2.y - p1.y, p1.x - p2.x, p2.y * p1.x - p2.x * p1.y);
        }

        public Line(int a, int b, int c) {
            if (b != 0) {
                if (b < 0) {
                    a = -a;
                    b = -b;
                    c = -c;
                }
                int gcd1 = gcd(b, Math.abs(a));
                int gcd2 = gcd(gcd1, Math.abs(c));
                this.a = a / gcd2;
                this.b = b / gcd2;
                this.c = c / gcd2;
            } else {
                if (a != 0) {
                    if (a < 0) {
                        a = -a;
                        c = -c;
                    }
                    int gcd1 = gcd(a, Math.abs(c));
                    this.a = a / gcd1;
                    this.b = 0;
                    this.c = c / gcd1;
                } else {
                    this.a = 0;
                    this.b = 0;
                    this.c = 0;
                }
            }
        }

        private int gcd(int x, int y) {
            // Assume x >= 0, y >= 0.
            if (x > y) {
                return gcd(y, x);
            } else if (x == 0) {
                return y;
            } else {
                return gcd(y % x, x);
            }
        }

        public boolean equals(Object o) {
            if (o instanceof Line) {
                Line l = (Line) o;
                return (a == l.a && b == l.b && c == l.c);
            } else {
                return false;
            }
        }

        public int hashCode() {
            return a * 31 * 31 + b * 31 + c;
        }
    }

    public int maxPoints(Point[] points) {
        if (points.length <= 2) {
            return points.length;
        }
        Arrays.sort(points, new PointComparator());
        Map<Line, Integer> m = new HashMap<Line, Integer>();
        int max = 0;
        for (int i = 0; i < points.length; ) {
            m.clear();
            int sameCount = 1;
            for (int j = i + 1; j < points.length; j++) {
                if (points[i].x == points[j].x && points[i].y == points[j].y) {
                    sameCount++;
                } else {
                    Line l = new Line(points[i], points[j]);
                    if (m.containsKey(l)) {
                        m.put(l, m.get(l) + 1);
                    } else {
                        m.put(l, sameCount + 1);
                    }
                }
            }
            m.put(new Line(0, 0, 0), sameCount);
            for (int v : m.values()) {
                max = Math.max(max, v);
            }
            i += sameCount;
        }
        return max;
    }

    private static class PointComparator implements Comparator<Point> {
        public int compare(Point p1, Point p2) {
            if (p1.x > p2.x) {
                return 1;
            } else if (p1.x < p2.x) {
                return -1;
            } else if (p1.y > p2.y) {
                return 1;
            } else if (p1.y < p2.y) {
                return -1;
            } else {
                return 0;
            }
        }
    }
}
Interestingly, the running times of the O(n^2) and O(n^3) solutions are not very different on the LeetCode OJ, probably due to large constant overhead and small test cases.

Monday, December 2, 2013

Choice of Hash functions for integer k-tuple in Java

In theory classes, we learn about universal hashing, perfect hashing, cuckoo hashing, tabulation hashing, and so forth. Most of these techniques have great theoretical value, yet face difficulties in practice. For example, universal hashing asks us to encode randomness into the hashCode() method, which is awkward to achieve; we would rather just hard-code some coefficients a and b, and in practice we rarely need adversarial guarantees anyway. For another example, we usually cannot afford the construction time and large constant factor of perfect hashing.

We often find ourselves just wanting a hashCode() function for a class whose objects are uniquely identified by a k-tuple of integers (a1, a2, ..., ak). Since Java's HashMap already applies its own supplemental bit-mixing hash when mapping an integer key to a bin, the hashCode() function only needs to convert the k-tuple into a single integer. The following scheme is what Java uses to hash a String, and it is what I suggest for hashing a k-tuple of integers:

h((a1, a2, ..., ak))
= a1 + a2 * p + a3 * p^2 + ... + ak * p^(k-1)

There is an implicit modulo operation due to integer overflow. Also, people often choose a Mersenne prime for p, such as 31, since multiplying by it reduces to a shift and a subtraction.
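As a hedged sketch of this scheme, here is how such a hashCode() (with a matching equals()) might look for a class identified by a k-tuple of ints. IntTuple is a made-up name for illustration; overflow simply wraps around, providing the implicit modulo:

```java
import java.util.Arrays;

// Hypothetical class whose objects are identified by a k-tuple of ints.
class IntTuple {
    private final int[] a;

    IntTuple(int... a) {
        this.a = a.clone();
    }

    @Override
    public boolean equals(Object o) {
        return o instanceof IntTuple && Arrays.equals(a, ((IntTuple) o).a);
    }

    @Override
    public int hashCode() {
        // h = a1 + a2*p + ... + ak*p^(k-1) with p = 31,
        // evaluated Horner-style from the last component.
        int h = 0;
        for (int i = a.length - 1; i >= 0; i--) {
            h = h * 31 + a[i];
        }
        return h;
    }
}
```

Note that java.util.Arrays.hashCode(int[]) computes essentially the same polynomial, just oriented from the first component instead of the last; either orientation is fine as long as it is used consistently.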