Files
SproutGate/sproutgate-backend/internal/storage/registration.go
2026-03-20 20:42:33 +08:00

215 lines
5.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package storage
import (
"crypto/rand"
"errors"
"os"
"strings"
"time"
"sproutgate-backend/internal/models"
)
// InviteEntry is a registration invite code issued by an administrator.
// Codes are matched case-insensitively by the lookup functions in this file.
type InviteEntry struct {
	Code      string `json:"code"`
	Note      string `json:"note,omitempty"`
	MaxUses   int    `json:"maxUses"` // 0 means unlimited uses
	Uses      int    `json:"uses"`    // incremented by RedeemInvite on each successful redemption
	ExpiresAt string `json:"expiresAt,omitempty"` // RFC3339 timestamp; empty means the code never expires
	CreatedAt string `json:"createdAt"`
}
// RegistrationConfig holds the registration policy together with the list of
// invite codes. It is stored as JSON in data/config/registration.json.
type RegistrationConfig struct {
	// RequireInviteCode, when true, makes a valid invite code mandatory before
	// registration can start (i.e. before the email verification code is sent).
	RequireInviteCode bool `json:"requireInviteCode"`
	// Invites lists every issued invite code.
	Invites []InviteEntry `json:"invites"`
}
// normalizeInviteCode canonicalizes user-supplied invite input: surrounding
// whitespace is removed and the remainder is upper-cased.
func normalizeInviteCode(raw string) string {
	trimmed := strings.TrimSpace(raw)
	return strings.ToUpper(trimmed)
}
// loadOrCreateRegistrationConfig fills s.registrationConfig from the file at
// s.registrationPath. If that file does not exist yet, a default config
// (invite codes not required, empty invite list) is written to disk and used.
// Any other stat problem falls through to the read, which reports it.
func (s *Store) loadOrCreateRegistrationConfig() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	_, statErr := os.Stat(s.registrationPath)
	if errors.Is(statErr, os.ErrNotExist) {
		defaults := RegistrationConfig{RequireInviteCode: false, Invites: []InviteEntry{}}
		if err := writeJSONFile(s.registrationPath, defaults); err != nil {
			return err
		}
		s.registrationConfig = defaults
		return nil
	}

	var loaded RegistrationConfig
	if err := readJSONFile(s.registrationPath, &loaded); err != nil {
		return err
	}
	// Keep the list non-nil so it re-encodes as [] rather than null.
	if loaded.Invites == nil {
		loaded.Invites = []InviteEntry{}
	}
	s.registrationConfig = loaded
	return nil
}
// persistRegistrationConfigLocked writes the in-memory registration config to
// s.registrationPath. Every caller in this file already holds s.mu.
func (s *Store) persistRegistrationConfigLocked() error {
	path, cfg := s.registrationPath, s.registrationConfig
	return writeJSONFile(path, cfg)
}
// RegistrationRequireInvite reports whether an invite code is mandatory before
// a registration can be started (i.e. before the email verification code is sent).
func (s *Store) RegistrationRequireInvite() bool {
	s.mu.Lock()
	required := s.registrationConfig.RequireInviteCode
	s.mu.Unlock()
	return required
}
// GetRegistrationConfig returns a snapshot of the registration config for the
// admin side. The Invites slice is a fresh copy, so callers may modify it
// without affecting the store's state. Invites is never nil, so an empty list
// encodes to JSON as [] rather than null.
func (s *Store) GetRegistrationConfig() RegistrationConfig {
	s.mu.Lock()
	defer s.mu.Unlock()
	out := s.registrationConfig
	// append([]InviteEntry(nil), empty...) would return nil and serialize as
	// null; allocate explicitly so the copy is always a non-nil slice.
	out.Invites = make([]InviteEntry, len(s.registrationConfig.Invites))
	copy(out.Invites, s.registrationConfig.Invites)
	return out
}
// SetRegistrationRequireInvite updates whether an invite code is mandatory and
// persists the change. If writing to disk fails, the previous value is
// restored so the in-memory state keeps matching what is on disk, and the
// write error is returned.
func (s *Store) SetRegistrationRequireInvite(require bool) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	prev := s.registrationConfig.RequireInviteCode
	s.registrationConfig.RequireInviteCode = require
	if err := s.persistRegistrationConfigLocked(); err != nil {
		s.registrationConfig.RequireInviteCode = prev
		return err
	}
	return nil
}
// inviteEntryValid reports whether e can still be redeemed.
//
// It returns an error if:
//   - ExpiresAt is set and cannot be parsed as RFC3339. This fails closed, so a
//     corrupted expiry does not silently turn into "never expires".
//   - ExpiresAt is in the past.
//   - MaxUses > 0 and Uses has reached MaxUses.
//
// Otherwise it returns nil. Surrounding whitespace in ExpiresAt is ignored.
func inviteEntryValid(e *InviteEntry) error {
	if raw := strings.TrimSpace(e.ExpiresAt); raw != "" {
		// Parse the trimmed value. Parsing the raw string would fail for
		// padded timestamps and previously skipped the expiry check entirely.
		expiry, err := time.Parse(time.RFC3339, raw)
		if err != nil {
			return errors.New("invite code has an invalid expiry time")
		}
		if time.Now().After(expiry) {
			return errors.New("invite code expired")
		}
	}
	if e.MaxUses > 0 && e.Uses >= e.MaxUses {
		return errors.New("invite code has been fully used")
	}
	return nil
}
// ValidateInviteForRegister checks that code names an existing invite that can
// still be used. It runs before the verification code is sent and does not
// consume a use. An empty code (after trimming) is rejected.
func (s *Store) ValidateInviteForRegister(code string) error {
	want := normalizeInviteCode(code)
	if want == "" {
		return errors.New("invite code is required")
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	invites := s.registrationConfig.Invites
	for idx := 0; idx < len(invites); idx++ {
		if !strings.EqualFold(invites[idx].Code, want) {
			continue
		}
		return inviteEntryValid(&invites[idx])
	}
	return errors.New("invalid invite code")
}
// RedeemInvite consumes one use of the invite identified by code and persists
// the new count. It is called after email verification succeeds and the user
// has been created. An empty code (after trimming) is a no-op. The invite must
// still pass inviteEntryValid; an unknown code is an error.
//
// NOTE(review): if persisting fails, the in-memory Uses stays incremented.
func (s *Store) RedeemInvite(code string) error {
	want := normalizeInviteCode(code)
	if want == "" {
		return nil
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	invites := s.registrationConfig.Invites
	for idx := range invites {
		entry := &invites[idx]
		if !strings.EqualFold(entry.Code, want) {
			continue
		}
		if err := inviteEntryValid(entry); err != nil {
			return err
		}
		entry.Uses++
		return s.persistRegistrationConfigLocked()
	}
	return errors.New("invalid invite code")
}
// inviteCodeAlphabet is the character set for generated invite codes. It leaves
// out I, O, 0 and 1 (presumably to avoid look-alike characters). Its length is
// 32, which divides 256, so reducing a random byte with % adds no modulo bias.
const inviteCodeAlphabet = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"

// randomInviteToken returns a string of n characters drawn from
// inviteCodeAlphabet using bytes from crypto/rand. It returns any error
// reported by rand.Read.
func randomInviteToken(n int) (string, error) {
	entropy := make([]byte, n)
	if _, err := rand.Read(entropy); err != nil {
		return "", err
	}
	size := len(inviteCodeAlphabet)
	var token strings.Builder
	token.Grow(n)
	for _, v := range entropy {
		token.WriteByte(inviteCodeAlphabet[int(v)%size])
	}
	return token.String(), nil
}
// AddInviteEntry issues a new invite with a freshly generated 8-character code,
// appends it to the config and persists it.
//
// Parameters:
//   - note is trimmed.
//   - maxUses below 0 is stored as 0 (unlimited).
//   - expiresAt is trimmed and must be empty (never expires) or RFC3339.
//
// If persisting fails, the appended entry is removed again and the write error
// is returned.
func (s *Store) AddInviteEntry(note string, maxUses int, expiresAt string) (InviteEntry, error) {
	// Validate and normalize caller input before taking the lock or drawing
	// randomness, so invalid requests fail fast with a deterministic error.
	expiresAt = strings.TrimSpace(expiresAt)
	if expiresAt != "" {
		if _, err := time.Parse(time.RFC3339, expiresAt); err != nil {
			return InviteEntry{}, errors.New("invalid expiresAt (use RFC3339)")
		}
	}
	if maxUses < 0 {
		maxUses = 0
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	code, err := s.generateUniqueInviteCodeLocked()
	if err != nil {
		return InviteEntry{}, err
	}
	entry := InviteEntry{
		Code:      code,
		Note:      strings.TrimSpace(note),
		MaxUses:   maxUses,
		Uses:      0,
		ExpiresAt: expiresAt,
		CreatedAt: models.NowISO(),
	}
	s.registrationConfig.Invites = append(s.registrationConfig.Invites, entry)
	if err := s.persistRegistrationConfigLocked(); err != nil {
		// Roll back so the in-memory list keeps matching what is on disk.
		s.registrationConfig.Invites = s.registrationConfig.Invites[:len(s.registrationConfig.Invites)-1]
		return InviteEntry{}, err
	}
	return entry, nil
}

// generateUniqueInviteCodeLocked draws random 8-character codes until one does
// not collide (case-insensitively) with an existing invite, giving up after 24
// attempts. The caller must hold s.mu.
func (s *Store) generateUniqueInviteCodeLocked() (string, error) {
	const maxAttempts = 24
	for attempt := 0; attempt < maxAttempts; attempt++ {
		candidate, err := randomInviteToken(8)
		if err != nil {
			return "", err
		}
		dup := false
		for i := range s.registrationConfig.Invites {
			if strings.EqualFold(s.registrationConfig.Invites[i].Code, candidate) {
				dup = true
				break
			}
		}
		if !dup {
			return candidate, nil
		}
	}
	return "", errors.New("failed to generate unique invite code")
}
// DeleteInviteEntry removes the invite whose code matches code
// (case-insensitive, surrounding whitespace ignored) and persists the result.
//
// It returns an error if:
//   - code is empty after trimming;
//   - no invite matches;
//   - writing to disk fails. In that case the original list is restored so the
//     in-memory state keeps matching what is on disk.
func (s *Store) DeleteInviteEntry(code string) error {
	n := normalizeInviteCode(code)
	if n == "" {
		return errors.New("code is required")
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	prev := s.registrationConfig.Invites
	for i := range prev {
		if !strings.EqualFold(prev[i].Code, n) {
			continue
		}
		// Build a new slice instead of deleting in place so prev stays intact
		// and can be restored if the write fails.
		next := make([]InviteEntry, 0, len(prev)-1)
		next = append(next, prev[:i]...)
		next = append(next, prev[i+1:]...)
		s.registrationConfig.Invites = next
		if err := s.persistRegistrationConfigLocked(); err != nil {
			s.registrationConfig.Invites = prev
			return err
		}
		return nil
	}
	return errors.New("invite not found")
}