[알고리즘][X] 토끼와 경주

2024. 4. 3. 21:58 · 알고리즘 풀이/Java

https://www.codetree.ai/training-field/frequent-problems/problems/rabit-and-race/description?page=1&pageSize=20

 

코드트리 | 코딩테스트 준비를 위한 알고리즘 정석

국가대표가 만든 코딩 공부의 가이드북 코딩 왕초보부터 꿈의 직장 코테 합격까지, 국가대표가 엄선한 커리큘럼으로 준비해보세요.

www.codetree.ai

import java.awt.Point;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.PriorityQueue;
import java.util.StringTokenizer;

/**
 * Codetree "토끼와 경주" (rabbit race) solution.
 *
 * <p>Reads the command count Q from the first line of standard input, then Q command lines:
 * <ul>
 *   <li>{@code 100 N M P pid_1 d_1 ... pid_P d_P} - set up an N x M grid with P rabbits at (1, 1).</li>
 *   <li>{@code 200 K S} - perform K jumps, then award S to the best-placed rabbit that jumped.</li>
 *   <li>{@code 300 pid L} - multiply rabbit {@code pid}'s jump distance by L.</li>
 *   <li>{@code 400} - print the highest score among all rabbits.</li>
 * </ul>
 */
public class Main {

    /**
     * One rabbit. The natural ordering is the "who jumps next" priority used by {@code pq}:
     * fewer jumps first, then smaller row + col, smaller row, smaller col, and finally smaller id.
     */
    static class Rabbit implements Comparable<Rabbit> {
        int id;
        Point point;
        int jump;
        // Jump distance; a long so repeated "300" multiplications cannot silently wrap around.
        long d;
        // Bonus points awarded directly to this rabbit (the S rewards from races).
        long score;
        // Sum of (row + col) over this rabbit's own jumps. printBest subtracts it from totalSum
        // because a rabbit does not earn points from its own jump.
        long minus;

        public Rabbit(int id, Point point, int jump, int d) {
            this.id = id;
            this.point = point;
            this.jump = jump;
            this.d = d;
            this.score = 0;
            this.minus = 0;
        }

        @Override
        public int compareTo(Rabbit o) {
            if (jump != o.jump) {
                return Integer.compare(jump, o.jump);
            }
            int byPosition = comparePosition(point, o.point);
            if (byPosition != 0) {
                return byPosition;
            }
            return Integer.compare(id, o.id);
        }

        @Override
        public String toString() {
            return "Rabbit [id=" + id + ", point=" + point + ", jump=" + jump + ", d=" + d + ", score=" + score
                    + ", minus=" + minus + "]";
        }
    }

    /**
     * Orders grid cells by row + col, then row, then col (all ascending).
     *
     * <p>The rabbit priority uses this order ascending. The choice of landing cell and the choice
     * of bonus winner use it descending (largest wins).
     */
    private static int comparePosition(Point a, Point b) {
        int bySum = Integer.compare(a.x + a.y, b.x + b.y);
        if (bySum != 0) {
            return bySum;
        }
        int byRow = Integer.compare(a.x, b.x);
        if (byRow != 0) {
            return byRow;
        }
        return Integer.compare(a.y, b.y);
    }

    /**
     * Handles "400": prints the highest current score.
     *
     * <p>A rabbit's score is the total (row + col) gained by every jump so far, minus its own
     * jumps, plus its bonuses.
     */
    private static void printBest() {
        long result = Long.MIN_VALUE;
        for (Rabbit rabbit : pq) {
            result = Math.max(result, totalSum - rabbit.minus + rabbit.score);
        }
        System.out.println(result);
    }

    /** Handles "300 pid L": multiplies rabbit {@code pid}'s jump distance by L. */
    private static void changeDir() {
        int pid = Integer.parseInt(st.nextToken());
        int L = Integer.parseInt(st.nextToken());
        Rabbit rabbit = rabbits.get(pid);
        rabbit.d *= L;
    }

    /**
     * Handles "200 K S": performs K jumps, each by the current highest-priority rabbit.
     *
     * <p>Afterwards, S bonus points go to the rabbit that jumped in this race and now sits
     * largest by (row + col, row, col, id).
     */
    private static void race() {
        int K = Integer.parseInt(st.nextToken());
        int S = Integer.parseInt(st.nextToken());

        // Rabbits that jumped in this race; duplicates are harmless for the max scan below.
        List<Rabbit> jumped = new ArrayList<>();
        for (int k = 0; k < K && !pq.isEmpty(); k++) {
            Rabbit rabbit = pq.poll();

            // Of the four landing cells, keep the largest by (row + col, row, col).
            Point best = null;
            for (int dir = 0; dir < 4; dir++) {
                Point candidate = convert(rabbit.point.x, rabbit.point.y, dir, rabbit.d);
                if (best == null || comparePosition(candidate, best) > 0) {
                    best = candidate;
                }
            }

            rabbit.point = best;
            int gained = best.x + best.y;
            // Every other rabbit earns `gained`. Add it to the shared total once and remember
            // it on the jumper so printBest can exclude it.
            totalSum += gained;
            rabbit.minus += gained;
            rabbit.jump++;
            jumped.add(rabbit);
            // Re-insert only after the key fields (jump, point) were updated.
            pq.add(rabbit);
        }

        // Pick the winner from final positions, after all K jumps are done.
        Rabbit winner = null;
        for (Rabbit rabbit : jumped) {
            if (winner == null) {
                winner = rabbit;
                continue;
            }
            int cmp = comparePosition(rabbit.point, winner.point);
            if (cmp == 0) {
                cmp = Integer.compare(rabbit.id, winner.id);
            }
            if (cmp > 0) {
                winner = rabbit;
            }
        }
        if (winner != null) {
            winner.score += S;
        }
    }

