Files
calspot/main.go

499 lines
13 KiB
Go

package main
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"database/sql"
"errors"
"fmt"
"io"
"log"
"math/big"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
_ "github.com/mattn/go-sqlite3"
"golang.org/x/crypto/bcrypt"
"golang.org/x/net/webdav"
)
// --- Database & Models ---

// db is the shared SQLite handle. initDB assigns it; the WebDAV filesystem,
// HTTP handlers and REPL all use it.
var db *sql.DB

// initDB opens the SQLite database at dbPath into db and makes sure the
// users and calendars tables exist. Either failure terminates the process
// via log.Fatalf.
//
// calendars is keyed by user_id, so each user stores at most one calendar.
func initDB(dbPath string) {
	const schema = `
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
username TEXT UNIQUE,
password_hash TEXT
);
CREATE TABLE IF NOT EXISTS calendars (
user_id TEXT PRIMARY KEY,
filename TEXT,
content BLOB,
mod_time DATETIME
);
`
	conn, openErr := sql.Open("sqlite3", dbPath)
	if openErr != nil {
		log.Fatalf("Failed to open db: %v", openErr)
	}
	db = conn

	if _, execErr := db.Exec(schema); execErr != nil {
		log.Fatalf("Failed to init schema: %v", execErr)
	}
}
// --- WebDAV FileSystem Implementation (SQLite Backed) ---

// sqlFS implements webdav.FileSystem on top of the calendars table. Its
// methods read the caller's user ID from the request context (stored under
// "userID" by authMiddleware), and each user sees a flat root holding at
// most one file.
type sqlFS struct{}

