-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement a simple prefix-Trie structure for storing arbitrary payloads
PiperOrigin-RevId: 446284709
- Loading branch information
1 parent
6077c1c
commit 13cc10b
Showing
3 changed files
with
335 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Bring the go_library and go_test rules used below into scope. | ||
load("//tools:defs.bzl", "go_library", "go_test") | ||
|
||
package(licenses = ["notice"]) | ||
|
||
# Library target for the trie package, built from trie.go; visibility is limited to //:sandbox. | ||
go_library( | ||
name = "trie", | ||
srcs = ["trie.go"], | ||
visibility = ["//:sandbox"], | ||
) | ||
|
||
# Unit tests for :trie; go-cmp and cmpopts are the only external dependencies listed. | ||
go_test( | ||
name = "trie_test", | ||
srcs = ["trie_test.go"], | ||
library = ":trie", | ||
deps = [ | ||
"@com_github_google_go_cmp//cmp:go_default_library", | ||
"@com_github_google_go_cmp//cmp/cmpopts:go_default_library", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
// Copyright 2022 The gVisor Authors. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
// Package trie provides a character-based prefix trie data structure for storing arbitrary payloads | ||
// in an efficiently retrievable manner. | ||
package trie | ||
|
||
// Visitor accepts a prefix string and an associated value, and returns true iff searching should | ||
// continue deeper into the Trie. It is the callback type taken by FindPrefixes() and | ||
// FindSuffixes(); returning false makes either of them stop without visiting further entries. | ||
type Visitor func(prefix string, value interface{}) bool | ||
|
||
// Trie stores data at given strings in tree structure, for linear-time retrieval. | ||
// Call New() to obtain a valid Trie. | ||
// | ||
// The zero value is not usable: its root is nil, and the lookup and insert methods dereference | ||
// the root unconditionally. | ||
type Trie struct { | ||
// root is the node for the empty key; every other key is reached from it one rune at a time. | ||
root *node | ||
// size is the number of nodes currently holding a non-nil value; it is what Size() returns. | ||
size int | ||
} | ||
|
||
// New creates a new instance of the Trie interface. | ||
func New() *Trie { | ||
return &Trie{root: &node{children: make(map[rune]*node)}, size: 0} | ||
} | ||
|
||
// node is a single position in the trie, reached from its parent by one rune. | ||
type node struct { | ||
// value is the payload stored for the key that ends at this node; nil means no value is stored. | ||
value interface{} | ||
// children maps the next rune of a key to the node one level deeper. | ||
children map[rune]*node | ||
} | ||
|
||
// FindPrefixes invokes the Visitor with all key-value pairs where the key is a prefix of `key`, | ||
// including exact matches. It does this in increasing order of key length, and terminates early if | ||
// Visitor returns false. | ||
func (t *Trie) FindPrefixes(key string, f Visitor) { | ||
cur := t.root | ||
if cur.value != nil && !f("", cur.value) { | ||
return | ||
} | ||
|
||
for i, r := range key { | ||
next, ok := cur.children[r] | ||
if !ok { | ||
return | ||
} | ||
|
||
if next.value != nil && !f(key[:(i+1)], next.value) { | ||
return | ||
} | ||
cur = next | ||
} | ||
} | ||
|
||
func (t *Trie) updateNode(n *node, newValue interface{}) { | ||
if n.value != nil { | ||
t.size-- | ||
} | ||
if newValue != nil { | ||
t.size++ | ||
} | ||
n.value = newValue | ||
} | ||
|
||
// SetValue associates the specified key with the given value, replacing any existing value. | ||
func (t *Trie) SetValue(key string, value interface{}) { | ||
cur := t.root | ||
for _, r := range key { | ||
next, ok := cur.children[r] | ||
if !ok { | ||
next = &node{children: make(map[rune]*node)} | ||
cur.children[r] = next | ||
} | ||
cur = next | ||
} | ||
|
||
if cur.value != nil { | ||
t.size-- | ||
} | ||
if value != nil { | ||
t.size++ | ||
} | ||
cur.value = value | ||
} | ||
|
||
type queueEntry struct { | ||
key string | ||
value *node | ||
} | ||
|
||
// FindSuffixes invokes the Visitor with all key-value pairs where the key is prefixed by `key`, | ||
// including exact matches. It does this in an unspecified order, and terminates early if the | ||
// Visitor returns false. | ||
// | ||
// Invoking FindSuffixes with the empty string as a key will iterate over all values. | ||
func (t *Trie) FindSuffixes(key string, f Visitor) { | ||
cur := t.root | ||
for _, r := range key { | ||
next, ok := cur.children[r] | ||
if !ok { | ||
return | ||
} | ||
cur = next | ||
} | ||
|
||
queue := make([]queueEntry, 0) | ||
queue = append(queue, queueEntry{key: key, value: cur}) | ||
|
||
for len(queue) > 0 { | ||
cur := queue[0] | ||
queue = queue[1:] | ||
|
||
if cur.value.value != nil && !f(cur.key, cur.value.value) { | ||
return | ||
} | ||
|
||
for r, v := range cur.value.children { | ||
queue = append(queue, queueEntry{key: cur.key + string(r), value: v}) | ||
} | ||
} | ||
} | ||
|
||
// Size returns the total number of values in the Trie, i.e. the number of keys whose stored | ||
// value is non-nil. It reads a counter kept up to date on insertion rather than walking the tree. | ||
func (t *Trie) Size() int { | ||
return t.size | ||
} | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
// Copyright 2022 The gVisor Authors. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package trie | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/google/go-cmp/cmp" | ||
"github.com/google/go-cmp/cmp/cmpopts" | ||
) | ||
|
||
// Entry is a key/value pair as reported to a Visitor, with the value asserted to a string. | ||
// The tests collect visitor callbacks into []Entry so that results can be diffed as a whole. | ||
type Entry struct { | ||
Key string | ||
Value string | ||
} | ||
|
||
func collectPrefixes(tr *Trie, key string) []Entry { | ||
arr := make([]Entry, 0) | ||
tr.FindPrefixes(key, func(p string, v interface{}) bool { | ||
arr = append(arr, Entry{Key: p, Value: v.(string)}) | ||
return true | ||
}) | ||
return arr | ||
} | ||
|
||
func collectSuffixes(tr *Trie, key string) []Entry { | ||
arr := make([]Entry, 0) | ||
tr.FindSuffixes(key, func(p string, v interface{}) bool { | ||
arr = append(arr, Entry{Key: p, Value: v.(string)}) | ||
return true | ||
}) | ||
return arr | ||
} | ||
|
||
// sortEntries is a less function that orders Entries by Key. It is handed to | ||
// cmpopts.SortSlices so that FindSuffixes results, whose order is unspecified, can be compared. | ||
func sortEntries(a Entry, b Entry) bool { | ||
return a.Key < b.Key | ||
} | ||
|
||
func TestEmpty(t *testing.T) { | ||
tr := New() | ||
if tr.Size() != 0 { | ||
t.Errorf("tr.Size() = %d; want 0", tr.Size()) | ||
} | ||
|
||
arr := collectPrefixes(tr, "foo") | ||
if d := cmp.Diff([]Entry{}, arr); d != "" { | ||
t.Errorf("collectPrefixes(tr, 'foo') returned diff (-want +got):\n%s", d) | ||
} | ||
|
||
arr = collectSuffixes(tr, "foo") | ||
if d := cmp.Diff([]Entry{}, arr); d != "" { | ||
t.Errorf("collectSuffixes(tr, '') returned diff (-want +got):\n%s", d) | ||
} | ||
|
||
arr = collectPrefixes(tr, "") | ||
if d := cmp.Diff([]Entry{}, arr); d != "" { | ||
t.Errorf("collectPrefixes(tr, '') returned diff (-want +got):\n%s", d) | ||
} | ||
|
||
arr = collectSuffixes(tr, "") | ||
if d := cmp.Diff([]Entry{}, arr); d != "" { | ||
t.Errorf("collectSuffixes(tr, '') returned diff (-want +got):\n%s", d) | ||
} | ||
} | ||
|
||
func TestAscendingSearch(t *testing.T) { | ||
tr := New() | ||
tr.SetValue("a", "value a") | ||
tr.SetValue("ab", "value ab") | ||
tr.SetValue("abc", "value abc") | ||
tr.SetValue("abcd", "value abcd") | ||
tr.SetValue("abcde", "value abcde") | ||
|
||
expected := []Entry{ | ||
Entry{Key: "a", Value: "value a"}, | ||
Entry{Key: "ab", Value: "value ab"}, | ||
Entry{Key: "abc", Value: "value abc"}, | ||
Entry{Key: "abcd", Value: "value abcd"}, | ||
Entry{Key: "abcde", Value: "value abcde"}} | ||
arr := collectPrefixes(tr, "abcdef") | ||
if d := cmp.Diff(expected, arr); d != "" { | ||
t.Errorf("collectPrefixes(tr, 'abcdef') returned diff (-want +got):\n%s", d) | ||
} | ||
|
||
suffixTests := []struct { | ||
key string | ||
entries []Entry | ||
}{ | ||
{"", expected}, | ||
{"zzz", []Entry{}}, | ||
{"a", expected}, | ||
{"ab", expected[1:]}, | ||
{"abc", expected[2:]}, | ||
{"abd", []Entry{}}, | ||
{"abcd", expected[3:]}, | ||
{"abcde", expected[4:]}, | ||
} | ||
for _, tt := range suffixTests { | ||
t.Run(tt.key, func(t *testing.T) { | ||
arr := collectSuffixes(tr, tt.key) | ||
if d := cmp.Diff(tt.entries, arr, cmpopts.SortSlices(sortEntries)); d != "" { | ||
t.Errorf("collectSuffixes(tr, %q) returned sorted diff (-want +got):\n%s", tt.key, d) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestRoot(t *testing.T) { | ||
tr := New() | ||
tr.SetValue("", "root value") | ||
if tr.Size() != 1 { | ||
t.Errorf("tr.Size() = %d; want 1", tr.Size()) | ||
} | ||
|
||
expected := []Entry{Entry{Key: "", Value: "root value"}} | ||
arr := collectPrefixes(tr, "foo") | ||
if d := cmp.Diff(expected, arr); d != "" { | ||
t.Errorf("collectPrefixes(tr, 'foo') returned diff (-want +got):\n%s", d) | ||
} | ||
|
||
arr = collectPrefixes(tr, "") | ||
if d := cmp.Diff(expected, arr); d != "" { | ||
t.Errorf("collectPrefixes(tr, '') returned diff (-want +got):\n%s", d) | ||
} | ||
} | ||
|
||
func TestMultiplePrefixes(t *testing.T) { | ||
tr := New() | ||
tr.SetValue("foo", "old foo value") | ||
if tr.Size() != 1 { | ||
t.Errorf("tr.Size() = %d; want 1", tr.Size()) | ||
} | ||
tr.SetValue("foobar", "foobar value") | ||
if tr.Size() != 2 { | ||
t.Errorf("tr.Size() = %d; want 2", tr.Size()) | ||
} | ||
tr.SetValue("qux", "qux value") | ||
if tr.Size() != 3 { | ||
t.Errorf("tr.Size() = %d; want 3", tr.Size()) | ||
} | ||
tr.SetValue("foo", "foo value") | ||
if tr.Size() != 3 { | ||
t.Errorf("tr.Size() = %d; want 3", tr.Size()) | ||
} | ||
|
||
fooEntry := Entry{Key: "foo", Value: "foo value"} | ||
foobarEntry := Entry{Key: "foobar", Value: "foobar value"} | ||
quxEntry := Entry{Key: "qux", Value: "qux value"} | ||
|
||
prefixTests := []struct { | ||
key string | ||
entries []Entry | ||
}{ | ||
{"foobar", []Entry{fooEntry, foobarEntry}}, | ||
{"fooba", []Entry{fooEntry}}, | ||
{"foo", []Entry{fooEntry}}, | ||
{"quxiho", []Entry{quxEntry}}, | ||
{"fo", []Entry{}}, | ||
{"qu", []Entry{}}, | ||
{"nowhere", []Entry{}}, | ||
{"", []Entry{}}, | ||
} | ||
for _, tt := range prefixTests { | ||
t.Run(tt.key, func(t *testing.T) { | ||
arr := collectPrefixes(tr, tt.key) | ||
if d := cmp.Diff(tt.entries, arr); d != "" { | ||
t.Errorf("collectPrefixes(tr, %q) returned diff (-want +got):\n%s", tt.key, d) | ||
} | ||
}) | ||
} | ||
} |