// ddns/main.go (listing metadata: 568 lines, 14 KiB, Go)

package main
import (
	"context"
	"database/sql"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"time"

	"ddns-manager/db"
	"ddns-manager/templates"

	"github.com/cloudflare/cloudflare-go"
	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
	_ "github.com/mattn/go-sqlite3"
	"github.com/robfig/cron/v3"
)
var (
api *cloudflare.API
scheduler *cron.Cron
lastIP string
jobID cron.EntryID
queries *db.Queries
)
// hostnamePattern matches the characters accepted in a CNAME target:
// ASCII letters, digits, hyphens and dots. It is compiled once here rather
// than on every validation call.
var hostnamePattern = regexp.MustCompile(`^[a-zA-Z0-9\-\.]+$`)

// validateDNSRecord performs lightweight checks on a record before it is
// sent to Cloudflare.
//
// name and content must be non-blank. Additional per-type rules:
//   - "A": content must be a plain IPv4 address (no IPv6 text, including
//     IPv4-mapped forms such as "::ffff:1.2.3.4").
//   - "AAAA": content must be an IPv6 address.
//   - "CNAME": content must match hostnamePattern.
//
// Other record types are only checked for non-blank name and content.
func validateDNSRecord(name, recordType, content string) error {
	if strings.TrimSpace(name) == "" {
		return fmt.Errorf("name is required")
	}
	if strings.TrimSpace(content) == "" {
		return fmt.Errorf("content is required")
	}

	switch recordType {
	case "A":
		// net.ParseIP also accepts IPv6 text, so require a 4-byte form and
		// reject any colon-separated notation.
		if ip := net.ParseIP(content); ip == nil || ip.To4() == nil || strings.Contains(content, ":") {
			return fmt.Errorf("invalid IP address")
		}
	case "AAAA":
		if ip := net.ParseIP(content); ip == nil || !strings.Contains(content, ":") {
			return fmt.Errorf("invalid IPv6 address")
		}
	case "CNAME":
		if !hostnamePattern.MatchString(content) {
			return fmt.Errorf("invalid domain name")
		}
	}
	return nil
}
// sanitizeInput normalizes a raw form value by removing leading and
// trailing whitespace (as defined by strings.TrimSpace). Inner whitespace
// is left untouched.
func sanitizeInput(input string) string {
	cleaned := strings.TrimSpace(input)
	return cleaned
}
// errorResponse answers the request with HTTP 400 and an error
// notification fragment containing message. The status line is written
// before the fragment is rendered.
func errorResponse(c echo.Context, message string) error {
	res := c.Response()
	res.WriteHeader(http.StatusBadRequest)
	return templates.Render(res, templates.ErrorNotification(message))
}
// successResponse renders a success notification fragment containing
// message, without setting an explicit status code.
func successResponse(c echo.Context, message string) error {
	res := c.Response()
	fragment := templates.SuccessNotification(message)
	return templates.Render(res, fragment)
}
// createDNSRecord validates a record and creates it in the given zone.
//
// name may be:
//   - "@" for the zone apex (becomes domain),
//   - a bare label such as "www" (becomes "www.<domain>"),
//   - or a name already equal to or under domain (sent unchanged).
//
// ttl and proxied are passed through to Cloudflare as given. Errors are
// produced as follows:
//   - Validation errors are returned as-is.
//   - Some known Cloudflare error codes map to short messages.
//   - Other Cloudflare errors surface their first message.
//   - Anything else is wrapped with "failed to create record".
func createDNSRecord(zoneID, domain, name, recordType, content string, ttl int, proxied bool) error {
	if api == nil {
		return fmt.Errorf("cloudflare API not initialized")
	}
	if err := validateDNSRecord(name, recordType, content); err != nil {
		return err
	}

	// Qualify the name. The suffix check uses "."+domain rather than domain
	// alone, so a label such as "myexample.com" under "example.com" is still
	// qualified instead of being mistaken for a name inside the zone.
	fullName := name
	switch {
	case name == "@":
		fullName = domain
	case domain != "" && name != domain && !strings.HasSuffix(name, "."+domain):
		fullName = name + "." + domain
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	_, err := api.CreateDNSRecord(ctx, cloudflare.ZoneIdentifier(zoneID), cloudflare.CreateDNSRecordParams{
		Type:    recordType,
		Name:    fullName,
		Content: content,
		TTL:     ttl,
		Proxied: &proxied,
	})
	if err == nil {
		return nil
	}

	if cfErr, ok := err.(*cloudflare.Error); ok {
		// Cloudflare does not always populate these slices; index only
		// after checking the length to avoid a panic.
		if len(cfErr.ErrorCodes) > 0 {
			switch cfErr.ErrorCodes[0] {
			case 10000:
				return fmt.Errorf("invalid API credentials")
			case 81044:
				// NOTE(review): confirm 81044 is Cloudflare's "record already
				// exists" code; the DNS API error list may assign it differently.
				return fmt.Errorf("record already exists")
			}
		}
		if len(cfErr.ErrorMessages) > 0 {
			return fmt.Errorf("cloudflare error: %s", cfErr.ErrorMessages[0])
		}
	}
	return fmt.Errorf("failed to create record: %w", err)
}
// updateDNSRecord validates a record and overwrites the existing record id
// in the given zone.
//
// name is sent to Cloudflare exactly as given; the PUT /records/:id handler
// qualifies it against the configured domain before calling this. If the
// failure is a *cloudflare.Error carrying a message, that message is
// surfaced. Otherwise the error is wrapped with "failed to update record".
func updateDNSRecord(zoneID, id, name, recordType, content string, ttl int, proxied bool) error {
	if api == nil {
		return fmt.Errorf("cloudflare API not initialized")
	}
	if err := validateDNSRecord(name, recordType, content); err != nil {
		return err
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	_, err := api.UpdateDNSRecord(ctx, cloudflare.ZoneIdentifier(zoneID), cloudflare.UpdateDNSRecordParams{
		ID:      id,
		Type:    recordType,
		Name:    name,
		Content: content,
		TTL:     ttl,
		Proxied: &proxied,
	})
	if err == nil {
		return nil
	}
	// Only use Cloudflare's message when one was sent; indexing an empty
	// ErrorMessages slice would panic.
	if cfErr, ok := err.(*cloudflare.Error); ok && len(cfErr.ErrorMessages) > 0 {
		return fmt.Errorf("cloudflare error: %s", cfErr.ErrorMessages[0])
	}
	return fmt.Errorf("failed to update record: %w", err)
}
// deleteDNSRecord removes record id from the given zone.
//
// If the failure is a *cloudflare.Error carrying a message, that message is
// surfaced. Otherwise the error is wrapped with "failed to delete record".
func deleteDNSRecord(zoneID, id string) error {
	if api == nil {
		return fmt.Errorf("cloudflare API not initialized")
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	err := api.DeleteDNSRecord(ctx, cloudflare.ZoneIdentifier(zoneID), id)
	if err == nil {
		return nil
	}
	// Guard the index: Cloudflare errors do not always include messages.
	if cfErr, ok := err.(*cloudflare.Error); ok && len(cfErr.ErrorMessages) > 0 {
		return fmt.Errorf("cloudflare error: %s", cfErr.ErrorMessages[0])
	}
	return fmt.Errorf("failed to delete record: %w", err)
}
// initDatabase opens the SQLite database, ensures its schema exists and
// sets the package-level queries handle.
//
// The path comes from $DB_PATH, defaulting to "./ddns.db". The caller owns
// the returned *sql.DB and must Close it. On any failure the handle is
// closed here and nil is returned.
func initDatabase() (*sql.DB, error) {
	dbPath := os.Getenv("DB_PATH")
	if dbPath == "" {
		dbPath = "./ddns.db"
	}

	sqlDB, err := sql.Open("sqlite3", dbPath)
	if err != nil {
		return nil, fmt.Errorf("opening database %q: %w", dbPath, err)
	}
	if err := db.InitSchema(sqlDB); err != nil {
		// Don't leak the pool when the caller never receives it. The schema
		// error is the one worth reporting, so the Close error is dropped.
		_ = sqlDB.Close()
		return nil, fmt.Errorf("initializing schema in %q: %w", dbPath, err)
	}

	queries = db.New(sqlDB)
	return sqlDB, nil
}
// initCloudflare (re)builds the package-level Cloudflare client from an
// API token.
//
// An empty token means "not configured": api is left untouched and nil is
// returned. Otherwise api is replaced by whatever the constructor returns,
// and the constructor's error is passed back.
func initCloudflare(apiToken string) error {
	if apiToken != "" {
		client, err := cloudflare.NewWithAPIToken(apiToken)
		api = client
		return err
	}
	return nil
}
// ipLookupURL is the plain-text "what is my IP" endpoint queried by
// getCurrentIP. It is a variable so tests and deployments can point it at
// another service.
var ipLookupURL = "https://api.ipify.org"

// ipLookupClient bounds how long a public-IP lookup may take. The default
// client used by http.Get has no timeout and could block callers (HTTP
// handlers, the cron job) indefinitely.
var ipLookupClient = &http.Client{Timeout: 10 * time.Second}

// getCurrentIP returns this host's public IP address as reported by
// ipLookupURL, in canonical textual form.
//
// The response must be HTTP 200 and its whitespace-trimmed body must parse
// as an IP address; otherwise an error is returned. The result is written
// into DNS records and HTML, so an error page must never be passed through.
func getCurrentIP() (string, error) {
	resp, err := ipLookupClient.Get(ipLookupURL)
	if err != nil {
		return "", fmt.Errorf("looking up public IP: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("looking up public IP: unexpected status %s", resp.Status)
	}

	// An address is tiny; cap the read so a misbehaving endpoint cannot make
	// us buffer an arbitrarily large body.
	body, err := io.ReadAll(io.LimitReader(resp.Body, 256))
	if err != nil {
		return "", fmt.Errorf("reading public IP response: %w", err)
	}

	raw := strings.TrimSpace(string(body))
	ip := net.ParseIP(raw)
	if ip == nil {
		return "", fmt.Errorf("public IP lookup returned %q, which is not an IP address", raw)
	}
	return ip.String(), nil
}
// getDNSRecords lists the DNS records of zoneID and converts them to the
// template view model, with CreatedOn formatted as RFC 3339.
//
// It lists with empty ListDNSRecordsParams, relying on the client library's
// default paging (presumably fetching all pages; confirm in cloudflare-go).
// The result is nil when the zone has no records.
func getDNSRecords(zoneID string) ([]templates.DNSRecord, error) {
	if api == nil {
		return nil, fmt.Errorf("cloudflare API not initialized")
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	recs, _, err := api.ListDNSRecords(ctx, cloudflare.ZoneIdentifier(zoneID), cloudflare.ListDNSRecordsParams{})
	if err != nil {
		return nil, err
	}

	var records []templates.DNSRecord
	for _, rec := range recs {
		// Proxied is a *bool in the API model. Treat a missing value as
		// "not proxied" instead of dereferencing nil and panicking the handler.
		proxied := rec.Proxied != nil && *rec.Proxied
		records = append(records, templates.DNSRecord{
			ID:        rec.ID,
			Type:      rec.Type,
			Name:      rec.Name,
			Content:   rec.Content,
			TTL:       rec.TTL,
			Proxied:   proxied,
			CreatedOn: rec.CreatedOn.Format(time.RFC3339),
		})
	}
	return records, nil
}
func updateAllRecordsWithCurrentIP(zoneID string) error {
currentIP, err := getCurrentIP()
if err != nil {
return err
}
if currentIP == lastIP {
return nil
}
lastIP = currentIP
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
rc := cloudflare.ZoneIdentifier(zoneID)
records, _, err := api.ListDNSRecords(ctx, rc, cloudflare.ListDNSRecordsParams{Type: "A"})
if err != nil {
return err
}
for _, rec := range records {
if rec.Content != currentIP {
_, err := api.UpdateDNSRecord(ctx, rc, cloudflare.UpdateDNSRecordParams{
ID: rec.ID,
Type: rec.Type,
Name: rec.Name,
Content: currentIP,
TTL: rec.TTL,
Proxied: rec.Proxied,
})
if err != nil {
log.Printf("Failed to update record %s: %v", rec.Name, err)
}
}
}
return nil
}
// scheduleUpdates replaces the periodic IP-sync job.
//
// Any previously scheduled job is removed first. An empty updatePeriod
// leaves no job scheduled. Otherwise updatePeriod must be a cron spec that
// scheduler accepts (e.g. "*/5 * * * *"), and the job runs
// updateAllRecordsWithCurrentIP(zoneID) on that schedule. Each run logs
// either its failure or its completion, never both.
func scheduleUpdates(zoneID, updatePeriod string) error {
	if jobID != 0 {
		scheduler.Remove(jobID)
		// Forget the removed entry so jobID never refers to a dead job.
		jobID = 0
		log.Println("Scheduled update removed")
	}
	if updatePeriod == "" {
		return nil
	}

	id, err := scheduler.AddFunc(updatePeriod, func() {
		if err := updateAllRecordsWithCurrentIP(zoneID); err != nil {
			log.Printf("Scheduled update failed: %v", err)
			return
		}
		log.Println("Scheduled update completed")
	})
	if err != nil {
		return fmt.Errorf("scheduling updates with period %q: %w", updatePeriod, err)
	}
	jobID = id
	return nil
}
// getUpdateFrequencies lists the update schedules offered in the UI, in
// display order.
//
// Each Value is a five-field cron spec, presumably submitted as
// "update_period" and handed to scheduleUpdates. The empty Value
// ("Never (manual only)") results in no scheduled job.
func getUpdateFrequencies() []templates.UpdateFrequency {
	choices := [...]struct{ label, spec string }{
		{"Every 1 minute", "*/1 * * * *"},
		{"Every 5 minutes", "*/5 * * * *"},
		{"Every 30 minutes", "*/30 * * * *"},
		{"Hourly", "0 * * * *"},
		{"Every 6 hours", "0 */6 * * *"},
		{"Daily", "0 0 * * *"},
		{"Never (manual only)", ""},
	}

	freqs := make([]templates.UpdateFrequency, 0, len(choices))
	for _, choice := range choices {
		freqs = append(freqs, templates.UpdateFrequency{Label: choice.label, Value: choice.spec})
	}
	return freqs
}
// main wires up storage, the Cloudflare client, the periodic IP-sync job
// and the HTTP routes, then serves on :3000.
//
// Configuration is loaded once from the database. If none can be loaded, a
// placeholder (domain "example.com", six-hourly updates) is used until the
// user saves one via POST /config.
func main() {
	sqlDB, err := initDatabase()
	if err != nil {
		log.Fatalf("Database init failed: %v", err)
	}
	defer sqlDB.Close()

	// Load the stored config. An empty table is the normal first-run case;
	// any other failure is logged before falling back to defaults.
	// NOTE(review): assumes GetConfig returns sql.ErrNoRows unwrapped when
	// nothing is stored (typical for a single-row query); confirm in db.
	config, err := queries.GetConfig(context.Background())
	if err != nil {
		if err != sql.ErrNoRows {
			log.Printf("Loading config failed, using defaults: %v", err)
		}
		config = db.Config{Domain: "example.com", UpdatePeriod: "0 */6 * * *"}
	}

	// Requests are served on separate goroutines and POST /config rewrites
	// config, so every access goes through cfgMu. Handlers work on a copy.
	var cfgMu sync.RWMutex
	currentConfig := func() db.Config {
		cfgMu.RLock()
		defer cfgMu.RUnlock()
		return config
	}

	// qualifyName expands "@" to the zone apex and a bare label to
	// "<label>.<domain>". Names equal to or already under domain are kept.
	// The "."+domain suffix check avoids mistaking e.g. "myexample.com" for
	// a name inside "example.com".
	qualifyName := func(name, domain string) string {
		switch {
		case name == "@":
			return domain
		case domain != "" && name != domain && !strings.HasSuffix(name, "."+domain):
			return name + "." + domain
		default:
			return name
		}
	}

	// recordForm holds the record fields shared by the create and update
	// routes.
	type recordForm struct {
		name, recordType, content string
		ttl                       int
		proxied                   bool
	}

	// readRecordForm parses a record form. When "use my IP" is ticked, the
	// content is replaced by the current public IP, and a failed lookup is
	// returned as an error.
	readRecordForm := func(c echo.Context) (recordForm, error) {
		f := recordForm{
			name:       sanitizeInput(c.FormValue("name")),
			recordType: sanitizeInput(c.FormValue("type")),
			content:    sanitizeInput(c.FormValue("content")),
			proxied:    c.FormValue("proxied") == "on",
		}
		if c.FormValue("use_my_ip") == "on" {
			ip, err := getCurrentIP()
			if err != nil {
				return recordForm{}, err
			}
			f.content = ip
		}
		// A missing or unparsable TTL, or 0, falls back to 1.
		ttl, err := strconv.Atoi(sanitizeInput(c.FormValue("ttl")))
		if err != nil || ttl == 0 {
			ttl = 1
		}
		f.ttl = ttl
		return f, nil
	}

	// respondWithTable sends a success notification followed by the
	// refreshed records table, the response used by every mutating record
	// route. Lookup failures degrade to an empty table or IP.
	respondWithTable := func(c echo.Context, zoneID, message string) error {
		records, _ := getDNSRecords(zoneID)
		currentIP, _ := getCurrentIP()
		return templates.RenderMultiple(c.Response().Writer,
			templates.SuccessNotification(message),
			templates.DNSRecordsTable(records, currentIP))
	}

	if err := initCloudflare(config.ApiToken); err != nil {
		log.Printf("Cloudflare init failed: %v", err)
	}

	scheduler = cron.New()
	scheduler.Start()
	defer scheduler.Stop()
	if config.ApiToken != "" && config.ZoneID != "" && config.UpdatePeriod != "" {
		if err := scheduleUpdates(config.ZoneID, config.UpdatePeriod); err != nil {
			log.Printf("Scheduling updates failed: %v", err)
		}
	}

	e := echo.New()
	e.Use(middleware.Logger())
	e.Use(middleware.Recover())
	// NOTE(review): middleware.CORS() with Echo's default config allows any
	// origin. GET / and GET /config pass the stored API token to their
	// templates, so consider restricting AllowOrigins (or dropping CORS if
	// the UI is only ever served from this origin).
	e.Use(middleware.CORS())

	e.GET("/", func(c echo.Context) error {
		cfg := currentConfig()
		// A failed lookup just renders an empty IP; the page still loads.
		currentIP, _ := getCurrentIP()
		isConfigured := cfg.ApiToken != "" && cfg.ZoneID != ""
		var records []templates.DNSRecord
		if isConfigured {
			records, _ = getDNSRecords(cfg.ZoneID)
		}
		return templates.Render(c.Response(), templates.Index(templates.IndexProps{
			Title:        "DNS Manager",
			IsConfigured: isConfigured,
			CurrentIP:    currentIP,
			Config: templates.ConfigData{
				ZoneID:       cfg.ZoneID,
				Domain:       cfg.Domain,
				UpdatePeriod: cfg.UpdatePeriod,
				ApiToken:     cfg.ApiToken,
			},
			Records:     records,
			UpdateFreqs: getUpdateFrequencies(),
		}))
	})

	e.GET("/refresh-ip", func(c echo.Context) error {
		ip, err := getCurrentIP()
		if err != nil {
			return errorResponse(c, "Failed to get current IP")
		}
		// The value comes from an external service and is interpolated into
		// raw HTML. Only emit it if it parses as an IP, and emit the canonical
		// form (hex digits, '.', ':'), which cannot carry markup.
		parsed := net.ParseIP(ip)
		if parsed == nil {
			return errorResponse(c, "Failed to get current IP")
		}
		return c.HTML(http.StatusOK, fmt.Sprintf(`<span id="current-ip" class="fw-bold">%s</span>`, parsed.String()))
	})

	e.POST("/config", func(c echo.Context) error {
		apiToken := sanitizeInput(c.FormValue("api_token"))
		zoneID := sanitizeInput(c.FormValue("zone_id"))
		domain := sanitizeInput(c.FormValue("domain"))
		updatePeriod := sanitizeInput(c.FormValue("update_period"))
		if apiToken == "" || zoneID == "" || domain == "" {
			return errorResponse(c, "Please fill all required fields")
		}

		// Hold the write lock for the whole save so concurrent saves cannot
		// interleave their database writes, client re-init and rescheduling.
		cfgMu.Lock()
		defer cfgMu.Unlock()

		// Replace the stored config: clear it, then insert the new values.
		// NOTE(review): these two writes are not in one transaction; a failed
		// insert leaves no stored config.
		if err := queries.DeleteAllConfig(context.Background()); err != nil {
			log.Printf("Clearing config failed: %v", err)
			return errorResponse(c, "Failed to save configuration")
		}
		if err := queries.InsertConfig(context.Background(), db.InsertConfigParams{
			ApiToken:     apiToken,
			ZoneID:       zoneID,
			Domain:       domain,
			UpdatePeriod: updatePeriod,
		}); err != nil {
			log.Printf("Saving config failed: %v", err)
			return errorResponse(c, "Failed to save configuration")
		}

		config.ApiToken = apiToken
		config.ZoneID = zoneID
		config.Domain = domain
		config.UpdatePeriod = updatePeriod

		// The config is stored at this point. Still attempt both follow-up
		// steps, but report a failure instead of claiming full success.
		warning := ""
		if err := initCloudflare(apiToken); err != nil {
			log.Printf("Cloudflare init failed: %v", err)
			warning = "Configuration saved, but the Cloudflare client could not be initialized"
		}
		if err := scheduleUpdates(zoneID, updatePeriod); err != nil {
			log.Printf("Scheduling updates failed: %v", err)
			if warning == "" {
				warning = "Configuration saved, but the update schedule could not be set"
			}
		}
		if warning != "" {
			return errorResponse(c, warning)
		}

		return templates.Render(c.Response(), templates.ConfigStatus(templates.ConfigData{
			ZoneID:       zoneID,
			Domain:       domain,
			UpdatePeriod: updatePeriod,
			ApiToken:     apiToken,
		}))
	})

	e.GET("/config", func(c echo.Context) error {
		cfg := currentConfig()
		return templates.Render(c.Response(), templates.ConfigModal(templates.ConfigData{
			ZoneID:       cfg.ZoneID,
			Domain:       cfg.Domain,
			UpdatePeriod: cfg.UpdatePeriod,
			ApiToken:     cfg.ApiToken,
		}, getUpdateFrequencies()))
	})

	e.GET("/records/new", func(c echo.Context) error {
		cfg := currentConfig()
		return templates.Render(c.Response(), templates.RecordForm("Add DNS Record", "", cfg.Domain, templates.DNSRecord{Type: "A", TTL: 1}))
	})

	e.POST("/records", func(c echo.Context) error {
		cfg := currentConfig()
		f, err := readRecordForm(c)
		if err != nil {
			return errorResponse(c, "Failed to get current IP")
		}
		if err := createDNSRecord(cfg.ZoneID, cfg.Domain, f.name, f.recordType, f.content, f.ttl, f.proxied); err != nil {
			return errorResponse(c, err.Error())
		}
		return respondWithTable(c, cfg.ZoneID, "DNS record created")
	})

	e.PUT("/records/:id", func(c echo.Context) error {
		cfg := currentConfig()
		id := c.Param("id")
		f, err := readRecordForm(c)
		if err != nil {
			return errorResponse(c, "Failed to get current IP")
		}
		fullName := qualifyName(f.name, cfg.Domain)
		if err := updateDNSRecord(cfg.ZoneID, id, fullName, f.recordType, f.content, f.ttl, f.proxied); err != nil {
			return errorResponse(c, err.Error())
		}
		return respondWithTable(c, cfg.ZoneID, "DNS record updated")
	})

	e.DELETE("/records/:id", func(c echo.Context) error {
		cfg := currentConfig()
		if err := deleteDNSRecord(cfg.ZoneID, c.Param("id")); err != nil {
			log.Printf("Deleting record failed: %v", err)
			return errorResponse(c, "Failed to delete record")
		}
		return respondWithTable(c, cfg.ZoneID, "DNS record deleted")
	})

	e.GET("/edit-record/:id", func(c echo.Context) error {
		cfg := currentConfig()
		id := c.Param("id")
		records, err := getDNSRecords(cfg.ZoneID)
		if err != nil {
			return errorResponse(c, "Failed to load records")
		}
		for _, r := range records {
			if r.ID == id {
				return templates.Render(c.Response(), templates.RecordForm("Edit DNS Record", id, cfg.Domain, r))
			}
		}
		return errorResponse(c, "Record not found")
	})

	e.POST("/update-all-records", func(c echo.Context) error {
		cfg := currentConfig()
		if err := updateAllRecordsWithCurrentIP(cfg.ZoneID); err != nil {
			log.Printf("Manual update failed: %v", err)
			return errorResponse(c, "Failed to update records")
		}
		return respondWithTable(c, cfg.ZoneID, "All A records updated")
	})

	log.Println("Starting server on :3000")
	log.Fatal(e.Start(":3000"))
}