// Mkdir always refuses with os.ErrPermission: the per-user namespace is
// flat, so collections can never be created.
func (fs *sqlFS) Mkdir(ctx context.Context, name string, perm os.FileMode) error {
	denied := os.ErrPermission
	return denied
}
// RemoveAll deletes the caller's calendar when name (compared by base name)
// refers to the stored file. Removing the root is rejected with os.ErrInvalid.
// As with os.RemoveAll, a name that does not exist is not an error: nothing
// is deleted and nil is returned. A request without a user ID in its context
// is refused with os.ErrPermission instead of panicking.
func (fs *sqlFS) RemoveAll(ctx context.Context, name string) error {
	userID, ok := ctx.Value("userID").(string)
	if !ok {
		return os.ErrPermission
	}
	if name == "/" || name == "" {
		return os.ErrInvalid
	}
	// Match on filename as well as user: without it, removing any path at all
	// would delete whatever calendar the user currently has.
	_, err := db.Exec("DELETE FROM calendars WHERE user_id = ? AND filename = ?", userID, filepath.Base(name))
	return err
}
// Rename changes the stored filename of the caller's calendar.
//
// oldName must name the stored file (compared by base name), otherwise
// os.ErrNotExist is returned. newName must end in ".ics", matching the
// constraint OpenFile enforces for new files. Directory components of both
// names are ignored because the namespace is flat. A request without a user
// ID in its context is refused with os.ErrPermission.
func (fs *sqlFS) Rename(ctx context.Context, oldName, newName string) error {
	userID, ok := ctx.Value("userID").(string)
	if !ok {
		return os.ErrPermission
	}
	newBase := filepath.Base(newName)
	if !strings.HasSuffix(strings.ToLower(newBase), ".ics") {
		return errors.New("only .ics files are allowed")
	}
	// Require oldName to be the stored file; otherwise renaming an arbitrary
	// (non-existent) path would silently rename the user's calendar.
	res, err := db.Exec("UPDATE calendars SET filename = ? WHERE user_id = ? AND filename = ?",
		newBase, userID, filepath.Base(oldName))
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if n == 0 {
		return os.ErrNotExist
	}
	return nil
}
// Stat describes the root ("/" or "") as a directory and the caller's stored
// calendar as a regular file. Any other name is resolved by its base name and
// must equal the stored filename; otherwise, or when the user has no calendar,
// os.ErrNotExist is returned. Other database errors are returned as-is. A
// request without a user ID in its context is refused with os.ErrPermission.
func (fs *sqlFS) Stat(ctx context.Context, name string) (os.FileInfo, error) {
	userID, ok := ctx.Value("userID").(string)
	if !ok {
		return nil, os.ErrPermission
	}
	// Root directory listing
	if name == "/" || name == "" {
		return &memFileInfo{name: "/", isDir: true, modTime: time.Now()}, nil
	}

	var filename string
	var size int64
	var modTime time.Time
	// COALESCE: an empty upload is bound from a nil []byte, which the SQLite
	// driver may store as NULL; length(NULL) is NULL and cannot be scanned
	// into an int64.
	err := db.QueryRow("SELECT filename, COALESCE(length(content), 0), mod_time FROM calendars WHERE user_id = ?", userID).
		Scan(&filename, &size, &modTime)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, os.ErrNotExist
	}
	if err != nil {
		return nil, err
	}
	// A name that doesn't match the stored file is reported as missing, so
	// WebDAV clients see a consistent namespace.
	if filepath.Base(name) != filename {
		return nil, os.ErrNotExist
	}
	return &memFileInfo{name: filename, size: size, modTime: modTime, isDir: false}, nil
}
// OpenFile opens a resource in the caller's flat namespace:
//   - "/" (or "") returns a directory handle that lists the stored calendar;
//   - any other name must end in ".ics";
//   - write-style flags (O_WRONLY, O_RDWR or O_CREATE) return a buffer that
//     replaces the user's calendar when closed;
//   - otherwise the stored calendar is returned read-only, provided its
//     filename matches the base name of name (else os.ErrNotExist).
//
// The returned file's modification time is the stored mod_time, so ETag and
// Last-Modified stay stable between reads of unchanged content. A request
// without a user ID in its context is refused with os.ErrPermission.
func (fs *sqlFS) OpenFile(ctx context.Context, name string, flag int, perm os.FileMode) (webdav.File, error) {
	userID, ok := ctx.Value("userID").(string)
	if !ok {
		return nil, os.ErrPermission
	}
	if name == "/" || name == "" {
		// Return a virtual directory file for listing
		return &sqlDir{userID: userID}, nil
	}
	// Check extension constraint
	if !strings.HasSuffix(strings.ToLower(name), ".ics") {
		return nil, errors.New("only .ics files are allowed")
	}
	// Writing: collect the body in memory, persist on Close.
	if flag&(os.O_WRONLY|os.O_RDWR|os.O_CREATE) != 0 {
		return &sqlFileBuffer{
			userID:   userID,
			filename: filepath.Base(name),
			buffer:   new(bytes.Buffer),
		}, nil
	}

	// Reading
	var content []byte
	var filename string
	var modTime time.Time
	err := db.QueryRow("SELECT content, filename, mod_time FROM calendars WHERE user_id = ?", userID).
		Scan(&content, &filename, &modTime)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, os.ErrNotExist
	}
	if err != nil {
		return nil, err
	}
	if filepath.Base(name) != filename {
		return nil, os.ErrNotExist
	}
	return &memFile{
		Reader: bytes.NewReader(content),
		info: &memFileInfo{
			name:    filename,
			size:    int64(len(content)),
			modTime: modTime,
		},
	}, nil
}
// --- SQL File Helpers ---
// Helper structs for FileSystem

// memFileInfo is a plain-value os.FileInfo used to describe both the virtual
// root directory and a user's stored calendar file.
type memFileInfo struct {
	name    string
	size    int64
	modTime time.Time
	isDir   bool
}

// Name returns the name the info was built with.
func (m *memFileInfo) Name() string {
	return m.name
}

// Size returns the recorded length in bytes.
func (m *memFileInfo) Size() int64 {
	return m.size
}

// Mode reports 0644 for files and the directory bit plus 0755 for directories.
func (m *memFileInfo) Mode() os.FileMode {
	if !m.isDir {
		return 0644
	}
	return os.ModeDir | 0755
}

// ModTime returns the recorded modification time.
func (m *memFileInfo) ModTime() time.Time {
	return m.modTime
}

// IsDir reports whether the info describes a directory.
func (m *memFileInfo) IsDir() bool {
	return m.isDir
}