    /**
     * Returns the cell reached by jumping {@code L} cells from (x, y) in direction {@code d}.
     * The jump reflects off the grid edges.
     *
     * @param x current row (1..N)
     * @param y current column (1..M)
     * @param d direction index: 0 = left (column decreases), 1 = right (column increases),
     *          2 = up (row decreases), 3 = down (row increases)
     * @param L jump distance, assumed non-negative
     * @return the landing cell
     */
    private static Point convert(int x, int y, int d, long L) {
        if (d < 2) {
            return new Point(x, bounce(y, M, d == 1, L));
        }
        return new Point(bounce(x, N, d == 3, L), y);
    }

    /**
     * Moves {@code dist} cells along an axis of cells 1..size, starting at {@code pos} and
     * reflecting at both ends.
     *
     * <p>A bouncing walk repeats every {@code 2 * (size - 1)} steps. The walk is re-anchored at
     * the wall behind the mover: cell 1 when moving forward, cell {@code size} when moving
     * backward. The offset from that wall, reduced modulo the cycle, then fixes the final cell.
     *
     * @param forward true to move toward larger indices
     */
    private static int bounce(int pos, int size, boolean forward, long dist) {
        int cycle = 2 * (size - 1);
        if (cycle == 0) {
            // A single-cell axis has nowhere to go (and the modulo below would divide by zero).
            return pos;
        }
        int half = size - 1;
        int steps = (int) (dist % cycle);
        int offset = ((forward ? pos - 1 : size - pos) + steps) % cycle;
        if (forward) {
            // Offsets 0..half run 1 -> size; beyond that the walk heads back toward 1.
            return offset <= half ? 1 + offset : size - (offset - half);
        }
        // Offsets 0..half run size -> 1; beyond that the walk heads back toward size.
        return offset <= half ? size - offset : 1 + (offset - half);
    }

    /**
     * Handles "100 N M P pid_1 d_1 ...".
     *
     * <p>Resets all state and places P rabbits at (1, 1) with zero jumps.
     */
    private static void init() throws Exception {
        pq = new PriorityQueue<>();
        rabbits = new HashMap<>();
        totalSum = 0;

        N = Integer.parseInt(st.nextToken());
        M = Integer.parseInt(st.nextToken());
        P = Integer.parseInt(st.nextToken());

        for (int i = 0; i < P; i++) {
            int pid = Integer.parseInt(st.nextToken());
            int d = Integer.parseInt(st.nextToken());

            Rabbit rabbit = new Rabbit(pid, new Point(1, 1), 0, d);
            pq.add(rabbit);
            rabbits.put(pid, rabbit);
        }
    }

    static BufferedReader br;
    // Tokenizer over the current command line; each handler reads its own arguments from it.
    static StringTokenizer st;
    static int N;
    static int M;
    static int P;
    // All rabbits, ordered by jump priority (see Rabbit.compareTo).
    static PriorityQueue<Rabbit> pq;
    static HashMap<Integer, Rabbit> rabbits;
    // Sum of (row + col) over every jump made by any rabbit.
    static long totalSum = 0;

    /** Reads Q commands from standard input and dispatches each one by its leading code. */
    public static void main(String[] args) throws Exception {
        br = new BufferedReader(new InputStreamReader(System.in));
        int Q = Integer.parseInt(br.readLine());
        for (int q = 0; q < Q; q++) {
            st = new StringTokenizer(br.readLine());
            int cmd = Integer.parseInt(st.nextToken());
            switch (cmd) {
                case 100:
                    init();
                    break;
                case 200:
                    race();
                    break;
                case 300:
                    changeDir();
                    break;
                case 400:
                    printBest();
                    break;
                default:
                    // Unknown commands are ignored, matching the original behavior.
                    break;
            }
        }
    }

}

나의 풀이

- 한 칸씩 시뮬레이션하지 않고, 벽에 튕기는 이동이 주기(가로 2*(M-1), 세로 2*(N-1))마다 반복된다는 점을 이용해 이동 후 위치를 한 번에 구하는 수식을 세웠다.

- totalSum은 100억까지 커질 수 있어서 int로는 담을 수 없고 long을 사용해야 한다. 조심하자!

'알고리즘 풀이 > Java' 카테고리의 다른 글

[알고리즘][X] 술래 잡기  (0) 2024.04.09
[알고리즘] 보급로  (0) 2024.04.04
[알고리즘][X] 키 순서  (0) 2024.04.03
[알고리즘] 싸움땅  (0) 2024.04.02
[알고리즘][X] 산타의 선물 공장  (0) 2024.04.02