mirror of https://github.com/grafana/grafana.git
implement and test the migration
This commit is contained in:
parent
a6f731072d
commit
14bfe58a69
|
|
@ -46,7 +46,7 @@ type GlobalEncryptedValueStorage interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type EncryptedValueMigrationExecutor interface {
|
type EncryptedValueMigrationExecutor interface {
|
||||||
Execute(ctx context.Context) error
|
Execute(ctx context.Context) (int, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ConsolidationService interface {
|
type ConsolidationService interface {
|
||||||
|
|
|
||||||
|
|
@ -21,10 +21,6 @@ import (
|
||||||
"github.com/grafana/grafana/pkg/util"
|
"github.com/grafana/grafana/pkg/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
keyIdDelimiter = '#'
|
|
||||||
)
|
|
||||||
|
|
||||||
type EncryptionManager struct {
|
type EncryptionManager struct {
|
||||||
tracer trace.Tracer
|
tracer trace.Tracer
|
||||||
store contracts.DataKeyStorage
|
store contracts.DataKeyStorage
|
||||||
|
|
@ -286,20 +282,6 @@ func (s *EncryptionManager) Decrypt(ctx context.Context, namespace string, paylo
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// payload = payload[1:]
|
|
||||||
// endOfKey := bytes.Index(payload, []byte{keyIdDelimiter})
|
|
||||||
// if endOfKey == -1 {
|
|
||||||
// err = fmt.Errorf("could not find valid key id in encrypted payload")
|
|
||||||
// return nil, err
|
|
||||||
// }
|
|
||||||
// b64Key := payload[:endOfKey]
|
|
||||||
// payload = payload[endOfKey+1:]
|
|
||||||
// keyId := make([]byte, base64.RawStdEncoding.DecodedLen(len(b64Key)))
|
|
||||||
// _, err = base64.RawStdEncoding.Decode(keyId, b64Key)
|
|
||||||
// if err != nil {
|
|
||||||
// return nil, err
|
|
||||||
// }
|
|
||||||
|
|
||||||
dataKey, err := s.dataKeyById(ctx, namespace, payload.DataKeyID)
|
dataKey, err := s.dataKeyById(ctx, namespace, payload.DataKeyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.log.FromContext(ctx).Error("Failed to lookup data key by id", "id", payload.DataKeyID, "error", err)
|
s.log.FromContext(ctx).Error("Failed to lookup data key by id", "id", payload.DataKeyID, "error", err)
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ import (
|
||||||
osskmsproviders "github.com/grafana/grafana/pkg/registry/apis/secret/encryption/kmsproviders"
|
osskmsproviders "github.com/grafana/grafana/pkg/registry/apis/secret/encryption/kmsproviders"
|
||||||
"github.com/grafana/grafana/pkg/registry/apis/secret/encryption/manager"
|
"github.com/grafana/grafana/pkg/registry/apis/secret/encryption/manager"
|
||||||
"github.com/grafana/grafana/pkg/registry/apis/secret/secretkeeper/sqlkeeper"
|
"github.com/grafana/grafana/pkg/registry/apis/secret/secretkeeper/sqlkeeper"
|
||||||
|
"github.com/grafana/grafana/pkg/registry/apis/secret/testutils"
|
||||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
"github.com/grafana/grafana/pkg/storage/secret/database"
|
"github.com/grafana/grafana/pkg/storage/secret/database"
|
||||||
|
|
@ -65,7 +66,8 @@ func setupTestService(t *testing.T, cfg *setting.Cfg) (*OSSKeeperService, error)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Initialize the keeper service
|
// Initialize the keeper service
|
||||||
keeperService, err := ProvideService(tracer, encValueStore, encryptionManager, nil)
|
keeperService, err := ProvideService(tracer, encValueStore, encryptionManager, &testutils.NoopMigrationExecutor{}, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
return keeperService, err
|
return keeperService, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ func NewSQLKeeper(
|
||||||
reg prometheus.Registerer,
|
reg prometheus.Registerer,
|
||||||
) (*SQLKeeper, error) {
|
) (*SQLKeeper, error) {
|
||||||
// Run the encrypted value store migration before anything else, otherwise operations may fail
|
// Run the encrypted value store migration before anything else, otherwise operations may fail
|
||||||
err := migrationExecutor.Execute(context.Background())
|
_, err := migrationExecutor.Execute(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to execute encrypted value store migration: %w", err)
|
return nil, fmt.Errorf("failed to execute encrypted value store migration: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -126,8 +126,13 @@ func Setup(t *testing.T, opts ...func(*SetupConfig)) Sut {
|
||||||
globalEncryptedValueStorage, err := encryptionstorage.ProvideGlobalEncryptedValueStorage(database, tracer)
|
globalEncryptedValueStorage, err := encryptionstorage.ProvideGlobalEncryptedValueStorage(database, tracer)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// TODO create a migration executor
|
// Initialize a noop migration executor for the sql keeper so it doesn't interfere with initialization
|
||||||
sqlKeeper, err := sqlkeeper.NewSQLKeeper(tracer, encryptionManager, encryptedValueStorage, nil, nil)
|
noopMigrationExecutor := &NoopMigrationExecutor{}
|
||||||
|
sqlKeeper, err := sqlkeeper.NewSQLKeeper(tracer, encryptionManager, encryptedValueStorage, noopMigrationExecutor, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Initialize a real migration executor for test
|
||||||
|
realMigrationExecutor, err := encryptionstorage.ProvideEncryptedValueMigrationExecutor(database, tracer, encryptedValueStorage, globalEncryptedValueStorage)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var keeperService contracts.KeeperService = newKeeperServiceWrapper(sqlKeeper)
|
var keeperService contracts.KeeperService = newKeeperServiceWrapper(sqlKeeper)
|
||||||
|
|
@ -160,39 +165,41 @@ func Setup(t *testing.T, opts ...func(*SetupConfig)) Sut {
|
||||||
keeperService)
|
keeperService)
|
||||||
|
|
||||||
return Sut{
|
return Sut{
|
||||||
SecureValueService: secureValueService,
|
SecureValueService: secureValueService,
|
||||||
SecureValueMetadataStorage: secureValueMetadataStorage,
|
SecureValueMetadataStorage: secureValueMetadataStorage,
|
||||||
DecryptStorage: decryptStorage,
|
DecryptStorage: decryptStorage,
|
||||||
DecryptService: decryptService,
|
DecryptService: decryptService,
|
||||||
EncryptedValueStorage: encryptedValueStorage,
|
EncryptedValueStorage: encryptedValueStorage,
|
||||||
GlobalEncryptedValueStorage: globalEncryptedValueStorage,
|
GlobalEncryptedValueStorage: globalEncryptedValueStorage,
|
||||||
SQLKeeper: sqlKeeper,
|
EncryptedValueMigrationExecutor: realMigrationExecutor,
|
||||||
Database: database,
|
SQLKeeper: sqlKeeper,
|
||||||
AccessClient: accessClient,
|
Database: database,
|
||||||
ConsolidationService: consolidationService,
|
AccessClient: accessClient,
|
||||||
EncryptionManager: encryptionManager,
|
ConsolidationService: consolidationService,
|
||||||
GlobalDataKeyStore: globalDataKeyStore,
|
EncryptionManager: encryptionManager,
|
||||||
GarbageCollectionWorker: garbageCollectionWorker,
|
GlobalDataKeyStore: globalDataKeyStore,
|
||||||
Clock: clock,
|
GarbageCollectionWorker: garbageCollectionWorker,
|
||||||
KeeperService: keeperService,
|
Clock: clock,
|
||||||
KeeperMetadataStorage: keeperMetadataStorage,
|
KeeperService: keeperService,
|
||||||
|
KeeperMetadataStorage: keeperMetadataStorage,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type Sut struct {
|
type Sut struct {
|
||||||
SecureValueService contracts.SecureValueService
|
SecureValueService contracts.SecureValueService
|
||||||
SecureValueMetadataStorage contracts.SecureValueMetadataStorage
|
SecureValueMetadataStorage contracts.SecureValueMetadataStorage
|
||||||
DecryptStorage contracts.DecryptStorage
|
DecryptStorage contracts.DecryptStorage
|
||||||
DecryptService decryptcontracts.DecryptService
|
DecryptService decryptcontracts.DecryptService
|
||||||
EncryptedValueStorage contracts.EncryptedValueStorage
|
EncryptedValueStorage contracts.EncryptedValueStorage
|
||||||
GlobalEncryptedValueStorage contracts.GlobalEncryptedValueStorage
|
GlobalEncryptedValueStorage contracts.GlobalEncryptedValueStorage
|
||||||
SQLKeeper *sqlkeeper.SQLKeeper
|
EncryptedValueMigrationExecutor contracts.EncryptedValueMigrationExecutor
|
||||||
Database *database.Database
|
SQLKeeper *sqlkeeper.SQLKeeper
|
||||||
AccessClient types.AccessClient
|
Database *database.Database
|
||||||
ConsolidationService contracts.ConsolidationService
|
AccessClient types.AccessClient
|
||||||
EncryptionManager contracts.EncryptionManager
|
ConsolidationService contracts.ConsolidationService
|
||||||
GlobalDataKeyStore contracts.GlobalDataKeyStorage
|
EncryptionManager contracts.EncryptionManager
|
||||||
GarbageCollectionWorker *garbagecollectionworker.Worker
|
GlobalDataKeyStore contracts.GlobalDataKeyStorage
|
||||||
|
GarbageCollectionWorker *garbagecollectionworker.Worker
|
||||||
// The fake clock passed to implementations to make testing easier
|
// The fake clock passed to implementations to make testing easier
|
||||||
Clock *FakeClock
|
Clock *FakeClock
|
||||||
KeeperService contracts.KeeperService
|
KeeperService contracts.KeeperService
|
||||||
|
|
@ -368,3 +375,10 @@ func (c *FakeClock) Now() time.Time {
|
||||||
func (c *FakeClock) AdvanceBy(duration time.Duration) {
|
func (c *FakeClock) AdvanceBy(duration time.Duration) {
|
||||||
c.Current = c.Current.Add(duration)
|
c.Current = c.Current.Add(duration)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type NoopMigrationExecutor struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *NoopMigrationExecutor) Execute(ctx context.Context) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
package encryption
|
package encryption
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -343,26 +345,75 @@ func (s *globalEncryptedValStorage) CountAll(ctx context.Context, untilTime *int
|
||||||
}
|
}
|
||||||
|
|
||||||
type encryptedValMigrationExecutor struct {
|
type encryptedValMigrationExecutor struct {
|
||||||
db contracts.Database
|
db contracts.Database
|
||||||
dialect sqltemplate.Dialect
|
dialect sqltemplate.Dialect
|
||||||
tracer trace.Tracer
|
tracer trace.Tracer
|
||||||
|
encryptedValueStore contracts.EncryptedValueStorage
|
||||||
|
globalStore contracts.GlobalEncryptedValueStorage
|
||||||
}
|
}
|
||||||
|
|
||||||
func ProvideEncryptedValueMigrationExecutor(
|
func ProvideEncryptedValueMigrationExecutor(
|
||||||
db contracts.Database,
|
db contracts.Database,
|
||||||
tracer trace.Tracer,
|
tracer trace.Tracer,
|
||||||
|
encryptedValueStore contracts.EncryptedValueStorage,
|
||||||
|
globalStore contracts.GlobalEncryptedValueStorage,
|
||||||
) (contracts.EncryptedValueMigrationExecutor, error) {
|
) (contracts.EncryptedValueMigrationExecutor, error) {
|
||||||
return &encryptedValMigrationExecutor{
|
return &encryptedValMigrationExecutor{
|
||||||
db: db,
|
db: db,
|
||||||
dialect: sqltemplate.DialectForDriver(db.DriverName()),
|
dialect: sqltemplate.DialectForDriver(db.DriverName()),
|
||||||
tracer: tracer,
|
tracer: tracer,
|
||||||
|
encryptedValueStore: encryptedValueStore,
|
||||||
|
globalStore: globalStore,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *encryptedValMigrationExecutor) Execute(ctx context.Context) error {
|
func (s *encryptedValMigrationExecutor) Execute(ctx context.Context) (int, error) {
|
||||||
ctx, span := s.tracer.Start(ctx, "EncryptedValueMigrationExecutor.Execute")
|
ctx, span := s.tracer.Start(ctx, "EncryptedValueMigrationExecutor.Execute")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
panic("not implemented")
|
// 1. Retrieve all encrypted values
|
||||||
return nil
|
encryptedValues, err := s.globalStore.ListAll(ctx, contracts.ListOpts{}, nil)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("listing all encrypted values: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This doesn't need to be done in a single transaction because there's no risk to successful rows if other rows fail
|
||||||
|
rowsAffected := 0
|
||||||
|
for _, encryptedValue := range encryptedValues {
|
||||||
|
// 2. If the value already has the data key id broken out, skip it
|
||||||
|
if encryptedValue.DataKeyID != "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Split the data key id and the encrypted data out from the encoded payload
|
||||||
|
payload := encryptedValue.EncryptedData
|
||||||
|
const keyIdDelimiter = '#'
|
||||||
|
payload = payload[1:]
|
||||||
|
endOfKey := bytes.Index(payload, []byte{keyIdDelimiter})
|
||||||
|
if endOfKey == -1 {
|
||||||
|
return 0, fmt.Errorf("could not find valid key id in encrypted payload with namespace %s and name %s and version %d", encryptedValue.Namespace, encryptedValue.Name, encryptedValue.Version)
|
||||||
|
}
|
||||||
|
b64Key := payload[:endOfKey]
|
||||||
|
encryptedData := payload[endOfKey+1:]
|
||||||
|
if len(encryptedData) == 0 {
|
||||||
|
return 0, fmt.Errorf("encrypted data is empty with namespace %s and name %s and version %d", encryptedValue.Namespace, encryptedValue.Name, encryptedValue.Version)
|
||||||
|
}
|
||||||
|
keyId := make([]byte, base64.RawStdEncoding.DecodedLen(len(b64Key)))
|
||||||
|
_, err := base64.RawStdEncoding.Decode(keyId, b64Key)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("decoding key id with namespace %s and name %s and version %d: %w", encryptedValue.Namespace, encryptedValue.Name, encryptedValue.Version, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Update the encrypted value with the data key id and the encrypted data
|
||||||
|
err = s.encryptedValueStore.Update(ctx, encryptedValue.Namespace, encryptedValue.Name, encryptedValue.Version, contracts.EncryptedPayload{
|
||||||
|
DataKeyID: string(keyId),
|
||||||
|
EncryptedData: encryptedData,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("updating encrypted value with namespace %s and name %s and version %d: %w", encryptedValue.Namespace, encryptedValue.Name, encryptedValue.Version, err)
|
||||||
|
}
|
||||||
|
rowsAffected++
|
||||||
|
}
|
||||||
|
|
||||||
|
return rowsAffected, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,15 +2,23 @@ package encryption_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
|
"text/template"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/grafana/grafana/pkg/infra/usagestats"
|
||||||
"github.com/grafana/grafana/pkg/registry/apis/secret/contracts"
|
"github.com/grafana/grafana/pkg/registry/apis/secret/contracts"
|
||||||
|
"github.com/grafana/grafana/pkg/registry/apis/secret/encryption/cipher"
|
||||||
|
cipherService "github.com/grafana/grafana/pkg/registry/apis/secret/encryption/cipher/service"
|
||||||
"github.com/grafana/grafana/pkg/registry/apis/secret/testutils"
|
"github.com/grafana/grafana/pkg/registry/apis/secret/testutils"
|
||||||
"github.com/grafana/grafana/pkg/storage/secret/encryption"
|
"github.com/grafana/grafana/pkg/storage/secret/encryption"
|
||||||
|
"github.com/grafana/grafana/pkg/storage/unified/sql/sqltemplate"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.opentelemetry.io/otel/trace/noop"
|
||||||
"pgregory.net/rapid"
|
"pgregory.net/rapid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -233,6 +241,135 @@ func TestEncryptedValueStoreImpl(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEncryptedValueMigration(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
sut := testutils.Setup(t)
|
||||||
|
tracer := noop.NewTracerProvider().Tracer("test")
|
||||||
|
usageStats := &usagestats.UsageStatsMock{T: t}
|
||||||
|
enc, err := cipherService.ProvideAESGCMCipherService(tracer, usageStats)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
namespace string
|
||||||
|
name string
|
||||||
|
version int64
|
||||||
|
plaintext string
|
||||||
|
dataKeyId string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
namespace: "test-namespace-1",
|
||||||
|
name: "test-name-1",
|
||||||
|
version: 1,
|
||||||
|
plaintext: "test-plaintext-1",
|
||||||
|
dataKeyId: "test-data-key-id-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
namespace: "test-namespace-1",
|
||||||
|
name: "test-name-2",
|
||||||
|
version: 1,
|
||||||
|
plaintext: "test-plaintext-2",
|
||||||
|
dataKeyId: "test-data-key-id-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
namespace: "test-namespace-2",
|
||||||
|
name: "test-name-3",
|
||||||
|
version: 1,
|
||||||
|
plaintext: "test-plaintext-3",
|
||||||
|
dataKeyId: "test-data-key-id-2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Seed with data in the legacy format
|
||||||
|
for _, tc := range testCases {
|
||||||
|
err := createLegacyEncryptedData(t, sut, enc, tc.namespace, tc.name, tc.version, tc.plaintext, tc.dataKeyId)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the migration and blindy trust it
|
||||||
|
rowsAffected, err := sut.EncryptedValueMigrationExecutor.Execute(t.Context())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, len(testCases), rowsAffected)
|
||||||
|
|
||||||
|
// Now validate that the data is in the new format
|
||||||
|
encryptedValues, err := sut.GlobalEncryptedValueStorage.ListAll(t.Context(), contracts.ListOpts{}, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, encryptedValues, 3)
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
ev, err := sut.EncryptedValueStorage.Get(t.Context(), tc.namespace, tc.name, tc.version)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Decrypt the encrypted data and check for equality
|
||||||
|
decrypted, err := enc.Decrypt(t.Context(), ev.EncryptedData, tc.dataKeyId)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, tc.dataKeyId, ev.DataKeyID)
|
||||||
|
require.Equal(t, tc.plaintext, string(decrypted))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function that bypasses interfaces and creates data in the legacy format directly in the database.
|
||||||
|
// The format is "#{encoded_key_id}#{encrypted_data}".
|
||||||
|
func createLegacyEncryptedData(t *testing.T, sut testutils.Sut, enc cipher.Cipher, namespace, name string, version int64, plaintext string, dataKeyId string) error {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
encryptedData, err := enc.Encrypt(t.Context(), []byte(plaintext), dataKeyId)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Encode using the legacy format
|
||||||
|
const keyIdDelimiter = '#'
|
||||||
|
prefix := make([]byte, base64.RawStdEncoding.EncodedLen(len(dataKeyId))+2)
|
||||||
|
base64.RawStdEncoding.Encode(prefix[1:], []byte(dataKeyId))
|
||||||
|
prefix[0] = keyIdDelimiter
|
||||||
|
prefix[len(prefix)-1] = keyIdDelimiter
|
||||||
|
|
||||||
|
blob := make([]byte, len(prefix)+len(encryptedData))
|
||||||
|
copy(blob, prefix)
|
||||||
|
copy(blob[len(prefix):], encryptedData)
|
||||||
|
|
||||||
|
createdTime := time.Now().Unix()
|
||||||
|
|
||||||
|
encryptedValue := &encryption.EncryptedValue{
|
||||||
|
Namespace: namespace,
|
||||||
|
Name: name,
|
||||||
|
Version: version,
|
||||||
|
EncryptedData: blob,
|
||||||
|
DataKeyID: "",
|
||||||
|
Created: createdTime,
|
||||||
|
Updated: createdTime,
|
||||||
|
}
|
||||||
|
|
||||||
|
req := struct {
|
||||||
|
sqltemplate.SQLTemplate
|
||||||
|
Row *encryption.EncryptedValue
|
||||||
|
}{
|
||||||
|
SQLTemplate: sqltemplate.New(sqltemplate.DialectForDriver(sut.Database.DriverName())),
|
||||||
|
Row: encryptedValue,
|
||||||
|
}
|
||||||
|
tmpl, err := template.ParseFiles("data/encrypted_value_create.sql")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing template: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
query, err := sqltemplate.Execute(tmpl, req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("executing template: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := sut.Database.ExecContext(t.Context(), query, req.GetArgs()...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("inserting row: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rowsAffected, err := res.RowsAffected(); err != nil {
|
||||||
|
return fmt.Errorf("getting rows affected: %w", err)
|
||||||
|
} else if rowsAffected != 1 {
|
||||||
|
return fmt.Errorf("expected 1 row affected, got %d", rowsAffected)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestStateMachine(t *testing.T) {
|
func TestStateMachine(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|
@ -345,7 +482,7 @@ func (m *model) create(namespace, name string, version int64, encryptedData []by
|
||||||
if err != nil && !errors.Is(err, encryption.ErrEncryptedValueNotFound) {
|
if err != nil && !errors.Is(err, encryption.ErrEncryptedValueNotFound) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// The entry being creted already exists
|
// The entry being created already exists
|
||||||
if v != nil {
|
if v != nil {
|
||||||
return nil, encryption.ErrEncryptedValueAlreadyExists
|
return nil, encryption.ErrEncryptedValueAlreadyExists
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue