기술/알고리즘

[c++] tree structure를 base로 한 max heap 구현하기

하기싫지만어떡해해야지 2024. 12. 2. 20:11

원래 흔히 배우는 max heap은

자료 구조를 array로 많이 사용한다

 

나도 배울 때는 분명 array 자료구조를 이용해서

max heap을 구현하는 것을 배웠지만

어째서인지 (..) 이번 과제는 array가 아닌

tree 구조를 base로 Max heap을 구현하는 것이었다

 

max heap에 대한 강의 내용 정리는 아래 링크를 참고!

https://think0905.tistory.com/entry/computer-science-Binary-Tree-Max-Heap

 

[computer science] Binary Tree, Max Heap

이 게시글은서울대학교 데이터사이언스대학원조요한 교수님의데이터사이언스 응용을 위한 컴퓨팅 강의를학습을 위해 재구성하였습니다.중간고사가 지나고 이전까지는 기본적인c++에 대해서

think0905.tistory.com

 

우선 test로 실행할 main.cpp 파일은 아래와 같았다

#include <iostream> 
#include <climits> 
#include <queue>
#include <cmath>
using namespace std; 

#include "functions.hpp"
  
// Driver program to test above functions 
int main() 
{ 
    MaxHeap h; 
    h.enqueue(10);
    h.enqueue(30);
    h.enqueue(40);    
    h.enqueue(500);
    h.enqueue(80);
    h.enqueue(170);
    h.enqueue(1000);
    h.enqueue(180);
    h.enqueue(180);
    h.enqueue(80);
    h.enqueue(90);
    h.enqueue(250);
    h.enqueue(580);

    h.dequeue();
    h.dequeue();

    h.printHeap();

    cout << "Maximum value of the heap: " << h.getMax()->val << endl; 

    return 0; 
}

 

enqueue의 case들은 내가 다양하게 늘려봤다

 

그리고 Node와 MaxHeap class를 정의한 functions.hpp파일은 아래와 같다

#include <iostream> 
#include <climits> 
#include <queue>
#include <cmath>
using namespace std; 

struct Node {
    int val;
    Node * left = nullptr;
    Node * right = nullptr;
    Node * parent = nullptr;
};
  

// A class for Min Heap 
class MaxHeap 
{ 
    Node * root;
    Node * last_node;
    int heapsize;

public: 
    // Constructor 
    MaxHeap();
  
    // Returns the node pointer that includes the maximum value (root node)
    Node * getMax();

    void printHeap();

    void swap(Node * a, Node * b);

    // Inserts a new key 'k' 
    void enqueue(int k); 

    // Delete the maxmium value (root node)
    void dequeue();
    
};

 

몇 가지 주목해서 봐야할 점은

보통의 max heap을 구현할 때 Node에는

parent를 거의 저장해서 사용하지 않지만

왜인지 모르겠지만 이번 과제에는 Node마다 parent node정보를 저장해야했다

 

또한, MaxHeap class를 보면 heapsize를 int로 저장하고

last_node를 또 따로 저장해주고있다

 

이도 일반적인 max heap 구현에서는 거의 없지만

이번 과제에서는 maxHeap의 size와 last_node값을

계속해서 업데이트해서 저장해줘야했다

 

그리고 MaxHeap class의 function들은

가장 큰 node(root)를 return하는 getMax

Heap을 순서대로 출력하는 printHeap

node a와 node b의 값을 바꿔주는 swap

maxHeap에 새로운 node를 추가해주는 enqueue

maxHeap에 가장 값이 큰 node(root)를 제거하는 dequeue

이렇게 5개의 함수가 존재한다

 

이 maxHeap을 구현하면서 가장 중요한 점은

enqueue와 dequeue에서 시간 복잡도를 반드시 지켜야한다는 점인데

maxHeap의 enqueue, dequeue의 time complexity는 O(log n)이다

 

따라서 이를 계속 염두에 두며 코딩해야한다

 

 

1. getMax

Node * MaxHeap::getMax(){
    return root;
}

 

MaxHeap class 외부에 정의해줬으므로

:: operator를 통해서 함수를 정의해준다

 

