https://leetcode.com/problems/android-unlock-patterns/
这是一道dp题,但是一时间没想出怎么dp。
dfs解法:
class Solution {
vector<vector<bool> > visited;
int res;
int m;
int n;
bool isValid(int x1, int y1, int x2, int y2) {
int dx = abs(x1 - x2);
int dy = abs(y1 - y2);
if(x1 == x2 && y1 == y2) return false;
if(visited[x2][y2]) return false;
if(dx <= 1 && dy <= 1) return true;
if(dx == 2 && dy == 0) return visited[(x1+x2)/2][(y1+y2)/2];
if(dx == 0 && dy == 2) return visited[(x1+x2)/2][(y1+y2)/2];
if(dx == 2 && dy == 2) return visited[(x1+x2)/2][(y1+y2)/2];
return true;
}
void dfs(int x, int y, int keys) {
if(keys == n) { res++; return;}
visited[x][y] = true;
if(keys >= m && keys <= n)
res++;
for(int i = 0; i < 9; i++) {
int x2 = i / 3;
int y2 = i % 3;
if(x == x2 && y == y2) continue;
if(isValid(x,y,x2,y2))
dfs(x2,y2,keys+1);
}
visited[x][y] = false;
}
public:
int numberOfPatterns(int _m, int _n) {
visited.resize(3,vector<bool>(3,false));
res = 0;
m = _m;
n = _n;
dfs(0,0,1);
int res1 = res;
res = 0;
dfs(0,1,1);
int res2 = res;
res = 0;
dfs(1,1,1);
int res3 = res;
res = res1 * 4 + res2 * 4 + res3;
return res;
}
};
https://leetcode.com/discuss/104293/share-a-bitmask-dp-solution
dp:
dp[i][j],其中i表示当前的局面,j表示当前的key。
i共有1<<9个,从000 000 000,到111 111 111。
j共有9个,从0到8。
ret[i-1]用来统计key count为i的结果。
最后返回ret[m-1]到ret[n-1]。
初始化dp[1<<i][i] = 1;从第i个key开始只有1个key的pattern数目为1.
for(int i=1;i<(1<<9);++i)循环遍历所有情况。
1.首先数该情况下有多少个1,用count计算。
tmp&=(tmp-1)会把最后一个1置为0。
2.for(int j=0;j<9;++j)
把当前局面下到第j个key的所有情况从dp[i][j]加到res数组里。
更新下一个情况的dp[i+(1<<k)][k]的值。
需要仔细观察的是下面的if语句:
if(((x==x1&&abs(y-y1)==2)
|| (y==y1&&abs(x-x1)==2)
||(abs(x-x1)==2&&abs(y-y1)==2))
&& !(i&(1<<(3*(x+x1)/2+(y+y1)/2))))
continue;
这句话的意思是如果是中间有节点的情况,并且该节点在当前局面下没有被访问过,则continue。
int numberOfPatterns(int m, int n) {
vector<vector<int>> dp((1<<9),vector<int>(9,0));
vector<int> ret(9,0);
for(int i=0;i<9;++i) dp[1<<i][i]=1;
for(int i=1;i<(1<<9);++i) {
int count=0;
int tmp=i;
while(tmp) {
count++;
tmp&=(tmp-1);
}
if(count>n) continue;
for(int j=0;j<9;++j) {
ret[count-1]+=dp[i][j];
if(i&(1<<j)) {
int x=j/3;
int y=j%3;
for(int k=0;k<9;++k) {
if(!(i&(1<<k))) {
int x1=k/3;
int y1=k%3;
if(((x==x1&&abs(y-y1)==2)||(y==y1&&abs(x-x1)==2)||(abs(x-x1)==2&&abs(y-y1)==2))&&!(i&(1<<(3*(x+x1)/2+(y+y1)/2)))) continue;
dp[i+(1<<k)][k]+=dp[i][j];
}
}
}
}
}
int ans=0;
for(int i=m-1;i<n;++i) ans+=ret[i];
return ans;
}