diff --git a/ottomain.go b/ottomain.go index e79e353..403a699 100644 --- a/ottomain.go +++ b/ottomain.go @@ -1,6 +1,7 @@ package ottoman import ( + "context" "encoding/base64" "errors" "fmt" @@ -72,6 +73,14 @@ func ProcessRequest(script string, params map[string]interface{}) (response map[ // Initialize semaphore and config once once.Do(initSemaphore) + // Validate input + if script == "" { + return nil, fmt.Errorf("script cannot be empty") + } + if params == nil { + return nil, fmt.Errorf("params cannot be nil") + } + // Acquire semaphore to limit concurrent executions select { case semaphore <- struct{}{}: @@ -99,40 +108,48 @@ func ProcessRequest(script string, params map[string]interface{}) (response map[ } vm.Interrupt = make(chan func(), 1) - timeoutDone := make(chan struct{}) + timeoutCtx, timeoutCancel := context.WithCancel(context.Background()) timeoutTimer := time.NewTimer(time.Duration(cfg.Timeout) * time.Second) // Start timeout goroutine go func() { select { case <-timeoutTimer.C: - // Timer expired, try to interrupt + // Timer expired, try to interrupt (non-blocking) select { case vm.Interrupt <- func() { panic(errors.New("some code took to long! Stopping after timeout")) }: - // Interrupt sent - case <-timeoutDone: - // Script already completed - } - case <-timeoutDone: - // Script completed before timeout, ensure timer is stopped - timeoutTimer.Stop() - // Drain timer channel if needed (non-blocking) - select { - case <-timeoutTimer.C: + // Interrupt sent successfully + case <-timeoutCtx.Done(): + // Script already completed, cancel prevented interrupt default: + // Can't send interrupt (channel full), but context cancelled + // Just exit - script might be done anyway + } + case <-timeoutCtx.Done(): + // Script completed before timeout, stop timer + if !timeoutTimer.Stop() { + // Timer already fired, drain it + select { + case <-timeoutTimer.C: + default: + } } } }() defer func() { - // Signal timeout goroutine to exit by closing channel - close(timeoutDone) + // Cancel timeout context to signal goroutine to exit + timeoutCancel() // Stop timer - this is safe even if it already fired - // We don't need to drain the channel here as goroutine handles it timeoutTimer.Stop() + // Drain timer channel to prevent goroutine leak + select { + case <-timeoutTimer.C: + default: + } if r := recover(); r != nil { switch x := r.(type) { @@ -151,10 +168,7 @@ func ProcessRequest(script string, params map[string]interface{}) (response map[ return nil, fmt.Errorf("otto run error: %w", err) } - getOttoValue := func(variable string, err error) (interface{}, error) { - if err != nil { - return nil, err - } + getOttoValue := func(variable string) (interface{}, error) { value, err := vm.Get(variable) if err != nil { return nil, err @@ -169,9 +183,9 @@ func ProcessRequest(script string, params map[string]interface{}) (response map[ response = make(map[string]interface{}) for v := range params { - message, err := getOttoValue(v, err) - if err != nil { - return nil, fmt.Errorf("otto get value error: %w", err) + message, getErr := getOttoValue(v) + if getErr != nil { + return nil, fmt.Errorf("otto get value error for '%s': %w", v, getErr) } switch message := message.(type) { @@ -184,10 +198,14 @@ func ProcessRequest(script string, params map[string]interface{}) (response map[ rt := reflect.TypeOf(message) switch rt.Kind() { case reflect.Slice, reflect.Array: - response[v] = []byte(message.([]byte)) + // Safe type assertion for byte slice + if byteSlice, ok := message.([]byte); ok { + response[v] = byteSlice + } else { + response[v] = message + } default: response[v] = message - //response[v] = nil } }