mirror of
				https://github.com/abhinavxd/libredesk.git
				synced 2025-11-04 05:53:30 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			163 lines
		
	
	
		
			5.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			163 lines
		
	
	
		
			5.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Package ai manages AI prompts and integrates with LLM providers.
 | 
						|
package ai
 | 
						|
 | 
						|
import (
 | 
						|
	"database/sql"
 | 
						|
	"embed"
 | 
						|
	"encoding/json"
 | 
						|
	"errors"
 | 
						|
 | 
						|
	"github.com/abhinavxd/libredesk/internal/ai/models"
 | 
						|
	"github.com/abhinavxd/libredesk/internal/dbutil"
 | 
						|
	"github.com/abhinavxd/libredesk/internal/envelope"
 | 
						|
	"github.com/jmoiron/sqlx"
 | 
						|
	"github.com/knadh/go-i18n"
 | 
						|
	"github.com/zerodha/logf"
 | 
						|
)
 | 
						|
 | 
						|
var (
 | 
						|
	//go:embed queries.sql
 | 
						|
	efs embed.FS
 | 
						|
 | 
						|
	ErrInvalidAPIKey = errors.New("invalid API Key")
 | 
						|
	ErrApiKeyNotSet  = errors.New("api Key not set")
 | 
						|
)
 | 
						|
 | 
						|
type Manager struct {
 | 
						|
	q    queries
 | 
						|
	lo   *logf.Logger
 | 
						|
	i18n *i18n.I18n
 | 
						|
}
 | 
						|
 | 
						|
// Opts contains options for initializing the Manager.
 | 
						|
type Opts struct {
 | 
						|
	DB   *sqlx.DB
 | 
						|
	I18n *i18n.I18n
 | 
						|
	Lo   *logf.Logger
 | 
						|
}
 | 
						|
 | 
						|
// queries contains prepared SQL queries.
 | 
						|
type queries struct {
 | 
						|
	GetDefaultProvider *sqlx.Stmt `query:"get-default-provider"`
 | 
						|
	GetPrompt          *sqlx.Stmt `query:"get-prompt"`
 | 
						|
	GetPrompts         *sqlx.Stmt `query:"get-prompts"`
 | 
						|
	SetOpenAIKey       *sqlx.Stmt `query:"set-openai-key"`
 | 
						|
}
 | 
						|
 | 
						|
// New creates and returns a new instance of the Manager.
 | 
						|
func New(opts Opts) (*Manager, error) {
 | 
						|
	var q queries
 | 
						|
	if err := dbutil.ScanSQLFile("queries.sql", &q, opts.DB, efs); err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	return &Manager{
 | 
						|
		q:    q,
 | 
						|
		lo:   opts.Lo,
 | 
						|
		i18n: opts.I18n,
 | 
						|
	}, nil
 | 
						|
}
 | 
						|
 | 
						|
// Completion sends a prompt to the default provider and returns the response.
 | 
						|
