AOJ1281 The Morning after Halloween

問題リンク The Morning after Halloween

  • 概要

W*Hの大きさのオバケ屋敷がある。オバケはN匹存在する。
オバケはアルファベット小文字で表され、それぞれ定位置へ移動したい。各オバケの定位置は対応するアルファベットの大文字で表される。オバケは1ステップに隣接する4方向へ動くことができる。その場に留まってもよい。ただし、オバケは壁を通ったり、同じマスに2匹以上のオバケが止まったり、すれ違って体がぶつかるような動き方はできない。全てのオバケが定位置に動くまでの最短ステップを答えよ。以下のことが保証されている
4 <= W, H <= 16
1 <= N <= 3
屋敷の周囲は壁マスで覆われている
任意の2マスについてお互いに到達可能である
壁マスは全て連続している
任意の2*2マスにおいて少なくとも1つの壁マスがある

  • 解法

3匹のオバケの位置の組をノードにした幅優先探索で解きました。
オバケは5通りの動き方があるので、次の状態は5^3 = 125個あります。このとき、同じマスに止まっているオバケがいないか、すれ違う動きをしていないかを調べます。すれ違いを検出する方法は、オバケA, Bの現在の位置がA1, B1, 次の移動先がA2, B2だとすると
A2 == B1 && B2 == A1
が成立したときこの2匹はすれ違っています。
更に、オバケの動き方を3重for文で書くと思いますが、ネストの一番内側に枝刈りをまとめて書くのではなく、forを書くたびに、書ける枝刈りを書いておくと実行時間が結構早くなると思います。
探索はこれで問題ないのですが、一番ネックなのがメモリ制限です。32MBと指定されており、使用言語がJavaだとすると、その倍の64MBまで使えます。
1匹のオバケの状態は高々256通りです。最大で3匹なので256^3 = 16777216通りの状態があります。開始状態からの最短ステップ数を格納するために、256^3のint配列を作ろうとすると、
256^3 * 4 = 67108KB = 67MB
でアウトです。そこで、2Byteの大きさのshort配列ならこの半分で済むので大丈夫かと思うのですがMLEになりました(ジャッジには80MBと記載されていました)。よって、1Byteのbyte配列を使うことにしました。byteにすると、答えがオーバーフローを起こすという別の問題が生じます。正の数が127までしか使えないからです。byteを符号なしで解釈すれば255まで使えますが、ジャッジデータには答えがそれよりも大きいケースがあるのでこれでもアウトです。自分のソースでは、ステップ数が255を超えそうになったら、ステップ数をint型の変数へ保存させ、探索済みの状態全ての最短ステップ数を-128で初期化しています。解を出すときには、求まった答えに更に退避させておいたステップ数を足しこむということをしています。これで何とか解くことができました。

  • ソース
import java.util.LinkedList;
import java.util.Queue;
import java.util.Scanner;

//The Morning after Halloween
public class AOJ1281 {
	int w, h, N;
	char[][] map;
	int[][] d = {{-1,0},{0,1},{1,0},{0,-1},{0,0}};
	byte INF = Byte.MAX_VALUE, MIN = Byte.MIN_VALUE;
	//short型で256^3のメモリをとるとMLEになってしまうのでbyteでガマン
	//unsigned風に値を考える必要がある -2^7〜2^7-1 -> 0〜2^8-1 つまり128を足せばいい
	byte[][][] dist;
	int[] si, sj, gi, gj;
	
