Chore: add/update sqlstore-related helper functions (#77408)

* add/update sqlstore-related helper functions

* add documentation & tests for InsertQuery and UpdateQuery, make generated SQL deterministic by sorting columns

* remove old log line
This commit is contained in:
Dan Cech 2023-11-03 10:30:52 -04:00 committed by GitHub
parent 6b729389b5
commit 67b2972052
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 237 additions and 11 deletions

View File

@ -5,6 +5,7 @@ import (
"os"
"xorm.io/core"
"xorm.io/xorm"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
@ -29,6 +30,8 @@ type DB interface {
GetDialect() migrator.Dialect
// GetDBType returns the name of the database type available to the runtime.
GetDBType() core.DbType
// GetEngine returns the underlying xorm engine.
GetEngine() *xorm.Engine
// GetSqlxSession is an experimental extension to use sqlx instead of xorm to
// communicate with the database.
GetSqlxSession() *session.SessionDB

View File

@ -4,6 +4,7 @@ import (
"context"
"xorm.io/core"
"xorm.io/xorm"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
@ -43,6 +44,10 @@ func (f *FakeDB) GetDialect() migrator.Dialect {
return nil
}
func (f *FakeDB) GetEngine() *xorm.Engine {
return nil
}
func (f *FakeDB) GetSqlxSession() *session.SessionDB {
return nil
}

View File

@ -100,7 +100,6 @@ func (s *XormLogger) SetLevel(l core.LogLevel) {
// ShowSQL implement core.ILogger
func (s *XormLogger) ShowSQL(show ...bool) {
s.grafanaLog.Error("ShowSQL", "show", "show")
if len(show) == 0 {
s.showSQL = true
return

View File

@ -5,6 +5,7 @@ import (
"strconv"
"strings"
"golang.org/x/exp/slices"
"xorm.io/xorm"
)
@ -73,6 +74,14 @@ type Dialect interface {
Unlock(LockCfg) error
GetDBName(string) (string, error)
// InsertQuery accepts a table name and a map of column names to values to insert.
// It returns a query string and a slice of parameters that can be executed against the database.
InsertQuery(tableName string, row map[string]any) (string, []any, error)
// UpdateQuery accepts a table name, a map of column names to values to update, and a map of
// column names to values to use in the where clause.
// It returns a query string and a slice of parameters that can be executed against the database.
UpdateQuery(tableName string, row map[string]any, where map[string]any) (string, []any, error)
}
type LockCfg struct {
@ -344,3 +353,71 @@ func (b *BaseDialect) OrderBy(order string) string {
func (b *BaseDialect) GetDBName(_ string) (string, error) {
return "", nil
}
func (b *BaseDialect) InsertQuery(tableName string, row map[string]any) (string, []any, error) {
if len(row) < 1 {
return "", nil, fmt.Errorf("no columns provided")
}
// allocate slices
cols := make([]string, 0, len(row))
vals := make([]any, 0, len(row))
keys := make([]string, 0, len(row))
// create sorted list of columns
for col := range row {
keys = append(keys, col)
}
slices.Sort[string](keys)
// build query and values
for _, col := range keys {
cols = append(cols, b.dialect.Quote(col))
vals = append(vals, row[col])
}
return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", b.dialect.Quote(tableName), strings.Join(cols, ", "), strings.Repeat("?, ", len(row)-1)+"?"), vals, nil
}
func (b *BaseDialect) UpdateQuery(tableName string, row map[string]any, where map[string]any) (string, []any, error) {
if len(row) < 1 {
return "", nil, fmt.Errorf("no columns provided")
}
if len(where) < 1 {
return "", nil, fmt.Errorf("no where clause provided")
}
// allocate slices
cols := make([]string, 0, len(row))
whereCols := make([]string, 0, len(where))
vals := make([]any, 0, len(row)+len(where))
keys := make([]string, 0, len(row))
// create sorted list of columns to update
for col := range row {
keys = append(keys, col)
}
slices.Sort[string](keys)
// build update query and values
for _, col := range keys {
cols = append(cols, b.dialect.Quote(col)+"=?")
vals = append(vals, row[col])
}
// create sorted list of columns for where clause
keys = make([]string, 0, len(where))
for col := range where {
keys = append(keys, col)
}
slices.Sort[string](keys)
// build where clause and values
for _, col := range keys {
whereCols = append(whereCols, b.dialect.Quote(col)+"=?")
vals = append(vals, where[col])
}
return fmt.Sprintf("UPDATE %s SET %s WHERE %s", b.dialect.Quote(tableName), strings.Join(cols, ", "), strings.Join(whereCols, " AND ")), vals, nil
}

View File

@ -0,0 +1,117 @@
package migrator
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestInsertQuery(t *testing.T) {
tests := []struct {
name string
tableName string
values map[string]any
expectedErr bool
expectedPostgresQuery string
expectedPostgresArgs []any
expectedMySQLQuery string
expectedMySQLArgs []any
expectedSQLiteQuery string
expectedSQLiteArgs []any
}{
{
"insert one",
"some_table",
map[string]any{"col1": "val1", "col2": "val2", "col3": "val3"},
false,
"INSERT INTO \"some_table\" (\"col1\", \"col2\", \"col3\") VALUES (?, ?, ?)",
[]any{"val1", "val2", "val3"},
"INSERT INTO `some_table` (`col1`, `col2`, `col3`) VALUES (?, ?, ?)",
[]any{"val1", "val2", "val3"},
"INSERT INTO `some_table` (`col1`, `col2`, `col3`) VALUES (?, ?, ?)",
[]any{"val1", "val2", "val3"},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
var db Dialect
db = NewPostgresDialect()
q, args, err := db.InsertQuery(tc.tableName, tc.values)
require.True(t, (err != nil) == tc.expectedErr)
require.Equal(t, tc.expectedPostgresQuery, q, "Postgres query incorrect")
require.Equal(t, tc.expectedPostgresArgs, args, "Postgres args incorrect")
db = NewMysqlDialect()
q, args, err = db.InsertQuery(tc.tableName, tc.values)
require.True(t, (err != nil) == tc.expectedErr)
require.Equal(t, tc.expectedMySQLQuery, q, "MySQL query incorrect")
require.Equal(t, tc.expectedMySQLArgs, args, "MySQL args incorrect")
db = NewSQLite3Dialect()
q, args, err = db.InsertQuery(tc.tableName, tc.values)
require.True(t, (err != nil) == tc.expectedErr)
require.Equal(t, tc.expectedSQLiteQuery, q, "SQLite query incorrect")
require.Equal(t, tc.expectedSQLiteArgs, args, "SQLite args incorrect")
})
}
}
func TestUpdateQuery(t *testing.T) {
tests := []struct {
name string
tableName string
values map[string]any
where map[string]any
expectedErr bool
expectedPostgresQuery string
expectedPostgresArgs []any
expectedMySQLQuery string
expectedMySQLArgs []any
expectedSQLiteQuery string
expectedSQLiteArgs []any
}{
{
"insert one",
"some_table",
map[string]any{"col1": "val1", "col2": "val2", "col3": "val3"},
map[string]any{"key1": 10},
false,
"UPDATE \"some_table\" SET \"col1\"=?, \"col2\"=?, \"col3\"=? WHERE \"key1\"=?",
[]any{"val1", "val2", "val3", 10},
"UPDATE `some_table` SET `col1`=?, `col2`=?, `col3`=? WHERE `key1`=?",
[]any{"val1", "val2", "val3", 10},
"UPDATE `some_table` SET `col1`=?, `col2`=?, `col3`=? WHERE `key1`=?",
[]any{"val1", "val2", "val3", 10},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
var db Dialect
db = NewPostgresDialect()
q, args, err := db.UpdateQuery(tc.tableName, tc.values, tc.where)
require.True(t, (err != nil) == tc.expectedErr)
require.Equal(t, tc.expectedPostgresQuery, q, "Postgres query incorrect")
require.Equal(t, tc.expectedPostgresArgs, args, "Postgres args incorrect")
db = NewMysqlDialect()
q, args, err = db.UpdateQuery(tc.tableName, tc.values, tc.where)
require.True(t, (err != nil) == tc.expectedErr)
require.Equal(t, tc.expectedMySQLQuery, q, "MySQL query incorrect")
require.Equal(t, tc.expectedMySQLArgs, args, "MySQL args incorrect")
db = NewSQLite3Dialect()
q, args, err = db.UpdateQuery(tc.tableName, tc.values, tc.where)
require.True(t, (err != nil) == tc.expectedErr)
require.Equal(t, tc.expectedSQLiteQuery, q, "SQLite query incorrect")
require.Equal(t, tc.expectedSQLiteArgs, args, "SQLite args incorrect")
})
}
}

View File

@ -59,7 +59,7 @@ func NewScopedMigrator(engine *xorm.Engine, cfg *setting.Cfg, scope string) *Mig
mg.Logger = log.New("migrator")
} else {
mg.tableName = scope + "_migration_log"
mg.Logger = log.New(scope + " migrator")
mg.Logger = log.New(scope + "-migrator")
}
return mg
}

View File

@ -69,6 +69,9 @@ func (db *MySQLDialect) SQLType(c *Column) string {
c.Length = 64
case DB_NVarchar:
res = DB_Varchar
case DB_Uuid:
res = DB_Char
c.Length = 36
default:
res = c.Type
}

View File

@ -42,7 +42,7 @@ func (gs *SessionDB) NamedExec(ctx context.Context, query string, arg any) (sql.
return gs.sqlxdb.NamedExecContext(ctx, gs.sqlxdb.Rebind(query), arg)
}
func (gs *SessionDB) driverName() string {
func (gs *SessionDB) DriverName() string {
return gs.sqlxdb.DriverName()
}
@ -69,7 +69,7 @@ func (gs *SessionDB) WithTransaction(ctx context.Context, callback func(*Session
}
func (gs *SessionDB) ExecWithReturningId(ctx context.Context, query string, args ...any) (int64, error) {
return execWithReturningId(ctx, gs.driverName(), query, gs, args...)
return execWithReturningId(ctx, gs.DriverName(), query, gs, args...)
}
type SessionTx struct {
@ -125,3 +125,10 @@ func execWithReturningId(ctx context.Context, driverName string, query string, s
}
return id, nil
}
type SessionQuerier interface {
Query(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
var _ SessionQuerier = &SessionDB{}
var _ SessionQuerier = &SessionTx{}

View File

@ -1,8 +1,13 @@
package sqlstash
import "strings"
import (
"strings"
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
)
type selectQuery struct {
dialect migrator.Dialect
fields []string // SELECT xyz
from string // FROM object
limit int64
@ -12,21 +17,27 @@ type selectQuery struct {
args []any
}
func (q *selectQuery) addWhere(f string, val any) {
q.args = append(q.args, val)
q.where = append(q.where, f+"=?")
func (q *selectQuery) addWhere(f string, val ...any) {
q.args = append(q.args, val...)
// if the field contains a question mark, we assume it's a raw where clause
if strings.Contains(f, "?") {
q.where = append(q.where, f)
// otherwise we assume it's a field name
} else {
q.where = append(q.where, q.dialect.Quote(f)+"=?")
}
}
func (q *selectQuery) addWhereInSubquery(f string, subquery string, subqueryArgs []any) {
q.args = append(q.args, subqueryArgs...)
q.where = append(q.where, f+" IN ("+subquery+")")
q.where = append(q.where, q.dialect.Quote(f)+" IN ("+subquery+")")
}
func (q *selectQuery) addWhereIn(f string, vals []string) {
count := len(vals)
if count > 1 {
sb := strings.Builder{}
sb.WriteString(f)
sb.WriteString(q.dialect.Quote(f))
sb.WriteString(" IN (")
for i := 0; i < count; i++ {
if i > 0 {
@ -46,7 +57,11 @@ func (q *selectQuery) toQuery() (string, []any) {
args := q.args
sb := strings.Builder{}
sb.WriteString("SELECT ")
sb.WriteString(strings.Join(q.fields, ","))
quotedFields := make([]string, len(q.fields))
for i, f := range q.fields {
quotedFields[i] = q.dialect.Quote(f)
}
sb.WriteString(strings.Join(quotedFields, ","))
sb.WriteString(" FROM ")
sb.WriteString(q.from)