Files
shrunner/main.go
2025-05-23 13:01:19 +01:00

355 lines
10 KiB
Go

package main
import (
	"context"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"os/exec"
	"strings"
	"time"

	"github.com/charmbracelet/huh"
	"github.com/google/generative-ai-go/genai"
	"github.com/sashabaranov/go-openai"
	"google.golang.org/api/iterator"
	"google.golang.org/api/option"
)
// Limits and model identifiers shared across the program.
const (
	// maxRedirects is the maximum number of redirects followed manually while
	// resolving a non-.sh URL to a .sh file.
	maxRedirects = 10

	// geminiModelName is the Gemini model used for script analysis.
	geminiModelName = "gemini-1.5-flash-latest"

	// openAIModelName is the OpenAI chat model used for script analysis.
	openAIModelName = "gpt-4o-mini"

	// maxScriptSize caps how many bytes of a downloaded script are read into
	// memory, so a hostile or misbehaving server cannot exhaust memory.
	maxScriptSize = 10 << 20 // 10 MiB

	// httpTimeout bounds each HTTP request (connect, headers and body) so an
	// unresponsive server cannot hang the program indefinitely.
	httpTimeout = 30 * time.Second
)

// isValidURLAndFetchContent resolves url to a shell script and downloads it.
//
// If url ends in ".sh" (case-insensitive) it is fetched directly, and any
// redirects on that request are followed by the HTTP client. Otherwise the
// URL is requested without following redirects. A 301/302/303/307/308
// response is followed by recursing on its Location with redirectCount+1.
// Any other status is an error. More than maxRedirects manual hops is an
// error.
//
// On success it returns the script body and true. On failure it returns nil,
// false and a descriptive error. Bodies larger than maxScriptSize are
// rejected, and every request is bounded by httpTimeout.
func isValidURLAndFetchContent(url string, redirectCount int) ([]byte, bool, error) {
	if redirectCount > maxRedirects {
		return nil, false, fmt.Errorf("too many redirects")
	}

	if strings.HasSuffix(strings.ToLower(url), ".sh") {
		fmt.Printf("Attempting to download: %s\n", url)
		client := &http.Client{Timeout: httpTimeout}
		resp, err := client.Get(url)
		if err != nil {
			return nil, false, fmt.Errorf("failed to download from %s: %w", url, err)
		}
		defer resp.Body.Close()
		if resp.StatusCode < 200 || resp.StatusCode >= 300 {
			return nil, false, fmt.Errorf("failed to download %s: status code %d", url, resp.StatusCode)
		}
		// Read one byte past the limit so an oversized body is detected and
		// rejected instead of being silently truncated.
		content, err := io.ReadAll(io.LimitReader(resp.Body, maxScriptSize+1))
		if err != nil {
			return nil, false, fmt.Errorf("failed to read content from %s: %w", url, err)
		}
		if len(content) > maxScriptSize {
			return nil, false, fmt.Errorf("script at %s exceeds the %d-byte size limit", url, maxScriptSize)
		}
		return content, true, nil
	}

	fmt.Printf("URL %s does not end with .sh, checking for redirect...\n", url)
	client := &http.Client{
		Timeout: httpTimeout,
		// Stop at the first response so the redirect target can be inspected
		// and counted against maxRedirects.
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}
	resp, err := client.Get(url)
	if err != nil {
		return nil, false, fmt.Errorf("failed to request %s: %w", url, err)
	}
	defer resp.Body.Close()

	switch resp.StatusCode {
	case http.StatusMovedPermanently, http.StatusFound, http.StatusSeeOther, http.StatusTemporaryRedirect, http.StatusPermanentRedirect:
		// Location() resolves a relative Location header against the request URL.
		location, err := resp.Location()
		if err != nil {
			return nil, false, fmt.Errorf("failed to get redirect location from %s: %w", url, err)
		}
		fmt.Printf("Redirected to: %s\n", location.String())
		return isValidURLAndFetchContent(location.String(), redirectCount+1)
	default:
		return nil, false, fmt.Errorf("URL %s does not end in .sh and did not redirect (status: %d)", url, resp.StatusCode)
	}
}
// hasShebang reports whether content begins with the two-byte "#!"
// interpreter directive. Content shorter than two bytes never qualifies.
func hasShebang(content []byte) bool {
	if len(content) < 2 {
		return false
	}
	// Comparing the converted two-byte prefix directly lets the compiler
	// avoid allocating a string copy.
	return string(content[:2]) == "#!"
}
// executeScript writes content to a temporary file, marks it executable
// (mode 0700) and runs it directly. Because the file is executed directly,
// the OS chooses the interpreter from the script's shebang line. The child
// inherits this process's stdin, stdout and stderr. The temporary file is
// removed when the function returns.
//
// A normal non-zero exit is reported as "script execution failed with exit
// code N". If the script did not exit normally (for example it was killed by
// a signal, where ExitCode reports -1), the underlying error is wrapped
// instead, so the message names the real cause rather than "exit code -1".
func executeScript(content []byte) error {
	tmpFile, err := os.CreateTemp("", "script-*.sh")
	if err != nil {
		return fmt.Errorf("failed to create temporary file: %w", err)
	}
	scriptPath := tmpFile.Name()
	defer os.Remove(scriptPath)

	if _, err := tmpFile.Write(content); err != nil {
		tmpFile.Close()
		return fmt.Errorf("failed to write to temporary file: %w", err)
	}
	// Close before executing: a file still open for writing may fail to
	// exec ("text file busy").
	if err := tmpFile.Close(); err != nil {
		return fmt.Errorf("failed to close temporary file: %w", err)
	}
	if err := os.Chmod(scriptPath, 0700); err != nil {
		return fmt.Errorf("failed to make script executable: %w", err)
	}

	fmt.Printf("\n--- EXECUTING SCRIPT (%s) ---\n", scriptPath)
	cmd := exec.Command(scriptPath)
	cmd.Stdin, cmd.Stdout, cmd.Stderr = os.Stdin, os.Stdout, os.Stderr
	err = cmd.Run()
	fmt.Println("\n--- EXECUTION FINISHED ---")
	if err == nil {
		return nil
	}

	var exitErr *exec.ExitError
	if errors.As(err, &exitErr) && exitErr.ExitCode() >= 0 {
		return fmt.Errorf("script execution failed with exit code %d", exitErr.ExitCode())
	}
	// Signal termination (ExitCode == -1) or a failure to start the process.
	return fmt.Errorf("script execution failed: %w", err)
}
// buildAnalysisPrompt wraps scriptContent in the fixed instructions sent to
// the analysis backends. The model is asked for the script's purpose, its
// main steps and a safety assessment. The script text is inserted verbatim
// between "--- SCRIPT START ---" and "--- SCRIPT END ---" marker lines.
//
// NOTE(review): the script is untrusted input and could itself contain the
// end marker or instructions aimed at the model; the markers are a
// convention, not an isolation boundary.
func buildAnalysisPrompt(scriptContent string) string {
	sections := []string{
		"Please analyze the shell script provided below. Focus on the following aspects:",
		"1. **Purpose**: What is the primary goal or function of this script?",
		"2. **General Steps**: Outline the main actions or sequence of operations the script performs.",
		"3. **Safety Assessment**: Based on its actions, does the script appear safe or potentially malicious? Explain your reasoning. Highlight any suspicious commands or patterns.",
		"--- SCRIPT START ---",
		scriptContent,
		"--- SCRIPT END ---",
		"Provide your analysis as a stream of text.",
	}
	return strings.Join(sections, "\n")
}
// analyzeWithGemini streams a Gemini analysis (model geminiModelName) of
// scriptContent to stdout, using the prompt from buildAnalysisPrompt.
//
// The four harm categories listed below are set to HarmBlockNone, presumably
// so that scripts containing dangerous commands are analyzed rather than
// blocked; review if that is not the intent.
//
// Every text part of each streamed chunk is printed as it arrives. It returns
// an error if the client cannot be created or the stream fails. Normal end of
// stream is not an error.
func analyzeWithGemini(scriptContent string, apiKey string) error {
	ctx := context.Background()
	client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey))
	if err != nil {
		return fmt.Errorf("failed to create Gemini client: %w", err)
	}
	defer client.Close()

	model := client.GenerativeModel(geminiModelName)
	model.SafetySettings = []*genai.SafetySetting{
		{Category: genai.HarmCategoryDangerousContent, Threshold: genai.HarmBlockNone},
		{Category: genai.HarmCategoryHarassment, Threshold: genai.HarmBlockNone},
		{Category: genai.HarmCategorySexuallyExplicit, Threshold: genai.HarmBlockNone},
		{Category: genai.HarmCategoryHateSpeech, Threshold: genai.HarmBlockNone},
	}

	prompt := buildAnalysisPrompt(scriptContent)
	fmt.Printf("\n--- ANALYZING WITH GEMINI (%s) ---\n", geminiModelName)
	iter := model.GenerateContentStream(ctx, genai.Text(prompt))
	for {
		resp, err := iter.Next()
		// The genai stream iterator signals normal completion with
		// iterator.Done, not io.EOF. Checking io.EOF made every successful
		// analysis end with a spurious "Gemini stream error".
		if errors.Is(err, iterator.Done) {
			break
		}
		if err != nil {
			return fmt.Errorf("Gemini stream error: %w", err)
		}
		if resp == nil || len(resp.Candidates) == 0 || resp.Candidates[0].Content == nil {
			continue
		}
		// A chunk may carry several parts: print every text part, not just
		// the first. os.Stdout is unbuffered, so no explicit flush is needed.
		for _, part := range resp.Candidates[0].Content.Parts {
			if text, ok := part.(genai.Text); ok {
				fmt.Print(text)
			}
		}
	}
	fmt.Println("\n--- GEMINI ANALYSIS FINISHED ---")
	return nil
}
// analyzeWithOpenAI streams an analysis of scriptContent from the OpenAI chat
// completions API (model openAIModelName) to stdout.
//
// The conversation is a fixed system message followed by the prompt from
// buildAnalysisPrompt. The delta of the first choice in each streamed chunk
// is printed as it arrives. It returns an error if the stream cannot be
// opened or fails mid-way. io.EOF from the stream marks normal completion.
func analyzeWithOpenAI(scriptContent string, apiKey string) error {
	ctx := context.Background()
	client := openai.NewClient(apiKey)
	request := openai.ChatCompletionRequest{
		Model: openAIModelName,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleSystem, Content: "You are an expert shell script analyzer."},
			{Role: openai.ChatMessageRoleUser, Content: buildAnalysisPrompt(scriptContent)},
		},
		Stream: true,
	}

	fmt.Printf("\n--- ANALYZING WITH OPENAI (%s) ---\n", openAIModelName)
	stream, err := client.CreateChatCompletionStream(ctx, request)
	if err != nil {
		return fmt.Errorf("failed to create OpenAI completion stream: %w", err)
	}
	defer stream.Close()

	for {
		response, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			return fmt.Errorf("OpenAI stream error: %w", err)
		}
		// os.Stdout is an unbuffered *os.File, so each Print is visible
		// immediately. Calling Sync here would fsync, which does nothing useful
		// for a terminal and forces a disk flush per chunk when output is
		// redirected to a file.
		if len(response.Choices) > 0 {
			fmt.Print(response.Choices[0].Delta.Content)
		}
	}
	fmt.Println("\n--- OPENAI ANALYSIS FINISHED ---")
	return nil
}
// main downloads the shell script named by the URL argument, checks that it
// starts with a shebang, and then lets the user choose one action: read the
// script, execute it (after an explicit confirmation), analyze it with Gemini
// or OpenAI, or exit. Each analysis option is offered only when
// GEMINI_API_KEY / OPENAI_API_KEY is set.
//
// Exit status is 0 on success, user cancellation or the "Exit" choice. It is
// 1 on any download, validation, form, execution or analysis error.
func main() {
	if len(os.Args) < 2 {
		// Name the binary actually being run rather than a hard-coded file name.
		fmt.Fprintf(os.Stderr, "Usage: %s <URL>\n", os.Args[0])
		os.Exit(1)
	}
	initialURL := os.Args[1]

	content, isValid, err := isValidURLAndFetchContent(initialURL, 0)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error: %v\n", err)
		os.Exit(1)
	}
	// Defensive guard for a (content, false, nil) result.
	if !isValid {
		fmt.Fprintf(os.Stderr, "The provided URL is invalid or does not point to a valid .sh file after redirects.\n")
		os.Exit(1)
	}
	if !hasShebang(content) {
		fmt.Fprintf(os.Stderr, "Error: The file from URL does not start with a shebang '#!'.\n")
		os.Exit(1)
	}
	fmt.Println("Valid script found.")

	scriptOptions := []huh.Option[string]{
		huh.NewOption("Read the script", "read"),
		huh.NewOption("Execute the script (Potentially DANGEROUS!)", "execute"),
	}
	geminiAPIKey := os.Getenv("GEMINI_API_KEY")
	if geminiAPIKey != "" {
		scriptOptions = append(scriptOptions, huh.NewOption(fmt.Sprintf("Analyze with Gemini (%s)", geminiModelName), "analyze_gemini"))
	}
	openaiAPIKey := os.Getenv("OPENAI_API_KEY")
	if openaiAPIKey != "" {
		scriptOptions = append(scriptOptions, huh.NewOption(fmt.Sprintf("Analyze with OpenAI (%s)", openAIModelName), "analyze_openai"))
	}
	scriptOptions = append(scriptOptions, huh.NewOption("Exit", "exit"))

	var choice string
	scriptActionForm := huh.NewForm(
		huh.NewGroup(
			huh.NewSelect[string]().
				Title("What would you like to do with the script?").
				Options(scriptOptions...).
				Value(&choice),
		),
	)
	if err := scriptActionForm.Run(); err != nil {
		if errors.Is(err, huh.ErrUserAborted) {
			fmt.Println("Operation cancelled by user. Exiting.")
			os.Exit(0)
		}
		fmt.Fprintf(os.Stderr, "Error running selection form: %v\n", err)
		os.Exit(1)
	}

	switch choice {
	case "read":
		fmt.Println("\n--- SCRIPT CONTENT START ---")
		fmt.Print(string(content))
		fmt.Println("--- SCRIPT CONTENT END ---")

	case "execute":
		fmt.Println("Attempting to execute the script...")
		var confirmExecute bool
		confirmForm := huh.NewForm(
			huh.NewGroup(
				huh.NewConfirm().
					Title("DANGER: You are about to execute a script from the internet.").
					Description("Only proceed if you FULLY TRUST the source of this script.\nAre you sure you want to execute it?").
					Affirmative("Yes, execute it!").
					Negative("No, cancel.").
					Value(&confirmExecute),
			),
		)
		confirmErr := confirmForm.Run()
		// Check for a genuine form failure first. confirmExecute keeps its
		// zero value (false) unless the user confirmed, so testing it first
		// would report a real error as a user cancellation with status 0.
		if confirmErr != nil && !errors.Is(confirmErr, huh.ErrUserAborted) {
			fmt.Fprintf(os.Stderr, "Error during confirmation: %v\n", confirmErr)
			os.Exit(1)
		}
		if confirmErr != nil || !confirmExecute {
			fmt.Println("Execution cancelled by user.")
			os.Exit(0)
		}
		if err := executeScript(content); err != nil {
			fmt.Fprintf(os.Stderr, "Error during script execution: %v\n", err)
			os.Exit(1)
		}

	case "analyze_gemini":
		if geminiAPIKey == "" {
			fmt.Fprintln(os.Stderr, "Error: GEMINI_API_KEY not found.")
			os.Exit(1)
		}
		if err := analyzeWithGemini(string(content), geminiAPIKey); err != nil {
			fmt.Fprintf(os.Stderr, "\nError analyzing with Gemini: %v\n", err)
			os.Exit(1)
		}

	case "analyze_openai":
		if openaiAPIKey == "" {
			fmt.Fprintln(os.Stderr, "Error: OPENAI_API_KEY not found.")
			os.Exit(1)
		}
		if err := analyzeWithOpenAI(string(content), openaiAPIKey); err != nil {
			fmt.Fprintf(os.Stderr, "\nError analyzing with OpenAI: %v\n", err)
			os.Exit(1)
		}

	case "exit":
		fmt.Println("Exiting.")
		os.Exit(0)

	default:
		fmt.Fprintf(os.Stderr, "Invalid choice. Exiting.\n")
		os.Exit(1)
	}
}