[알고리즘] 국가행정

2024. 3. 13. 18:04 · 알고리즘 풀이/Java

import java.util.Comparator;
import java.util.PriorityQueue;
import java.util.Queue;

/**
 * Solution for the "국가행정" (national administration) problem.
 *
 * <p>Cities 0..N-1 lie on a line. Road {@code i} (1 <= i < N) joins city {@code i-1} and city
 * {@code i}. Its distance is {@code (population[i-1] + population[i]) / lanes}, starting with one
 * lane. A segment tree over the city indices holds the road distances so that a single road update
 * and a distance-sum query over a city range both take O(log N).
 */
class UserSolution {

    /** Populations of cities 0..N-1, as passed to {@link #init} (stored by reference). */
    static int[] population;
    /**
     * Roads ordered by current distance (largest first), ties broken by the smaller road index.
     * Each entry is {@code {roadIndex, distance, laneCount}}.
     */
    static Queue<int[]> distances;
    /** {@code distance[i]} is the current distance of road i (between cities i-1 and i). */
    static int[] distance;
    /** {@code initAccSum[i]} is the sum of the initial distances of roads 1..i (0 for i == 0). */
    static int[] initAccSum;
    /** Root of the segment tree covering the city interval [0, N-1]. */
    static Node root;

    /**
     * Segment-tree node covering the city interval [start, end]. {@code sum} is the total
     * distance of the roads inside that interval. A node with {@code end - start <= 1} is a leaf
     * and has no children.
     */
    static class Node {
        int start;
        int end;
        int sum;
        Node left;
        Node right;

        public Node(int start, int end, int sum) {
            this.start = start;
            this.end = end;
            this.sum = sum;
        }

        public Node() {
        }

        /** Re-initialises a pooled node and detaches any children from a previous use. */
        public void init(int start, int end, int sum) {
            this.start = start;
            this.end = end;
            this.sum = sum;
            left = null;
            right = null;
        }
    }

    /**
     * Node pool, allocated on the first {@link #init} and reused afterwards. A tree over N cities
     * uses about 2(N-1) nodes, so this presumably targets N up to ~10,000.
     * TODO confirm against the problem limits.
     */
    static Node[] nodes = new Node[20001];
    /** Number of pool nodes handed out since the last {@link #init}. */
    static int nodeCount = 0;

    /**
     * Resets all state for a new test case.
     *
     * @param N number of cities
     * @param mPopulation population of each city. The array is kept by reference, not copied.
     */
    void init(int N, int mPopulation[]) {
        if (nodes[0] == null) {
            for (int i = 0; i < nodes.length; i++) {
                nodes[i] = new Node();
            }
        }
        nodeCount = 0;
        population = mPopulation;
        distances = new PriorityQueue<>(new Comparator<int[]>() {
            @Override
            public int compare(int[] o1, int[] o2) {
                if (o1[1] != o2[1]) {
                    // Larger distance first; Integer.compare avoids subtraction overflow.
                    return Integer.compare(o2[1], o1[1]);
                }
                // Equal distances: the smaller road index wins.
                return Integer.compare(o1[0], o2[0]);
            }
        });
        distance = new int[N + 1];
        initAccSum = new int[N];
        for (int i = 1; i < N; i++) {
            // Road i connects city i-1 and city i and starts with a single lane.
            distance[i] = population[i - 1] + population[i];
            distances.add(new int[] {i, distance[i], 1});
            initAccSum[i] = initAccSum[i - 1] + distance[i];
        }

        root = nodes[nodeCount++];
        root.init(0, N - 1, initAccSum[N - 1]);
        initTree(root);
    }

    /** Recursively builds the children of {@code curNode} from the initial prefix sums. */
    private void initTree(Node curNode) {
        int start = curNode.start;
        int end = curNode.end;

        // "<= 1" (not "== 1") so a single-city tree [0, 0] is a leaf instead of recursing forever.
        if (end - start <= 1) {
            return;
        }
        int mid = (start + end) / 2;
        Node left = nodes[nodeCount++];
        left.init(start, mid, initAccSum[mid] - initAccSum[start]);
        Node right = nodes[nodeCount++];
        right.init(mid, end, initAccSum[end] - initAccSum[mid]);

        curNode.left = left;
        curNode.right = right;
        initTree(left);
        initTree(right);
    }

