# 算法与数据结构 - 平衡二叉树的实现

## 可旋转的二叉排序树

``````package main

import "fmt"

type TreeNode struct {
Parent *TreeNode

Value  int

Left   *TreeNode
Right  *TreeNode
}

// Insert 往树中合适位置插入节点
func Insert(root *TreeNode, value int) *TreeNode {
if root == nil {
return &TreeNode{Value: value}
}

if value < root.Value {
root.Left = Insert(root.Left, value)
root.Left.Parent = root
} else if value > root.Value {
root.Right = Insert(root.Right, value)
root.Right.Parent = root
} else {
return root
}

return root
}

// RotateRight 右旋
func RotateRight(root *TreeNode) *TreeNode {
if root.Left == nil {
return root
}

newRoot := root.Left
tmp := newRoot.Right
newRoot.Right = root
root.Left = tmp

if tmp != nil {
tmp.Parent = root
}
newRoot.Parent = root.Parent
root.Parent = newRoot

return newRoot
}

// RotateLeft 左旋
func RotateLeft(root *TreeNode) *TreeNode {
if root.Right == nil {
return root
}

newRoot := root.Right
tmp := newRoot.Left
newRoot.Left = root
root.Right = tmp

if tmp != nil {
tmp.Parent = root
}
newRoot.Parent = root.Parent
root.Parent = newRoot

return newRoot
}

// PrintTree 以树状形式打印树
func PrintTree(root *TreeNode) {
// 这里先不管
}

func main() {
var root *TreeNode
root = Insert(root, 7)
root = Insert(root, 3)
root = Insert(root, 2)
root = Insert(root, 5)
root = Insert(root, 8)
PrintTree(root)
fmt.Println("------------")

root = RotateLeft(root)
PrintTree(root)
fmt.Println("------------")

root = RotateRight(root)
PrintTree(root)
fmt.Println("------------")
}
``````

## 添加 Height 参数

``````type TreeNode struct {
Parent *TreeNode

Value  int
Height int

Left   *TreeNode
Right  *TreeNode
}

func NewTreeNode(value int) *TreeNode {
return &TreeNode{Value: value, Height: 1}
}
``````

## 检测树平衡

``````func Insert(root *TreeNode, value int) *TreeNode {
if root == nil {
return &TreeNode{Value: value}
}

if value < root.Value {
root.Left = Insert(root.Left, value)
root.Left.Parent = root
} else if value > root.Value {
root.Right = Insert(root.Right, value)
root.Right.Parent = root
} else {
return root
}

return root
}
``````

``````func max(a, b int) int {
if a > b {
return a
}
return b
}

// GetHeight 用来处理节点为 nil 的情况
func GetHeight(node *TreeNode) int {
if node == nil {
return 0
}

return node.Height
}

func Insert(root *TreeNode, value int) *TreeNode {
// ...

root.Height = max(GetHeight(root.Left), GetHeight(root.Right)) + 1

return root
}
``````

``````// GetBalanceFactor 获取平衡因子
func GetBalanceFactor(node *TreeNode) int {
if node == nil {
return 0
}

return GetHeight(node.Left) - GetHeight(node.Right)
}

func Insert(root *TreeNode, value int) *TreeNode {
// ...

root.Height = max(GetHeight(root.Left), GetHeight(root.Right)) + 1

bf := GetBalanceFactor(root)
if bf < -1 { // 应该左旋
root = RotateLeft(root)
} else if bf > 1 { // 应该右旋
root = RotateRight(root)
} else {
// do nothing
}

return root
}
``````

## 旋转时更新树高

``````// RotateRight 右旋
func RotateRight(root *TreeNode) *TreeNode {
// ...

root.Height = max(GetHeight(root.Left), GetHeight(root.Right)) + 1
newRoot.Height = max(GetHeight(newRoot.Left), GetHeight(newRoot.Right)) + 1

return newRoot
}

// RotateLeft 左旋
func RotateLeft(root *TreeNode) *TreeNode {
// ...

root.Height = max(GetHeight(root.Left), GetHeight(root.Right)) + 1
newRoot.Height = max(GetHeight(newRoot.Left), GetHeight(newRoot.Right)) + 1

return newRoot
}
``````

## 目前为止的完整代码

``````package main

import "fmt"

type TreeNode struct {
Parent *TreeNode

Value  int
Height int

Left  *TreeNode
Right *TreeNode
}

