AnCao/internal/services/ai_grading.go

132 lines
3.8 KiB
Go

package services
import (
"encoding/json"
"fmt"
"net/http"
"strings"
)
// AIGradingService AI评分服务接口(使用百度云AppBuilder)
type AIGradingService struct {
baiduService *BaiduAIGradingService
}
// NewAIGradingService 创建AI评分服务实例
func NewAIGradingService() (*AIGradingService, error) {
baiduService, err := NewBaiduAIGradingService()
if err != nil {
return nil, fmt.Errorf("创建百度云AI服务失败: %w", err)
}
return &AIGradingService{
baiduService: baiduService,
}, nil
}
// AIGradingResult AI评分结果
type AIGradingResult struct {
Score float64 `json:"score"` // 得分 (0-100)
IsCorrect bool `json:"is_correct"` // 是否正确 (Score >= 60 视为正确)
Feedback string `json:"feedback"` // 评语
Suggestion string `json:"suggestion"` // 改进建议
ReferenceAnswer string `json:"reference_answer"` // 参考答案(论述题)
ScoringRationale string `json:"scoring_rationale"` // 评分依据
}
// GradeEssay 对论述题进行AI评分(不需要标准答案)
// question: 题目内容
// userAnswer: 用户答案
func (s *AIGradingService) GradeEssay(question, userAnswer string) (*AIGradingResult, error) {
if s.baiduService == nil {
return nil, fmt.Errorf("百度云AI服务未初始化")
}
return s.baiduService.GradeEssay(question, userAnswer)
}
// GradeShortAnswer 对简答题进行AI评分
// question: 题目内容
// standardAnswer: 标准答案
// userAnswer: 用户答案
func (s *AIGradingService) GradeShortAnswer(question, standardAnswer, userAnswer string) (*AIGradingResult, error) {
if s.baiduService == nil {
return nil, fmt.Errorf("百度云AI服务未初始化")
}
return s.baiduService.GradeShortAnswer(question, standardAnswer, userAnswer)
}
// AIExplanationResult AI解析结果
type AIExplanationResult struct {
Explanation string `json:"explanation"` // 题目解析
}
// ExplainQuestionStream 生成题目解析(流式输出)
// writer: HTTP响应写入器
// question: 题目内容
// standardAnswer: 标准答案
// questionType: 题目类型
func (s *AIGradingService) ExplainQuestionStream(writer http.ResponseWriter, question, standardAnswer, questionType string) error {
if s.baiduService == nil {
return fmt.Errorf("百度云AI服务未初始化")
}
return s.baiduService.ExplainQuestionStream(writer, question, standardAnswer, questionType)
}
// ExplainQuestion 生成题目解析
// question: 题目内容
// standardAnswer: 标准答案
// questionType: 题目类型
func (s *AIGradingService) ExplainQuestion(question, standardAnswer, questionType string) (*AIExplanationResult, error) {
if s.baiduService == nil {
return nil, fmt.Errorf("百度云AI服务未初始化")
}
return s.baiduService.ExplainQuestion(question, standardAnswer, questionType)
}
// parseAIResponse 解析AI返回的JSON响应
func parseAIResponse(content string, result interface{}) error {
// 移除可能的markdown代码块标记
jsonStr := removeMarkdownCodeBlock(content)
// 使用json包解析
if err := json.Unmarshal([]byte(jsonStr), result); err != nil {
return fmt.Errorf("JSON解析失败: %w, 原始内容: %s", err, content)
}
return nil
}
// removeMarkdownCodeBlock 移除markdown代码块标记
func removeMarkdownCodeBlock(s string) string {
// 去除可能的```json和```标记
s = strings.TrimSpace(s)
// 移除开头的```json或```
if strings.HasPrefix(s, "```json") {
s = s[7:]
} else if strings.HasPrefix(s, "```") {
s = s[3:]
}
// 移除结尾的```
if strings.HasSuffix(s, "```") {
s = s[:len(s)-3]
}
s = strings.TrimSpace(s)
// 查找第一个{的位置
startIdx := strings.Index(s, "{")
if startIdx == -1 {
return s
}
// 查找最后一个}的位置
endIdx := strings.LastIndex(s, "}")
if endIdx == -1 || endIdx <= startIdx {
return s
}
return s[startIdx : endIdx+1]
}