Initial commit
This commit is contained in:
101
internal/auth/manager.go
Normal file
101
internal/auth/manager.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
ErrExpiredToken = errors.New("token has expired")
|
||||
ErrInvalidSignature = errors.New("invalid token signature")
|
||||
ErrMissingClaims = errors.New("missing required claims")
|
||||
)
|
||||
|
||||
// Claims represents the JWT claims structure
|
||||
type Claims struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
type JWTManager struct {
|
||||
secretKey []byte
|
||||
tokenDuration time.Duration
|
||||
}
|
||||
|
||||
func NewJWTManager(secretKey string, tokenDuration time.Duration) *JWTManager {
|
||||
return &JWTManager{
|
||||
secretKey: []byte(secretKey),
|
||||
tokenDuration: tokenDuration,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateToken creates a new JWT token for a user
|
||||
func (m *JWTManager) GenerateToken(userID int64, email string) (string, error) {
|
||||
claims := Claims{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(m.tokenDuration)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
Issuer: os.Getenv("JWT_ISSUER"),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
|
||||
tokenString, err := token.SignedString(m.secretKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
// ValidateToken validates the JWT token and returns the claims
|
||||
func (m *JWTManager) ValidateToken(tokenString string) (*Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(
|
||||
tokenString,
|
||||
&Claims{},
|
||||
func(token *jwt.Token) (interface{}, error) {
|
||||
// Verify signing method
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return m.secretKey, nil
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||
return nil, ErrExpiredToken
|
||||
}
|
||||
if errors.Is(err, jwt.ErrSignatureInvalid) {
|
||||
return nil, ErrInvalidSignature
|
||||
}
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*Claims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// RefreshToken generates a new token from an existing valid token
|
||||
func (m *JWTManager) RefreshToken(tokenString string) (string, error) {
|
||||
claims, err := m.ValidateToken(tokenString)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return m.GenerateToken(claims.UserID, claims.Email)
|
||||
}
|
||||
41
internal/database/postgres.go
Normal file
41
internal/database/postgres.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Host string
|
||||
Port int
|
||||
User string
|
||||
Password string
|
||||
DBName string
|
||||
}
|
||||
|
||||
func NewConnection(cfg Config) (*sql.DB, error) {
|
||||
dsn := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.DBName,
|
||||
)
|
||||
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Connection pooling
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxOpenConns(5)
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
db.SetConnMaxIdleTime(10 * time.Minute)
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
30
internal/handlers/health.go
Normal file
30
internal/handlers/health.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type HealthResponse struct {
|
||||
Status string `json:"status"`
|
||||
DBConnected bool `json:"db_connected"`
|
||||
}
|
||||
|
||||
type HealthcheckService struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewHealthcheckService(db *sql.DB) *HealthcheckService {
|
||||
return &HealthcheckService{db: db}
|
||||
}
|
||||
|
||||
func (h *HealthcheckService) Health(w http.ResponseWriter, r *http.Request) {
|
||||
response := HealthResponse{
|
||||
Status: "ok",
|
||||
}
|
||||
|
||||
err := h.db.Ping()
|
||||
response.DBConnected = err == nil
|
||||
|
||||
JSON(w, http.StatusOK, response)
|
||||
}
|
||||
16
internal/handlers/responses.go
Normal file
16
internal/handlers/responses.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func JSON(w http.ResponseWriter, status int, data any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
func Error(w http.ResponseWriter, status int, message string) {
|
||||
JSON(w, status, map[string]string{"error": message})
|
||||
}
|
||||
56
internal/middleware/auth.go
Normal file
56
internal/middleware/auth.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.kling.dev/jared/WorkoutTrackerAPI/internal/auth"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
UserIDKey contextKey = "user_id"
|
||||
EmailKey contextKey = "email"
|
||||
)
|
||||
|
||||
func ValidateJWT(jwtManager auth.JWTManager) func(http.HandlerFunc) http.HandlerFunc {
|
||||
return func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("Starting ValidateJWT handler")
|
||||
auth_header := r.Header.Get("Authorization")
|
||||
if auth_header == "" {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte("Missing authorization header"))
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.Split(auth_header, " ")
|
||||
if len(parts) != 2 || strings.ToLower(strings.Trim(parts[0], " ")) != "bearer" {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte("Invalid authorization header format"))
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := parts[1]
|
||||
claims, err := jwtManager.ValidateToken(tokenString)
|
||||
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
if errors.Is(err, auth.ErrExpiredToken) {
|
||||
w.Write([]byte("Token has expired"))
|
||||
return
|
||||
}
|
||||
w.Write([]byte("Invalid token"))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), UserIDKey, claims.Subject)
|
||||
ctx = context.WithValue(ctx, EmailKey, claims.Email)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
18
internal/middleware/cors.go
Normal file
18
internal/middleware/cors.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package middleware
|
||||
|
||||
import "net/http"
|
||||
|
||||
func CORS(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
30
internal/middleware/logging.go
Normal file
30
internal/middleware/logging.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Logging(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
|
||||
|
||||
next.ServeHTTP(wrapped, r)
|
||||
|
||||
log.Printf(
|
||||
"%s %s %d %s",
|
||||
r.Method,
|
||||
r.URL.Path,
|
||||
wrapped.statusCode,
|
||||
time.Since(start),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
108
internal/server/server.go
Normal file
108
internal/server/server.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"git.kling.dev/jared/WorkoutTrackerAPI/internal/auth"
|
||||
"git.kling.dev/jared/WorkoutTrackerAPI/internal/handlers"
|
||||
"git.kling.dev/jared/WorkoutTrackerAPI/internal/middleware"
|
||||
)
|
||||
|
||||
type Services struct {
|
||||
healthcheckService handlers.HealthcheckService
|
||||
}
|
||||
|
||||
type Handlers struct {
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
db *sql.DB
|
||||
router *http.ServeMux
|
||||
server *http.Server
|
||||
services *Services
|
||||
handlers *Handlers
|
||||
jwtManager *auth.JWTManager
|
||||
}
|
||||
|
||||
type UserMeResponse struct {
|
||||
UserId int64 `json:"userId"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func NewServer(db *sql.DB, jwtSecret string) *Server {
|
||||
s := &Server{
|
||||
db: db,
|
||||
router: http.NewServeMux(),
|
||||
server: nil,
|
||||
}
|
||||
|
||||
s.services = &Services{
|
||||
healthcheckService: *handlers.NewHealthcheckService(db),
|
||||
}
|
||||
|
||||
s.jwtManager = auth.NewJWTManager(jwtSecret, 60*time.Minute)
|
||||
|
||||
s.routes()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Server) routes() {
|
||||
auth := middleware.ValidateJWT(*s.jwtManager)
|
||||
s.router.HandleFunc("GET /health", s.services.healthcheckService.Health)
|
||||
s.router.HandleFunc("POST /auth", func(w http.ResponseWriter, r *http.Request) {
|
||||
token, err := s.jwtManager.GenerateToken(123, "jared")
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
handlers.JSON(w, http.StatusOK, token)
|
||||
})
|
||||
s.router.HandleFunc("GET /auth/me", auth(func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("Start of /auth/me")
|
||||
ctxUserId := r.Context().Value(middleware.UserIDKey).(string)
|
||||
email := r.Context().Value(middleware.EmailKey).(string)
|
||||
userId, _ := strconv.ParseInt(ctxUserId, 10, 64)
|
||||
log.Printf("Values: %d - %s", userId, email)
|
||||
|
||||
response := UserMeResponse{
|
||||
UserId: userId,
|
||||
Email: email,
|
||||
}
|
||||
handlers.JSON(w, http.StatusOK, response)
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *Server) Start(addr string) error {
|
||||
handler := middleware.Logging(
|
||||
middleware.CORS(
|
||||
s.router,
|
||||
),
|
||||
)
|
||||
|
||||
server := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: handler,
|
||||
ReadTimeout: 15 * time.Second,
|
||||
WriteTimeout: 15 * time.Second,
|
||||
IdleTimeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
s.server = server
|
||||
|
||||
return server.ListenAndServe()
|
||||
}
|
||||
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
if s.server == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.server.Shutdown(ctx)
|
||||
}
|
||||
Reference in New Issue
Block a user