diff --git a/ottomain.go b/ottomain.go index 02a10aa..b450598 100644 --- a/ottomain.go +++ b/ottomain.go @@ -1,7 +1,6 @@ package ottoman import ( - "context" "encoding/base64" "errors" "fmt" @@ -14,8 +13,9 @@ import ( ) type otto_config struct { - Timeout int `env:"JS_TIMEOUT" envDefault:"2"` - MaxConcurrent int `env:"JS_MAX_CONCURRENT" envDefault:"50"` + Timeout time.Duration `env:"JS_TIMEOUT" envDefault:"2s"` + MaxConcurrent int `env:"JS_MAX_CONCURRENT" envDefault:"50"` + WaitTimeout time.Duration `env:"JS_MAX_WAITTIMEOUT" envDefault:"30s"` } var ( @@ -26,6 +26,8 @@ var ( configMutex sync.Mutex ) +var ErrTimeout = errors.New("javascript execution timeout") + func jsBtoa(b string) string { return base64.StdEncoding.EncodeToString([]byte(b)) } @@ -59,9 +61,9 @@ func initSemaphore() { if !configChecked { var cfg otto_config if err := env.Parse(&cfg); err != nil { - // Use defaults if env parse fails cfg.MaxConcurrent = 50 - cfg.Timeout = 2 + cfg.Timeout = 2 * time.Second + cfg.WaitTimeout = 30 * time.Second } globalConfig = cfg semaphore = make(chan struct{}, cfg.MaxConcurrent) @@ -70,26 +72,22 @@ func initSemaphore() { } func ProcessRequest(script string, params map[string]interface{}) (response map[string]interface{}, err error) { - // Initialize semaphore and config once once.Do(initSemaphore) - // Acquire semaphore to limit concurrent executions + cfg := globalConfig + select { case semaphore <- struct{}{}: - defer func() { <-semaphore }() // Release semaphore when done - case <-time.After(30 * time.Second): - // Timeout if too many concurrent requests - return nil, fmt.Errorf("too many concurrent JavaScript executions, please try again later") + defer func() { <-semaphore }() + case <-time.After(cfg.WaitTimeout): + return nil, errors.New("too many concurrent JavaScript executions, please try again later") } - // Use cached config instead of parsing every time - cfg := globalConfig - vm := otto.New() err = jsRegisterFunctions(vm) if err != nil { - return nil, fmt.Errorf("otto registreing standart functions error: %w", err) + return nil, fmt.Errorf("otto registering standard functions error: %w", err) } for key, uf := range params { @@ -99,64 +97,20 @@ func ProcessRequest(script string, params map[string]interface{}) (response map[ } } - vm.Interrupt = make(chan func(), 1) - timeoutCtx, timeoutCancel := context.WithCancel(context.Background()) - timeoutTimer := time.NewTimer(time.Duration(cfg.Timeout) * time.Second) - - // Start timeout goroutine - go func() { - select { - case <-timeoutTimer.C: - // Timer expired, try to interrupt (non-blocking) - select { - case vm.Interrupt <- func() { - panic(errors.New("some code took to long! Stopping after timeout")) - }: - // Interrupt sent successfully - case <-timeoutCtx.Done(): - // Script already completed, cancel prevented interrupt - default: - // Can't send interrupt (channel full), but context cancelled - // Just exit - script might be done anyway - } - case <-timeoutCtx.Done(): - // Script completed before timeout, stop timer - if !timeoutTimer.Stop() { - // Timer already fired, drain it - select { - case <-timeoutTimer.C: - default: - } - } + timer := time.AfterFunc(cfg.Timeout, func() { + vm.Interrupt <- func() { + panic(ErrTimeout) } - }() - - defer func() { - // Cancel timeout context to signal goroutine to exit - timeoutCancel() - - // Stop timer - this is safe even if it already fired - timeoutTimer.Stop() - // Drain timer channel to prevent goroutine leak - select { - case <-timeoutTimer.C: - default: - } - - if r := recover(); r != nil { - switch x := r.(type) { - case error: - err = x - default: - err = errors.New("otto run error") - } - response = nil - } - }() + }) + defer timer.Stop() _, err = vm.Run(script) if err != nil { + if errors.Is(err, ErrTimeout) { + return nil, ErrTimeout + } + return nil, fmt.Errorf("otto run error: %w", err) }