187 lines
5.6 KiB
Go
187 lines
5.6 KiB
Go
// Package session implements a session manager that controls the session lifetime: creating, updating and destroying sessions.
|
|
package session
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// milis is the time resolution used for session lifetimes: the number of
// milliseconds counted per second. It defaults to 1000.
var milis = int64(1000)
|
|
|
|
func init() {
|
|
if testing.Testing() {
|
|
milis = 1
|
|
}
|
|
}
|
|
|
|
// for provider fun
|
|
var provides = make(map[string]Provider)
|
|
|
|
// ProviderNames return slice of strings - registered Provider names
|
|
func ProviderNames() []string {
|
|
var prdnames []string
|
|
for pn := range provides {
|
|
prdnames = append(prdnames, pn)
|
|
}
|
|
return prdnames
|
|
}
|
|
|
|
// MilisPerSec return time resolution (milliseconds / 1sec) changed for short time in testing
|
|
func MilisPerSec() int64 {
|
|
return milis
|
|
}
|
|
|
|
// Provider interace implement lifecycle for one session
|
|
type Provider interface {
|
|
//set additional params for provider ex: sql db connection, filesystem path .. etc.
|
|
SetParams(params any) error
|
|
//create new session using sid value
|
|
Init(sid string) (Session, error)
|
|
//read and return existing session by id or if not exist create new session
|
|
Load(sid string) (Session, error)
|
|
//destroy remove session with sid from storage if exist
|
|
Destroy(sid string) error
|
|
//regenerate id change old sid to newsid and preserve existing session data
|
|
ChangeID(oldsid, newsid string) (err error)
|
|
//Exists return true if session with sid exist
|
|
Exists(sid string) bool
|
|
//gc remove all outdated sessions
|
|
GC(maxlifetime int64)
|
|
}
|
|
|
|
// Session is the interface implemented by the storage of a single session.
// Implementations are expected to track a maximum lifetime and the time of
// the last access.
type Session interface {
	// Set stores value under key and updates the last access time.
	Set(key, value any) error

	// Get returns the value stored under key and updates the last access
	// time.
	Get(key any) (v any, err error)

	// Delete removes the value stored under key.
	Delete(key any) error

	// SessionID returns the ID of the session.
	SessionID() string
}
|
|
|
|
// Register makes a session provide available by the provided name.
|
|
// If Register is called twice with the same name or if driver is nil, it panics.
|
|
func Register(name string, provide Provider) {
|
|
if provide == nil {
|
|
panic("session: Register provide is nil, must be import any session storage implementation")
|
|
}
|
|
if _, dup := provides[name]; dup {
|
|
panic("session: Already registered provider: " + name)
|
|
}
|
|
provides[name] = provide
|
|
}
|
|
|
|
// Manager controls all sessions with registered storage provider
|
|
type Manager struct {
|
|
provider Provider
|
|
sessOpts *SessOpts
|
|
addOpts any
|
|
}
|
|
|
|
// SessOpts holds the session options given to NewManager.
type SessOpts struct {
	// CookieName is the name of the cookie that carries the session ID.
	CookieName string
	// MaxLifetime is the session lifetime; it is used as the cookie Max-Age
	// and handed to the provider's GC.
	MaxLifetime int64
	// Ssl sets the Secure attribute on the session cookie.
	Ssl bool
}
|
|
|
|
// NewManager create new *Manager using SesOpts and aditional any other opts for using in provider
|
|
func NewManager(providerName string, sopts *SessOpts, adopts any) (manager *Manager, err error) {
|
|
var prv Provider
|
|
var ok bool
|
|
if prv, ok = provides[providerName]; !ok {
|
|
return nil, fmt.Errorf("session: Provider: %q not found (forgotten import?)", providerName)
|
|
}
|
|
if err = prv.SetParams(adopts); err != nil {
|
|
return nil, fmt.Errorf("session params: %v not valid: %v", adopts, err)
|
|
}
|
|
m := &Manager{
|
|
provider: prv,
|
|
sessOpts: sopts,
|
|
addOpts: adopts,
|
|
}
|
|
go m.GC()
|
|
return m, nil
|
|
}
|
|
|
|
// generate new secure 32 byte sessionID
|
|
func (manager *Manager) sessionID() (sid string, err error) {
|
|
b := make([]byte, 32)
|
|
if _, err = io.ReadFull(rand.Reader, b); err != nil {
|
|
return "", fmt.Errorf("Session manager error: generate sessionID failed: %w", err)
|
|
}
|
|
return base64.URLEncoding.EncodeToString(b), nil
|
|
}
|
|
|
|
// SessionStart start session for next http response
|
|
func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session Session, err error) {
|
|
var cookie *http.Cookie
|
|
var sid string
|
|
if cookie, err = r.Cookie(manager.sessOpts.CookieName); err != nil || cookie == nil {
|
|
if sid, err = manager.sessionID(); err != nil {
|
|
return nil, err
|
|
}
|
|
if session, err = manager.provider.Init(sid); err != nil {
|
|
return nil, fmt.Errorf("Session init failed: %w", err)
|
|
}
|
|
cookie := http.Cookie{
|
|
Name: manager.sessOpts.CookieName, Value: url.QueryEscape(sid), Path: "/",
|
|
HttpOnly: true, MaxAge: int(manager.sessOpts.MaxLifetime),
|
|
Secure: manager.sessOpts.Ssl,
|
|
}
|
|
http.SetCookie(w, &cookie)
|
|
} else {
|
|
if sid, err = url.QueryUnescape(cookie.Value); err != nil {
|
|
return nil, fmt.Errorf("Session cookie decode error: %v", err)
|
|
}
|
|
if session, err = manager.provider.Load(sid); err != nil {
|
|
return nil, fmt.Errorf("Session provider load error: %w", err)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// SessionDestroy end session and delete session data at the server
|
|
func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) (err error) {
|
|
var cookie *http.Cookie
|
|
if cookie, err = r.Cookie(manager.sessOpts.CookieName); err != nil || cookie.Value == "" {
|
|
return fmt.Errorf("get cookie from request failed: %v", err)
|
|
}
|
|
|
|
manager.provider.Destroy(cookie.Value)
|
|
rmcookie := http.Cookie{
|
|
Name: manager.sessOpts.CookieName, Path: "/", HttpOnly: true,
|
|
Expires: time.Now(), MaxAge: -1, Secure: manager.sessOpts.Ssl,
|
|
}
|
|
http.SetCookie(w, &rmcookie)
|
|
return nil
|
|
}
|
|
|
|
// Exists return true if session with sid exists on server
|
|
func (manager *Manager) Exists(sid string) bool {
|
|
return manager.provider.Exists(sid)
|
|
}
|
|
|
|
// RegenerateID vhange sid and preserve all session data
|
|
func (manager *Manager) RegenerateID(w http.ResponseWriter, r *http.Request) {
|
|
if ck, err := r.Cookie(manager.sessOpts.CookieName); err == nil && ck.Value != "" {
|
|
if newid, err := manager.sessionID(); err != nil {
|
|
manager.provider.ChangeID(ck.Value, newid)
|
|
}
|
|
}
|
|
}
|
|
|
|
// GC remove sessions which exceeded manager.maxLifetime
|
|
func (manager *Manager) GC() {
|
|
manager.provider.GC(manager.sessOpts.MaxLifetime)
|
|
msec := milis * manager.sessOpts.MaxLifetime
|
|
time.AfterFunc(time.Duration(msec), func() { manager.GC() })
|
|
}
|