max heap에서 값이 가장 큰 node는 root node이므로

getMax는 root를 return해준다

 

2. printHeap()

void MaxHeap::printHeap(){
    Node * currNode = root;
    queue<Node*> q;

    q.push(root);
    std::cout << "Print Heap: ";
    while (!q.empty() && q.front()){
        std::cout << q.front()->val << " ";
        if (q.front()->left){
            q.push(q.front()->left);
        }

        if (q.front()->right){
            q.push(q.front()->right);
        }

        q.pop();
    }
    std::cout << std::endl;
}

 

값이 가장 큰 root node부터 차례대로 node의 값들을 출력하는

printHeap 함수이다

 

값의 차례대로 출력하기 위해서는 넓이우선탐색(BFS)로

tree를 탐색해야하므로 queue를 이용해서 탐색해주는 코드를 작성한다

 

 

3. swap

void MaxHeap::swap(Node * a, Node * b){
    int temp = a->val;
    a->val = b->val;
    b->val = temp;
}

 

두 node의 val만 가져와서 바꾸어주는 swap함수이다

 

maxHeap을 포함해서 대부분의 tree구조를 사용하는 알고리즘이 그렇겠지만

node들을 바꿔준다는게 물리적인 위치를 바꿔준다기보단

그 안의 값들만 바꿔주는 경우가 많다

 

 

4. enqueue

핵심적인 알고리즘인 enqueue를 살펴보자

 

기본적으로 max heap을 구현할 때 새로운 node를 추가할 때는

(1). 가장 마지막 node 위치에 새 node 추가

(2). max heapify를 통해 자기 위치 찾아가기

이렇게 2단계의 과정을 거친다

 

그럼 저 단계에 따라서 구현해보자

 

(1). 가장 마지막 node 위치 찾아서 추가

처음에는 그냥 위의 printHeap 함수와 비슷하게

BFS로 가장 마지막 노드의 위치를 찾아주도록 했다

 

사실상 이렇게 해줘도 크게 문제는 없으나

만약 maxHeap의 node개수가 굉장히 많다면

수행 속도가 느려진다는 단점이 있다

(그래서 실제로 과제에서 timeout이 떴다..)

 

timeout을 어떻게 해결할지 고민을 하다가

우리의 maxHeap class에는 heapsize를 저장하고있기때문에

이를 활용할 수 있지않을까 생각했다

 

그래서 계속 구글링을 하던 중

마지막 node의 인덱스를 이진법으로 표기하면

maxHeap tree의 height만큼의 자릿수가 나오고

여기서 1이면 해당 높이에서는 right로

0이면 해당 높이에서는 left로 이동한 것이라는 사실을 알게되었다

 

따라서 heapsize는 사실 마지막 node의 인덱스와 같으므로

heapsize를 이진법으로 표기한 후

자릿수의 숫자에 따라 right 혹은 left로 이동한 뒤

이진법 표기에서 마지막 자릿수 확인을 통해

마지막 자릿수가 1이면 -> 오른쪽 node가 비었음

마지막 자릿수가 0이면 -> 왼쪽 node가 비었음

이기에 해당 위치에 바로 새 node를 삽입해주면 될 것 같았다

 

이에따라 코드를 작성해주었다

Node* newNode = new Node();
newNode->val = k;

if (!root) {
    root = newNode;
    last_node = newNode;
    heapsize++;
    return;
}

 

우선 enqueue함수에서는 새로운 node의 값만 input k로 받아오므로

함수 내부에서 new Node를 동적할당해주고

그 안에 val을 k로 설정해준다

 

그 다음 만약 root가 없다면

(maxHeap tree가 아예 없다면)

root에 그냥 새 node를 넣어주면 되므로

해당 로직을 구현해준다

 

그리고 heapsize를 1개 올려주고 return을 통해 함수를 종료한다

 

 

이제 heapsize를 이진법으로 표현한뒤

위에서 말한 로직으로 마지막 node 위치를 찾고

거기에 새로운 node를 추가해보자

heapsize++;

Node* parent = root;
int path = heapsize;
int level = (int)log2(path);
for (int i = level - 1; i > 0; --i) {
    if (path & (1 << i)) {
        parent = parent->right;
    } else {
        parent = parent->left;
    }
}

 