// Sys has no underlying data source to expose and always returns nil.
func (m *memFileInfo) Sys() interface{} {
	return nil
}
// sqlDir is the webdav.File returned for the root collection. Listing it
// yields the user's stored calendar (at most one entry); it cannot be read
// from or written to.
type sqlDir struct {
	userID string
	pos    int // 0 until the listing has been handed out, then 1
}

func (d *sqlDir) Close() error                   { return nil }
func (d *sqlDir) Read([]byte) (int, error)       { return 0, io.EOF }
func (d *sqlDir) Seek(int64, int) (int64, error) { return 0, io.EOF }
func (d *sqlDir) Write([]byte) (int, error)      { return 0, os.ErrPermission }
func (d *sqlDir) Stat() (os.FileInfo, error)     { return &memFileInfo{name: "/", isDir: true}, nil }

// Readdir returns the directory's single entry (if the user has a calendar)
// on the first call and follows os.File.Readdir conventions at the end of
// the directory: (nil, io.EOF) when count > 0, (nil, nil) otherwise.
// Database errors other than "no rows" are returned to the caller.
func (d *sqlDir) Readdir(count int) ([]os.FileInfo, error) {
	if d.pos > 0 {
		return d.endOfDir(count)
	}
	var filename string
	var size int64
	var modTime time.Time
	// COALESCE guards against a NULL blob, whose length() is NULL and would
	// fail to scan into an int64.
	err := db.QueryRow("SELECT filename, COALESCE(length(content), 0), mod_time FROM calendars WHERE user_id = ?", d.userID).
		Scan(&filename, &size, &modTime)
	if errors.Is(err, sql.ErrNoRows) {
		// Empty directory: mark it consumed so callers looping until io.EOF terminate.
		d.pos = 1
		if count > 0 {
			return nil, io.EOF
		}
		return []os.FileInfo{}, nil
	}
	if err != nil {
		return nil, err
	}
	d.pos = 1
	return []os.FileInfo{&memFileInfo{name: filename, size: size, modTime: modTime}}, nil
}

// endOfDir is what Readdir returns once the listing has been consumed.
func (d *sqlDir) endOfDir(count int) ([]os.FileInfo, error) {
	if count > 0 {
		return nil, io.EOF
	}
	return nil, nil
}
// sqlFileBuffer is the webdav.File returned by sqlFS.OpenFile for writes.
// Written bytes are collected in memory (capped at 10MB) and, when Close is
// called, stored as the user's calendar, replacing any previous one.
type sqlFileBuffer struct {
	userID   string
	filename string
	buffer   *bytes.Buffer
	// writeErr is the first error returned by Write. Once set, further writes
	// fail and Close refuses to persist the incomplete buffer.
	writeErr error
	// modTime is fixed on first use so that the FileInfo from Stat and the
	// mod_time stored by Close agree (x/net/webdav's PUT handler builds the
	// response ETag from Stat, which it calls before Close).
	modTime time.Time
}

func (f *sqlFileBuffer) Read(p []byte) (n int, err error)            { return 0, io.EOF }
func (f *sqlFileBuffer) Seek(offset int64, whence int) (int64, error) { return 0, nil }
func (f *sqlFileBuffer) Readdir(count int) ([]os.FileInfo, error)     { return nil, os.ErrInvalid }

// Stat reports the bytes buffered so far and the timestamp Close will store.
func (f *sqlFileBuffer) Stat() (os.FileInfo, error) {
	return &memFileInfo{name: f.filename, size: int64(f.buffer.Len()), modTime: f.timestamp()}, nil
}

// Write appends p to the in-memory buffer, rejecting data that would push
// the file past the size limit.
func (f *sqlFileBuffer) Write(p []byte) (int, error) {
	if f.writeErr != nil {
		return 0, f.writeErr
	}
	// Limit calendar file size to 10MB to prevent DoS
	const maxSize = 10 * 1024 * 1024
	if f.buffer.Len()+len(p) > maxSize {
		f.writeErr = errors.New("file size exceeds maximum allowed (10MB)")
		return 0, f.writeErr
	}
	return f.buffer.Write(p)
}

