mirror of https://github.com/grafana/grafana.git
				
				
				
			
		
			
				
	
	
		
			291 lines
		
	
	
		
			8.0 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			291 lines
		
	
	
		
			8.0 KiB
		
	
	
	
		
			Go
		
	
	
	
| package postgres
 | |
| 
 | |
| import (
 | |
| 	"fmt"
 | |
| 	"path/filepath"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 	"testing"
 | |
| 
 | |
| 	"github.com/grafana/grafana/pkg/components/securejsondata"
 | |
| 	"github.com/grafana/grafana/pkg/components/simplejson"
 | |
| 	"github.com/grafana/grafana/pkg/infra/log"
 | |
| 	"github.com/grafana/grafana/pkg/models"
 | |
| 	"github.com/grafana/grafana/pkg/setting"
 | |
| 	"github.com/stretchr/testify/assert"
 | |
| 	"github.com/stretchr/testify/require"
 | |
| 
 | |
| 	_ "github.com/lib/pq"
 | |
| )
 | |
| 
 | |
| var writeCertFileCallNum int
 | |
| 
 | |
| // TestDataSourceCacheManager is to test the Cache manager
 | |
| func TestDataSourceCacheManager(t *testing.T) {
 | |
| 	cfg := setting.NewCfg()
 | |
| 	cfg.DataPath = t.TempDir()
 | |
| 	mng := tlsManager{
 | |
| 		logger:          log.New("tsdb.postgres"),
 | |
| 		dsCacheInstance: datasourceCacheManager{locker: newLocker()},
 | |
| 		dataPath:        cfg.DataPath,
 | |
| 	}
 | |
| 
 | |
| 	jsonData := simplejson.NewFromAny(map[string]interface{}{
 | |
| 		"sslmode":                "verify-full",
 | |
| 		"tlsConfigurationMethod": "file-content",
 | |
| 	})
 | |
| 	secureJSONData := securejsondata.GetEncryptedJsonData(map[string]string{
 | |
| 		"tlsClientCert": "I am client certification",
 | |
| 		"tlsClientKey":  "I am client key",
 | |
| 		"tlsCACert":     "I am CA certification",
 | |
| 	})
 | |
| 
 | |
| 	mockValidateCertFilePaths()
 | |
| 	t.Cleanup(resetValidateCertFilePaths)
 | |
| 
 | |
| 	t.Run("Check datasource cache creation", func(t *testing.T) {
 | |
| 		var wg sync.WaitGroup
 | |
| 		wg.Add(10)
 | |
| 		for id := int64(1); id <= 10; id++ {
 | |
| 			go func(id int64) {
 | |
| 				ds := &models.DataSource{
 | |
| 					Id:             id,
 | |
| 					Version:        1,
 | |
| 					Database:       "database",
 | |
| 					JsonData:       jsonData,
 | |
| 					SecureJsonData: secureJSONData,
 | |
| 					Uid:            "testData",
 | |
| 				}
 | |
| 				s := tlsSettings{}
 | |
| 				err := mng.writeCertFiles(ds, &s)
 | |
| 				require.NoError(t, err)
 | |
| 				wg.Done()
 | |
| 			}(id)
 | |
| 		}
 | |
| 		wg.Wait()
 | |
| 
 | |
| 		t.Run("check cache creation is succeed", func(t *testing.T) {
 | |
| 			for id := int64(1); id <= 10; id++ {
 | |
| 				version, ok := mng.dsCacheInstance.cache.Load(strconv.Itoa(int(id)))
 | |
| 				require.True(t, ok)
 | |
| 				require.Equal(t, int(1), version)
 | |
| 			}
 | |
| 		})
 | |
| 	})
 | |
| 
 | |
| 	t.Run("Check datasource cache modification", func(t *testing.T) {
 | |
| 		t.Run("check when version not changed, cache and files are not updated", func(t *testing.T) {
 | |
| 			mockWriteCertFile()
 | |
| 			t.Cleanup(resetWriteCertFile)
 | |
| 			var wg1 sync.WaitGroup
 | |
| 			wg1.Add(5)
 | |
| 			for id := int64(1); id <= 5; id++ {
 | |
| 				go func(id int64) {
 | |
| 					ds := &models.DataSource{
 | |
| 						Id:             1,
 | |
| 						Version:        2,
 | |
| 						Database:       "database",
 | |
| 						JsonData:       jsonData,
 | |
| 						SecureJsonData: secureJSONData,
 | |
| 						Uid:            "testData",
 | |
| 					}
 | |
| 					s := tlsSettings{}
 | |
| 					err := mng.writeCertFiles(ds, &s)
 | |
| 					require.NoError(t, err)
 | |
| 					wg1.Done()
 | |
| 				}(id)
 | |
| 			}
 | |
| 			wg1.Wait()
 | |
| 			assert.Equal(t, writeCertFileCallNum, 3)
 | |
| 		})
 | |
| 
 | |
| 		t.Run("cache is updated with the last datasource version", func(t *testing.T) {
 | |
| 			dsV2 := &models.DataSource{
 | |
| 				Id:             1,
 | |
| 				Version:        2,
 | |
| 				Database:       "database",
 | |
| 				JsonData:       jsonData,
 | |
| 				SecureJsonData: secureJSONData,
 | |
| 				Uid:            "testData",
 | |
| 			}
 | |
| 			dsV3 := &models.DataSource{
 | |
| 				Id:             1,
 | |
| 				Version:        3,
 | |
| 				Database:       "database",
 | |
| 				JsonData:       jsonData,
 | |
| 				SecureJsonData: secureJSONData,
 | |
| 				Uid:            "testData",
 | |
| 			}
 | |
| 			s := tlsSettings{}
 | |
| 			err := mng.writeCertFiles(dsV2, &s)
 | |
| 			require.NoError(t, err)
 | |
| 			err = mng.writeCertFiles(dsV3, &s)
 | |
| 			require.NoError(t, err)
 | |
| 			version, ok := mng.dsCacheInstance.cache.Load("1")
 | |
| 			require.True(t, ok)
 | |
| 			require.Equal(t, int(3), version)
 | |
| 		})
 | |
| 	})
 | |
| }
 | |