우선 enqueue를 할 것이므로 heapsize를 한개 올려준다

 

그 다음 최초 node는 root로 설정해준다

path는 heapsize를 할당해주고

level은 maxHeap의 높이이므로 path에 log2를 해준 값이다

 

그런 다음 이진법 계산을 해주는데

흔히 이진법을 계산하다보면 나누기 한 나머지를

거꾸로 적어줘야된다

 

따라서 for문을 가장 큰 수에서부터 작아지도록 구현해줘야한다

따라서 level-1에서부터 높이가 0보다 클 때까지 loop를 돌리며

path를 이진법으로 표기한 수에서 i번째 자릿수가 1이면 right로

i번째 자릿수가 0이면 left로 이동한다

 

이렇게 하면 for loop이 끝나고 parent에 저장된 노드는

가장 마지막 node의 위치의 parent 위치가 된다

 

왜냐하면 전체 level을 다 돌지 않고 level-1값을

최대 i값으로 해주었기 때문에

마지막 level에서 한 단계 위의 값까지만 위치를 저장했고

이는 새로 삽입될 node의 parent node가 되는 것이다

 

if (path & 1) {
    parent->right = newNode;
} else {
    parent->left = newNode;
}
newNode->parent = parent;
last_node = newNode;

 

이제 path의 이진법 표기에서 마지막 자릿수만 확인해서

1이면 right에 새로운 node를 넣어주고

0이면 left에 새로운 node를 넣어주었다

 

그런 다음 새로 할당해준 node의 parent도 할당해줘야하므로

newNode->parent = parent를 해주었고

새로 삽입된 node가 마지막 노드이므로

last_node를 newNode로 해주었다

 

 

(2). max heapify를 통해 자기 위치 찾아가기

이제 새로 추가될 node를 가장 마지막 위치에 추가했으므로

maxHeapify를 통해 제 위치를 찾아주도록하자

Node* currNode = newNode;
while (currNode->parent && currNode->val > currNode->parent->val) {
    swap(currNode, currNode->parent);
    currNode = currNode->parent;
}

 

우선 현재 node를 아까 삽입한 newNode로 선언해준다

 

while문을 돌면서 currNode의 parent가 존재하며

currNode의 값이 parent의 값보다 클 경우

parent와 currNode의 값을 swap해주는 방식으로 maxHeapify를 진행해준다

그런 다음 currNode를 parent로 할당해줘야

while문이 정상적으로 한 칸씩 올라가며 작동될 것이다

 

 

enqueue 전체 코드는 아래와 같다

// Inserts a new key 'k' 
void MaxHeap::enqueue(int k) 
{ 
    Node* newNode = new Node();
    newNode->val = k;

    if (!root) {
        root = newNode;
        last_node = newNode;
        heapsize++;
        return;
    }
    heapsize++;

    Node* parent = root;
    int path = heapsize;
    int level = (int)log2(path);
    for (int i = level - 1; i > 0; --i) {
        if (path & (1 << i)) {
            parent = parent->right;
        } else {
            parent = parent->left;
        }
    }

    if (path & 1) {
        parent->right = newNode;
    } else {
        parent->left = newNode;
    }
    newNode->parent = parent;
    last_node = newNode;

    Node* currNode = newNode;
    while (currNode->parent && currNode->val > currNode->parent->val) {
        swap(currNode, currNode->parent);
        currNode = currNode->parent;
    }
    
}

 

이제 마지막으로 dequeue를 구현해보자

 

 

5. dequeue

 

일반적인 maxHeap의 dequeue 구현 과정은

(1). 가장 마지막 node를 찾아 root와 swap

(2). 가장 마지막 위치로 swap된 root node 제거

(3). maxHeapify를 통해 정렬

3단계를 거친다

 

이번에도 이 단계에 맞게 로직을 작성해보자

 

가장 마지막 node를 찾아 제거하는 과정은

enqueue와 마찬가지로 heapsize의 이진법을 바탕으로 한다

 

