Skip to content

Commit

Permalink
Implement a simple prefix-Trie structure for storing arbitrary payloads
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 446284709
  • Loading branch information
gvisor-bot committed May 3, 2022
1 parent 6077c1c commit 13cc10b
Show file tree
Hide file tree
Showing 3 changed files with 335 additions and 0 deletions.
19 changes: 19 additions & 0 deletions pkg/trie/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
load("//tools:defs.bzl", "go_library", "go_test")

package(licenses = ["notice"])

go_library(
name = "trie",
srcs = ["trie.go"],
visibility = ["//:sandbox"],
)

go_test(
name = "trie_test",
srcs = ["trie_test.go"],
library = ":trie",
deps = [
"@com_github_google_go_cmp//cmp:go_default_library",
"@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
133 changes: 133 additions & 0 deletions pkg/trie/trie.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Copyright 2022 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package trie provides a character-based prefix trie data structure for storing arbitrary payloads
// in an efficiently retrievable manner.
package trie

// Visitor accepts a prefix string and an associated value, and returns true iff searching should
// continue deeper into the Trie. It is used by FindMatching().
type Visitor func(prefix string, value interface{}) bool

// Trie stores data at given strings in tree structure, for linear-time retrieval.
// Call New() to obtain a valid Trie.
type Trie struct {
root *node
size int
}

// New creates a new instance of the Trie interface.
func New() *Trie {
return &Trie{root: &node{children: make(map[rune]*node)}, size: 0}
}

type node struct {
value interface{}
children map[rune]*node
}

// FindPrefixes invokes the Visitor with all key-value pairs where the key is a prefix of `key`,
// including exact matches. It does this in increasing order of key length, and terminates early if
// Visitor returns false.
func (t *Trie) FindPrefixes(key string, f Visitor) {
cur := t.root
if cur.value != nil && !f("", cur.value) {
return
}

for i, r := range key {
next, ok := cur.children[r]
if !ok {
return
}

if next.value != nil && !f(key[:(i+1)], next.value) {
return
}
cur = next
}
}

func (t *Trie) updateNode(n *node, newValue interface{}) {
if n.value != nil {
t.size--
}
if newValue != nil {
t.size++
}
n.value = newValue
}

// SetValue associates the specified key with the given value, replacing any existing value.
func (t *Trie) SetValue(key string, value interface{}) {
cur := t.root
for _, r := range key {
next, ok := cur.children[r]
if !ok {
next = &node{children: make(map[rune]*node)}
cur.children[r] = next
}
cur = next
}

if cur.value != nil {
t.size--
}
if value != nil {
t.size++
}
cur.value = value
}

type queueEntry struct {
key string
value *node
}

// FindSuffixes invokes the Visitor with all key-value pairs where the key is prefixed by `key`,
// including exact matches. It does this in an unspecified order, and terminates early if the
// Visitor returns false.
//
// Invoking FindSuffixes with the empty string as a key will iterate over all values.
func (t *Trie) FindSuffixes(key string, f Visitor) {
cur := t.root
for _, r := range key {
next, ok := cur.children[r]
if !ok {
return
}
cur = next
}

queue := make([]queueEntry, 0)
queue = append(queue, queueEntry{key: key, value: cur})

for len(queue) > 0 {
cur := queue[0]
queue = queue[1:]

if cur.value.value != nil && !f(cur.key, cur.value.value) {
return
}

for r, v := range cur.value.children {
queue = append(queue, queueEntry{key: cur.key + string(r), value: v})
}
}
}

// Size returns the total number of values in the Trie.
func (t *Trie) Size() int {
return t.size
}
183 changes: 183 additions & 0 deletions pkg/trie/trie_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
// Copyright 2022 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package trie

import (
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)

type Entry struct {
Key string
Value string
}

func collectPrefixes(tr *Trie, key string) []Entry {
arr := make([]Entry, 0)
tr.FindPrefixes(key, func(p string, v interface{}) bool {
arr = append(arr, Entry{Key: p, Value: v.(string)})
return true
})
return arr
}

func collectSuffixes(tr *Trie, key string) []Entry {
arr := make([]Entry, 0)
tr.FindSuffixes(key, func(p string, v interface{}) bool {
arr = append(arr, Entry{Key: p, Value: v.(string)})
return true
})
return arr
}

func sortEntries(a Entry, b Entry) bool {
return a.Key < b.Key
}

func TestEmpty(t *testing.T) {
tr := New()
if tr.Size() != 0 {
t.Errorf("tr.Size() = %d; want 0", tr.Size())
}

arr := collectPrefixes(tr, "foo")
if d := cmp.Diff([]Entry{}, arr); d != "" {
t.Errorf("collectPrefixes(tr, 'foo') returned diff (-want +got):\n%s", d)
}

arr = collectSuffixes(tr, "foo")
if d := cmp.Diff([]Entry{}, arr); d != "" {
t.Errorf("collectSuffixes(tr, '') returned diff (-want +got):\n%s", d)
}

arr = collectPrefixes(tr, "")
if d := cmp.Diff([]Entry{}, arr); d != "" {
t.Errorf("collectPrefixes(tr, '') returned diff (-want +got):\n%s", d)
}

arr = collectSuffixes(tr, "")
if d := cmp.Diff([]Entry{}, arr); d != "" {
t.Errorf("collectSuffixes(tr, '') returned diff (-want +got):\n%s", d)
}
}

func TestAscendingSearch(t *testing.T) {
tr := New()
tr.SetValue("a", "value a")
tr.SetValue("ab", "value ab")
tr.SetValue("abc", "value abc")
tr.SetValue("abcd", "value abcd")
tr.SetValue("abcde", "value abcde")

expected := []Entry{
Entry{Key: "a", Value: "value a"},
Entry{Key: "ab", Value: "value ab"},
Entry{Key: "abc", Value: "value abc"},
Entry{Key: "abcd", Value: "value abcd"},
Entry{Key: "abcde", Value: "value abcde"}}
arr := collectPrefixes(tr, "abcdef")
if d := cmp.Diff(expected, arr); d != "" {
t.Errorf("collectPrefixes(tr, 'abcdef') returned diff (-want +got):\n%s", d)
}

suffixTests := []struct {
key string
entries []Entry
}{
{"", expected},
{"zzz", []Entry{}},
{"a", expected},
{"ab", expected[1:]},
{"abc", expected[2:]},
{"abd", []Entry{}},
{"abcd", expected[3:]},
{"abcde", expected[4:]},
}
for _, tt := range suffixTests {
t.Run(tt.key, func(t *testing.T) {
arr := collectSuffixes(tr, tt.key)
if d := cmp.Diff(tt.entries, arr, cmpopts.SortSlices(sortEntries)); d != "" {
t.Errorf("collectSuffixes(tr, %q) returned sorted diff (-want +got):\n%s", tt.key, d)
}
})
}
}

func TestRoot(t *testing.T) {
tr := New()
tr.SetValue("", "root value")
if tr.Size() != 1 {
t.Errorf("tr.Size() = %d; want 1", tr.Size())
}

expected := []Entry{Entry{Key: "", Value: "root value"}}
arr := collectPrefixes(tr, "foo")
if d := cmp.Diff(expected, arr); d != "" {
t.Errorf("collectPrefixes(tr, 'foo') returned diff (-want +got):\n%s", d)
}

arr = collectPrefixes(tr, "")
if d := cmp.Diff(expected, arr); d != "" {
t.Errorf("collectPrefixes(tr, '') returned diff (-want +got):\n%s", d)
}
}

func TestMultiplePrefixes(t *testing.T) {
tr := New()
tr.SetValue("foo", "old foo value")
if tr.Size() != 1 {
t.Errorf("tr.Size() = %d; want 1", tr.Size())
}
tr.SetValue("foobar", "foobar value")
if tr.Size() != 2 {
t.Errorf("tr.Size() = %d; want 2", tr.Size())
}
tr.SetValue("qux", "qux value")
if tr.Size() != 3 {
t.Errorf("tr.Size() = %d; want 3", tr.Size())
}
tr.SetValue("foo", "foo value")
if tr.Size() != 3 {
t.Errorf("tr.Size() = %d; want 3", tr.Size())
}

fooEntry := Entry{Key: "foo", Value: "foo value"}
foobarEntry := Entry{Key: "foobar", Value: "foobar value"}
quxEntry := Entry{Key: "qux", Value: "qux value"}

prefixTests := []struct {
key string
entries []Entry
}{
{"foobar", []Entry{fooEntry, foobarEntry}},
{"fooba", []Entry{fooEntry}},
{"foo", []Entry{fooEntry}},
{"quxiho", []Entry{quxEntry}},
{"fo", []Entry{}},
{"qu", []Entry{}},
{"nowhere", []Entry{}},
{"", []Entry{}},
}
for _, tt := range prefixTests {
t.Run(tt.key, func(t *testing.T) {
arr := collectPrefixes(tr, tt.key)
if d := cmp.Diff(tt.entries, arr); d != "" {
t.Errorf("collectPrefixes(tr, %q) returned diff (-want +got):\n%s", tt.key, d)
}
})
}
}

0 comments on commit 13cc10b

Please sign in to comment.