OpenAI and Gemini integration
This commit is contained in:
206
main.go
206
main.go
@ -1,28 +1,35 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
// "time" // Removed: import and not used
|
||||
|
||||
"github.com/charmbracelet/huh"
|
||||
"github.com/google/generative-ai-go/genai"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
// maxRedirects defines the maximum number of redirects to follow.
|
||||
const maxRedirects = 10
|
||||
|
||||
// isValidURLAndFetchContent checks if the URL is valid and fetches its content.
|
||||
// It handles redirects and checks for the .sh extension.
|
||||
// It returns the content, a boolean indicating validity, and an error if any.
|
||||
// Constants for AI model names
|
||||
const geminiModelName = "gemini-1.5-flash-latest"
|
||||
const openAIModelName = "gpt-4o-mini"
|
||||
|
||||
// isValidURLAndFetchContent (no changes)
|
||||
func isValidURLAndFetchContent(url string, redirectCount int) ([]byte, bool, error) {
|
||||
if redirectCount > maxRedirects {
|
||||
return nil, false, fmt.Errorf("too many redirects")
|
||||
}
|
||||
|
||||
// Check if the URL ends with .sh
|
||||
if strings.HasSuffix(strings.ToLower(url), ".sh") {
|
||||
fmt.Printf("Attempting to download: %s\n", url)
|
||||
resp, err := http.Get(url)
|
||||
@ -44,7 +51,7 @@ func isValidURLAndFetchContent(url string, redirectCount int) ([]byte, bool, err
|
||||
fmt.Printf("URL %s does not end with .sh, checking for redirect...\n", url)
|
||||
client := &http.Client{
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse // Do not follow redirects automatically
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
|
||||
@ -67,41 +74,37 @@ func isValidURLAndFetchContent(url string, redirectCount int) ([]byte, bool, err
|
||||
}
|
||||
}
|
||||
|
||||
// hasShebang checks if the content starts with "#!".
|
||||
// hasShebang (no changes)
|
||||
func hasShebang(content []byte) bool {
|
||||
return len(content) >= 2 && content[0] == '#' && content[1] == '!'
|
||||
}
|
||||
|
||||
// executeScript saves the script content to a temporary file and executes it.
|
||||
// executeScript (no changes)
|
||||
func executeScript(content []byte) error {
|
||||
tmpFile, err := os.CreateTemp("", "script-*.sh")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temporary file: %w", err)
|
||||
}
|
||||
scriptPath := tmpFile.Name()
|
||||
defer os.Remove(scriptPath) // Clean up the temp file
|
||||
defer os.Remove(scriptPath)
|
||||
|
||||
if _, err := tmpFile.Write(content); err != nil {
|
||||
tmpFile.Close() // Close before attempting remove on error
|
||||
tmpFile.Close()
|
||||
return fmt.Errorf("failed to write to temporary file: %w", err)
|
||||
}
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close temporary file: %w", err)
|
||||
}
|
||||
|
||||
// Make the script executable
|
||||
if err := os.Chmod(scriptPath, 0700); err != nil {
|
||||
return fmt.Errorf("failed to make script executable: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("\n--- EXECUTING SCRIPT (%s) ---\n", scriptPath)
|
||||
// The command to execute. On Unix-like systems, the kernel will use the shebang.
|
||||
// For Windows, or if a specific interpreter is needed universally,
|
||||
// one might need to parse the shebang and prepend the interpreter.
|
||||
cmd := exec.Command(scriptPath)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Stdin = os.Stdin // Allow script to read from stdin
|
||||
cmd.Stdin = os.Stdin
|
||||
|
||||
err = cmd.Run()
|
||||
fmt.Println("\n--- EXECUTION FINISHED ---")
|
||||
@ -114,6 +117,120 @@ func executeScript(content []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildAnalysisPrompt(scriptContent string) string {
|
||||
return fmt.Sprintf(`Please analyze the shell script provided below. Focus on the following aspects:
|
||||
1. **Purpose**: What is the primary goal or function of this script?
|
||||
2. **General Steps**: Outline the main actions or sequence of operations the script performs.
|
||||
3. **Safety Assessment**: Based on its actions, does the script appear safe or potentially malicious? Explain your reasoning. Highlight any suspicious commands or patterns.
|
||||
|
||||
--- SCRIPT START ---
|
||||
%s
|
||||
--- SCRIPT END ---
|
||||
|
||||
Provide your analysis as a stream of text.`, scriptContent)
|
||||
}
|
||||
|
||||
func analyzeWithGemini(scriptContent string, apiKey string) error {
|
||||
ctx := context.Background()
|
||||
client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create Gemini client: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
model := client.GenerativeModel(geminiModelName)
|
||||
// Corrected: Assign directly to the SafetySettings field
|
||||
model.SafetySettings = []*genai.SafetySetting{
|
||||
{
|
||||
Category: genai.HarmCategoryDangerousContent,
|
||||
Threshold: genai.HarmBlockNone,
|
||||
},
|
||||
{
|
||||
Category: genai.HarmCategoryHarassment,
|
||||
Threshold: genai.HarmBlockNone,
|
||||
},
|
||||
{
|
||||
Category: genai.HarmCategorySexuallyExplicit,
|
||||
Threshold: genai.HarmBlockNone,
|
||||
},
|
||||
{
|
||||
Category: genai.HarmCategoryHateSpeech,
|
||||
Threshold: genai.HarmBlockNone,
|
||||
},
|
||||
}
|
||||
prompt := buildAnalysisPrompt(scriptContent)
|
||||
|
||||
fmt.Printf("\n--- ANALYZING WITH GEMINI (%s) ---\n", geminiModelName)
|
||||
iter := model.GenerateContentStream(ctx, genai.Text(prompt))
|
||||
for {
|
||||
resp, err := iter.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("Gemini stream error: %w", err)
|
||||
}
|
||||
if resp == nil || len(resp.Candidates) == 0 || resp.Candidates[0].Content == nil || len(resp.Candidates[0].Content.Parts) == 0 {
|
||||
continue
|
||||
}
|
||||
// Ensure part is genai.Text before trying to print
|
||||
if len(resp.Candidates[0].Content.Parts) > 0 {
|
||||
if textPart, ok := resp.Candidates[0].Content.Parts[0].(genai.Text); ok {
|
||||
fmt.Print(textPart)
|
||||
os.Stdout.Sync() // Ensure immediate output
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Println("\n--- GEMINI ANALYSIS FINISHED ---")
|
||||
return nil
|
||||
}
|
||||
|
||||
func analyzeWithOpenAI(scriptContent string, apiKey string) error {
|
||||
ctx := context.Background()
|
||||
client := openai.NewClient(apiKey)
|
||||
|
||||
prompt := buildAnalysisPrompt(scriptContent)
|
||||
messages := []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleSystem,
|
||||
Content: "You are an expert shell script analyzer.",
|
||||
},
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: prompt,
|
||||
},
|
||||
}
|
||||
|
||||
request := openai.ChatCompletionRequest{
|
||||
Model: openAIModelName,
|
||||
Messages: messages,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
fmt.Printf("\n--- ANALYZING WITH OPENAI (%s) ---\n", openAIModelName)
|
||||
stream, err := client.CreateChatCompletionStream(ctx, request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create OpenAI completion stream: %w", err)
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
for {
|
||||
response, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("OpenAI stream error: %w", err)
|
||||
}
|
||||
if len(response.Choices) > 0 {
|
||||
fmt.Print(response.Choices[0].Delta.Content)
|
||||
os.Stdout.Sync() // Ensure immediate output
|
||||
}
|
||||
}
|
||||
fmt.Println("\n--- OPENAI ANALYSIS FINISHED ---")
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 2 {
|
||||
fmt.Println("Usage: go run script_downloader.go <URL>")
|
||||
@ -128,7 +245,6 @@ func main() {
|
||||
}
|
||||
|
||||
if !isValid {
|
||||
// This case should ideally be covered by errors from isValidURLAndFetchContent
|
||||
fmt.Fprintf(os.Stderr, "The provided URL is invalid or does not point to a valid .sh file after redirects.\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
@ -140,28 +256,40 @@ func main() {
|
||||
|
||||
fmt.Println("Valid script found.")
|
||||
|
||||
scriptOptions := []huh.Option[string]{
|
||||
huh.NewOption("Read the script", "read"),
|
||||
huh.NewOption("Execute the script (Potentially DANGEROUS!)", "execute"),
|
||||
}
|
||||
|
||||
geminiAPIKey := os.Getenv("GEMINI_API_KEY")
|
||||
if geminiAPIKey != "" {
|
||||
scriptOptions = append(scriptOptions, huh.NewOption("Analyze with Gemini", "analyze_gemini"))
|
||||
}
|
||||
|
||||
openaiAPIKey := os.Getenv("OPENAI_API_KEY")
|
||||
if openaiAPIKey != "" {
|
||||
scriptOptions = append(scriptOptions, huh.NewOption(fmt.Sprintf("Analyze with OpenAI (%s)", openAIModelName), "analyze_openai"))
|
||||
}
|
||||
|
||||
scriptOptions = append(scriptOptions, huh.NewOption("Exit", "exit"))
|
||||
|
||||
var choice string
|
||||
scriptActionForm := huh.NewForm(
|
||||
huh.NewGroup(
|
||||
huh.NewSelect[string]().
|
||||
Title("What would you like to do with the script?").
|
||||
Options(
|
||||
huh.NewOption("Read the script", "read"),
|
||||
huh.NewOption("Execute the script (Potentially DANGEROUS!)", "execute"),
|
||||
huh.NewOption("Exit", "exit"),
|
||||
).
|
||||
Options(scriptOptions...).
|
||||
Value(&choice),
|
||||
),
|
||||
)
|
||||
|
||||
err = scriptActionForm.Run()
|
||||
if err != nil {
|
||||
// This can happen if the user cancels the form (e.g., Ctrl+C)
|
||||
if err == huh.ErrUserAborted {
|
||||
formErr := scriptActionForm.Run()
|
||||
if formErr != nil {
|
||||
if formErr == huh.ErrUserAborted {
|
||||
fmt.Println("Operation cancelled by user. Exiting.")
|
||||
os.Exit(0)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Error running selection form: %v\n", err)
|
||||
fmt.Fprintf(os.Stderr, "Error running selection form: %v\n", formErr)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@ -172,7 +300,6 @@ func main() {
|
||||
fmt.Println("--- SCRIPT CONTENT END ---")
|
||||
case "execute":
|
||||
fmt.Println("Attempting to execute the script...")
|
||||
// Add an additional confirmation for execution due to security risks
|
||||
var confirmExecute bool
|
||||
confirmForm := huh.NewForm(
|
||||
huh.NewGroup(
|
||||
@ -184,12 +311,12 @@ func main() {
|
||||
Value(&confirmExecute),
|
||||
),
|
||||
)
|
||||
err := confirmForm.Run()
|
||||
if err != nil || !confirmExecute {
|
||||
if err == huh.ErrUserAborted || !confirmExecute {
|
||||
confirmErr := confirmForm.Run()
|
||||
if confirmErr != nil || !confirmExecute {
|
||||
if confirmErr == huh.ErrUserAborted || !confirmExecute {
|
||||
fmt.Println("Execution cancelled by user.")
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Error during confirmation: %v\n", err)
|
||||
fmt.Fprintf(os.Stderr, "Error during confirmation: %v\n", confirmErr)
|
||||
}
|
||||
os.Exit(0)
|
||||
return
|
||||
@ -199,11 +326,28 @@ func main() {
|
||||
fmt.Fprintf(os.Stderr, "Error during script execution: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
case "analyze_gemini":
|
||||
if geminiAPIKey == "" {
|
||||
fmt.Fprintln(os.Stderr, "Error: GEMINI_API_KEY not found.")
|
||||
os.Exit(1)
|
||||
}
|
||||
err := analyzeWithGemini(string(content), geminiAPIKey)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "\nError analyzing with Gemini: %v\n", err)
|
||||
}
|
||||
case "analyze_openai":
|
||||
if openaiAPIKey == "" {
|
||||
fmt.Fprintln(os.Stderr, "Error: OPENAI_API_KEY not found.")
|
||||
os.Exit(1)
|
||||
}
|
||||
err := analyzeWithOpenAI(string(content), openaiAPIKey)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "\nError analyzing with OpenAI: %v\n", err)
|
||||
}
|
||||
case "exit":
|
||||
fmt.Println("Exiting.")
|
||||
os.Exit(0)
|
||||
default:
|
||||
// This path should ideally not be reached if huh.Select is used correctly
|
||||
fmt.Fprintf(os.Stderr, "Invalid choice. Exiting.\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
Reference in New Issue
Block a user