diff --git a/tsdb/agent/db.go b/tsdb/agent/db.go index 02a2eada88..7884366ebe 100644 --- a/tsdb/agent/db.go +++ b/tsdb/agent/db.go @@ -437,7 +437,7 @@ func (db *DB) resetWALReplayResources() { func (db *DB) loadWAL(r *wlog.Reader, multiRef map[chunks.HeadSeriesRef]chunks.HeadSeriesRef) (err error) { var ( syms = labels.NewSymbolTable() // One table for the whole WAL. - dec = record.NewDecoder(syms) + dec = record.NewDecoder(syms, db.logger) lastRef = chunks.HeadSeriesRef(db.nextRef.Load()) decoded = make(chan any, 10) diff --git a/tsdb/agent/db_test.go b/tsdb/agent/db_test.go index 7dc1f812a0..c2674c8871 100644 --- a/tsdb/agent/db_test.go +++ b/tsdb/agent/db_test.go @@ -211,7 +211,7 @@ func TestCommit(t *testing.T) { // Read records from WAL and check for expected count of series, samples, and exemplars. var ( r = wlog.NewReader(sr) - dec = record.NewDecoder(labels.NewSymbolTable()) + dec = record.NewDecoder(labels.NewSymbolTable(), promslog.NewNopLogger()) walSeriesCount, walSamplesCount, walExemplarsCount, walHistogramCount, walFloatHistogramCount int ) @@ -344,7 +344,7 @@ func TestRollback(t *testing.T) { // Read records from WAL and check for expected count of series and samples. 
var ( r = wlog.NewReader(sr) - dec = record.NewDecoder(labels.NewSymbolTable()) + dec = record.NewDecoder(labels.NewSymbolTable(), promslog.NewNopLogger()) walSeriesCount, walSamplesCount, walHistogramCount, walFloatHistogramCount, walExemplarsCount int ) @@ -892,7 +892,7 @@ func TestStorage_DuplicateExemplarsIgnored(t *testing.T) { defer sr.Close() r := wlog.NewReader(sr) - dec := record.NewDecoder(labels.NewSymbolTable()) + dec := record.NewDecoder(labels.NewSymbolTable(), promslog.NewNopLogger()) for r.Next() { rec := r.Record() if dec.Type(rec) == record.Exemplars { @@ -1332,7 +1332,7 @@ func readWALSamples(t *testing.T, walDir string) []*walSample { }(sr) r := wlog.NewReader(sr) - dec := record.NewDecoder(labels.NewSymbolTable()) + dec := record.NewDecoder(labels.NewSymbolTable(), promslog.NewNopLogger()) var ( samples []record.RefSample diff --git a/tsdb/db_test.go b/tsdb/db_test.go index 8e649982fc..cb4697f7bf 100644 --- a/tsdb/db_test.go +++ b/tsdb/db_test.go @@ -334,7 +334,7 @@ func TestDataNotAvailableAfterRollback(t *testing.T) { // Read records from WAL and check for expected count of series and samples. 
var ( r = wlog.NewReader(sr) - dec = record.NewDecoder(labels.NewSymbolTable()) + dec = record.NewDecoder(labels.NewSymbolTable(), promslog.NewNopLogger()) walSeriesCount, walSamplesCount, walHistogramCount, walFloatHistogramCount, walExemplarsCount int ) @@ -4572,7 +4572,7 @@ func testOOOWALWrite(t *testing.T, }() var records []any - dec := record.NewDecoder(nil) + dec := record.NewDecoder(nil, promslog.NewNopLogger()) for r.Next() { rec := r.Record() switch typ := dec.Type(rec); typ { @@ -7088,7 +7088,7 @@ func testWBLAndMmapReplay(t *testing.T, scenario sampleTypeScenario) { require.NoError(t, err) sr, err := wlog.NewSegmentsReader(originalWblDir) require.NoError(t, err) - dec := record.NewDecoder(labels.NewSymbolTable()) + dec := record.NewDecoder(labels.NewSymbolTable(), promslog.NewNopLogger()) r, markers, addedRecs := wlog.NewReader(sr), 0, 0 for r.Next() { rec := r.Record() diff --git a/tsdb/head_test.go b/tsdb/head_test.go index 37aef96a66..d2470952c2 100644 --- a/tsdb/head_test.go +++ b/tsdb/head_test.go @@ -34,6 +34,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/prometheus/client_golang/prometheus" prom_testutil "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/prometheus/common/promslog" "github.com/stretchr/testify/require" "go.uber.org/atomic" "golang.org/x/sync/errgroup" @@ -185,7 +186,7 @@ func readTestWAL(t testing.TB, dir string) (recs []any) { require.NoError(t, sr.Close()) }() - dec := record.NewDecoder(labels.NewSymbolTable()) + dec := record.NewDecoder(labels.NewSymbolTable(), promslog.NewNopLogger()) r := wlog.NewReader(sr) for r.Next() { diff --git a/tsdb/head_wal.go b/tsdb/head_wal.go index d343355993..3c5390cab4 100644 --- a/tsdb/head_wal.go +++ b/tsdb/head_wal.go @@ -155,7 +155,7 @@ func (h *Head) loadWAL(r *wlog.Reader, syms *labels.SymbolTable, multiRef map[ch go func() { defer close(decoded) var err error - dec := record.NewDecoder(syms) + dec := record.NewDecoder(syms, h.logger) for r.Next() { switch 
dec.Type(r.Record()) { case record.Series: @@ -767,7 +767,7 @@ func (h *Head) loadWBL(r *wlog.Reader, syms *labels.SymbolTable, multiRef map[ch go func() { defer close(decodedCh) - dec := record.NewDecoder(syms) + dec := record.NewDecoder(syms, h.logger) for r.Next() { var err error rec := r.Record() @@ -1572,7 +1572,7 @@ func (h *Head) loadChunkSnapshot() (int, int, map[chunks.HeadSeriesRef]*memSerie refSeries map[chunks.HeadSeriesRef]*memSeries exemplarBuf []record.RefExemplar syms = labels.NewSymbolTable() // New table for the whole snapshot. - dec = record.NewDecoder(syms) + dec = record.NewDecoder(syms, h.logger) ) wg.Add(concurrency) diff --git a/tsdb/record/record.go b/tsdb/record/record.go index bcddad1b52..bd385058fe 100644 --- a/tsdb/record/record.go +++ b/tsdb/record/record.go @@ -18,6 +18,7 @@ package record import ( "errors" "fmt" + "log/slog" "math" "github.com/prometheus/common/model" @@ -202,10 +203,11 @@ type RefMmapMarker struct { // Decoder decodes series, sample, metadata and tombstone records. type Decoder struct { builder labels.ScratchBuilder + logger *slog.Logger } -func NewDecoder(*labels.SymbolTable) Decoder { // FIXME remove t - return Decoder{builder: labels.NewScratchBuilder(0)} +func NewDecoder(_ *labels.SymbolTable, logger *slog.Logger) Decoder { // FIXME remove the unused SymbolTable parameter + return Decoder{builder: labels.NewScratchBuilder(0), logger: logger} } // Type returns the type of the record. 
@@ -433,7 +435,7 @@ func (*Decoder) MmapMarkers(rec []byte, markers []RefMmapMarker) ([]RefMmapMarke return markers, nil } -func (*Decoder) HistogramSamples(rec []byte, histograms []RefHistogramSample) ([]RefHistogramSample, error) { +func (d *Decoder) HistogramSamples(rec []byte, histograms []RefHistogramSample) ([]RefHistogramSample, error) { dec := encoding.Decbuf{B: rec} t := Type(dec.Byte()) if t != HistogramSamples && t != CustomBucketsHistogramSamples { @@ -457,6 +459,18 @@ func (*Decoder) HistogramSamples(rec []byte, histograms []RefHistogramSample) ([ } DecodeHistogram(&dec, rh.H) + + if !histogram.IsKnownSchema(rh.H.Schema) { + d.logger.Warn("invalid histogram schema in WAL record", "schema", rh.H.Schema, "timestamp", rh.T) + continue + } + if rh.H.Schema > histogram.ExponentialSchemaMax && rh.H.Schema <= histogram.ExponentialSchemaMaxReserved { + // This is a very slow path, but it should only happen if the + // record is from a newer Prometheus version that supports higher + // resolution. 
+ rh.H.ReduceResolution(histogram.ExponentialSchemaMax) + } + histograms = append(histograms, rh) } @@ -525,7 +539,7 @@ func DecodeHistogram(buf *encoding.Decbuf, h *histogram.Histogram) { } } -func (*Decoder) FloatHistogramSamples(rec []byte, histograms []RefFloatHistogramSample) ([]RefFloatHistogramSample, error) { +func (d *Decoder) FloatHistogramSamples(rec []byte, histograms []RefFloatHistogramSample) ([]RefFloatHistogramSample, error) { dec := encoding.Decbuf{B: rec} t := Type(dec.Byte()) if t != FloatHistogramSamples && t != CustomBucketsFloatHistogramSamples { @@ -549,6 +563,18 @@ func (*Decoder) FloatHistogramSamples(rec []byte, histograms []RefFloatHistogram } DecodeFloatHistogram(&dec, rh.FH) + + if !histogram.IsKnownSchema(rh.FH.Schema) { + d.logger.Warn("invalid histogram schema in WAL record", "schema", rh.FH.Schema, "timestamp", rh.T) + continue + } + if rh.FH.Schema > histogram.ExponentialSchemaMax && rh.FH.Schema <= histogram.ExponentialSchemaMaxReserved { + // This is a very slow path, but it should only happen if the + // record is from a newer Prometheus version that supports higher + // resolution. 
+ rh.FH.ReduceResolution(histogram.ExponentialSchemaMax) + } + histograms = append(histograms, rh) } diff --git a/tsdb/record/record_test.go b/tsdb/record/record_test.go index 6734f907f0..6d95130f26 100644 --- a/tsdb/record/record_test.go +++ b/tsdb/record/record_test.go @@ -15,11 +15,13 @@ package record import ( + "bytes" "fmt" "math/rand" "testing" "github.com/prometheus/common/model" + "github.com/prometheus/common/promslog" "github.com/stretchr/testify/require" "github.com/prometheus/prometheus/model/histogram" @@ -32,7 +34,7 @@ import ( func TestRecord_EncodeDecode(t *testing.T) { var enc Encoder - dec := NewDecoder(labels.NewSymbolTable()) + dec := NewDecoder(labels.NewSymbolTable(), promslog.NewNopLogger()) series := []RefSeries{ { @@ -224,11 +226,151 @@ func TestRecord_EncodeDecode(t *testing.T) { require.Equal(t, floatHistograms, decGaugeFloatHistograms) } +func TestRecord_DecodeInvalidHistogramSchema(t *testing.T) { + for _, schema := range []int32{-100, 100} { + t.Run(fmt.Sprintf("schema=%d", schema), func(t *testing.T) { + var enc Encoder + + var output bytes.Buffer + logger := promslog.New(&promslog.Config{Writer: &output}) + dec := NewDecoder(labels.NewSymbolTable(), logger) + histograms := []RefHistogramSample{ + { + Ref: 56, + T: 1234, + H: &histogram.Histogram{ + Count: 5, + ZeroCount: 2, + ZeroThreshold: 0.001, + Sum: 18.4 * rand.Float64(), + Schema: schema, + PositiveSpans: []histogram.Span{ + {Offset: 0, Length: 2}, + {Offset: 1, Length: 2}, + }, + PositiveBuckets: []int64{1, 1, -1, 0}, + }, + }, + } + histSamples, _ := enc.HistogramSamples(histograms, nil) + decHistograms, err := dec.HistogramSamples(histSamples, nil) + require.NoError(t, err) + require.Empty(t, decHistograms) + require.Contains(t, output.String(), "invalid histogram schema in WAL record") + }) + } +} + +func TestRecord_DecodeInvalidFloatHistogramSchema(t *testing.T) { + for _, schema := range []int32{-100, 100} { + t.Run(fmt.Sprintf("schema=%d", schema), func(t *testing.T) { 
+ var enc Encoder + + var output bytes.Buffer + logger := promslog.New(&promslog.Config{Writer: &output}) + dec := NewDecoder(labels.NewSymbolTable(), logger) + histograms := []RefFloatHistogramSample{ + { + Ref: 56, + T: 1234, + FH: &histogram.FloatHistogram{ + Count: 5, + ZeroCount: 2, + ZeroThreshold: 0.001, + Sum: 18.4 * rand.Float64(), + Schema: schema, + PositiveSpans: []histogram.Span{ + {Offset: 0, Length: 2}, + {Offset: 1, Length: 2}, + }, + PositiveBuckets: []float64{1, 1, -1, 0}, + }, + }, + } + histSamples, _ := enc.FloatHistogramSamples(histograms, nil) + decHistograms, err := dec.FloatHistogramSamples(histSamples, nil) + require.NoError(t, err) + require.Empty(t, decHistograms) + require.Contains(t, output.String(), "invalid histogram schema in WAL record") + }) + } +} + +func TestRecord_DecodeTooHighResolutionHistogramSchema(t *testing.T) { + for _, schema := range []int32{9, 52} { + t.Run(fmt.Sprintf("schema=%d", schema), func(t *testing.T) { + var enc Encoder + + var output bytes.Buffer + logger := promslog.New(&promslog.Config{Writer: &output}) + dec := NewDecoder(labels.NewSymbolTable(), logger) + histograms := []RefHistogramSample{ + { + Ref: 56, + T: 1234, + H: &histogram.Histogram{ + Count: 5, + ZeroCount: 2, + ZeroThreshold: 0.001, + Sum: 18.4 * rand.Float64(), + Schema: schema, + PositiveSpans: []histogram.Span{ + {Offset: 0, Length: 2}, + {Offset: 1, Length: 2}, + }, + PositiveBuckets: []int64{1, 1, -1, 0}, + }, + }, + } + histSamples, _ := enc.HistogramSamples(histograms, nil) + decHistograms, err := dec.HistogramSamples(histSamples, nil) + require.NoError(t, err) + require.Len(t, decHistograms, 1) + require.Equal(t, histogram.ExponentialSchemaMax, decHistograms[0].H.Schema) + }) + } +} + +func TestRecord_DecodeTooHighResolutionFloatHistogramSchema(t *testing.T) { + for _, schema := range []int32{9, 52} { + t.Run(fmt.Sprintf("schema=%d", schema), func(t *testing.T) { + var enc Encoder + + var output bytes.Buffer + logger := 
promslog.New(&promslog.Config{Writer: &output}) + dec := NewDecoder(labels.NewSymbolTable(), logger) + histograms := []RefFloatHistogramSample{ + { + Ref: 56, + T: 1234, + FH: &histogram.FloatHistogram{ + Count: 5, + ZeroCount: 2, + ZeroThreshold: 0.001, + Sum: 18.4 * rand.Float64(), + Schema: schema, + PositiveSpans: []histogram.Span{ + {Offset: 0, Length: 2}, + {Offset: 1, Length: 2}, + }, + PositiveBuckets: []float64{1, 1, -1, 0}, + }, + }, + } + histSamples, _ := enc.FloatHistogramSamples(histograms, nil) + decHistograms, err := dec.FloatHistogramSamples(histSamples, nil) + require.NoError(t, err) + require.Len(t, decHistograms, 1) + require.Equal(t, histogram.ExponentialSchemaMax, decHistograms[0].FH.Schema) + }) + } +} + // TestRecord_Corrupted ensures that corrupted records return the correct error. // Bugfix check for pull/521 and pull/523. func TestRecord_Corrupted(t *testing.T) { var enc Encoder - dec := NewDecoder(labels.NewSymbolTable()) + dec := NewDecoder(labels.NewSymbolTable(), promslog.NewNopLogger()) t.Run("Test corrupted series record", func(t *testing.T) { series := []RefSeries{ diff --git a/tsdb/wlog/checkpoint.go b/tsdb/wlog/checkpoint.go index 64abe21aa9..c26f3f1052 100644 --- a/tsdb/wlog/checkpoint.go +++ b/tsdb/wlog/checkpoint.go @@ -156,7 +156,7 @@ func Checkpoint(logger *slog.Logger, w *WL, from, to int, keep func(id chunks.He exemplars []record.RefExemplar metadata []record.RefMetadata st = labels.NewSymbolTable() // Needed for decoding; labels do not outlive this function. 
- dec = record.NewDecoder(st) + dec = record.NewDecoder(st, logger) enc record.Encoder buf []byte recs [][]byte diff --git a/tsdb/wlog/checkpoint_test.go b/tsdb/wlog/checkpoint_test.go index 0b0d11ac45..b83724ea2e 100644 --- a/tsdb/wlog/checkpoint_test.go +++ b/tsdb/wlog/checkpoint_test.go @@ -311,7 +311,7 @@ func TestCheckpoint(t *testing.T) { require.NoError(t, err) defer sr.Close() - dec := record.NewDecoder(labels.NewSymbolTable()) + dec := record.NewDecoder(labels.NewSymbolTable(), promslog.NewNopLogger()) var series []record.RefSeries var metadata []record.RefMetadata r := NewReader(sr) diff --git a/tsdb/wlog/watcher.go b/tsdb/wlog/watcher.go index 95bd554a76..3594cf47bc 100644 --- a/tsdb/wlog/watcher.go +++ b/tsdb/wlog/watcher.go @@ -494,7 +494,7 @@ func (w *Watcher) garbageCollectSeries(segmentNum int) error { // Also used with readCheckpoint - implements segmentReadFn. func (w *Watcher) readSegment(r *LiveReader, segmentNum int, tail bool) error { var ( - dec = record.NewDecoder(labels.NewSymbolTable()) // One table per WAL segment means it won't grow indefinitely. + dec = record.NewDecoder(labels.NewSymbolTable(), w.logger) // One table per WAL segment means it won't grow indefinitely. series []record.RefSeries samples []record.RefSample samplesToSend []record.RefSample @@ -647,7 +647,7 @@ func (w *Watcher) readSegment(r *LiveReader, segmentNum int, tail bool) error { // Used with readCheckpoint - implements segmentReadFn. func (w *Watcher) readSegmentForGC(r *LiveReader, segmentNum int, _ bool) error { var ( - dec = record.NewDecoder(labels.NewSymbolTable()) // Needed for decoding; labels do not outlive this function. + dec = record.NewDecoder(labels.NewSymbolTable(), w.logger) // Needed for decoding; labels do not outlive this function. series []record.RefSeries ) for r.Next() && !isClosed(w.quit) {