add rules to AI

This commit is contained in:
2025-05-23 13:14:41 +01:00
parent c3413973f2
commit fc9e6a5fce

54
main.go
View File

@ -9,11 +9,11 @@ import (
"os" "os"
"os/exec" "os/exec"
"strings" "strings"
// "time" // Removed: import and not used
"github.com/charmbracelet/huh" "github.com/charmbracelet/huh"
"github.com/google/generative-ai-go/genai" "github.com/google/generative-ai-go/genai"
"github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
"google.golang.org/api/iterator" // Added for iterator.Done
"google.golang.org/api/option" "google.golang.org/api/option"
) )
@ -21,7 +21,7 @@ import (
const maxRedirects = 10 const maxRedirects = 10
// Constants for AI model names // Constants for AI model names
const geminiModelName = "gemini-1.5-flash-latest" const geminiModelName = "gemini-2.5-flash-preview-05-20"
const openAIModelName = "gpt-4o-mini" const openAIModelName = "gpt-4o-mini"
// isValidURLAndFetchContent (no changes) // isValidURLAndFetchContent (no changes)
@ -127,7 +127,9 @@ func buildAnalysisPrompt(scriptContent string) string {
%s %s
--- SCRIPT END --- --- SCRIPT END ---
Provide your analysis as a stream of text.`, scriptContent)
Do not over explain yourself or include very basic information, assume the user already understands the basics of a shell script.
Do not use Markdown symbols such as asterisks or backticks.`, scriptContent)
} }
func analyzeWithGemini(scriptContent string, apiKey string) error { func analyzeWithGemini(scriptContent string, apiKey string) error {
@ -139,24 +141,11 @@ func analyzeWithGemini(scriptContent string, apiKey string) error {
defer client.Close() defer client.Close()
model := client.GenerativeModel(geminiModelName) model := client.GenerativeModel(geminiModelName)
// Corrected: Assign directly to the SafetySettings field
model.SafetySettings = []*genai.SafetySetting{ model.SafetySettings = []*genai.SafetySetting{
{ {Category: genai.HarmCategoryDangerousContent, Threshold: genai.HarmBlockNone},
Category: genai.HarmCategoryDangerousContent, {Category: genai.HarmCategoryHarassment, Threshold: genai.HarmBlockNone},
Threshold: genai.HarmBlockNone, {Category: genai.HarmCategorySexuallyExplicit, Threshold: genai.HarmBlockNone},
}, {Category: genai.HarmCategoryHateSpeech, Threshold: genai.HarmBlockNone},
{
Category: genai.HarmCategoryHarassment,
Threshold: genai.HarmBlockNone,
},
{
Category: genai.HarmCategorySexuallyExplicit,
Threshold: genai.HarmBlockNone,
},
{
Category: genai.HarmCategoryHateSpeech,
Threshold: genai.HarmBlockNone,
},
} }
prompt := buildAnalysisPrompt(scriptContent) prompt := buildAnalysisPrompt(scriptContent)
@ -164,21 +153,21 @@ func analyzeWithGemini(scriptContent string, apiKey string) error {
iter := model.GenerateContentStream(ctx, genai.Text(prompt)) iter := model.GenerateContentStream(ctx, genai.Text(prompt))
for { for {
resp, err := iter.Next() resp, err := iter.Next()
if err == io.EOF { // Corrected check: Use errors.Is with iterator.Done
if errors.Is(err, iterator.Done) {
break break
} }
if err != nil { if err != nil {
// This will catch other errors that are not iterator.Done
return fmt.Errorf("Gemini stream error: %w", err) return fmt.Errorf("Gemini stream error: %w", err)
} }
if resp == nil || len(resp.Candidates) == 0 || resp.Candidates[0].Content == nil || len(resp.Candidates[0].Content.Parts) == 0 { if resp == nil || len(resp.Candidates) == 0 || resp.Candidates[0].Content == nil || len(resp.Candidates[0].Content.Parts) == 0 {
continue continue
} }
// Ensure part is genai.Text before trying to print
if len(resp.Candidates[0].Content.Parts) > 0 {
if textPart, ok := resp.Candidates[0].Content.Parts[0].(genai.Text); ok { if textPart, ok := resp.Candidates[0].Content.Parts[0].(genai.Text); ok {
fmt.Print(textPart) fmt.Print(textPart)
os.Stdout.Sync() // Ensure immediate output os.Stdout.Sync()
}
} }
} }
fmt.Println("\n--- GEMINI ANALYSIS FINISHED ---") fmt.Println("\n--- GEMINI ANALYSIS FINISHED ---")
@ -191,14 +180,8 @@ func analyzeWithOpenAI(scriptContent string, apiKey string) error {
prompt := buildAnalysisPrompt(scriptContent) prompt := buildAnalysisPrompt(scriptContent)
messages := []openai.ChatCompletionMessage{ messages := []openai.ChatCompletionMessage{
{ {Role: openai.ChatMessageRoleSystem, Content: "You are an expert shell script analyzer."},
Role: openai.ChatMessageRoleSystem, {Role: openai.ChatMessageRoleUser, Content: prompt},
Content: "You are an expert shell script analyzer.",
},
{
Role: openai.ChatMessageRoleUser,
Content: prompt,
},
} }
request := openai.ChatCompletionRequest{ request := openai.ChatCompletionRequest{
@ -216,7 +199,7 @@ func analyzeWithOpenAI(scriptContent string, apiKey string) error {
for { for {
response, err := stream.Recv() response, err := stream.Recv()
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) { // OpenAI SDK uses io.EOF for stream end
break break
} }
if err != nil { if err != nil {
@ -224,7 +207,7 @@ func analyzeWithOpenAI(scriptContent string, apiKey string) error {
} }
if len(response.Choices) > 0 { if len(response.Choices) > 0 {
fmt.Print(response.Choices[0].Delta.Content) fmt.Print(response.Choices[0].Delta.Content)
os.Stdout.Sync() // Ensure immediate output os.Stdout.Sync()
} }
} }
fmt.Println("\n--- OPENAI ANALYSIS FINISHED ---") fmt.Println("\n--- OPENAI ANALYSIS FINISHED ---")
@ -333,6 +316,7 @@ func main() {
} }
err := analyzeWithGemini(string(content), geminiAPIKey) err := analyzeWithGemini(string(content), geminiAPIKey)
if err != nil { if err != nil {
// This will now only print if there's an actual error other than iterator.Done
fmt.Fprintf(os.Stderr, "\nError analyzing with Gemini: %v\n", err) fmt.Fprintf(os.Stderr, "\nError analyzing with Gemini: %v\n", err)
} }
case "analyze_openai": case "analyze_openai":