Given a string s, rearrange its characters so that no two adjacent characters are the same. If no such rearrangement exists, return an empty string.
Examples
I: s = "aab"
O: "aba"
I: s = "aaab"
O: "" (no valid rearrangement exists)
I: s = "aabb"
O: "abab" or "baba" (multiple valid solutions)
Brute Force
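Generate all permutations and check whether any of them is a valid rearrangement.
This is guaranteed to find an answer if one exists, but it takes exponential time. `itertools.permutations` produces all n! orderings, treating repeated characters as distinct, so input length matters a lot. It is only practical for very short strings: as shown below, a 19-character input already exceeds the time limit.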
Code
from itertools import permutations
def reorganize_string(self, s: str) -> str:
    """Brute force: return the first permutation of ``s`` with no equal neighbours.

    ``itertools.permutations`` emits orderings lexicographically by input
    *position* (not sorted by character value) and treats repeated characters
    as distinct, so up to n! tuples may be examined -- exponential time.

    Returns ``s`` unchanged when it has at most one character, and ``""`` when
    no permutation passes ``self.is_not_adj``.
    """
    if len(s) < 2:
        return s
    # First ordering (in permutations' emission order) that passes the check.
    first_valid = next(
        (perm for perm in permutations(s) if self.is_not_adj(perm)), None
    )
    return "" if first_valid is None else "".join(first_valid)
def is_not_adj(self, p_list) -> bool:
    """Return True when no two neighbouring items of ``p_list`` are equal.

    ``p_list`` may be any sliceable sequence (a tuple from
    ``itertools.permutations``, a list, or a str). Empty and single-item
    sequences trivially qualify.
    """
    return all(left != right for left, right in zip(p_list, p_list[1:]))
Passed 26 / 71
Time Limit Exceeded on s = "kkkkzrkatkwpkkkktrq"
A Better Approach
We want to bring the time complexity down to polynomial time. One intuitive solution works in three steps:

1. Group the letters.
2. Sort the groups by frequency.
3. Working from the end, repeatedly take the last character and insert it at the first position where neither neighbor equals it (an insertion-sort-like step).

The initial grouping and sorting by frequency is meant to give a starting state that can be fixed within n − 1 moves. From that state, ideally no insertion step is wasted. If an insertion step cannot find a valid place for the last character, the input is treated as unsolvable.
The proof of correctness is not immediately apparent. We can try out a few examples to validate this approach.
The motivation of having this initial set up and inserting the last element is to maximize the number of valid positions per insertion. We'll explore that part a bit more later.
Code
from collections import defaultdict
def reorganize_string(self, s: str) -> str:
    """Rearrange ``s`` so no two adjacent characters are equal (insertion approach).

    The characters are first grouped into runs ordered by descending
    frequency (``self.group_and_sort``). Then, up to ``n - 1`` times, the
    arrangement is checked with ``self.is_not_adj``; if it is still invalid,
    the last character is moved to the first slot where it has no equal
    neighbour.

    Returns ``s`` unchanged when it has at most one character. Returns ``""``
    when some step finds no valid slot for the last character, or when the
    arrangement is still invalid after the final move.
    """
    if len(s) <= 1:
        return s
    group_list = self.group_and_sort(list(s))
    n = len(group_list)
    for _ in range(n - 1):
        if self.is_not_adj(group_list):
            return "".join(group_list)
        last = group_list[-1]
        # Slot 0 is the front; slot i + 1 lies between indices i and i + 1.
        # Index n - 1 holds ``last`` itself, so no slot touching it qualifies.
        for i in range(n - 1):
            if i == 0 and group_list[0] != last:
                target = 0
                break
            if group_list[i] != last and group_list[i + 1] != last:
                target = i + 1
                break
        else:
            # No position keeps ``last`` away from an equal neighbour.
            return ""
        # Popping the tail does not shift indices < n - 1, so ``target`` stays valid.
        group_list.pop()
        group_list.insert(target, last)
    # The loop can end right after its final move without re-checking;
    # verify here so an invalid arrangement is never returned.
    return "".join(group_list) if self.is_not_adj(group_list) else ""
def group_and_sort(self, s_list):
    """Return the items of ``s_list`` grouped into runs, most frequent first.

    Runs with equal frequency keep the order in which their character first
    appears in ``s_list``: the dict preserves insertion order, and ``sorted``
    stays stable even with ``reverse=True``.
    """
    freq = {}
    for ch in s_list:
        freq[ch] = freq.get(ch, 0) + 1
    by_count = sorted(freq, key=freq.get, reverse=True)
    return [ch for ch in by_count for _ in range(freq[ch])]
def is_not_adj(self, p_list) -> bool:
    """Check that ``p_list`` has no two equal items side by side.

    Works on any indexable sequence with a length. Sequences with fewer than
    two items have no adjacent pair and therefore pass.
    """
    last_index = len(p_list) - 1
    return not any(p_list[k] == p_list[k + 1] for k in range(last_index))
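This code is accepted with a 102 ms runtime. The initial grouping and sorting is O(n log n). The insertion phase is O(n^2): it runs at most n − 1 times, and each pass does an O(n) adjacency check, an O(n) scan for a slot, and an O(n) list insert.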
Heap Approach
To do better than O(n^2), we need an O(n log n) or O(n) solution. A max-heap keyed on letter frequency achieves O(n log n). More precisely, it is O(n log k), where k is the number of distinct letters.
Code
import heapq
def reorganizeString(self, s: str) -> str:
    """Rearrange ``s`` so no two adjacent characters are equal (heap approach).

    At each step, take the most frequent remaining letter unless it equals
    the previously placed letter; in that case take the runner-up instead and
    return the most frequent letter to the heap untouched.

    Returns ``s`` unchanged when it has at most one character, and ``""`` when
    only the just-placed letter remains (no valid arrangement exists).
    """
    if len(s) <= 1:
        return s
    # Count letters with a plain dict so this snippet only needs ``heapq``.
    counts = {}
    for ch in s:
        counts[ch] = counts.get(ch, 0) + 1
    # heapq is a min-heap: negate counts so the most frequent pops first.
    # Letters are unique, so ties on count are broken by the letter itself.
    max_heap = [(-count, ch) for ch, count in counts.items()]
    heapq.heapify(max_heap)
    # Build into a list and join once; repeated ``str +=`` can be quadratic.
    result = []
    while max_heap:
        neg_count, ch = heapq.heappop(max_heap)
        if result and result[-1] == ch:
            # The top letter would repeat; place the runner-up instead.
            if not max_heap:
                return ""
            neg_second, second = heapq.heappop(max_heap)
            result.append(second)
            if neg_second < -1:  # runner-up still has copies left
                heapq.heappush(max_heap, (neg_second + 1, second))
            # The skipped top letter goes back with its count unchanged.
            heapq.heappush(max_heap, (neg_count, ch))
        else:
            result.append(ch)
            if neg_count < -1:  # letter still has copies left
                heapq.heappush(max_heap, (neg_count + 1, ch))
    return "".join(result)
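Counting is O(n) and building the heap is O(k log k), where k is the number of distinct letters. Each of the n appended characters costs at most two pops and two pushes, each O(log k). The worst case is therefore O(n log k) overall, which is effectively O(n) when the alphabet is fixed (e.g. 26 lowercase letters). Runtime: 24 ms, faster than 99.12% of submissions.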