if (!root) return;
if (root->left == nullptr && root->right == nullptr) {
    delete root;
    root = nullptr;
    last_node = nullptr;
    heapsize = 0;
    return;
}

 

우선 만약 root node가 없다면 아무것도 dequeue할게 없으므로

그냥 함수를 종료한다

 

또, left나 right같은 자식 node가 없이 root node만 있다면

그냥 root node만 제거해주고 heapsize를 0으로 해주면된다

그렇게 코드를 작성해주고 함수를 종료한다

 

(1). 가장 마지막 node를 찾아 root와 swap

원래대로라면 가장 마지막 node의 위치를 찾아서 swap해야겠지만

우리의 maxHeap class에는 last_node를 저장하는 field가 있다

따라서 이를 이용해서 그냥 바로 swap해준다

 

swap(root, last_node);

 

이제 가장 마지막 node의 parent를 찾아

가장 마지막 node로 가는 path를 모두 지워줘야한다

 

이 과정은 위의 enqueue와 마찬가지로

heapsize의 이진법을 통해 찾는다

 

Node* parent = root;
int path = heapsize;
int level = (int)log2(path);
for (int i = level - 1; i > 0; --i) {
    if (path & (1 << i)) {
        parent = parent->right;
    } else {
        parent = parent->left;
    }
}

if (path & 1) {
    delete parent->right;
    parent->right = nullptr;
} else {
    delete parent->left;
    parent->left = nullptr;
}
last_node = parent;
heapsize--;

 

이진법을 통해 가장 마지막 node의 parent 위치를 찾고

enqueue과정과 동일하게

마지막 자릿수가 1이면 right를 nullptr로 변경해주고

마지막 자릿수가 0이면 left를 nullptr로 변경해준다

 

그런다음 last_node를 parent로 변경해주고

heapsize를 1개 줄여준다

 

이제 마지막 노드를 삭제한 tree에서 다시 마지막 node를 찾아서

last_node를 다시 할당시켜줘야한다

Node* new_last_node = root;
int new_path = heapsize;
level = (int)log2(new_path);
for (int i = level - 1; i >= 0; --i) {
    if (new_path & (1 << i)) {
        new_last_node = new_last_node->right;
    } else {
        new_last_node = new_last_node->left;
    }
}
last_node = new_last_node;

 

마지막 노드를 찾는 과정은 위와 동일해서 설명은 생략한다

 

 

이제 root node부터 시작해서 maxHeapify를 통해 재정렬해주면된다

Node* curr_node = root;
while (curr_node) {
    Node* max_node = curr_node;
    if (curr_node->left && curr_node->left->val > max_node->val) {
        max_node = curr_node->left;
    }
    if (curr_node->right && curr_node->right->val > max_node->val) {
        max_node = curr_node->right;
    }
    if (max_node == curr_node) {
        break;
    }

    swap(curr_node, max_node);
    curr_node = max_node;
}

 

curr_node를 우선 root로 설정해준뒤

curr_node를 이용해서 while문을 돌려준다

 

왼쪽 자식이나 오른쪽 자식의 값이 curr_node보다 크다면

max_node에 curr_node의 left 혹은 right를 할당하고

swap을 시켜준다

 

그런 다음 curr_node를 다시 max_node로 할당해서 while문을 돌린다

이런식으로 하면 위의 root에서부터 maxHeapify가 가능해진다

 

 

dequeue의 전체 코드는 다음과 같다

// Removes the root node and heapify
void MaxHeap::dequeue(){
    if (!root) return;
    if (root->left == nullptr && root->right == nullptr) {
        delete root;
        root = nullptr;
        last_node = nullptr;
        heapsize = 0;
        return;
    }

    swap(root, last_node);

    Node* parent = root;
    int path = heapsize;
    int level = (int)log2(path);
    for (int i = level - 1; i > 0; --i) {
        if (path & (1 << i)) {
            parent = parent->right;
        } else {
            parent = parent->left;
        }
    }

    if (path & 1) {
        delete parent->right;
        parent->right = nullptr;
    } else {
        delete parent->left;
        parent->left = nullptr;
    }
    last_node = parent;
    heapsize--;

    Node* new_last_node = root;
    int new_path = heapsize;
    level = (int)log2(new_path);
    for (int i = level - 1; i >= 0; --i) {
        if (new_path & (1 << i)) {
            new_last_node = new_last_node->right;
        } else {
            new_last_node = new_last_node->left;
        }
    }
    last_node = new_last_node;

    Node* curr_node = root;
    while (curr_node) {
        Node* max_node = curr_node;
        if (curr_node->left && curr_node->left->val > max_node->val) {
            max_node = curr_node->left;
        }
        if (curr_node->right && curr_node->right->val > max_node->val) {
            max_node = curr_node->right;
        }
        if (max_node == curr_node) {
            break;
        }

        swap(curr_node, max_node);
        curr_node = max_node;
    }

}

 

 

