implement web push
This commit is contained in:
306
main.go
306
main.go
@@ -2,24 +2,27 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"embed"
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/coder/websocket"
|
||||
"github.com/SherClockHolmes/webpush-go"
|
||||
"github.com/go-co-op/gocron/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/joho/godotenv"
|
||||
"golang.org/x/time/rate"
|
||||
"google.golang.org/genai"
|
||||
"html/template"
|
||||
"io"
|
||||
"log"
|
||||
"mime"
|
||||
_ "modernc.org/sqlite"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -37,21 +40,39 @@ type pageTemplate struct {
|
||||
summary *template.Template
|
||||
}
|
||||
|
||||
type state struct {
|
||||
ctx context.Context
|
||||
apiKey apiKey
|
||||
template pageTemplate
|
||||
summaries sync.Map
|
||||
summaryChans map[string]chan string
|
||||
genai *genai.Client
|
||||
subscriberCount atomic.Int64
|
||||
}
|
||||
|
||||
type summaryTemplateData struct {
|
||||
Summary string
|
||||
Location string
|
||||
}
|
||||
|
||||
type updateSubscription struct {
|
||||
Subscription webpush.Subscription `json:"subscription"`
|
||||
Locations []string `json:"locations"`
|
||||
}
|
||||
|
||||
type registeredSubscription struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Subscription *webpush.Subscription `json:"-"`
|
||||
Locations []string `json:"locations"`
|
||||
}
|
||||
|
||||
type state struct {
|
||||
ctx context.Context
|
||||
db *sql.DB
|
||||
genai *genai.Client
|
||||
apiKey apiKey
|
||||
template pageTemplate
|
||||
|
||||
summaries sync.Map
|
||||
summaryChans map[string]chan string
|
||||
|
||||
subscriptions map[string][]registeredSubscription
|
||||
subscriptionsMutex sync.Mutex
|
||||
|
||||
vapidPublicKey string
|
||||
vapidPrivateKey string
|
||||
}
|
||||
|
||||
//go:embed web
|
||||
var webDir embed.FS
|
||||
|
||||
@@ -72,6 +93,11 @@ func main() {
|
||||
log.Fatalln("Please create a .env file using the provided template!")
|
||||
}
|
||||
|
||||
db, err := initDB()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize db: %e\n", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
@@ -88,6 +114,7 @@ func main() {
|
||||
|
||||
state := state{
|
||||
ctx: ctx,
|
||||
db: db,
|
||||
apiKey: apiKey{
|
||||
openWeatherMap: os.Getenv("OPEN_WEATHER_MAP_API_KEY"),
|
||||
},
|
||||
@@ -97,6 +124,11 @@ func main() {
|
||||
summaries: sync.Map{},
|
||||
summaryChans: map[string]chan string{},
|
||||
genai: genaiClient,
|
||||
|
||||
subscriptions: map[string][]registeredSubscription{},
|
||||
|
||||
vapidPublicKey: os.Getenv("VAPID_PUBLIC_KEY_BASE64"),
|
||||
vapidPrivateKey: os.Getenv("VAPID_PRIVATE_KEY_BASE64"),
|
||||
}
|
||||
|
||||
var schedulers []gocron.Scheduler
|
||||
@@ -114,17 +146,26 @@ func main() {
|
||||
|
||||
_, err = s.NewJob(
|
||||
gocron.DurationJob(time.Minute),
|
||||
gocron.NewTask(updateSummaries, &state, locKey, &loc))
|
||||
gocron.NewTask(updateSummaries, &state, locKey, &loc),
|
||||
gocron.WithStartAt(gocron.WithStartImmediately()),
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
schedulers = append(schedulers, s)
|
||||
state.summaryChans[locKey] = make(chan string)
|
||||
c := make(chan string)
|
||||
|
||||
state.subscriptions[locKey] = []registeredSubscription{}
|
||||
state.summaryChans[locKey] = c
|
||||
|
||||
go listenForSummaryUpdates(&state, locKey)
|
||||
|
||||
s.Start()
|
||||
}
|
||||
|
||||
loadSubscriptions(&state)
|
||||
|
||||
http.HandleFunc("/", handleHTTPRequest(&state))
|
||||
http.ListenAndServe(":8080", nil)
|
||||
|
||||
@@ -137,28 +178,83 @@ func handleHTTPRequest(state *state) http.HandlerFunc {
|
||||
return func(writer http.ResponseWriter, request *http.Request) {
|
||||
path := strings.TrimPrefix(request.URL.Path, "/")
|
||||
|
||||
switch path {
|
||||
case "":
|
||||
index, _ := webDir.ReadFile("web/index.html")
|
||||
writer.Write(index)
|
||||
|
||||
case "ws":
|
||||
conn, err := websocket.Accept(writer, request, nil)
|
||||
if err != nil {
|
||||
log.Printf("error accepting incoming ws connection: %e\n", err)
|
||||
if path == "" {
|
||||
if request.Method == "" || request.Method == "GET" {
|
||||
index, _ := webDir.ReadFile("web/index.html")
|
||||
writer.Write(index)
|
||||
} else {
|
||||
writer.WriteHeader(http.StatusMethodNotAllowed)
|
||||
}
|
||||
defer conn.CloseNow()
|
||||
} else if path == "vapid" {
|
||||
if request.Method == "" || request.Method == "GET" {
|
||||
writer.Write([]byte(state.vapidPublicKey))
|
||||
} else {
|
||||
writer.WriteHeader(http.StatusMethodNotAllowed)
|
||||
}
|
||||
} else if strings.HasPrefix(path, "registrations") {
|
||||
if path == "registrations" && request.Method == "POST" {
|
||||
defer request.Body.Close()
|
||||
|
||||
log.Println("accepted incoming websocket connection")
|
||||
update := updateSubscription{}
|
||||
err := json.NewDecoder(request.Body).Decode(&update)
|
||||
if err != nil {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
locKey := request.URL.Query().Get("location")
|
||||
if c, ok := state.summaryChans[locKey]; ok {
|
||||
state.subscriberCount.Add(1)
|
||||
sendSummaryUpdates(state, c, conn)
|
||||
state.subscriberCount.Add(-1)
|
||||
reg, err := registerSubscription(state, &update)
|
||||
if err != nil {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
err = json.NewEncoder(writer).Encode(reg)
|
||||
if err != nil {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
}
|
||||
} else if request.Method == "PATCH" {
|
||||
parts := strings.Split(path, "/")
|
||||
if len(parts) < 2 {
|
||||
writer.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
regID, err := uuid.Parse(parts[1])
|
||||
if err != nil {
|
||||
writer.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
defer request.Body.Close()
|
||||
|
||||
update := updateSubscription{}
|
||||
err = json.NewDecoder(request.Body).Decode(&update)
|
||||
if err != nil {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
reg, err := updateRegisteredSubscription(state, regID, &update)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
writer.WriteHeader(http.StatusNotFound)
|
||||
} else {
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
json.NewEncoder(writer).Encode(reg)
|
||||
} else {
|
||||
writer.WriteHeader(http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
} else {
|
||||
if request.Method != "" && request.Method != "GET" {
|
||||
writer.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
default:
|
||||
summary, ok := state.summaries.Load(path)
|
||||
if ok {
|
||||
state.template.summary.Execute(writer, summaryTemplateData{summary.(string), path})
|
||||
@@ -178,34 +274,113 @@ func handleHTTPRequest(state *state) http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func sendSummaryUpdates(state *state, c <-chan string, conn *websocket.Conn) {
|
||||
l := rate.NewLimiter(rate.Every(time.Millisecond*100), 10)
|
||||
ctx, cancel := context.WithCancel(state.ctx)
|
||||
defer cancel()
|
||||
func initDB() (*sql.DB, error) {
|
||||
db, err := sql.Open("sqlite", "file:data.sqlite")
|
||||
if err != nil {
|
||||
log.Fatalln("failed to initialize database")
|
||||
}
|
||||
|
||||
for {
|
||||
err := l.Wait(ctx)
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS subscriptions(
|
||||
id TEXT PRIMARY KEY,
|
||||
locations TEXT NOT NULL,
|
||||
subscription_json TEXT NOT NULL
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func loadSubscriptions(state *state) error {
|
||||
rows, err := state.db.Query(`SELECT id, locations, subscription_json FROM subscriptions;`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var id string
|
||||
var locations string
|
||||
var j string
|
||||
|
||||
err := rows.Scan(&id, &locations, &j)
|
||||
if err != nil {
|
||||
return
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case summary := <-c:
|
||||
log.Println("summary updated. sending updates via sockets...")
|
||||
s := webpush.Subscription{}
|
||||
err = json.Unmarshal([]byte(j), &s)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
w, err := conn.Writer(ctx, websocket.MessageText)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = w.Write([]byte(summary))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.Close()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
reg := registeredSubscription{
|
||||
ID: uuid.MustParse(id),
|
||||
Locations: strings.Split(locations, ","),
|
||||
Subscription: &s,
|
||||
}
|
||||
|
||||
for _, l := range reg.Locations {
|
||||
state.subscriptions[l] = append(state.subscriptions[l], reg)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateRegisteredSubscription(state *state, id uuid.UUID, update *updateSubscription) (*registeredSubscription, error) {
|
||||
j, err := json.Marshal(update.Subscription)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = state.db.Exec(
|
||||
"UPDATE subscriptions SET subscription_json = ?, locations = ? WHERE id = ?",
|
||||
string(j), strings.Join(update.Locations, ","), id,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ®isteredSubscription{
|
||||
ID: id,
|
||||
Subscription: &update.Subscription,
|
||||
Locations: update.Locations,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func registerSubscription(state *state, sub *updateSubscription) (*registeredSubscription, error) {
|
||||
j, err := json.Marshal(sub.Subscription)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := uuid.NewV7()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = state.db.Exec(
|
||||
"INSERT INTO subscriptions (id, locations, subscription_json) VALUES (?, ?, ?);",
|
||||
id, strings.Join(sub.Locations, ","), string(j),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reg := registeredSubscription{
|
||||
ID: id,
|
||||
Subscription: &sub.Subscription,
|
||||
Locations: sub.Locations,
|
||||
}
|
||||
|
||||
for _, l := range sub.Locations {
|
||||
state.subscriptions[l] = append(state.subscriptions[l], reg)
|
||||
}
|
||||
|
||||
return ®, nil
|
||||
}
|
||||
|
||||
func updateSummaries(state *state, locKey string, loc *location) {
|
||||
@@ -239,9 +414,32 @@ func updateSummaries(state *state, locKey string, loc *location) {
|
||||
c := state.summaryChans[locKey]
|
||||
|
||||
state.summaries.Store(locKey, summary)
|
||||
if state.subscriberCount.Load() > 0 {
|
||||
if len(state.subscriptions[locKey]) > 0 {
|
||||
c <- summary
|
||||
}
|
||||
|
||||
log.Printf("updated summary for %v successfully\n", locKey)
|
||||
}
|
||||
|
||||
func listenForSummaryUpdates(state *state, locKey string) {
|
||||
c := state.summaryChans[locKey]
|
||||
for {
|
||||
select {
|
||||
case summary := <-c:
|
||||
log.Printf("sending summary for %v to subscribers...\n", locKey)
|
||||
for _, sub := range state.subscriptions[locKey] {
|
||||
_, err := webpush.SendNotificationWithContext(state.ctx, []byte(summary), sub.Subscription, &webpush.Options{
|
||||
VAPIDPublicKey: state.vapidPublicKey,
|
||||
VAPIDPrivateKey: state.vapidPrivateKey,
|
||||
TTL: 30,
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("failed to send notification %e\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
case <-state.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user