added generalized min and max heap implementations

This commit is contained in:
alexchao26
2020-12-09 22:27:50 -05:00
parent e48c6de353
commit 1c498b4389
2 changed files with 195 additions and 0 deletions
+124
View File
@@ -0,0 +1,124 @@
package structures
// MinHeap is an implementation of a min heap
type MinHeap struct {
heap
}
// NewMinHeap initializes a heap with a closerToRootFunction that simply
// returns true if the first arg is smaller than the second
func NewMinHeap() MinHeap {
nestedHeap := heap{
closerToRoot: func(val1, val2 int) bool {
return val1 < val2
},
}
return MinHeap{nestedHeap}
}
// MaxHeap is an implementation of max heap
type MaxHeap struct {
heap
}
// NewMaxHeap initializes a heap with a closerToRootFunction that simply
// returns true if the first arg is larger than the second
func NewMaxHeap() MaxHeap {
nestedHeap := heap{
closerToRoot: func(val1, val2 int) bool {
return val1 > val2
},
}
return MaxHeap{nestedHeap}
}
// heap contains a slice of heapNodes
// A heap can be represented as an array/slice with no gaps because
// calculating the indices of two children or the parent is simple
// from any given index
type heap struct {
nodes []heapNode
closerToRoot func(val1, val2 int) bool
}
// heapNode is an interface making the type for a Min/MaxHeap node flexible
// nodes must be be able to state their value to be sorted by
type heapNode interface {
Value() int
}
// Add appends a new node onto the heap and heapifies it
// to ensure correct ordering
func (h *heap) Add(newNode heapNode) {
h.nodes = append(h.nodes, newNode)
h.heapifyFromEnd()
}
// Remove returns the node at the root, i.e. the minimum value node
func (h *heap) Remove() heapNode {
if len(h.nodes) == 0 {
return nil
}
rootNode := h.nodes[0]
// move last node to start & reduce length by one
h.nodes[0] = h.nodes[len(h.nodes)-1]
h.nodes = h.nodes[:len(h.nodes)-1]
// heapify the heap from the start to sort the minimum value into the 0 index
h.heapifyFromStart()
return rootNode
}
func (h *heap) swap(i, j int) {
h.nodes[i], h.nodes[j] = h.nodes[j], h.nodes[i]
}
// heapify from end expects an unordered value in the last index, it will compare
// it to its parent index and swapped if applicable, and repeated until the heap
// is valid
func (h *heap) heapifyFromEnd() {
currentIndex := len(h.nodes) - 1
for currentIndex > 0 {
parentIndex := (currentIndex - 1) / 2
parentNode := h.nodes[parentIndex]
if h.closerToRoot(h.nodes[currentIndex].Value(), parentNode.Value()) {
h.swap(parentIndex, currentIndex)
currentIndex = parentIndex
} else {
break
}
}
}
// heapify from start expects an unordered value in the heap in index zero,
// that node's value is compared to its children, and swaps are made as needed
// until the heap is valid
func (h *heap) heapifyFromStart() {
currentIndex := 0
for {
// find smaller of two children
smallerChildIndex := currentIndex
for i := 1; i <= 2; i++ {
childIndex := currentIndex*2 + i
// if a child value is closer to the root than the current node,
// store it's index
if childIndex < len(h.nodes) &&
h.closerToRoot(h.nodes[childIndex].Value(), h.nodes[smallerChildIndex].Value()) {
smallerChildIndex = childIndex
}
}
// if smallerChildIndex was not reassigned, no swap is needed, return out
if smallerChildIndex == currentIndex {
return
}
// otherwise swap & update currentIndex to keep checking on next loop
h.swap(smallerChildIndex, currentIndex)
currentIndex = smallerChildIndex
}
}
+71
View File
@@ -0,0 +1,71 @@
package structures
import (
"testing"
)
type mockNode int
func (n mockNode) Value() int {
return int(n)
}
func TestMinHeap(t *testing.T) {
h := NewMinHeap()
h.Add(mockNode(5))
h.Add(mockNode(93))
if h.nodes[0].Value() != 5 {
t.Errorf("After adding 5, h.nodes[0].Value() = %d, want 5", h.nodes[0].Value())
}
if h.nodes[1].Value() != 93 {
t.Errorf("After adding 93, h.nodes[1].Value() = %d, want 93", h.nodes[1].Value())
}
// Add a bunch of nodes, make sure they are removed in order
h.Add(mockNode(10))
h.Add(mockNode(2))
h.Add(mockNode(1))
h.Add(mockNode(3))
h.Add(mockNode(4))
h.Add(mockNode(123))
h.Add(mockNode(32))
h.Add(mockNode(-15))
// Ensure removing nodes returns in ascending order
for _, want := range []int{-15, 1, 2, 3, 4, 5, 10, 32, 93, 123} {
if got := h.Remove(); got.Value() != want {
t.Errorf("h.Remove().Value() = %d, want %d", got.Value(), want)
}
}
}
func TestMaxHeap(t *testing.T) {
h := NewMaxHeap()
h.Add(mockNode(5))
h.Add(mockNode(93))
if h.nodes[0].Value() != 93 {
t.Errorf("After adding 93, h.nodes[0].Value() = %d, want 93", h.nodes[1].Value())
}
if h.nodes[1].Value() != 5 {
t.Errorf("After adding 5, h.nodes[1].Value() = %d, want 5", h.nodes[0].Value())
}
// Add a bunch of nodes, make sure they are removed in order
h.Add(mockNode(10))
h.Add(mockNode(2))
h.Add(mockNode(1))
h.Add(mockNode(3))
h.Add(mockNode(4))
h.Add(mockNode(123))
h.Add(mockNode(32))
h.Add(mockNode(-15))
// Ensure removing returns in descending order
for _, want := range []int{123, 93, 32, 10, 5, 4, 3, 2, 1, -15} {
if got := h.Remove(); got.Value() != want {
t.Errorf("h.Remove().Value() = %d, want %d", got.Value(), want)
}
}
}