diff --git a/src/heap.c b/src/heap.c index 490e159..124cd1e 100644 --- a/src/heap.c +++ b/src/heap.c @@ -30,6 +30,7 @@ static int parent(int i) bool heap_peek(struct _heap* self, void* obj) { assert(self != NULL); + assert(obj != NULL); if(obj == NULL) { return false; @@ -52,13 +53,14 @@ static void heap_swap(struct _heap* self, int i, int j) memmove(self->obj + j * self->_obj_size, tmp, self->_obj_size); } -static void heap_fixed(struct _heap* self, int i) +static void heap_fixed_up(struct _heap* self, int i) { assert(self != NULL); int p = 0; while(1) { p = parent(i); + // 若当前节点大于其父节点,则交换位置,否则退出循环 if(p < 0 || self->compare(self->obj + i * self->_obj_size, self->obj + p * self->_obj_size) <= 0) { break; @@ -79,7 +81,37 @@ bool heap_push(struct _heap* self, void* obj) memmove(self->obj + index * self->_obj_size, obj, self->_obj_size); self->_size++; - heap_fixed(self, index); + heap_fixed_up(self, index); +} + +static void heap_fixed_down(struct _heap* self, int i) +{ + assert(self != NULL); + int l = 0,r = 0; + int max = 0; + + while(1) + { + l = left(i); + r = right(i); + max = i; + + if(l < self->size(self) && self->compare(self->obj + l * self->_obj_size, self->obj + max * self->_obj_size) > 0) + { + max = l; + } + + if(r < self->size(self) && self->compare(self->obj + r * self->_obj_size, self->obj + max * self->_obj_size) > 0) + { + max = r; + } + if(max == i) + { + break; + } + heap_swap(self, i, max); + i = max; + } } bool heap_pop(struct _heap* self, void* obj) @@ -89,13 +121,15 @@ bool heap_pop(struct _heap* self, void* obj) { return false; } - heap_swap(self, 0, self->size(self) - 1); + int index = self->size(self) - 1; + heap_swap(self, 0, index); if(obj != NULL) { - memmove(obj, self->obj, self->_obj_size); + memmove(obj, self->obj + index * self->_obj_size, self->_obj_size); } self->_size--; - heap_fixed(self, 0); + heap_fixed_down(self, 0); + return true; } void heap_setmin(struct _heap* self, bool min_flag) @@ -141,7 +175,7 @@ void heap_print(struct _heap* self) void* obj = NULL; uint32_t offset = 0; - for (int i = self->size(self) - 1; i >= 0; i--) + for (int i = 0; i < self->size(self); i++) { offset = self->_obj_size * i; obj = (char *)self->obj + offset; diff --git a/test/test_heap.c b/test/test_heap.c index 9b4cce5..dc81fe3 100644 --- a/test/test_heap.c +++ b/test/test_heap.c @@ -13,10 +13,10 @@ void test_heap_num(void) { uint32_t i = 0; - int data[] = { 2,1,3,4}; + // int data[] = { 2,1,3,4}; // int data[] = { 1,2,3,4,5,6}; // int data[] = { 5,2,3,1,7,8,6 }; - // int data[] = { 5,2,3,1,7,8,6,4,9,10,12,11,15,14,13 }; + int data[] = { 5,2,3,1,7,8,6,4,9,10,12,11,15,14,13 }; int temp = 0; uint32_t len = sizeof(data) / sizeof(data[0]); @@ -25,6 +25,10 @@ void test_heap_num(void) heap->print_obj = print_num; heap->compare = compare_num; + // default: maxheap + // maxheap or minheap + // heap->setmin(heap, true); + printf("\n\n----- test_heap_num -----\n"); printf("----- push -----\n"); @@ -57,6 +61,24 @@ void test_heap_num(void) temp = data[i]; heap->push(heap, &temp); } + + printf("----- pop -----\n"); + for (i = 0; i < len; i++) + { + temp = data[i]; + heap->pop(heap, &temp); + + printf("pop = "); + heap->print_obj(&temp); + printf("size = %2d : ", heap->size(heap)); + heap->print(heap); + printf("\n"); + } + if(heap->empty(heap)) + { + printf("----- empty -----\n"); + } + heap_free(heap); } void test_heap(void)