oktopus/backend/services/mochi/vendor/github.com/timshannon/badgerhold/query.go
2023-03-20 11:19:02 -03:00

1023 lines
22 KiB
Go

// Copyright 2019 Tim Shannon. All rights reserved.
// Use of this source code is governed by the MIT license
// that can be found in the LICENSE file.
package badgerhold
import (
"fmt"
"reflect"
"regexp"
"sort"
"strings"
"unicode"
"github.com/dgraph-io/badger"
)
// Operators that a Criterion can apply. The chosen operator is stored in Criterion.operator.
const (
eq = iota // ==
ne // !=
gt // >
lt // <
ge // >=
le // <=
in // in
re // regular expression
fn // func
isnil // tests for nil
sw // string starts with
ew // string ends with
)
// Key is shorthand for specifying a query to run against the Key in a badgerhold. It is simply the
// empty string, for example:
// Where(badgerhold.Key).Eq("testkey")
const Key = ""
// Query is a chained collection of criteria which an object in the badgerhold needs to match to be
// returned. An empty query matches against all records.
type Query struct {
// index is the name of the index to use, set by Index; "" means no index was requested
index string
// currentField is the field the next criterion operator is filed under (set by Where and And)
currentField string
// fieldCriteria holds the criteria to satisfy, grouped by field name
fieldCriteria map[string][]*Criterion
// ors holds the alternative queries added through Or
ors []*Query
// NOTE(review): badIndex is assigned outside this section; runQuery returns an
// "index does not exist" error when it is true and an index was requested
badIndex bool
// dataType is the pointer-stripped type of the records being queried, set by runQuery
dataType reflect.Type
// tx is the transaction of the running query, set by runQuery and reused by sub-queries
tx *badger.Txn
// writable is set to true by delete and update queries and to false by find and aggregate queries
writable bool
// subquery is set when the query is run through RecordAccess.SubQuery or SubAggregateQuery
subquery bool
// bookmark is handed to newIterator; runQuery resets it to nil when it returns
bookmark *iterBookmark
// limit is the maximum number of records to return; 0 means no limit
limit int
// skip is the number of matching records to drop before returning any
skip int
// sort lists the field names to sort the results by
sort []string
// reverse inverts the sort order when true
reverse bool
}
// IsEmpty reports whether the query carries no index, no field criteria and no
// Or'd queries. Such a query matches every record.
func (q *Query) IsEmpty() bool {
	return q.index == "" && len(q.fieldCriteria) == 0 && q.ors == nil
}
// Criterion is an operator and a value that a given field needs to match on
type Criterion struct {
// query is the Query this criterion belongs to; the operator methods return it so the chain can continue
query *Query
// operator is one of the operator constants (eq, ne, gt, ...)
operator int
// value is the operand used by every operator except in
value interface{}
// inValues holds the candidate values for the in operator
inValues []interface{}
}
// hasMatchFunc reports whether any of the criteria uses the fn (MatchFunc) operator.
func hasMatchFunc(criteria []*Criterion) bool {
	for i := range criteria {
		if criteria[i].operator == fn {
			return true
		}
	}

	return false
}
// Field allows for referencing a field in the structure being compared.
// NOTE(review): nothing in this section consumes Field; presumably the comparison code treats a
// Field value as "the named field of the current record" - confirm against compare.
type Field string
// Where starts a query by naming the first field an object in the badgerhold has to match on in
// order to be returned in a Find result. The returned Criterion is completed with an operator
// method (Eq, Lt, In, ...) which hands back the Query for further chaining.
/*
Query API Example

	s.Find(badgerhold.Where("FieldName").Eq(value).And("AnotherField").Lt(AnotherValue).
		Or(badgerhold.Where("FieldName").Eq(anotherValue)))

Since Gobs only encode exported fields, this will panic if you pass in a field with a lower case first letter
*/
func Where(field string) *Criterion {
	if !startsUpper(field) {
		panic("The first letter of a field in a badgerhold query must be upper-case")
	}

	q := &Query{
		currentField:  field,
		fieldCriteria: make(map[string][]*Criterion),
	}

	return &Criterion{query: q}
}
// And adds another field to the query. The criterion completed on the returned Criterion has to
// hold in addition to the ones already present. It panics when field does not start with an
// upper-case letter.
func (q *Query) And(field string) *Criterion {
	if !startsUpper(field) {
		panic("The first letter of a field in a badgerhold query must be upper-case")
	}

	next := &Criterion{query: q}
	q.currentField = field

	return next
}
// Skip makes the query drop the first amount records that match all of its criteria, so they are
// not part of the result set. It panics when amount is negative or when a non-zero skip has
// already been set.
func (q *Query) Skip(amount int) *Query {
	switch {
	case amount < 0:
		panic("Skip must be set to a positive number")
	case q.skip != 0:
		panic(fmt.Sprintf("Skip has already been set to %d", q.skip))
	}

	q.skip = amount
	return q
}
// Limit caps the number of records a query can return. It panics when amount is negative or when
// a non-zero limit has already been set.
func (q *Query) Limit(amount int) *Query {
	switch {
	case amount < 0:
		panic("Limit must be set to a positive number")
	case q.limit != 0:
		panic(fmt.Sprintf("Limit has already been set to %d", q.limit))
	}

	q.limit = amount
	return q
}
// SortBy orders the results by the given field names, in the order they are given. A field that
// is already part of the sort is not added a second time. Sorting by Key is not possible and
// panics.
func (q *Query) SortBy(fields ...string) *Query {
nextField:
	for _, name := range fields {
		if name == Key {
			panic("Cannot sort by Key.")
		}

		for _, existing := range q.sort {
			if existing == name {
				continue nextField
			}
		}

		q.sort = append(q.sort, name)
	}

	return q
}
// Reverse will reverse the current result set, useful with SortBy.
// Each call toggles the flag, so calling Reverse twice restores the original order.
func (q *Query) Reverse() *Query {
q.reverse = !q.reverse
return q
}
// Index names the index to use when running this query. Index names containing a "." panic,
// because only top level structures can be indexed.
func (q *Query) Index(indexName string) *Query {
	if strings.ContainsRune(indexName, '.') {
		// NOTE: I may reconsider this in the future
		panic("Nested indexes are not supported. Only top level structures can be indexed")
	}

	q.index = indexName
	return q
}
// Or attaches a separate query whose results get unioned with the results of this query.
// Skip and limit are only allowed on top level queries, so Or panics when the query passed in
// has either of them set.
func (q *Query) Or(query *Query) *Query {
	paged := query.skip != 0 || query.limit != 0
	if paged {
		panic("Or'd queries cannot contain skip or limit values")
	}

	q.ors = append(q.ors, query)
	return q
}
// matchesAllFields reports whether the record satisfies the criteria of every field in the query.
//
// key is the encoded key of the record and is what criteria on the Key field are tested against.
// value is the decoded record, from which all other fields are read. currentRow is passed through
// to the criteria. An empty query matches everything.
func (q *Query) matchesAllFields(key []byte, value reflect.Value, currentRow interface{}) (bool, error) {
	if q.IsEmpty() {
		return true, nil
	}

	for name, fieldCriteria := range q.fieldCriteria {
		if name == q.index && !q.badIndex && !hasMatchFunc(fieldCriteria) {
			// already handled by index Iterator
			continue
		}

		var (
			matched bool
			err     error
		)

		if name == Key {
			// keys are still encoded and need the type name to be decoded
			matched, err = matchesAllCriteria(fieldCriteria, key, true, q.dataType.Name(), currentRow)
		} else {
			var fVal reflect.Value
			fVal, err = fieldValue(value, name)
			if err != nil {
				return false, err
			}
			matched, err = matchesAllCriteria(fieldCriteria, fVal.Interface(), false, "", currentRow)
		}

		if err != nil {
			return false, err
		}
		if !matched {
			return false, nil
		}
	}

	return true, nil
}
// fieldValue returns the value of the (possibly nested) field named by field within value.
//
// field may be a dot separated path such as "Address.City"; every segment is looked up by name on
// the struct reached so far. Pointers and interfaces met on the way to a struct are dereferenced.
//
// An error is returned, instead of a reflect panic, when value is invalid, when a segment does not
// exist, when a nil pointer or interface would have to be followed, or when a segment is looked up
// on something that is not a struct.
func fieldValue(value reflect.Value, field string) (reflect.Value, error) {
	if !value.IsValid() {
		return reflect.Value{}, fmt.Errorf("The field %s cannot be read from an invalid value", field)
	}

	current := value
	for _, name := range strings.Split(field, ".") {
		// FieldByName is only defined for structs, so unwrap everything in front of it first
		for current.Kind() == reflect.Ptr || current.Kind() == reflect.Interface {
			if current.IsNil() {
				return reflect.Value{}, fmt.Errorf("The field %s cannot be read through a nil %s in the type %s",
					field, current.Type(), value.Type())
			}
			current = current.Elem()
		}

		if current.Kind() != reflect.Struct {
			return reflect.Value{}, fmt.Errorf("The field %s does not exist in the type %s", field, value.Type())
		}

		current = current.FieldByName(name)
		if !current.IsValid() {
			return reflect.Value{}, fmt.Errorf("The field %s does not exist in the type %s", field, value.Type())
		}
	}

	return current, nil
}
// op completes the criterion with the given operator and operand, files it under the query's
// current field and returns the query so the chain can continue.
func (c *Criterion) op(op int, value interface{}) *Query {
	c.operator = op
	c.value = value

	q := c.query
	if q.fieldCriteria == nil {
		// a zero value Query (e.g. new(Query).And(...)) has no map yet; writing to it would panic
		q.fieldCriteria = make(map[string][]*Criterion)
	}
	q.fieldCriteria[q.currentField] = append(q.fieldCriteria[q.currentField], c)

	return q
}
// Eq tests if the current field is Equal to the passed in value.
// It returns the query the criterion was added to.
func (c *Criterion) Eq(value interface{}) *Query {
return c.op(eq, value)
}
// Ne tests if the current field is Not Equal to the passed in value.
// It returns the query the criterion was added to.
func (c *Criterion) Ne(value interface{}) *Query {
return c.op(ne, value)
}
// Gt tests if the current field is Greater Than the passed in value.
// It returns the query the criterion was added to.
func (c *Criterion) Gt(value interface{}) *Query {
return c.op(gt, value)
}
// Lt tests if the current field is Less Than the passed in value.
// It returns the query the criterion was added to.
func (c *Criterion) Lt(value interface{}) *Query {
return c.op(lt, value)
}
// Ge tests if the current field is Greater Than or Equal To the passed in value.
// It returns the query the criterion was added to.
func (c *Criterion) Ge(value interface{}) *Query {
return c.op(ge, value)
}
// Le tests if the current field is Less Than or Equal To the passed in value.
// It returns the query the criterion was added to.
func (c *Criterion) Le(value interface{}) *Query {
return c.op(le, value)
}
// In tests if the current field is a member of the values passed in.
// It returns the query the criterion was added to.
func (c *Criterion) In(values ...interface{}) *Query {
	c.operator = in
	c.inValues = values

	q := c.query
	if q.fieldCriteria == nil {
		// a zero value Query (e.g. new(Query).And(...)) has no map yet; writing to it would panic
		q.fieldCriteria = make(map[string][]*Criterion)
	}
	q.fieldCriteria[q.currentField] = append(q.fieldCriteria[q.currentField], c)

	return q
}
// RegExp will test if a field matches against the regular expression.
// The field value is formatted as a string with %s before it is tested.
func (c *Criterion) RegExp(expression *regexp.Regexp) *Query {
return c.op(re, expression)
}
// IsNil will test if a field is equal to nil.
// The criterion is stored with a nil operand.
func (c *Criterion) IsNil() *Query {
return c.op(isnil, nil)
}
// HasPrefix will test if a field starts with the provided string.
// The field value is formatted as a string with %s before it is tested.
func (c *Criterion) HasPrefix(prefix string) *Query {
return c.op(sw, prefix)
}
// HasSuffix will test if a field ends with the provided string.
// The field value is formatted as a string with %s before it is tested.
func (c *Criterion) HasSuffix(suffix string) *Query {
return c.op(ew, suffix)
}
// MatchFunc is a function used to test an arbitrary matching value in a query.
// It receives access to the record under test and reports whether the record matches; a returned
// error is passed back as the error of the criterion test.
type MatchFunc func(ra *RecordAccess) (bool, error)
// RecordAccess allows access to the current record, field or allows running a subquery within a
// MatchFunc
type RecordAccess struct {
// record is the complete row that is being tested
record interface{}
// field is the value of the field the MatchFunc criterion was placed on
field interface{}
// query is the query the criterion belongs to; sub-queries reuse its transaction and bookmark
query *Query
}
// Field is the current field being queried, i.e. the value the MatchFunc criterion is tested against
func (r *RecordAccess) Field() interface{} {
return r.field
}
// Record is the complete record for a given row in badgerhold
func (r *RecordAccess) Record() interface{} {
return r.record
}
// SubQuery allows you to run another query in the same transaction for each
// record in a parent query.
// The passed in query is marked as a subquery and takes over the parent query's bookmark before it
// is run with findQuery, so result must be a pointer to a slice.
func (r *RecordAccess) SubQuery(result interface{}, query *Query) error {
query.subquery = true
query.bookmark = r.query.bookmark
return findQuery(r.query.tx, result, query)
}
// SubAggregateQuery allows you to run another aggregate query in the same transaction for each
// record in a parent query.
// The passed in query is marked as a subquery and takes over the parent query's bookmark. The
// parent's current record is handed to aggregateQuery as the data type to query.
func (r *RecordAccess) SubAggregateQuery(query *Query, groupBy ...string) ([]*AggregateResult, error) {
query.subquery = true
query.bookmark = r.query.bookmark
return aggregateQuery(r.query.tx, r.record, query, groupBy...)
}
// MatchFunc will test if a field matches the passed in function.
// It panics when the current field of the query is Key.
func (c *Criterion) MatchFunc(match MatchFunc) *Query {
if c.query.currentField == Key {
panic("Match func cannot be used against Keys, as the Key type is unknown at runtime, and there is " +
"no value compare against")
}
return c.op(fn, match)
}
// test reports whether the criterion passes for the passed in value.
//
// When encoded is true, testValue must be a []byte. A non-empty value is decoded into a new value
// of the operand's type before it is compared: with decodeKey when keyType is set, otherwise with
// decode. When encoded is false, testValue is compared as is. currentRow is the complete record and
// is handed to the comparison and to MatchFuncs.
//
// An error from decoding, from the comparison or from a MatchFunc is returned unchanged.
func (c *Criterion) test(testValue interface{}, encoded bool, keyType string, currentRow interface{}) (bool, error) {
	var value interface{}

	if encoded {
		raw := testValue.([]byte)

		switch {
		case len(raw) == 0:
			// nothing to decode, value stays nil
		case c.operator == isnil:
			// the operand is nil, so there is no type to decode into; test the raw bytes instead
			value = raw
		case c.operator == in:
			if len(c.inValues) == 0 {
				// nothing can be a member of an empty set, and there is no type to decode into
				return false, nil
			}
			// value is a slice of values, use c.inValues
			value = reflect.New(reflect.TypeOf(c.inValues[0])).Interface()
			if err := decode(raw, value); err != nil {
				return false, err
			}
		default:
			// used with keys
			value = reflect.New(reflect.TypeOf(c.value)).Interface()

			var err error
			if keyType != "" {
				err = decodeKey(raw, value, keyType)
			} else {
				err = decode(raw, value)
			}
			if err != nil {
				return false, err
			}
		}
	} else {
		value = testValue
	}

	switch c.operator {
	case in:
		for i := range c.inValues {
			result, err := c.compare(value, c.inValues[i], currentRow)
			if err != nil {
				return false, err
			}
			if result == 0 {
				return true, nil
			}
		}

		return false, nil
	case re:
		return c.value.(*regexp.Regexp).MatchString(fmt.Sprintf("%s", value)), nil
	case fn:
		return c.value.(MatchFunc)(&RecordAccess{
			field:  value,
			record: currentRow,
			query:  c.query,
		})
	case isnil:
		if value == nil {
			// a nil interface has no reflect.Value to call IsNil on
			return true, nil
		}

		// reflect.Value.IsNil panics for kinds that cannot be nil, those are simply not nil
		switch rv := reflect.ValueOf(value); rv.Kind() {
		case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice,
			reflect.UnsafePointer:
			return rv.IsNil(), nil
		default:
			return false, nil
		}
	case sw:
		return strings.HasPrefix(fmt.Sprintf("%s", value), fmt.Sprintf("%s", c.value)), nil
	case ew:
		return strings.HasSuffix(fmt.Sprintf("%s", value), fmt.Sprintf("%s", c.value)), nil
	default:
		// comparison operators
		result, err := c.compare(value, c.value, currentRow)
		if err != nil {
			return false, err
		}

		switch c.operator {
		case eq:
			return result == 0, nil
		case ne:
			return result != 0, nil
		case gt:
			return result > 0, nil
		case lt:
			return result < 0, nil
		case le:
			return result <= 0, nil
		case ge:
			return result >= 0, nil
		default:
			panic("invalid operator")
		}
	}
}
// matchesAllCriteria reports whether value passes every one of the criteria. It stops at the first
// criterion that fails or returns an error. encoded, keyType and currentRow are handed to each
// criterion test unchanged.
func matchesAllCriteria(criteria []*Criterion, value interface{}, encoded bool, keyType string,
	currentRow interface{}) (bool, error) {
	for _, criterion := range criteria {
		switch ok, err := criterion.test(value, encoded, keyType, currentRow); {
		case err != nil:
			return false, err
		case !ok:
			return false, nil
		}
	}

	return true, nil
}
// startsUpper reports whether the first rune of str is an upper-case letter.
// The empty string is accepted as well.
func startsUpper(str string) bool {
	for _, first := range str {
		// only the first rune decides
		return unicode.IsUpper(first)
	}

	// no runes at all: the empty string
	return true
}
// String returns a human readable description of the query: the index in use (if any), every
// criterion joined with AND, followed by one "Or" section per Or'd query.
//
// Fields are listed in sorted order so that the same query always produces the same text.
func (q *Query) String() string {
	var b strings.Builder

	if q.index != "" {
		b.WriteString("Using Index [" + q.index + "] ")
	}

	// map iteration order is random, sort the field names for a stable result
	fields := make([]string, 0, len(q.fieldCriteria))
	for field := range q.fieldCriteria {
		fields = append(fields, field)
	}
	sort.Strings(fields)

	first := true
	for _, field := range fields {
		for _, c := range q.fieldCriteria[field] {
			if first {
				b.WriteString("Where ")
				first = false
			} else {
				b.WriteString("\n\tAND ")
			}
			b.WriteString(field + " " + c.String())
		}
	}

	for _, or := range q.ors {
		b.WriteString("\nOr " + or.String())
	}

	return b.String()
}
// String returns a human readable description of the criterion: its operator followed by the
// value(s) it compares against. It panics on an unknown operator.
func (c *Criterion) String() string {
	var symbol string

	switch c.operator {
	case in:
		return fmt.Sprintf("in %v", c.inValues)
	case isnil:
		return "is nil"
	case sw:
		return fmt.Sprintf("starts with %+v", c.value)
	case ew:
		return fmt.Sprintf("ends with %+v", c.value)
	case eq:
		symbol = "=="
	case ne:
		symbol = "!="
	case gt:
		symbol = ">"
	case lt:
		symbol = "<"
	case le:
		symbol = "<="
	case ge:
		symbol = ">="
	case re:
		symbol = "matches the regular expression"
	case fn:
		symbol = "matches the function"
	default:
		panic("invalid operator")
	}

	return fmt.Sprintf("%s %v", symbol, c.value)
}
// record is a single row found by runQuery, as handed to the action callback
type record struct {
// key is the encoded key of the row as returned by the iterator
key []byte
// value is a pointer (created with reflect.New) to the decoded row
value reflect.Value
}
// runQuery iterates over the stored records of dataType's type inside tx and calls action for
// every record that matches query.
//
// retrievedKeys lists the keys that were already handed to action by an earlier part of an Or'd
// query; records with those keys are not returned again. skip is the number of matching records
// to drop before action is called for the first time. A query with sort fields is delegated to
// runQuerySort. The first error from decoding, matching, action or the iterator ends the query.
func runQuery(tx *badger.Txn, dataType interface{}, query *Query, retrievedKeys keyList, skip int,
action func(r *record) error) error {
storer := newStorer(dataType)
// strip every level of pointer so that tp holds the record value itself
tp := dataType
for reflect.TypeOf(tp).Kind() == reflect.Ptr {
tp = reflect.ValueOf(tp).Elem().Interface()
}
query.dataType = reflect.TypeOf(tp)
if len(query.sort) > 0 {
return runQuerySort(tx, dataType, query, action)
}
iter := newIterator(tx, storer.Type(), query, query.bookmark)
// writable queries and sub-queries get a bookmark from the iterator unless one was handed in
if (query.writable || query.subquery) && query.bookmark == nil {
query.bookmark = iter.createBookmark()
}
defer func() {
iter.Close()
query.bookmark = nil
}()
if query.index != "" && query.badIndex {
return fmt.Errorf("The index %s does not exist", query.index)
}
newKeys := make(keyList, 0)
// limit counts down the records that may still be returned; it is only consulted when query.limit != 0
limit := query.limit - len(retrievedKeys)
// the iterator signals its end with a nil key
for k, v := iter.Next(); k != nil; k, v = iter.Next() {
if len(retrievedKeys) != 0 {
// don't check this record if it's already been retrieved
if retrievedKeys.in(k) {
continue
}
}
val := reflect.New(reflect.TypeOf(tp))
err := decode(v, val.Interface())
if err != nil {
return err
}
// make the transaction available to MatchFuncs that run sub-queries
query.tx = tx
ok, err := query.matchesAllFields(k, val, val.Interface())
if err != nil {
return err
}
if ok {
// matching records are dropped until the requested skip is used up
if skip > 0 {
skip--
continue
}
err = action(&record{
key: k,
value: val,
})
if err != nil {
return err
}
// track that this key's entry has been added to the result list
newKeys.add(k)
if query.limit != 0 {
limit--
if limit == 0 {
break
}
}
}
}
if iter.Error() != nil {
return iter.Error()
}
// the limit has been reached, the Or'd queries must not add anything
if query.limit != 0 && limit == 0 {
return nil
}
if len(query.ors) > 0 {
// this query's iterator is closed before the Or'd queries are run
// (the deferred Close above runs again on return)
iter.Close()
for i := range newKeys {
retrievedKeys.add(newKeys[i])
}
// the Or'd queries continue with whatever is left of skip and do not return known keys again
for i := range query.ors {
err := runQuery(tx, tp, query.ors[i], retrievedKeys, skip, action)
if err != nil {
return err
}
}
}
return nil
}
// runQuerySort runs the query without sort, skip, or limit, then applies them to the entire result set.
//
// It expects query.dataType to be set already (runQuery does so before calling it). An error is
// returned when one of the sort fields does not exist in that type. All matching records are
// collected in memory, sorted, and only then handed to action one by one.
func runQuerySort(tx *badger.Txn, dataType interface{}, query *Query, action func(r *record) error) error {
// Validate sort fields
for _, field := range query.sort {
// a sort field may be a dot separated path into nested structs
fields := strings.Split(field, ".")
current := query.dataType
for i := range fields {
var structField reflect.StructField
found := false
if current.Kind() == reflect.Ptr {
structField, found = current.Elem().FieldByName(fields[i])
} else {
structField, found = current.FieldByName(fields[i])
}
if !found {
return fmt.Errorf("The field %s does not exist in the type %s", field, query.dataType)
}
current = structField.Type
}
}
// Run query without sort, skip or limit
// apply sort, skip and limit to entire dataset
// (a shallow copy, so the caller's query keeps its sort, limit and skip)
qCopy := *query
qCopy.sort = nil
qCopy.limit = 0
qCopy.skip = 0
var records []*record
err := runQuery(tx, dataType, &qCopy, nil, 0,
func(r *record) error {
records = append(records, r)
return nil
})
if err != nil {
return err
}
sort.Slice(records, func(i, j int) bool {
// the first sort field on which the two records differ decides the order
for _, field := range query.sort {
val, err := fieldValue(records[i].value.Elem(), field)
if err != nil {
panic(err.Error()) // shouldn't happen due to field check above
}
value := val.Interface()
val, err = fieldValue(records[j].value.Elem(), field)
if err != nil {
panic(err.Error()) // shouldn't happen due to field check above
}
other := val.Interface()
// a reversed sort is the same comparison with the operands swapped
if query.reverse {
value, other = other, value
}
cmp, cerr := compare(value, other)
if cerr != nil {
// if for some reason there is an error on compare, fallback to a lexicographic compare
valS := fmt.Sprintf("%s", value)
otherS := fmt.Sprintf("%s", other)
if valS < otherS {
return true
} else if valS == otherS {
continue
}
return false
}
// a result of -1 is treated as value sorting before other, 0 as a tie on this field
if cmp == -1 {
return true
} else if cmp == 0 {
continue
}
return false
}
// equal on every sort field
return false
})
// apply skip and limit
limit := query.limit
skip := query.skip
if skip > len(records) {
records = records[0:0]
} else {
records = records[skip:]
}
if limit > 0 && limit <= len(records) {
records = records[:limit]
}
for i := range records {
err = action(records[i])
if err != nil {
return err
}
}
return nil
}
// findQuery runs query inside tx and stores every matching record in result.
//
// result must be a pointer to a slice, otherwise findQuery panics. The slice elements may be
// structs or pointers to structs. A nil query is treated as an empty query, which matches
// everything. Matching records are appended to whatever the slice already holds.
func findQuery(tx *badger.Txn, result interface{}, query *Query) error {
if query == nil {
query = &Query{}
}
query.writable = false
resultVal := reflect.ValueOf(result)
if resultVal.Kind() != reflect.Ptr || resultVal.Elem().Kind() != reflect.Slice {
panic("result argument must be a slice address")
}
sliceVal := resultVal.Elem()
elType := sliceVal.Type().Elem()
// tp is the element type with every level of pointer removed
tp := elType
for tp.Kind() == reflect.Ptr {
tp = tp.Elem()
}
// look for the first field that is tagged as the key; it receives the decoded key of each record
var keyType reflect.Type
var keyField string
for i := 0; i < tp.NumField(); i++ {
if strings.Contains(string(tp.Field(i).Tag), BadgerholdKeyTag) ||
tp.Field(i).Tag.Get(badgerholdPrefixTag) == badgerholdPrefixKeyValue {
keyType = tp.Field(i).Type
keyField = tp.Field(i).Name
break
}
}
val := reflect.New(tp)
err := runQuery(tx, val.Interface(), query, nil, query.skip,
func(r *record) error {
// r.value is a pointer to the record; append it as is, or dereferenced, to match the element type
var rowValue reflect.Value
if elType.Kind() == reflect.Ptr {
rowValue = r.value
} else {
rowValue = r.value.Elem()
}
if keyType != nil {
rowKey := rowValue
for rowKey.Kind() == reflect.Ptr {
rowKey = rowKey.Elem()
}
// decode the record's key straight into the key field
err := decodeKey(r.key, rowKey.FieldByName(keyField).Addr().Interface(), tp.Name())
if err != nil {
return err
}
}
sliceVal = reflect.Append(sliceVal, rowValue)
return nil
})
if err != nil {
return err
}
// write the grown slice back through the result pointer
resultVal.Elem().Set(sliceVal.Slice(0, sliceVal.Len()))
return nil
}
// deleteQuery deletes every record of dataType's type that matches query, together with its index
// entries. A nil query is treated as an empty query, which matches everything.
//
// The matching records are collected first and only deleted once the query has finished.
func deleteQuery(tx *badger.Txn, dataType interface{}, query *Query) error {
	if query == nil {
		query = &Query{}
	}
	query.writable = true

	var matched []*record
	collect := func(r *record) error {
		matched = append(matched, r)
		return nil
	}

	if err := runQuery(tx, dataType, query, nil, query.skip, collect); err != nil {
		return err
	}

	storer := newStorer(dataType)

	for _, rec := range matched {
		if err := tx.Delete(rec.key); err != nil {
			return err
		}

		// remove any indexes
		if err := indexDelete(storer, tx, rec.key, rec.value.Interface()); err != nil {
			return err
		}
	}

	return nil
}
// updateQuery calls update for every record of dataType's type that matches query and stores the
// result under the record's key. A nil query is treated as an empty query, which matches
// everything.
//
// The matching records are collected first. For each of them the index entries of the old value
// are removed, update is applied, the record is re-encoded and written, and the index entries of
// the new value are added. The first error stops the processing.
func updateQuery(tx *badger.Txn, dataType interface{}, query *Query, update func(record interface{}) error) error {
	if query == nil {
		query = &Query{}
	}
	query.writable = true

	var matched []*record
	collect := func(r *record) error {
		matched = append(matched, r)
		return nil
	}

	if err := runQuery(tx, dataType, query, nil, query.skip, collect); err != nil {
		return err
	}

	storer := newStorer(dataType)

	for _, rec := range matched {
		current := rec.value.Interface()

		// delete any existing indexes based on the original value
		if err := indexDelete(storer, tx, rec.key, current); err != nil {
			return err
		}

		if err := update(current); err != nil {
			return err
		}

		encoded, err := encode(current)
		if err != nil {
			return err
		}

		if err := tx.Set(rec.key, encoded); err != nil {
			return err
		}

		// insert any new indexes
		if err := indexAdd(storer, tx, rec.key, current); err != nil {
			return err
		}
	}

	return nil
}
// aggregateQuery runs query inside tx and groups the matching records by the values of the groupBy
// fields.
//
// Without groupBy fields a single AggregateResult holding every matching record is returned. With
// groupBy fields one AggregateResult is returned per distinct combination of field values; the
// results are kept ordered by those values. A nil query is treated as an empty query. An error is
// returned when a groupBy field does not exist or when two group values cannot be compared.
func aggregateQuery(tx *badger.Txn, dataType interface{}, query *Query, groupBy ...string) ([]*AggregateResult, error) {
if query == nil {
query = &Query{}
}
query.writable = false
var result []*AggregateResult
if len(groupBy) == 0 {
result = append(result, &AggregateResult{})
}
err := runQuery(tx, dataType, query, nil, query.skip,
func(r *record) error {
if len(groupBy) == 0 {
result[0].reduction = append(result[0].reduction, r.value)
return nil
}
// collect this record's value for every groupBy field
grouping := make([]reflect.Value, len(groupBy))
for i := range groupBy {
fVal := r.value.Elem().FieldByName(groupBy[i])
if !fVal.IsValid() {
return fmt.Errorf("The field %s does not exist in the type %s", groupBy[i],
r.value.Type())
}
grouping[i] = fVal
}
// err, c and allEqual are written by the search function below
var err error
var c int
var allEqual bool
// binary search for the first existing group that does not sort before this record's grouping
i := sort.Search(len(result), func(i int) bool {
for j := range grouping {
c, err = compare(result[i].group[j].Interface(), grouping[j].Interface())
if err != nil {
// end the search; the error is checked once Search returns
return true
}
if c != 0 {
return c >= 0
}
// if group part is equal, compare the next group part
}
// every part matched: this is the record's group
allEqual = true
return true
})
if err != nil {
return err
}
if i < len(result) {
if allEqual {
// group already exists, append results to reduction
result[i].reduction = append(result[i].reduction, r.value)
return nil
}
}
// group not found, create another grouping at i
// (grow by one, shift the tail right, then fill the gap so the results stay ordered)
result = append(result, nil)
copy(result[i+1:], result[i:])
result[i] = &AggregateResult{
group: grouping,
reduction: []reflect.Value{r.value},
}
return nil
})
if err != nil {
return nil, err
}
return result, nil
}