// Close upserts the buffered content as the user's calendar. If any Write
// failed, nothing is stored (the previous calendar is kept) and that error
// is returned.
//
// NOTE(review): if the client disconnects mid-upload the WebDAV handler
// still calls Close, which stores whatever arrived; detecting that would
// need a check at the HTTP handler level.
func (f *sqlFileBuffer) Close() error {
	if f.writeErr != nil {
		return f.writeErr
	}
	content := f.buffer.Bytes()
	if content == nil {
		// An empty buffer yields a nil slice, which the driver may bind as
		// NULL; store a zero-length BLOB instead.
		content = []byte{}
	}
	// Flush buffer to DB
	_, err := db.Exec(`
INSERT INTO calendars (user_id, filename, content, mod_time)
VALUES (?, ?, ?, ?)
ON CONFLICT(user_id) DO UPDATE SET
filename=excluded.filename,
content=excluded.content,
mod_time=excluded.mod_time
`, f.userID, f.filename, content, f.timestamp())
	return err
}

// timestamp returns the file's modification time, fixing it on first call.
func (f *sqlFileBuffer) timestamp() time.Time {
	if f.modTime.IsZero() {
		f.modTime = time.Now()
	}
	return f.modTime
}
// memFile is a read-only webdav.File over an in-memory copy of a calendar.
// Read and Seek come from the embedded *bytes.Reader; writing and directory
// listing are refused.
type memFile struct {
	*bytes.Reader
	info os.FileInfo
}

// Close has nothing to release, since the content is a plain byte slice.
func (f *memFile) Close() error {
	return nil
}

// Readdir fails with os.ErrInvalid: a memFile is never a directory.
func (f *memFile) Readdir(count int) ([]os.FileInfo, error) {
	return nil, os.ErrInvalid
}

// Stat returns the FileInfo supplied when the memFile was built.
func (f *memFile) Stat() (os.FileInfo, error) {
	return f.info, nil
}

// Write always fails with os.ErrPermission because the file is read-only.
func (f *memFile) Write(p []byte) (int, error) {
	return 0, os.ErrPermission
}
// --- HTTP Handlers ---

