BOJ 1182 부분수열의 합 (Java, Python)
문제 탐색하기
양수 크기의 부분수열 중 원소들의 합이 특정 값 S가 되는 경우를 찾는 부분집합 문제이다.
즉, 모든 부분집합을 생성한 후, 그 합이 S인지 확인하는 방식으로 접근한다.
- 이전에 python3로 풀었던 코드에서는
combination
내장 함수를 사용했었다. - Java로 구현하는 경우 부분 집합 코드를 직접 구현해 주어야 함에 유의하라.
코드 설계
부분집합 구현하기
재귀/백트래킹 방식
- 가장 직관적이며 코드 작성이 비교적 쉽다.
- 모든 부분집합을 자연스럽게 탐색할 수 있으며, 문제에 따라 불필요한 경우를 가지치기(pruning)하기도 용이하다.
/**
* 1) 재귀/백트래킹 방식
* 모든 부분집합(빈 집합 포함)을 출력하는 예시
*/
public static void printAllSubsetsBacktracking(int[] arr) {
List<Integer> current = new ArrayList<>();
backtrack(arr, 0, current);
}
private static void backtrack(int[] arr, int start, List<Integer> current) {
// 현재까지 만든 부분집합 출력
System.out.println(current);
// start 인덱스부터 끝까지 순회하며 원소를 하나씩 포함/미포함 결정
for (int i = start; i < arr.length; i++) {
current.add(arr[i]);
backtrack(arr, i + 1, current);
current.remove(current.size() - 1);
}
}
비트마스킹 방식
- 코드가 간결하지만, 일반적으로 N의 크기가 작을 때 주로 사용한다.
- N이 작다면 문제 해결에 충분하지만, 가독성 측면에서는 다소 떨어질 수 있다.
/**
* 2) 비트마스킹 방식
* 0 ~ (2^n - 1)까지의 수를 순회하며, 각 비트가 켜져 있으면 해당 원소를 부분집합에 포함.
*/public static void printAllSubsetsBitmask(int[] arr) {
int n = arr.length;
// 모든 비트마스크를 순회 (0은 모든 원소가 미포함인 빈 집합)
for (int mask = 0; mask < (1 << n); mask++) {
List<Integer> subset = new ArrayList<>();
for (int i = 0; i < n; i++) {
// i번째 비트가 켜져 있으면 arr[i]를 포함
if ((mask & (1 << i)) != 0) {
subset.add(arr[i]);
}
}
System.out.println(subset);
}
}
조합 방식
- 특정 크기의 조합만 고려한다면 유용하다.
- 하지만 모든 부분집합(크기가 양수인 모든 경우)을 탐색해야 한다면, 재귀적 백트래킹 방식이 더 자연스럽다.
/**
* 3) 조합(Combination) 방식
* 특정 크기 r의 부분집합(조합)만을 구하고 싶을 때 유용하다.
* 다음 예시는 크기 r인 모든 조합을 출력하는 예시이다.
*/public static void printCombinationsOfSize(int[] arr, int r) {
combinationHelper(arr, 0, r, new ArrayList<>());
}
private static void combinationHelper(int[] arr, int start, int r, List<Integer> current) {
// 현재 조합의 크기가 r이면 출력(혹은 다른 처리)
if (current.size() == r) {
System.out.println(current);
return;
}
// start부터 남은 원소들을 이용해 r - current.size() 만큼 뽑는 방식
for (int i = start; i < arr.length; i++) {
current.add(arr[i]);
combinationHelper(arr, i + 1, r, current);
current.remove(current.size() - 1);
}
}
위와 같은 부분집합 구현 방식을 참고하고, N = 20이므로 가능한 경우의 수는 약 100만 개 수준이다. 따라서 시간 복잡도를 고려하여 재귀 방식과 비트마스킹 방식을 모두 구현해 보기로 했다.
정답 코드 (Java)
Backtracking을 사용한 풀이
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;
public class b1182_backtracking {
private static int N, S, answer;
private static int[] arr;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
N = Integer.parseInt(st.nextToken());
S = Integer.parseInt(st.nextToken());
arr = new int[N];
answer = 0;
// 배열 입력받기
st = new StringTokenizer(br.readLine());
for (int i = 0; i < N; i++) {
arr[i] = Integer.parseInt(st.nextToken());
}
backtracking(0, 0);
// S가 0이면 공집합은 제외한다.
System.out.println(S == 0 ? answer - 1 : answer );
}
private static void backtracking(int idx, int sum) {
if (idx == N) {
if (sum == S) {
answer++;
}
return;
}
backtracking(idx + 1, sum + arr[idx]); // 선택하는 경우
backtracking(idx + 1, sum); // 선택하지 않는 경우
}
}
Bitmasking을 사용한 풀이
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;
public class b1182_bitmask {
private static int N, S, answer;
private static int[] arr;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
N = Integer.parseInt(st.nextToken());
S = Integer.parseInt(st.nextToken());
arr = new int[N];
answer = 0;
// 배열 입력받기
st = new StringTokenizer(br.readLine());
for (int i = 0; i < N; i++) {
arr[i] = Integer.parseInt(st.nextToken());
}
bitmasking(N, arr, S);
System.out.println(answer);
}
private static void bitmasking(int N, int[] arr, int S) {
for (int bitmask = 1; bitmask < (1 << N); bitmask++) {
int sum = 0;
for (int j = 0; j < N; j++) {
if ((bitmask & (1 << j)) != 0) {
sum += arr[j];
}
}
if (sum == S) {
answer++;
}
}
}
}
정답 코드 (Python3)
itertools
모듈의 combinations
함수를 활용한다.
import sys
input = sys.stdin.readline
from itertools import combinations
N, S = map(int, input().split())
data = list(map(int, input().split()))
count = 0
for i in range(1,N+1):
per = combinations(data,i)
for j in list(per):
if sum(j) == S:
count += 1
print(count)
Leave a comment