func NewTreeNode(value int) *TreeNode {
return &TreeNode{Value: value, Height: 1}
}

// Insert 往树中合适位置插入节点
func Insert(root *TreeNode, value int) *TreeNode {
if root == nil {
return &TreeNode{Value: value}
}

if value < root.Value {
root.Left = Insert(root.Left, value)
root.Left.Parent = root
} else if value > root.Value {
root.Right = Insert(root.Right, value)
root.Right.Parent = root
} else {
return root
}

root.Height = max(GetHeight(root.Left), GetHeight(root.Right)) + 1

bf := GetBalanceFactor(root)
if bf < -1 { // 应该左旋
root = RotateLeft(root)
} else if bf > 1 { // 应该右旋
root = RotateRight(root)
} else {
// do nothing
}

return root
}

func max(a, b int) int {
if a > b {
return a
}
return b
}

// GetHeight 用来处理节点为 nil 的情况
func GetHeight(node *TreeNode) int {
if node == nil {
return 0
}

return node.Height
}

// GetBalanceFactor 获取平衡因子
func GetBalanceFactor(node *TreeNode) int {
if node == nil {
return 0
}

return GetHeight(node.Left) - GetHeight(node.Right)
}

// RotateRight 右旋
func RotateRight(root *TreeNode) *TreeNode {
if root.Left == nil {
return root
}

// 旋转
newRoot := root.Left
tmp := newRoot.Right
newRoot.Right = root
root.Left = tmp

// 更新节点的父节点信息
if tmp != nil {
tmp.Parent = root
}
newRoot.Parent = root.Parent
root.Parent = newRoot

// 更新树高
root.Height = max(GetHeight(root.Left), GetHeight(root.Right)) + 1
newRoot.Height = max(GetHeight(newRoot.Left), GetHeight(newRoot.Right)) + 1

return newRoot
}

// RotateLeft 左旋
func RotateLeft(root *TreeNode) *TreeNode {
if root.Right == nil {
return root
}

// 旋转
newRoot := root.Right
tmp := newRoot.Left
newRoot.Left = root
root.Right = tmp

// 更新节点的父节点信息
if tmp != nil {
tmp.Parent = root
}
newRoot.Parent = root.Parent
root.Parent = newRoot

// 更新树高
root.Height = max(GetHeight(root.Left), GetHeight(root.Right)) + 1
newRoot.Height = max(GetHeight(newRoot.Left), GetHeight(newRoot.Right)) + 1

return newRoot
}

func PrintTree(root *TreeNode) {
}

func main() {
var root *TreeNode
root = Insert(root, 7)
root = Insert(root, 3)
root = Insert(root, 2)
root = Insert(root, 5)
root = Insert(root, 8)
PrintTree(root)
fmt.Println("------------")

root = Insert(root, 6)
PrintTree(root)
fmt.Println("------------")
}
``````

## 旋转的问题

1. 左子树高还是右子树高
2. 不平衡是由子树的左子树还是右子树引起的

## 旋转方式改进

1. 左子树的左子树引起的失衡，用 LL（Left-Left） 表示；
2. 左子树的右子树引起的失衡，用 LR（Left-Right） 表示；
3. 右子树的左子树引起的失衡，用 RL（Right-Left） 表示；
4. 右子树的右子树引起的失衡，用 RR（Right-Right） 表示。

``````func Insert(root *TreeNode, value int) *TreeNode {
// ...

bf := GetBalanceFactor(root)
if bf < -1 { // 应该左旋
if value < root.Right.Value { // 在右子树的左子树上
root = RLRotation(root)
} else { // 在右子树的右子树上
root = RRRotation(root)
}
} else if bf > 1 { // 应该右旋
if value < root.Left.Value { // 在左子树的左子树上
root = LLRotation(root)
} else { // 在左子树的右子树上
root = LRRotation(root)
}
} else {
// do nothing
}

return root
}

func LLRotation(root *TreeNode) *TreeNode {
return RotateRight(root)
}

func LRRotation(root *TreeNode) *TreeNode {
root.Left = RotateLeft(root.Left)
return RotateRight(root)
}

func RRRotation(root *TreeNode) *TreeNode {
return RotateLeft(root)
}

func RLRotation(root *TreeNode) *TreeNode {
root.Right = RotateRight(root.Right)
return RotateLeft(root)
}
``````