알고리즘/백준

[백준] 15824. 너 봄에는 캡사이신이 맛있단다 (python / 파이썬)

AquaplaneMode 2023. 6. 30. 17:05

문제 출처 : https://www.acmicpc.net/problem/15824

 

15824번: 너 봄에는 캡사이신이 맛있단다

한 줄에 모든 조합의 주헌고통지수 합을 1,000,000,007로 나눈 나머지를 출력한다.

www.acmicpc.net


1. 첫 풀이 : (50점)

가령 입력 배열이 다음과 같이 주어졌다고 가정하자.

1 4 5 5 6 10

해당 배열을 오름차순으로 정렬한 다음, 두 수 a, b를 뽑는다. 만약 a = 1, b = 6이라고 가정해보자.

a, b는 각각 최소값, 최대값이기 때문에, 두 수 사이에 있는 음식은 먹더라도 먹지 않더라도 주헌고통지수는 변하지 않는다.

 

- 캡사이신 수치가 4인 음식은 먹을수도 있고, 먹지 않을 수도 있다. (경우의 수 2)

- 캡사이신 수치가 5인 음식은 먹을수도 있고, 먹지 않을 수도 있다. (경우의 수 2)

- 캡사이신 수치가 5인 음식은 먹을수도 있고, 먹지 않을 수도 있다. (경우의 수 2)

 

따라서 이 경우, 주헌고통지수는 (6-1) * 2^3 = 40 이며, 이를 일반화한다면 다음과 같이 쓸 수 있다.

주헌고통지수 = (b - a) * 2^(b와 a 사이에 있는 음식의 개수)

해당 공식에 따라, 모든 음식 조합을 찾아 주헌고통지수를 구하면 된다.

코드는 다음과 같다.

import sys
input = sys.stdin.readline

"""
정렬 후 두 점을 잡고
두 점 사이에 있는 모든 부분집합 개수 (2^n)
"""

num_food = int(input())
foods = sorted(list(map(int, input().split())))

answer = 0

for first_idx in range(0, num_food-1):
    first_food = foods[first_idx]

    for second_idx in range(first_idx+1, num_food):
        second_food = foods[second_idx]

        juheon_index = second_food - first_food # 주헌 고통 지수 계산
        answer += juheon_index * (1<< (second_idx - first_idx -1))

print(answer%(10**9+7))

2. 보다 일반화된 풀이 : (50점)

1번 풀이는 다음과 같은 방식으로 계산횟수를 줄일 수 있다. 다시 한 번, 입력 배열이 다음과 같이 주어졌다고 가정해보자.

1 4 5 5 6 10

마찬가지로 두 음식 a, b를 뽑는데, 두 음식 사이에 있는 다른 음식의 개수로 상황을 한 번 나눠보자.

 

- 상황 1 : 두 음식 사이에 음식이 하나도 없는 경우

 

1번 풀이와 같이 했을 경우, 해당 경우에서 주헌고통지수를 구하는 방법은 다음과 같다.

$$
(4-1) * (2^{0}) + (5-4) * (2^0) + (5-5) * (2^0) + (6-5) * (2^0) + (10 - 6) * (2^0)
$$

 

- 상황 2 : 두 음식 사이에 음식이 한 개 있는 경우

 

1번 풀이와 같이 했을 경우, 해당 경우에서 주헌고통지수를 구하는 방법은 다음과 같다.

$$
(5-1) * (2^1) + (5-4) * (2^1) + (6-5) * (2^1) + (10-5)  * (2^1)
$$

 

- 상황 3 : 두 음식 사이에 음식이 두 개 있는 경우

 

1번 풀이와 같이 했을 경우, 해당 경우에서 주헌고통지수를 구하는 방법은 다음과 같다.

$$
(5-1) * (2^2) + (6-4) * (2^2) + (10-5) * (2^2)
$$

 

그리고 각각의 상황은 다음과 같이 정리할 수 있다.

$$
{(4+5+5+6+10) - (1+4+4+5+6)} * 2^0 \\
{(5+5+6+10) - (1+4+5+5)} * 2^1 \\
{(5+6+10) - (1+4+5)} * 2^2
$$

 

덧셈이 발생하는 구간과 뺄셈이 발생하는 구간을 하나로 묶어주는 것으로, 곱셈 연산과 제곱 연산을 1회로 줄일 수 있는 것이다. 이를 시각화하자면 다음과 같다.

 

두 음식 사이에 다른 음식이 없을 때
두 음식 사이에 하나의 음식이 있을 때
두 음식 사이에 두 개의 음식이 있을 때

 

이를 코드로 나타내면 다음과 같다.

import sys
input = sys.stdin.readline

# 입력값 받기

len_arr = int(input())
arr = list(map(int, input().split()))

arr.sort() # 정렬

steps = 0 # 최소값과 최대값 사이의 거리

left_pointer = 0
right_pointer = len_arr-1