| 
 | |
| // Test getFileName
 | |
| 
 | |
| func TestGetFileName(t *testing.T) {
 | |
| 	testCases := []struct {
 | |
| 		desc                  string
 | |
| 		datadir               string
 | |
| 		fileType              certFileType
 | |
| 		expErr                string
 | |
| 		expectedGeneratedPath string
 | |
| 	}{
 | |
| 		{
 | |
| 			desc:                  "Get File Name for root certification",
 | |
| 			datadir:               ".",
 | |
| 			fileType:              rootCert,
 | |
| 			expectedGeneratedPath: "root.crt",
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                  "Get File Name for client certification",
 | |
| 			datadir:               ".",
 | |
| 			fileType:              clientCert,
 | |
| 			expectedGeneratedPath: "client.crt",
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                  "Get File Name for client certification",
 | |
| 			datadir:               ".",
 | |
| 			fileType:              clientKey,
 | |
| 			expectedGeneratedPath: "client.key",
 | |
| 		},
 | |
| 	}
 | |
| 	for _, tt := range testCases {
 | |
| 		t.Run(tt.desc, func(t *testing.T) {
 | |
| 			generatedPath := getFileName(tt.datadir, tt.fileType)
 | |
| 			assert.Equal(t, tt.expectedGeneratedPath, generatedPath)
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Test getTLSSettings.
 | |
| func TestGetTLSSettings(t *testing.T) {
 | |
| 	cfg := setting.NewCfg()
 | |
| 	cfg.DataPath = t.TempDir()
 | |
| 
 | |
| 	mockValidateCertFilePaths()
 | |
| 	t.Cleanup(resetValidateCertFilePaths)
 | |
| 	testCases := []struct {
 | |
| 		desc           string
 | |
| 		expErr         string
 | |
| 		jsonData       map[string]interface{}
 | |
| 		secureJSONData map[string]string
 | |
| 		uid            string
 | |
| 		tlsSettings    tlsSettings
 | |
| 		version        int
 | |
| 	}{
 | |
| 		{
 | |
| 			desc:    "Custom TLS authentication disabled",
 | |
| 			version: 1,
 | |
| 			jsonData: map[string]interface{}{
 | |
| 				"sslmode":                "disable",
 | |
| 				"sslRootCertFile":        "i/am/coding/ca.crt",
 | |
| 				"sslCertFile":            "i/am/coding/client.crt",
 | |
| 				"sslKeyFile":             "i/am/coding/client.key",
 | |
| 				"tlsConfigurationMethod": "file-path",
 | |
| 			},
 | |
| 			tlsSettings: tlsSettings{Mode: "disable"},
 | |
| 		},
 | |
| 		{
 | |
| 			desc:    "Custom TLS authentication with file path",
 | |
| 			version: 2,
 | |
| 			jsonData: map[string]interface{}{
 | |
| 				"sslmode":                "verify-full",
 | |
| 				"sslRootCertFile":        "i/am/coding/ca.crt",
 | |
| 				"sslCertFile":            "i/am/coding/client.crt",
 | |
| 				"sslKeyFile":             "i/am/coding/client.key",
 | |
| 				"tlsConfigurationMethod": "file-path",
 | |
| 			},
 | |
| 			tlsSettings: tlsSettings{
 | |
| 				Mode:                "verify-full",
 | |
| 				ConfigurationMethod: "file-path",
 | |
| 				RootCertFile:        "i/am/coding/ca.crt",
 | |
| 				CertFile:            "i/am/coding/client.crt",
 | |
| 				CertKeyFile:         "i/am/coding/client.key",
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			desc:    "Custom TLS mode verify-full with certificate files content",
 | |
| 			version: 3,
 | |
| 			uid:     "xxx",
 | |
| 			jsonData: map[string]interface{}{
 | |
| 				"sslmode":                "verify-full",
 | |
| 				"tlsConfigurationMethod": "file-content",
 | |
| 			},
 | |
| 			secureJSONData: map[string]string{
 | |
| 				"tlsCACert":     "I am CA certification",
 | |
| 				"tlsClientCert": "I am client certification",
 | |
| 				"tlsClientKey":  "I am client key",
 | |
| 			},
 | |
| 			tlsSettings: tlsSettings{
 | |
| 				Mode:                "verify-full",
 | |
| 				ConfigurationMethod: "file-content",
 | |
| 				RootCertFile:        filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "root.crt"),
 | |
| 				CertFile:            filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.crt"),
 | |
| 				CertKeyFile:         filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.key"),
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 	for _, tt := range testCases {
 | |
| 		t.Run(tt.desc, func(t *testing.T) {
 | |
| 			var settings tlsSettings
 | |
| 			var err error
 | |
| 			mng := tlsManager{
 | |
| 				logger:          log.New("tsdb.postgres"),
 | |
| 				dsCacheInstance: datasourceCacheManager{locker: newLocker()},
 | |
| 				dataPath:        cfg.DataPath,
 | |
| 			}
 | |
| 
 | |
| 			jsonData := simplejson.NewFromAny(tt.jsonData)
 | |
| 			ds := &models.DataSource{
 | |
| 				JsonData:       jsonData,
 | |
| 				SecureJsonData: securejsondata.GetEncryptedJsonData(tt.secureJSONData),
 | |
| 				Uid:            tt.uid,
 | |
| 				Version:        tt.version,
 | |
| 			}
 | |
| 
 | |
| 			settings, err = mng.getTLSSettings(ds)
 | |
| 
 | |
| 			if tt.expErr == "" {
 | |
| 				require.NoError(t, err, tt.desc)
 | |
| 				assert.Equal(t, tt.tlsSettings, settings)
 | |
| 			} else {
 | |
| 				require.Error(t, err, tt.desc)
 | |
| 				assert.True(t, strings.HasPrefix(err.Error(), tt.expErr),
 | |
| 					fmt.Sprintf("%s: %q doesn't start with %q", tt.desc, err, tt.expErr))
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func mockValidateCertFilePaths() {
 | |
| 	validateCertFunc = func(rootCert, clientCert, clientKey string) error {
 | |
| 		return nil
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func resetValidateCertFilePaths() {
 | |
| 	validateCertFunc = validateCertFilePaths
 | |
| }
 | |
| 
 | |
| func mockWriteCertFile() {
 | |
| 	writeCertFileCallNum = 0
 | |
| 	writeCertFileFunc = func(ds *models.DataSource, logger log.Logger, fileContent string, generatedFilePath string) error {
 | |
| 		writeCertFileCallNum++
 | |
| 		return nil
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func resetWriteCertFile() {
 | |
| 	writeCertFileCallNum = 0
 | |
| 	writeCertFileFunc = writeCertFile
 | |
| }
 |