// publicCalendarHandler serves a user's stored calendar without authentication.
//
// Expected path: /<userid>/calendar.ics. Only the first path segment is used:
// it is the user's internal ID (printed as "Public ID" by the REPL); any
// further segments are ignored. Only GET and HEAD are allowed (405 otherwise).
// Unknown IDs get 404; database failures get 500. The body is served through
// http.ServeContent with the stored mod_time, so responses carry
// Last-Modified and conditional (If-Modified-Since) and Range requests work.
func publicCalendarHandler(w http.ResponseWriter, r *http.Request) {
	// Add security headers
	h := w.Header()
	h.Set("X-Content-Type-Options", "nosniff")
	h.Set("X-Frame-Options", "DENY")
	h.Set("X-XSS-Protection", "1; mode=block")
	h.Set("Referrer-Policy", "no-referrer")

	if r.Method != http.MethodGet && r.Method != http.MethodHead {
		h.Set("Allow", "GET, HEAD")
		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
		return
	}

	userID := strings.SplitN(strings.TrimPrefix(r.URL.Path, "/"), "/", 2)[0]
	if userID == "" {
		http.NotFound(w, r)
		return
	}

	var content []byte
	var modTime time.Time
	err := db.QueryRow("SELECT content, mod_time FROM calendars WHERE user_id = ?", userID).Scan(&content, &modTime)
	if errors.Is(err, sql.ErrNoRows) {
		http.NotFound(w, r)
		return
	}
	if err != nil {
		log.Printf("public calendar lookup for %q: %v", userID, err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	// Content-Type is set explicitly so ServeContent does not sniff it.
	h.Set("Content-Type", "text/calendar")
	http.ServeContent(w, r, "calendar.ics", modTime, bytes.NewReader(content))
}
// securityHeadersMiddleware wraps next so that every response carries a fixed
// set of defensive headers (no MIME sniffing, no framing, legacy XSS filter
// in block mode, no Referer). The headers are set before next runs, so next
// may still override them.
func securityHeadersMiddleware(next http.Handler) http.Handler {
	defensive := [][2]string{
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "DENY"},
		{"X-XSS-Protection", "1; mode=block"},
		{"Referrer-Policy", "no-referrer"},
	}
	wrapped := func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		for _, kv := range defensive {
			hdr.Set(kv[0], kv[1])
		}
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(wrapped)
}
func authMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user, pass, ok := r.BasicAuth()
if !ok {
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
// Input validation for username
if len(user) > 255 || len(user) == 0 {
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
var id, hash string
err := db.QueryRow("SELECT id, password_hash FROM users WHERE username = ?", user).Scan(&id, &hash)
if err != nil || bcrypt.CompareHashAndPassword([]byte(hash), []byte(pass)) != nil {
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
ctx := context.WithValue(r.Context(), "userID", id)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// --- Utils ---

// generatePassword returns a fresh 16-character password. Each character is
// drawn uniformly from ASCII letters, digits and "!@#$%^&*" using
// crypto/rand. If the OS randomness source fails, the process exits via
// log.Fatalf.
func generatePassword() string {
	const (
		alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*"
		length   = 16
	)
	bound := big.NewInt(int64(len(alphabet)))
	out := make([]byte, 0, length)
	for len(out) < length {
		idx, err := rand.Int(rand.Reader, bound)
		if err != nil {
			log.Fatalf("Failed to generate secure random password: %v", err)
		}
		out = append(out, alphabet[idx.Int64()])
	}
	return string(out)
}
// hashPassword returns the bcrypt hash of p at bcrypt.DefaultCost, as a
// string suitable for users.password_hash.
//
// If hashing fails, the error is logged and "" is returned. For example,
// x/crypto versions that enforce bcrypt's 72-byte input limit reject longer
// passwords. An empty stored hash never verifies, so the affected account
// cannot log in until its password is reset.
func hashPassword(p string) string {
	h, err := bcrypt.GenerateFromPassword([]byte(p), bcrypt.DefaultCost)
	if err != nil {
		log.Printf("Failed to hash password: %v", err)
		return ""
	}
	return string(h)
}
// --- REPL ---
func runREPL() {
scanner := bufio.NewScanner(os.Stdin)
fmt.Println("Server started. REPL available (commands: add, del, list, resetpassword).")
fmt.Print("> ")
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
parts := strings.Fields(line)
if len(parts) == 0 {
fmt.Print("> ")
continue
}
cmd := parts[0]
switch cmd {
case "add":
username := ""
if len(parts) > 1 {
username = parts[1]
// Validate username length and characters
if len(username) > 255 || len(username) == 0 {
fmt.Println("Error: username must be 1-255 characters")
break
}
} else {
u, _ := uuid.NewV7()
username = u.String()
fmt.Printf("Generated username: %s\n", username)
}
password := ""
if len(parts) > 2 {
password = parts[2]
} else {
password = generatePassword()
fmt.Printf("Generated password: %s\n", password)
}
// Generate ID for internal linking
uID, _ := uuid.NewV7()
id := uID.String()
_, err := db.Exec("INSERT INTO users (id, username, password_hash) VALUES (?, ?, ?)", id, username, hashPassword(password))
if err != nil {
fmt.Printf("Error adding user: %v\n", err)
} else {
fmt.Printf("User %s created. Public ID: %s\n", username, id)
}
case "del":
if len(parts) < 2 {
fmt.Println("Usage: del <username>")
break
}
username := parts[1]
res, err := db.Exec("DELETE FROM users WHERE username = ?", username)
if err != nil {
fmt.Printf("Error: %v\n", err)
}
rows, _ := res.RowsAffected()
// Foreign keys aren't strictly enforcing cascading in default SQLite driver without PRAGMA,
// so we manually clean up calendar if needed, or rely on logic.
// Ideally we fetch ID first.
if rows > 0 {
fmt.Println("User deleted.")
// Cleanup calendar (orphan cleanup simplified)
// In a real app we'd fetch ID first, but for this REPL:
// DELETE FROM calendars WHERE user_id NOT IN (SELECT id FROM users)
db.Exec("DELETE FROM calendars WHERE user_id NOT IN (SELECT id FROM users)")
} else {
fmt.Println("User not found.")
}
case "list":
rows, err := db.Query("SELECT username, id FROM users")
if err != nil {
fmt.Printf("Error: %v\n", err)
break
}
fmt.Println("--- Users ---")
for rows.Next() {
var u, i string
rows.Scan(&u, &i)
fmt.Printf("%s (ID: %s)\n", u, i)
}
rows.Close()
case "resetpassword":
if len(parts) < 2 {
fmt.Println("Usage: resetpassword <username> [newpassword]")
break
}
username := parts[1]
password := ""
if len(parts) > 2 {
password = parts[2]
} else {
password = generatePassword()
fmt.Printf("Generated password: %s\n", password)
}
_, err := db.Exec("UPDATE users SET password_hash = ? WHERE username = ?", hashPassword(password), username)
if err != nil {
fmt.Printf("Error: %v\n", err)
} else {
fmt.Println("Password updated.")
}
default:
fmt.Println("Unknown command.")
}
fmt.Print("> ")
}
}
// --- Main ---

// main opens the SQLite database (DB_PATH, default ./data/cal.db). It serves
// authenticated WebDAV under /webdav/ and the public calendar endpoint on
// :8000, and runs the admin REPL on stdin.
//
// The HTTP server runs on the main goroutine and the REPL in the background,
// so the process keeps serving after stdin is closed. That happens, for
// example, under a service manager or in a container without an attached
// terminal. Stop the process with a signal.
func main() {
	dbPath := os.Getenv("DB_PATH")
	if dbPath == "" {
		dbPath = "./data/cal.db"
	}
	// Ensure directory exists
	if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
		log.Fatalf("Failed to create data directory: %v", err)
	}
	initDB(dbPath)
	defer db.Close()

	davFS := &sqlFS{}
	davLogger := func(r *http.Request, err error) {
		if err != nil {
			log.Printf("WEBDAV [%s]: %v", r.Method, err)
		}
	}

	// Every user's files live at the same virtual paths (e.g. "/calendar.ics"),
	// so a single shared LockSystem would let one user's LOCK block every other
	// user's writes to that path. Keep a separate lock system per user.
	userLocks := make(map[string]webdav.LockSystem)
	userLocksSem := make(chan struct{}, 1) // one-slot semaphore guarding userLocks
	lockSystemFor := func(userID string) webdav.LockSystem {
		userLocksSem <- struct{}{}
		defer func() { <-userLocksSem }()
		ls, ok := userLocks[userID]
		if !ok {
			ls = webdav.NewMemLS()
			userLocks[userID] = ls
		}
		return ls
	}

	davHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// authMiddleware has stored the authenticated user's ID under "userID".
		userID, _ := r.Context().Value("userID").(string)
		h := &webdav.Handler{
			// Prefix (instead of http.StripPrefix) makes the handler emit
			// PROPFIND hrefs and resolve Destination headers under /webdav,
			// so clients following those links come back to this endpoint.
			Prefix:     "/webdav",
			FileSystem: davFS,
			LockSystem: lockSystemFor(userID),
			Logger:     davLogger,
		}
		h.ServeHTTP(w, r)
	})
	webdavRoute := securityHeadersMiddleware(authMiddleware(davHandler))

	// Mux: /webdav/... is authenticated WebDAV; every other path is treated
	// as a public calendar URL (there is no homepage).
	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		if strings.HasPrefix(r.URL.Path, "/webdav/") {
			webdavRoute.ServeHTTP(w, r)
			return
		}
		// Public Calendar
		publicCalendarHandler(w, r)
	})

	srv := &http.Server{
		Addr:    ":8000",
		Handler: mux,
		// Bound header read time and idle keep-alives. No overall ReadTimeout
		// is set, so large (up to 10MB) WebDAV uploads on slow links still work.
		ReadHeaderTimeout: 10 * time.Second,
		IdleTimeout:       2 * time.Minute,
	}

	// Start REPL
	go runREPL()

	// Start Server
	fmt.Println("Server listening on :8000")
	if err := srv.ListenAndServe(); err != nil {
		log.Fatalf("Server failed: %v", err)
	}
}