    /**
     * Adds one lane, M times, to whichever road currently has the largest distance.
     *
     * @param M number of lane additions
     * @return the new distance of the last road expanded, or -1 if {@code M <= 0}
     */
    int expand(int M) {
        int newDistance = -1;
        for (int i = 0; i < M; i++) {
            int[] road = distances.poll();
            int index = road[0];
            int lanes = road[2] + 1;
            // Recompute the distance with the extra lane (integer division).
            newDistance = (population[index - 1] + population[index]) / lanes;
            int gap = distance[index] - newDistance;
            road[1] = newDistance;
            road[2] = lanes;
            distance[index] = newDistance;
            distances.add(road);
            // Propagate the decrease along the path to this road's leaf.
            revalueTree(root, index, gap);
        }
        return newDistance;
    }

    /** Subtracts {@code gap} from every node on the path from {@code curNode} to road {@code index}. */
    private void revalueTree(Node curNode, int index, int gap) {
        curNode.sum -= gap;
        if (curNode.left == null) {
            return;
        }
        // Road `index` spans cities [index-1, index]; it lies in the left child iff index <= left.end.
        if (curNode.left.start < index && index <= curNode.left.end) {
            revalueTree(curNode.left, index, gap);
        } else {
            revalueTree(curNode.right, index, gap);
        }
    }

    /**
     * Returns the total distance of the roads between two cities (order-insensitive).
     *
     * @param mFrom one endpoint city
     * @param mTo the other endpoint city
     * @return the sum of the distances of roads mFrom+1..mTo (after ordering), or 0 if equal
     */
    int calculate(int mFrom, int mTo) {
        if (mFrom > mTo) {
            int temp = mFrom;
            mFrom = mTo;
            mTo = temp;
        }
        return findAccSum(root, mFrom, mTo);
    }

    /** Sums road distances over the city interval [from, to], which must lie inside curNode. */
    private int findAccSum(Node curNode, int from, int to) {
        if (from == to) {
            return 0;
        }
        if (curNode.start == from && curNode.end == to) {
            return curNode.sum;
        }
        int mid = (curNode.start + curNode.end) / 2;
        if (to <= mid) {
            return findAccSum(curNode.left, from, to);
        }
        if (from >= mid) {
            return findAccSum(curNode.right, from, to);
        }
        return findAccSum(curNode.left, from, mid) + findAccSum(curNode.right, mid, to);
    }

    /**
     * Splits cities mFrom..mTo into at most K contiguous districts and minimises the largest
     * district population. Uses a binary search on that maximum.
     *
     * @param mFrom one end of the city range (order-insensitive, like {@link #calculate})
     * @param mTo the other end of the city range
     * @param K maximum number of districts
     * @return the smallest achievable maximum district population
     */
    int divide(int mFrom, int mTo, int K) {
        if (mFrom > mTo) {
            int temp = mFrom;
            mFrom = mTo;
            mTo = temp;
        }
        // No cap below the largest single city is feasible; the whole range as one district is.
        int lo = 0;
        int hi = 0;
        for (int i = mFrom; i <= mTo; i++) {
            lo = Math.max(lo, population[i]);
            hi += population[i];
        }
        int ans = hi;
        while (lo <= hi) {
            // mid is a candidate for the largest district population.
            int mid = (lo + hi) >>> 1;
            if (countDistricts(mFrom, mTo, mid, K) <= K) {
                ans = mid;
                hi = mid - 1;
            } else {
                lo = mid + 1;
            }
        }
        return ans;
    }

    /**
     * Greedily counts the districts needed so that no district's population exceeds {@code cap}.
     * Returns a value greater than {@code limit} as soon as the cap is infeasible: either one city
     * alone exceeds it, or more than {@code limit} districts would be needed.
     */
    private int countDistricts(int from, int to, int cap, int limit) {
        int count = 0;
        int i = from;
        while (i <= to) {
            if (population[i] > cap) {
                return limit + 1;
            }
            int sum = 0;
            while (i <= to && sum + population[i] <= cap) {
                sum += population[i++];
            }
            count++;
            if (count > limit) {
                return count;
            }
        }
        return count;
    }
}