left_sum = sum(arr)
right_sum = left_sum

answer = 0
D = 10**9+7

while right_pointer :

    left_sum -= arr[left_pointer]
    right_sum -= arr[right_pointer]

    left_pointer += 1
    right_pointer -= 1

    answer += ((left_sum - right_sum) << steps) % D
    steps += 1

answer %= D
print(answer)

1. 우선 입력 배열의 전체 합을 구한다.

2. 가장 왼쪽을 left_pointer, 가장 오른쪽을 right_pointer로 설정한다.

3. 전체 합에서 pointer가 가리키는 값을 빼고, 두 합의 차이에 2^(두 음식 사이에 있는 음식의 개수)를 곱해준다.

4. 포인터를 가운데로 움직인다

5. 1-4 과정을 반복한다.

 

가장 처음 방법은 835ms가 나왔는데 반해, 두 번째 방법은 44ms로 시간을 획기적으로 단축할 수 있었다.

그러나 아직도 부분적으로 맞았을 뿐, 문제를 완전히 맞추지 못했다.


3. 분할을 이용한 거듭제곱 활용 : (250점)

 

문제를 온전히 맞추기 위해서는, 배열의 길이가 30만일 때를 풀 수 있어야 한다.

그러나 위와 같은 방법으로 푼다면, 2의 30만제곱을 계산할 때 시간초과가 날 수밖에 없다.

따라서 분할을 이용한 거듭제곱을 사용하기 위해 다음과 같은 코드를 추가하였다

# 분할을 통한 거듭제곱 
pow_dic = dict()
pow_dic[0] = 1
pow_dic[1] = 2

def pow(num):

    cur_result = pow_dic.get(num)

    # 재귀 탈출 조건
    if cur_result:
        return cur_result
    
    # 재귀 진행 조건
    mid = num//2

    left_val = pow(mid)
    right_val = pow(num - mid)

    total_val = left_val *  right_val
    pow_dic[num] = total_val % P

    return pow_dic[num]

가령 2의 10제곱을 구한다고 해보자. (pow(10))

아직 10이 pow_dic에 해당 값이 저장되어 있지 않기 때문에, 위 함수는 다음 두 함수를 호출한다. (pow(5), pow(5))

- 5.1은 첫번째 5, 5.2는 두번째 5를 의미하는데, 이는 트리를 만들기 위해 임의로 붙인 값이다.

또한, 5가 아직 pow_dic에 저장되어 있지 않기 때문에, 위 함수는 다음 두 함수를 호출한다. (pow(2), pow(3))

마찬가지로, 2가 아직 pow_dic에 저장되어 있지 않기 때문에, 위 함수는 다음 두 함수를 호출한다. (pow(1), pow(1))

1은 pow_dic에 저장되어 있다. 따라서 pow(2)는 pow(1) * pow(1) = 2*2 = 4가 되어 pow_dic에 저장되고 반환된다.

pow(3)은 pow(2)와 pow(1)로 나눠지고, 두 값 모두 저장되어 있기 때문에, pow_dic[3] = 8이 저장되고 반환된다.

따라서 pow(5) = pow(2) * pow(3) = 32가 반환되고, pow(10) = pow(5) * pow(5) = 1024가 반환된다.

 

또한 모듈러 연산은 다음과 같은 공식이 성립한다.

$$
(A * A)\ mod\ P = ((A\ mod\ P) * (A\ mod\ P))\ mod\ P
$$

 

따라서 거듭제곱의 결과를 바로 dict에 넣지 않고, P로 mod 연산을 수행한 후에 넣어주었다.

이를 통해 보다 작은 수를 적은 계산을 통해 계산할 수 있다. 

 

전체 코드는 다음과 같다.

import sys
# sys.stdin = open("input.txt", "r")
input = sys.stdin.readline

# 입력값 받기

len_arr = int(input())
arr = list(map(int, input().split()))

arr.sort() # 정렬

steps = 0 # 최소값과 최대값 사이의 거리

left_pointer = 0
right_pointer = len_arr-1

left_sum = sum(arr)
right_sum = left_sum

answer = 0
D = 10**9+7 

# 분할을 통한 거듭제곱 
pow_dic = dict()
pow_dic[0] = 1
pow_dic[1] = 2

def pow(num):

    cur_result = pow_dic.get(num)

    # 재귀 탈출 조건
    if cur_result:
        return cur_result
    
    # 재귀 진행 조건
    mid = num//2

    left_val = pow(mid)
    right_val = pow(num - mid)

    total_val = left_val *  right_val
    pow_dic[num] = total_val % D

    return pow_dic[num]


# 고통 지수 찾기
while right_pointer :

    left_sum -= arr[left_pointer]
    right_sum -= arr[right_pointer]

    left_pointer += 1
    right_pointer -= 1

    answer += (left_sum - right_sum) * pow(steps)
    steps += 1

answer %= D

print(answer)
# print(pow_dic)