	int one(){
		for(int i=0;i<w*h;i++)dist[0][0][i]=INF;
		dist[0][0][si[0]*w+sj[0]] = MIN;
		Queue<Integer> q = new LinkedList<Integer>();
		q.add(si[0]*w+sj[0]);
		while(!q.isEmpty()){
			int V = q.poll();
			int i = V/w, j = V%w;
			if(i==gi[0]&&j==gj[0])return (int)dist[0][0][V] + 128;
			for(int k=0;k<4;k++){
				int ni = i+d[k][0], nj = j+d[k][1];
				if(map[ni][nj]==' '&&dist[0][0][ni*w+nj]==INF){
					dist[0][0][ni*w+nj] = (byte) (dist[0][0][V]+1);
					q.add(ni*w+nj);
				}
			}
		}
		return -1;
	}
	int two(){
		int M = w*h;
		for(int i=0;i<M;i++)for(int j=0;j<M;j++)dist[i][j][0]=INF;
		dist[si[0]*w+sj[0]][si[1]*w+sj[1]][0] = MIN;
		Queue<Integer> q = new LinkedList<Integer>();
		q.add((si[0]*w+sj[0])*M+si[1]*w+sj[1]);
		while(!q.isEmpty()){
			int V = q.poll(), VA = V/M, VB = V%M;
			int ai = VA/w, aj = VA%w, bi = VB/w, bj = VB%w;
			if(ai==gi[0]&&aj==gj[0]&&bi==gi[1]&&bj==gj[1])return (int)dist[VA][VB][0] + 128;
			for(int k0=0;k0<5;k0++){
				int Ai = ai+d[k0][0], Aj = aj+d[k0][1];
				if(map[Ai][Aj]=='#')continue;
				for(int k1=0;k1<5;k1++){
					int Bi = bi+d[k1][0], Bj = bj+d[k1][1];
					if(map[Bi][Bj]=='#')continue;
					if(Ai==Bi&&Aj==Bj)continue;
					if(Ai==bi&&Aj==bj&&Bi==ai&&Bj==aj)continue;
					int PA = Ai*w+Aj, PB = Bi*w+Bj;
					if(dist[PA][PB][0]!=INF)continue;
					dist[PA][PB][0] = (byte) (dist[VA][VB][0]+1);
					q.add(PA*M+PB);
				}
			}
		}
		return -1;
	}
	int three(){
		int FIX = 0;
		int M = w*h, MM = w*h*w*h, SA = si[0]*w+sj[0], SB = si[1]*w+sj[1], SC = si[2]*w+sj[2];
		for(int i=0;i<M;i++)for(int j=0;j<M;j++)for(int k=0;k<M;k++)dist[i][j][k]=INF;
		dist[SA][SB][SC] = MIN;
		Queue<Integer> q = new LinkedList<Integer>();
		q.add(SA*MM+SB*M+SC);
		while(!q.isEmpty()){
			int V = q.poll(), VA = V/MM, VB = (V%MM)/M, VC = V%M;
			int ai = VA/w, aj = VA%w, bi = VB/w, bj = VB%w, ci = VC/w, cj = VC%w;
			//答えが2^8以上になる場合のための対処
			//カウントをMINでリセットし、現在までカウントした分をFIXにとっておき、答えを返すときに足しこむ
			if(dist[VA][VB][VC]==INF-1){
				FIX = dist[VA][VB][VC] + 128;
				for(int i=0;i<M;i++)for(int j=0;j<M;j++)for(int k=0;k<M;k++)if(dist[i][j][k]!=INF)dist[i][j][k]=MIN;
			}
			if(ai==gi[0]&&aj==gj[0]&&bi==gi[1]&&bj==gj[1]&&ci==gi[2]&&cj==gj[2])return (int)dist[VA][VB][VC] + 128 + FIX;
			for(int k0=0;k0<5;k0++){
				int Ai = ai+d[k0][0], Aj = aj+d[k0][1], PA = Ai*w+Aj;
				if(map[Ai][Aj]=='#')continue;
				for(int k1=0;k1<5;k1++){
					int Bi = bi+d[k1][0], Bj = bj+d[k1][1], PB = Bi*w+Bj;
					if(map[Bi][Bj]=='#')continue;
					if(PA==PB)continue;
					if(PA==VB&&PB==VA)continue;
					for(int k2=0;k2<5;k2++){
						int Ci = ci+d[k2][0], Cj = cj+d[k2][1], PC = Ci*w+Cj;
						if(map[Ci][Cj]=='#')continue;
						if(PA==PC||PB==PC)continue;
						if(PA==VC&&PC==VA || PB==VC&&PC==VB)continue;
						if(dist[PA][PB][PC]!=INF)continue;
						dist[PA][PB][PC] = (byte) (dist[VA][VB][VC]+1);
						q.add(PA*MM+PB*M+PC);
					}
				}
			}
		}
		return -1;
	}
	
	void run(){
		Scanner sc = new Scanner(System.in);
		dist = new byte[256][256][256];
		si = new int[3]; sj = new int[3]; gi = new int[3]; gj = new int[3];
		for(;;){
			w = sc.nextInt(); h = sc.nextInt(); N = sc.nextInt();
			if((w|h|N)==0)break;
			sc.nextLine();
			map = new char[h][];
			for(int i=0;i<h;i++){
				map[i] = sc.nextLine().toCharArray();
				for(int j=0;j<w;j++){
					if(Character.isLowerCase(map[i][j])){
						si[map[i][j]-'a'] = i; sj[map[i][j]-'a'] = j; map[i][j] = ' ';
					}
					else if(Character.isUpperCase(map[i][j])){
						gi[map[i][j]-'A'] = i; gj[map[i][j]-'A'] = j; map[i][j] = ' ';
					}
				}
			}
			System.out.println(N==1?one():N==2?two():three());
		}
	}
	
	public static void main(String[] args) {
		new AOJ1281().run();
	}
}