mirror of https://github.com/chaitin/PandaWiki.git
115 lines
3.0 KiB
Go
115 lines
3.0 KiB
Go
package pg
|
|
|
|
import (
|
|
"context"
|
|
|
|
"github.com/cloudwego/eino/schema"
|
|
"gorm.io/gorm"
|
|
|
|
"github.com/chaitin/panda-wiki/domain"
|
|
"github.com/chaitin/panda-wiki/log"
|
|
"github.com/chaitin/panda-wiki/store/pg"
|
|
)
|
|
|
|
type ModelRepository struct {
|
|
db *pg.DB
|
|
logger *log.Logger
|
|
}
|
|
|
|
func NewModelRepository(db *pg.DB, logger *log.Logger) *ModelRepository {
|
|
return &ModelRepository{db: db, logger: logger.WithModule("repo.pg.model")}
|
|
}
|
|
|
|
func (r *ModelRepository) Create(ctx context.Context, model *domain.Model) error {
|
|
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Create(model).Error; err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (r *ModelRepository) GetList(ctx context.Context) ([]*domain.ModelListItem, error) {
|
|
var models []*domain.ModelListItem
|
|
if err := r.db.WithContext(ctx).
|
|
Model(&domain.Model{}).
|
|
Order("created_at ASC").
|
|
Find(&models).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return models, nil
|
|
}
|
|
|
|
func (r *ModelRepository) Update(ctx context.Context, req *domain.UpdateModelReq) error {
|
|
param := domain.ModelParam{}
|
|
if req.Parameters != nil {
|
|
param = *req.Parameters
|
|
}
|
|
updateMap := map[string]any{
|
|
"model": req.Model,
|
|
"api_key": req.APIKey,
|
|
"api_header": req.APIHeader,
|
|
"base_url": req.BaseURL,
|
|
"api_version": req.APIVersion,
|
|
"provider": req.Provider,
|
|
"type": req.Type,
|
|
"parameters": param,
|
|
}
|
|
if req.IsActive != nil {
|
|
updateMap["is_active"] = *req.IsActive
|
|
}
|
|
return r.db.WithContext(ctx).
|
|
Model(&domain.Model{}).
|
|
Where("id = ?", req.ID).
|
|
Updates(updateMap).Error
|
|
}
|
|
|
|
func (r *ModelRepository) Delete(ctx context.Context, id string) error {
|
|
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
// delete model
|
|
if err := tx.Where("id = ?", id).
|
|
Delete(&domain.Model{}).Error; err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (r *ModelRepository) GetChatModel(ctx context.Context) (*domain.Model, error) {
|
|
var model domain.Model
|
|
if err := r.db.WithContext(ctx).
|
|
Model(&domain.Model{}).
|
|
Where("type = ?", domain.ModelTypeChat).
|
|
First(&model).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return &model, nil
|
|
}
|
|
|
|
func (r *ModelRepository) GetModelByType(ctx context.Context, modelType domain.ModelType) (*domain.Model, error) {
|
|
var model domain.Model
|
|
if err := r.db.WithContext(ctx).
|
|
Model(&domain.Model{}).
|
|
Where("type = ?", modelType).
|
|
First(&model).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return &model, nil
|
|
}
|
|
|
|
func (r *ModelRepository) UpdateUsage(ctx context.Context, modelID string, usage *schema.TokenUsage) error {
|
|
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
// update model usage
|
|
if err := tx.Model(&domain.Model{}).
|
|
Where("id = ?", modelID).
|
|
Updates(map[string]any{
|
|
"prompt_tokens": gorm.Expr("prompt_tokens + ?", usage.PromptTokens),
|
|
"completion_tokens": gorm.Expr("completion_tokens + ?", usage.CompletionTokens),
|
|
"total_tokens": gorm.Expr("total_tokens + ?", usage.TotalTokens),
|
|
}).Error; err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
}
|