func (m *Manager) Completion(k string, prompt string) (string, error) {
 | 
						|
	systemPrompt, err := m.getPrompt(k)
 | 
						|
	if err != nil {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
 | 
						|
	client, err := m.getDefaultProviderClient()
 | 
						|
	if err != nil {
 | 
						|
		m.lo.Error("error getting provider client", "error", err)
 | 
						|
		return "", envelope.NewError(envelope.GeneralError, m.i18n.Ts("globals.messages.errorFetching", "name", m.i18n.Ts("globals.terms.provider")), nil)
 | 
						|
	}
 | 
						|
 | 
						|
	payload := PromptPayload{
 | 
						|
		SystemPrompt: systemPrompt,
 | 
						|
		UserPrompt:   prompt,
 | 
						|
	}
 | 
						|
 | 
						|
	response, err := client.SendPrompt(payload)
 | 
						|
	if err != nil {
 | 
						|
		if errors.Is(err, ErrInvalidAPIKey) {
 | 
						|
			m.lo.Error("error invalid API key", "error", err)
 | 
						|
			return "", envelope.NewError(envelope.InputError, m.i18n.Ts("globals.messages.invalid", "name", "OpenAI API Key"), nil)
 | 
						|
		}
 | 
						|
		if errors.Is(err, ErrApiKeyNotSet) {
 | 
						|
			m.lo.Error("error API key not set", "error", err)
 | 
						|
			return "", envelope.NewError(envelope.InputError, m.i18n.Ts("ai.apiKeyNotSet", "provider", "OpenAI"), nil)
 | 
						|
		}
 | 
						|
		m.lo.Error("error sending prompt to provider", "error", err)
 | 
						|
		return "", envelope.NewError(envelope.GeneralError, err.Error(), nil)
 | 
						|
	}
 | 
						|
 | 
						|
	return response, nil
 | 
						|
}
 | 
						|
 | 
						|
// GetPrompts returns a list of prompts from the database.
 | 
						|
func (m *Manager) GetPrompts() ([]models.Prompt, error) {
 | 
						|
	var prompts = make([]models.Prompt, 0)
 | 
						|
	if err := m.q.GetPrompts.Select(&prompts); err != nil {
 | 
						|
		m.lo.Error("error fetching prompts", "error", err)
 | 
						|
		return nil, envelope.NewError(envelope.GeneralError, m.i18n.Ts("globals.messages.errorFetching", "name", m.i18n.Ts("globals.terms.template")), nil)
 | 
						|
	}
 | 
						|
	return prompts, nil
 | 
						|
}
 | 
						|
 | 
						|
// UpdateProvider updates a provider.
 | 
						|
func (m *Manager) UpdateProvider(provider, apiKey string) error {
 | 
						|
	switch ProviderType(provider) {
 | 
						|
	case ProviderOpenAI:
 | 
						|
		return m.setOpenAIAPIKey(apiKey)
 | 
						|
	default:
 | 
						|
		m.lo.Error("unsupported provider type", "provider", provider)
 | 
						|
		return envelope.NewError(envelope.GeneralError, m.i18n.Ts("globals.messages.invalid", "name", m.i18n.Ts("globals.terms.provider")), nil)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// setOpenAIAPIKey sets the OpenAI API key in the database.
 | 
						|
func (m *Manager) setOpenAIAPIKey(apiKey string) error {
 | 
						|
	if _, err := m.q.SetOpenAIKey.Exec(apiKey); err != nil {
 | 
						|
		m.lo.Error("error setting OpenAI API key", "error", err)
 | 
						|
		return envelope.NewError(envelope.GeneralError, m.i18n.Ts("globals.messages.errorUpdating", "name", "OpenAI API Key"), nil)
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// getPrompt returns a prompt from the database.
 | 
						|
func (m *Manager) getPrompt(k string) (string, error) {
 | 
						|
	var p models.Prompt
 | 
						|
	if err := m.q.GetPrompt.Get(&p, k); err != nil {
 | 
						|
		if err == sql.ErrNoRows {
 | 
						|
			m.lo.Error("error prompt not found", "key", k)
 | 
						|
			return "", envelope.NewError(envelope.InputError, m.i18n.Ts("globals.messages.notFound", "name", m.i18n.Ts("globals.terms.template")), nil)
 | 
						|
		}
 | 
						|
		m.lo.Error("error fetching prompt", "error", err)
 | 
						|
		return "", envelope.NewError(envelope.GeneralError, m.i18n.Ts("globals.messages.errorFetching", "name", m.i18n.Ts("globals.terms.template")), nil)
 | 
						|
	}
 | 
						|
	return p.Content, nil
 | 
						|
}
 | 
						|
 | 
						|
// getDefaultProviderClient returns a ProviderClient for the default provider.
 | 
						|
func (m *Manager) getDefaultProviderClient() (ProviderClient, error) {
 | 
						|
	var p models.Provider
 | 
						|
 | 
						|
	if err := m.q.GetDefaultProvider.Get(&p); err != nil {
 | 
						|
		m.lo.Error("error fetching provider details", "error", err)
 | 
						|
		return nil, envelope.NewError(envelope.GeneralError, m.i18n.Ts("globals.messages.errorFetching", "name", m.i18n.Ts("globals.terms.provider")), nil)
 | 
						|
	}
 | 
						|
 | 
						|
	switch ProviderType(p.Provider) {
 | 
						|
	case ProviderOpenAI:
 | 
						|
		config := struct {
 | 
						|
			APIKey string `json:"api_key"`
 | 
						|
		}{}
 | 
						|
		if err := json.Unmarshal([]byte(p.Config), &config); err != nil {
 | 
						|
			m.lo.Error("error parsing provider config", "error", err)
 | 
						|
			return nil, envelope.NewError(envelope.GeneralError, m.i18n.Ts("globals.messages.errorParsing", "name", m.i18n.Ts("globals.terms.provider")), nil)
 | 
						|
		}
 | 
						|
		return NewOpenAIClient(config.APIKey, m.lo), nil
 | 
						|
	default:
 | 
						|
		m.lo.Error("unsupported provider type", "provider", p.Provider)
 | 
						|
		return nil, envelope.NewError(envelope.GeneralError, m.i18n.Ts("globals.messages.invalid", "name", m.i18n.Ts("globals.terms.provider")), nil)
 | 
						|
	}
 | 
						|
}
 |