implement web push

This commit is contained in:
2025-05-10 13:04:39 +01:00
parent f832b95d69
commit 44d744266a
11 changed files with 559 additions and 142 deletions

306
main.go
View File

@@ -2,24 +2,27 @@ package main
import (
"context"
"database/sql"
"embed"
_ "embed"
"encoding/json"
"errors"
"fmt"
"github.com/coder/websocket"
"github.com/SherClockHolmes/webpush-go"
"github.com/go-co-op/gocron/v2"
"github.com/google/uuid"
"github.com/joho/godotenv"
"golang.org/x/time/rate"
"google.golang.org/genai"
"html/template"
"io"
"log"
"mime"
_ "modernc.org/sqlite"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
)
@@ -37,21 +40,39 @@ type pageTemplate struct {
summary *template.Template
}
type state struct {
ctx context.Context
apiKey apiKey
template pageTemplate
summaries sync.Map
summaryChans map[string]chan string
genai *genai.Client
subscriberCount atomic.Int64
}
type summaryTemplateData struct {
Summary string
Location string
}
type updateSubscription struct {
Subscription webpush.Subscription `json:"subscription"`
Locations []string `json:"locations"`
}
type registeredSubscription struct {
ID uuid.UUID `json:"id"`
Subscription *webpush.Subscription `json:"-"`
Locations []string `json:"locations"`
}
type state struct {
ctx context.Context
db *sql.DB
genai *genai.Client
apiKey apiKey
template pageTemplate
summaries sync.Map
summaryChans map[string]chan string
subscriptions map[string][]registeredSubscription
subscriptionsMutex sync.Mutex
vapidPublicKey string
vapidPrivateKey string
}
//go:embed web
var webDir embed.FS
@@ -72,6 +93,11 @@ func main() {
log.Fatalln("Please create a .env file using the provided template!")
}
db, err := initDB()
if err != nil {
log.Fatalf("failed to initialize db: %e\n", err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -88,6 +114,7 @@ func main() {
state := state{
ctx: ctx,
db: db,
apiKey: apiKey{
openWeatherMap: os.Getenv("OPEN_WEATHER_MAP_API_KEY"),
},
@@ -97,6 +124,11 @@ func main() {
summaries: sync.Map{},
summaryChans: map[string]chan string{},
genai: genaiClient,
subscriptions: map[string][]registeredSubscription{},
vapidPublicKey: os.Getenv("VAPID_PUBLIC_KEY_BASE64"),
vapidPrivateKey: os.Getenv("VAPID_PRIVATE_KEY_BASE64"),
}
var schedulers []gocron.Scheduler
@@ -114,17 +146,26 @@ func main() {
_, err = s.NewJob(
gocron.DurationJob(time.Minute),
gocron.NewTask(updateSummaries, &state, locKey, &loc))
gocron.NewTask(updateSummaries, &state, locKey, &loc),
gocron.WithStartAt(gocron.WithStartImmediately()),
)
if err != nil {
log.Fatal(err)
}
schedulers = append(schedulers, s)
state.summaryChans[locKey] = make(chan string)
c := make(chan string)
state.subscriptions[locKey] = []registeredSubscription{}
state.summaryChans[locKey] = c
go listenForSummaryUpdates(&state, locKey)
s.Start()
}
loadSubscriptions(&state)
http.HandleFunc("/", handleHTTPRequest(&state))
http.ListenAndServe(":8080", nil)
@@ -137,28 +178,83 @@ func handleHTTPRequest(state *state) http.HandlerFunc {
return func(writer http.ResponseWriter, request *http.Request) {
path := strings.TrimPrefix(request.URL.Path, "/")
switch path {
case "":
index, _ := webDir.ReadFile("web/index.html")
writer.Write(index)
case "ws":
conn, err := websocket.Accept(writer, request, nil)
if err != nil {
log.Printf("error accepting incoming ws connection: %e\n", err)
if path == "" {
if request.Method == "" || request.Method == "GET" {
index, _ := webDir.ReadFile("web/index.html")
writer.Write(index)
} else {
writer.WriteHeader(http.StatusMethodNotAllowed)
}
defer conn.CloseNow()
} else if path == "vapid" {
if request.Method == "" || request.Method == "GET" {
writer.Write([]byte(state.vapidPublicKey))
} else {
writer.WriteHeader(http.StatusMethodNotAllowed)
}
} else if strings.HasPrefix(path, "registrations") {
if path == "registrations" && request.Method == "POST" {
defer request.Body.Close()
log.Println("accepted incoming websocket connection")
update := updateSubscription{}
err := json.NewDecoder(request.Body).Decode(&update)
if err != nil {
writer.WriteHeader(http.StatusBadRequest)
return
}
locKey := request.URL.Query().Get("location")
if c, ok := state.summaryChans[locKey]; ok {
state.subscriberCount.Add(1)
sendSummaryUpdates(state, c, conn)
state.subscriberCount.Add(-1)
reg, err := registerSubscription(state, &update)
if err != nil {
writer.WriteHeader(http.StatusBadRequest)
return
}
err = json.NewEncoder(writer).Encode(reg)
if err != nil {
writer.WriteHeader(http.StatusBadRequest)
}
} else if request.Method == "PATCH" {
parts := strings.Split(path, "/")
if len(parts) < 2 {
writer.WriteHeader(http.StatusMethodNotAllowed)
return
}
regID, err := uuid.Parse(parts[1])
if err != nil {
writer.WriteHeader(http.StatusNotFound)
return
}
defer request.Body.Close()
update := updateSubscription{}
err = json.NewDecoder(request.Body).Decode(&update)
if err != nil {
writer.WriteHeader(http.StatusBadRequest)
return
}
reg, err := updateRegisteredSubscription(state, regID, &update)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
writer.WriteHeader(http.StatusNotFound)
} else {
writer.WriteHeader(http.StatusInternalServerError)
}
return
}
json.NewEncoder(writer).Encode(reg)
} else {
writer.WriteHeader(http.StatusMethodNotAllowed)
}
} else {
if request.Method != "" && request.Method != "GET" {
writer.WriteHeader(http.StatusMethodNotAllowed)
return
}
default:
summary, ok := state.summaries.Load(path)
if ok {
state.template.summary.Execute(writer, summaryTemplateData{summary.(string), path})
@@ -178,34 +274,113 @@ func handleHTTPRequest(state *state) http.HandlerFunc {
}
}
func sendSummaryUpdates(state *state, c <-chan string, conn *websocket.Conn) {
l := rate.NewLimiter(rate.Every(time.Millisecond*100), 10)
ctx, cancel := context.WithCancel(state.ctx)
defer cancel()
func initDB() (*sql.DB, error) {
db, err := sql.Open("sqlite", "file:data.sqlite")
if err != nil {
log.Fatalln("failed to initialize database")
}
for {
err := l.Wait(ctx)
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS subscriptions(
id TEXT PRIMARY KEY,
locations TEXT NOT NULL,
subscription_json TEXT NOT NULL
);
`)
if err != nil {
return nil, err
}
return db, nil
}
func loadSubscriptions(state *state) error {
rows, err := state.db.Query(`SELECT id, locations, subscription_json FROM subscriptions;`)
if err != nil {
return err
}
for rows.Next() {
var id string
var locations string
var j string
err := rows.Scan(&id, &locations, &j)
if err != nil {
return
continue
}
select {
case summary := <-c:
log.Println("summary updated. sending updates via sockets...")
s := webpush.Subscription{}
err = json.Unmarshal([]byte(j), &s)
if err != nil {
continue
}
w, err := conn.Writer(ctx, websocket.MessageText)
if err != nil {
return
}
_, err = w.Write([]byte(summary))
if err != nil {
return
}
w.Close()
case <-ctx.Done():
return
reg := registeredSubscription{
ID: uuid.MustParse(id),
Locations: strings.Split(locations, ","),
Subscription: &s,
}
for _, l := range reg.Locations {
state.subscriptions[l] = append(state.subscriptions[l], reg)
}
}
return nil
}
func updateRegisteredSubscription(state *state, id uuid.UUID, update *updateSubscription) (*registeredSubscription, error) {
j, err := json.Marshal(update.Subscription)
if err != nil {
return nil, err
}
_, err = state.db.Exec(
"UPDATE subscriptions SET subscription_json = ?, locations = ? WHERE id = ?",
string(j), strings.Join(update.Locations, ","), id,
)
if err != nil {
return nil, err
}
return &registeredSubscription{
ID: id,
Subscription: &update.Subscription,
Locations: update.Locations,
}, nil
}
func registerSubscription(state *state, sub *updateSubscription) (*registeredSubscription, error) {
j, err := json.Marshal(sub.Subscription)
if err != nil {
return nil, err
}
id, err := uuid.NewV7()
if err != nil {
return nil, err
}
_, err = state.db.Exec(
"INSERT INTO subscriptions (id, locations, subscription_json) VALUES (?, ?, ?);",
id, strings.Join(sub.Locations, ","), string(j),
)
if err != nil {
return nil, err
}
reg := registeredSubscription{
ID: id,
Subscription: &sub.Subscription,
Locations: sub.Locations,
}
for _, l := range sub.Locations {
state.subscriptions[l] = append(state.subscriptions[l], reg)
}
return &reg, nil
}
func updateSummaries(state *state, locKey string, loc *location) {
@@ -239,9 +414,32 @@ func updateSummaries(state *state, locKey string, loc *location) {
c := state.summaryChans[locKey]
state.summaries.Store(locKey, summary)
if state.subscriberCount.Load() > 0 {
if len(state.subscriptions[locKey]) > 0 {
c <- summary
}
log.Printf("updated summary for %v successfully\n", locKey)
}
func listenForSummaryUpdates(state *state, locKey string) {
c := state.summaryChans[locKey]
for {
select {
case summary := <-c:
log.Printf("sending summary for %v to subscribers...\n", locKey)
for _, sub := range state.subscriptions[locKey] {
_, err := webpush.SendNotificationWithContext(state.ctx, []byte(summary), sub.Subscription, &webpush.Options{
VAPIDPublicKey: state.vapidPublicKey,
VAPIDPrivateKey: state.vapidPrivateKey,
TTL: 30,
})
if err != nil {
log.Printf("failed to send notification %e\n", err)
}
}
case <-state.ctx.Done():
return
}
}
}