이렇게 5가지의 maxHeap class의 함수를 모두 구현했다

 

전체 소스코드는 다음과 같다

#include <iostream> 
#include <climits> 
#include <queue>
#include <cmath>
#include <stack>

#include "functions.hpp"

using namespace std; 

MaxHeap::MaxHeap(){
    root = nullptr;
    last_node = nullptr;
    heapsize = 0;   
}  

Node * MaxHeap::getMax(){
    return root;
}

void MaxHeap::printHeap(){
    Node * currNode = root;
    queue<Node*> q;

    q.push(root);
    std::cout << "Print Heap: ";
    while (!q.empty() && q.front()){
        std::cout << q.front()->val << " ";
        if (q.front()->left){
            q.push(q.front()->left);
        }

        if (q.front()->right){
            q.push(q.front()->right);
        }

        q.pop();
    }
    std::cout << std::endl;
}

void MaxHeap::swap(Node * a, Node * b){
    int temp = a->val;
    a->val = b->val;
    b->val = temp;
}

// Inserts a new key 'k' 
void MaxHeap::enqueue(int k) 
{ 
    Node* newNode = new Node();
    newNode->val = k;

    if (!root) {
        root = newNode;
        last_node = newNode;
        heapsize++;
        return;
    }
    heapsize++;

    Node* parent = root;
    int path = heapsize;
    int level = (int)log2(path);
    for (int i = level - 1; i > 0; --i) {
        if (path & (1 << i)) {
            parent = parent->right;
        } else {
            parent = parent->left;
        }
    }

    if (path & 1) {
        parent->right = newNode;
    } else {
        parent->left = newNode;
    }
    newNode->parent = parent;
    last_node = newNode;

    Node* currNode = newNode;
    while (currNode->parent && currNode->val > currNode->parent->val) {
        swap(currNode, currNode->parent);
        currNode = currNode->parent;
    }
    
} 

// Removes the root node and heapify
void MaxHeap::dequeue(){
    if (!root) return;
    if (root->left == nullptr && root->right == nullptr) {
        delete root;
        root = nullptr;
        last_node = nullptr;
        heapsize = 0;
        return;
    }

    swap(root, last_node);

    Node* parent = root;
    int path = heapsize;
    int level = (int)log2(path);
    for (int i = level - 1; i > 0; --i) {
        if (path & (1 << i)) {
            parent = parent->right;
        } else {
            parent = parent->left;
        }
    }

    if (path & 1) {
        delete parent->right;
        parent->right = nullptr;
    } else {
        delete parent->left;
        parent->left = nullptr;
    }
    last_node = parent;
    heapsize--;

    Node* new_last_node = root;
    int new_path = heapsize;
    level = (int)log2(new_path);
    for (int i = level - 1; i >= 0; --i) {
        if (new_path & (1 << i)) {
            new_last_node = new_last_node->right;
        } else {
            new_last_node = new_last_node->left;
        }
    }
    last_node = new_last_node;

    Node* curr_node = root;
    while (curr_node) {
        Node* max_node = curr_node;
        if (curr_node->left && curr_node->left->val > max_node->val) {
            max_node = curr_node->left;
        }
        if (curr_node->right && curr_node->right->val > max_node->val) {
            max_node = curr_node->right;
        }
        if (max_node == curr_node) {
            break;
        }

        swap(curr_node, max_node);
        curr_node = max_node;
    }

}