Implement strStr()

Question (LC.28)

Returns the index of the first occurrence of needle in haystack, or -1 if needle is not part of haystack.

Example

Input: haystack = "stack", needle = "queue"
Return: -1
Input: haystack = "abcdefgef", needle = "ef"
Return: 4

Analysis

Essentially, this question asks us to implement haystack.indexOf(needle). We start with the brute-force approach: for each position in the haystack, check whether needle matches starting there. Always work through examples on paper or a whiteboard first. This helps you understand the problem and arrive at a clear algorithm, and the examples often turn into good test cases.

Test cases:
case 0: "", "" return 0
case 1: "dsgsfg","" return 0
case 2: "", "asdfsdf" return -1
case 3: "abcd", "fg" return -1
case 4: "abcdegef", "ef" return 6
case 5: "abcdeef", "ef" return 5 (you have to reset the pointer after every failed attempt)
case 6: "abcdege", "ef" return -1
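
The cases above can be turned into a small self-check. The sketch below runs them against Java's built-in String.indexOf, which is the behavior the question asks us to reproduce. The class name StrStrCases and the helper reference are illustrative names only; swap in your own implementation once it exists.

```java
class StrStrCases {
    // Reference behavior: the built-in method we are asked to re-implement.
    static int reference(String haystack, String needle) {
        return haystack.indexOf(needle);
    }

    public static void main(String[] args) {
        // Cases 0..6 from the list above: {haystack, needle}
        String[][] inputs = {
            {"", ""}, {"dsgsfg", ""}, {"", "asdfsdf"}, {"abcd", "fg"},
            {"abcdegef", "ef"}, {"abcdeef", "ef"}, {"abcdege", "ef"}
        };
        int[] expected = {0, 0, -1, -1, 6, 5, -1};

        for (int k = 0; k < inputs.length; k++) {
            int actual = reference(inputs[k][0], inputs[k][1]);
            String status = (actual == expected[k]) ? "ok" : "MISMATCH";
            System.out.println("case " + k + ": " + actual + " " + status);
        }
    }
}
```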

Two Pointers

for i from 0 to length(hay)-1-(length(need)-1)
    for j from 0 to length(need)-1
        if (hay[i] == need[j]) {
            i++;
        } else {
            j = 0;
            break;
        }
    end for
    if (j == length(need)-1)
        return i;
end for

return -1; // we didn't find any occurrence

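Can you see the problem with this approach? The inner loop advances i but never resets it before the next iteration of the outer loop, so after a failed partial match the search resumes from the wrong position (see case 5). Instead of using two explicit pointers, we can "combine" them and index the haystack with i+j. Also note that the j++ of the inner for loop runs once more before the loop exits, so after a full match j equals length(need), not length(need)-1. The check after the loop therefore has to be j == length(need).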

for i from 0 to length(hay)-1-(length(need)-1)
    for j from 0 to length(need)-1
        if (hay[i+j] != need[j]) {
            break;
        }
    end for
    if (j == length(need))
        return i;
end for
return -1; // we didn't find any occurrence
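
To see concretely why the first version fails, here is a sketch that transcribes it into Java as literally as possible. This is one reading of the pseudocode, not a recommended implementation; the class name FlawedTwoPointers and the method search are made up for the illustration, and the post-loop check is written as j == need.length() so that only the pointer bug is on display.

```java
class FlawedTwoPointers {
    // Deliberately buggy: i is advanced inside the inner loop and never reset.
    static int search(String hay, String need) {
        for (int i = 0; i <= hay.length() - need.length(); i++) {
            int j;
            for (j = 0; j < need.length(); j++) {
                if (hay.charAt(i) == need.charAt(j)) {
                    i++;        // moves the outer pointer as a side effect
                } else {
                    break;      // i keeps whatever progress it made
                }
            }
            if (j == need.length()) {
                return i;       // also wrong: i now points past the match
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        // Case 5: the partial match at index 4 pushes i past the real match.
        System.out.println(search("abcdeef", "ef")); // prints -1 (expected 5)
    }
}
```

Indexing with i+j, as in the combined version, leaves i untouched, so both problems disappear.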

Code

public int strStr(String source, String target) {
    if (source == null || target == null) {
        return -1;
    }
    int n = source.length(), m = target.length();

    for (int i = 0; i < n-m+1; i++) {
        int j = 0;
        for (j = 0; j < m; j++) {
            if (source.charAt(i+j) != target.charAt(j)) {
                break;
            }
        }
        if (j == m) {
            return i;
        }
    }

    return -1;
}

Time Complexity

Time O(n*m) Space O(1)
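
The O(n*m) bound is reached when almost every window matches up to its last character. The sketch below counts character comparisons for such an input; the class BruteForceCost and its helpers are hypothetical names introduced only for this illustration.

```java
import java.util.Arrays;

class BruteForceCost {
    // Same scan as the brute-force solution, but returns the number of
    // character comparisons instead of the match index.
    static long countComparisons(String source, String target) {
        int n = source.length(), m = target.length();
        long comparisons = 0;
        for (int i = 0; i < n - m + 1; i++) {
            int j;
            for (j = 0; j < m; j++) {
                comparisons++;
                if (source.charAt(i + j) != target.charAt(j)) {
                    break;
                }
            }
            if (j == m) {
                break; // match found, the real solution would return here
            }
        }
        return comparisons;
    }

    // Builds a string of `count` copies of `c` (works on older Java versions too).
    static String repeat(char c, int count) {
        char[] chars = new char[count];
        Arrays.fill(chars, c);
        return new String(chars);
    }

    public static void main(String[] args) {
        String source = repeat('a', 1000);      // n = 1000
        String target = repeat('a', 9) + "b";   // m = 10, never matches
        // 991 windows, each compared for all 10 characters
        System.out.println(countComparisons(source, target)); // prints 9910
    }
}
```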

Follow Up Question

Can you do this in linear time? O(m+n)

Example

Input: "abcdeef", "ef"
Return: 5

You can't simply enumerate all possible substrings and hash them: enumerating the substrings is already O(n^2), and you need O(n^2) space as well. So how can we speed things up?

Rabin-Karp (Rolling Hash)

The Rabin-Karp algorithm is a string-searching algorithm that uses hashing to find any one of a set of pattern strings in a text. For text of length n and p patterns of combined length m, its average and best-case running time is O(n+m) in space O(p), but its worst-case time is O(nm).

Rabin-Karp essentially speeds up the O(m) inner loop of the brute-force approach to O(1) with a hash function. The trick is to use a rolling hash: remove the contribution of the first character, then add the next one.

Approach

Input: source = "abcde", target = "cde"

Step 1 Compute the hash value for target
hash("cde") = (c * 31^2) + (d * 31^1) + (e * 31^0)

Step 2 Compute the rolling hash for source then compare
hash("abc") = (a * 31^2) + (b * 31^1) + (c * 31^0)
hash("abcd") = hash("abc") * 31 + (d * 31^0)
hash("bcd") = hash("abcd") - (a * 31^3)
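
The arithmetic in Step 2 can be checked with a few lines of Java. This is a minimal sketch that leaves out the modulus (safe only because the strings are short); the class RollingHashDemo and the methods hash and roll are illustrative names, not part of the solution below.

```java
class RollingHashDemo {
    static final int BASE = 31;

    // Direct polynomial hash: s[0]*31^(k-1) + ... + s[k-1]*31^0.
    static long hash(String s) {
        long h = 0;
        for (int i = 0; i < s.length(); i++) {
            h = h * BASE + s.charAt(i);
        }
        return h;
    }

    // Slides a window of length m one step to the right:
    // first append `in`, then drop the contribution of `out`.
    static long roll(long windowHash, char out, char in, int m) {
        long extended = windowHash * BASE + in; // hash of the m+1 characters
        long power = 1;
        for (int i = 0; i < m; i++) {
            power *= BASE;                      // 31^m
        }
        return extended - out * power;
    }

    public static void main(String[] args) {
        long abc = hash("abc");
        long bcd = roll(abc, 'a', 'd', 3);
        long cde = roll(bcd, 'b', 'e', 3);
        System.out.println(bcd == hash("bcd")); // prints true
        System.out.println(cde == hash("cde")); // prints true
    }
}
```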

Code

static final int HASH_SIZE = 1000000;
static final int PRIME = 31;

public int strStr2(String source, String target) {
    if (source == null || target == null) {
        return -1;
    }

    int n = source.length(), m = target.length();

    if (n < m) {
        return -1;
    }
    if (m == 0) {
        return 0;
    }

    // Step 0 Get the remove power
    int removePow = 1;
    for (int i = 0; i < m - 1; i++) {
        removePow = (removePow * PRIME) % HASH_SIZE;
    }

    // Step 1 Compute target hash
    int targetHash = myHash(target);

    // Step 2 Get an initial hash for source
    int rollHash = myHash(source.substring(0, m - 1));

    // Step 3 Rolling hash then compare
    int j = m - 1;
    for (int i = 0; i < n - m + 1; i++) {
        // extend the hash
        rollHash = rollHash * PRIME + source.charAt(j++);
        rollHash = rollHash % HASH_SIZE;

        // check hash
        if (rollHash == targetHash) {
            if (source.substring(i, j).equals(target)) {
                return i;
            }
        }

        // System.out.println("substring: " + source.substring(i, j));
        // System.out.println("rollHash : " + rollHash);
        // System.out.println("targetHash : " + targetHash);

        // shrink the hash
        // use long arithmetic: charAt(i) * removePow can overflow int for large char values
        rollHash = (int) ((rollHash - (long) source.charAt(i) * removePow) % HASH_SIZE);
        if (rollHash < 0) { // see comment below
            rollHash = rollHash + HASH_SIZE;
        }
    }

    return -1; // no match found
}

// encode the string with prime 31 and hash_size 10^6
private int myHash(String input) {
    int hValue = 0;
    for (int i = 0; i < input.length(); i++) {
        hValue = hValue * PRIME + input.charAt(i);
        hValue = hValue % HASH_SIZE;
    }
    return hValue;
}

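Comment: % behaves differently in different languages. In Java, % returns the remainder, which takes the sign of the dividend and can therefore be negative. In Python, % returns the modulus, which takes the sign of the divisor. That is why the Java code adds HASH_SIZE back when rollHash is negative.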
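
A quick way to see the difference in Java is to compare % with Math.floorMod, which behaves like Python's % for a positive divisor. The class ModDemo and its method names are made up for this sketch.

```java
class ModDemo {
    // Java's % operator: result has the sign of the dividend.
    static int remainder(int a, int n) {
        return a % n;
    }

    // Math.floorMod: result has the sign of the divisor.
    static int modulus(int a, int n) {
        return Math.floorMod(a, n);
    }

    // The manual fix used in the rolling-hash code: add n back if negative.
    static int manualFix(int a, int n) {
        int r = a % n;
        return r < 0 ? r + n : r;
    }

    public static void main(String[] args) {
        System.out.println(remainder(-7, 3)); // prints -1
        System.out.println(modulus(-7, 3));   // prints 2
        System.out.println(manualFix(-7, 3)); // prints 2
    }
}
```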

Time Complexity

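O(n) for the linear scan and rolling hash of the source, plus O(m) for the hash of the target, so O(n + m) on average.

In the worst case, when many windows collide with the target hash and each one needs the substring comparison, the time degrades to O(n*m).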

Knuth-Morris-Pratt (Story)
