package host
import (
"bufio"
"errors"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"time"
"shep/internal/fbs/shep"
"shep/internal/logger"
"shep/internal/protocol"
)
// Plugin naming and timing parameters used by the manager.
const (
// pluginPrefix is the file-name prefix that marks an entry in the plugin
// directory as a plugin; the rest of the file name is the plugin name.
pluginPrefix = "shep-plugin-"
// handshakeTimeout bounds how long doHandshake waits for the plugin's
// opening handshake message.
handshakeTimeout = 5 * time.Second
// shutdownTimeout bounds how long stop waits for the process to exit
// before killing it.
shutdownTimeout = 3 * time.Second
// invokeTimeout bounds how long Invoke waits for a matching response.
invokeTimeout = 5 * time.Second
)
// Sentinel errors produced by the plugin manager. Most are returned wrapped
// with extra detail, so compare with errors.Is rather than ==.
var (
// ErrPluginNotFound is returned by Start when no file exists at the
// expected plugin path.
ErrPluginNotFound = errors.New("plugin not found")
// ErrHandshakeFailed is returned by Start when doHandshake reports an error.
ErrHandshakeFailed = errors.New("handshake failed")
// ErrPluginNotRunning is returned by Stop for a name that is not in the
// running set.
ErrPluginNotRunning = errors.New("plugin not running")
// ErrUnexpectedMessage marks a message whose kind is not the one the host
// was waiting for.
ErrUnexpectedMessage = errors.New("unexpected message type")
// ErrProtocolMismatch is returned by doHandshake when the plugin's
// protocol version differs from protocol.ProtocolVersion.
ErrProtocolMismatch = errors.New("protocol version mismatch")
// ErrHandshakeRejected is not referenced in this file.
// NOTE(review): presumably used by the plugin-side code when the host
// answers the handshake with a failure status — confirm.
ErrHandshakeRejected = errors.New("handshake rejected by host")
// ErrPluginStartFailed is returned by Start when the pipes cannot be
// created or the process cannot be started.
ErrPluginStartFailed = errors.New("plugin start failed")
// ErrPluginCommunication is returned by Invoke when the request envelope
// cannot be written to the plugin.
ErrPluginCommunication = errors.New("plugin communication error")
// ErrCapabilityNotFound is returned by Invoke when no running plugin has
// registered the requested capability.
ErrCapabilityNotFound = errors.New("capability not found")
// ErrInvokeFailed is returned by Invoke when the plugin's response
// envelope carries a non-empty error string.
ErrInvokeFailed = errors.New("capability invocation failed")
)
// CapabilityInfo describes a capability provided by a plugin, as announced
// in the plugin's handshake message.
type CapabilityInfo struct {
// Name is the capability name; Invoke routes requests by this name.
Name string
// Version is the capability version reported by the plugin.
Version int
}
// PluginInfo contains metadata about a plugin that has been started.
type PluginInfo struct {
// Name is the plugin name as passed to Start (without the file prefix).
Name string
// Path is the path of the plugin executable that was launched.
Path string
// Version is the protocol version the plugin reported in its handshake.
Version int
// Capabilities lists the capabilities the plugin announced in its
// handshake.
Capabilities []CapabilityInfo
}
// Plugin represents a running plugin process together with the pipes the
// host uses to exchange protocol messages with it.
type Plugin struct {
// info is the metadata recorded for this plugin: name and path are set by
// Start, version and capabilities are filled in by the handshake.
info PluginInfo
// cmd is the handle of the child process.
cmd *exec.Cmd
// stdin is the pipe to the plugin's standard input; the host writes
// protocol messages to it.
stdin io.WriteCloser
// stdout is a buffered reader over the plugin's standard output, from
// which protocol messages are read.
stdout *bufio.Reader
// log is the manager's logger with the plugin name added as a prefix.
log logger.Logger
// mu is held by Invoke for the duration of one request/response exchange.
mu sync.Mutex
}
// Manager handles plugin discovery, lifecycle, and communication.
// Requests are routed to plugins by capability name rather than plugin name.
type Manager struct {
// pluginDir is the directory searched for plugin executables.
pluginDir string
// plugins holds the running plugins, keyed by plugin name.
plugins map[string]*Plugin
capabilities map[string]*Plugin // capability name -> plugin
// log is the logger used for manager-level messages.
log logger.Logger
// mu guards the plugins and capabilities maps.
mu sync.RWMutex
}
// NewManager returns a Manager with no running plugins and no registered
// capabilities.
//
// The plugin directory defaults to the directory containing the current
// executable, or "." when the executable path cannot be determined. Manager
// log output goes to log under the "plugin-manager" prefix.
func NewManager(log logger.Logger) *Manager {
	// Fall back to the working directory if the executable cannot be located.
	dir := "."
	if exe, err := os.Executable(); err == nil {
		dir = filepath.Dir(exe)
	}

	mgr := new(Manager)
	mgr.pluginDir = dir
	mgr.plugins = map[string]*Plugin{}
	mgr.capabilities = map[string]*Plugin{}
	mgr.log = log.WithPrefix("plugin-manager")
	return mgr
}
// SetPluginDir sets the directory that is searched for plugin executables.
//
// The write is made under the manager lock because Start reads pluginDir
// while holding that lock; an unguarded write would be a data race with a
// concurrent Start.
func (m *Manager) SetPluginDir(dir string) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.pluginDir = dir
}
// Discover lists the plugins available in the plugin directory.
//
// Every non-directory entry whose file name starts with pluginPrefix is
// reported, with the prefix stripped. A file named exactly pluginPrefix has
// no plugin name and is skipped. Discover does not start anything and does
// not check that the files are executable.
//
// It returns an error wrapping the os.ReadDir failure when the directory
// cannot be read.
func (m *Manager) Discover() ([]string, error) {
	// Read pluginDir under the manager lock, as Start does, but do not hold
	// the lock across the directory scan.
	m.mu.RLock()
	dir := m.pluginDir
	m.mu.RUnlock()

	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil, fmt.Errorf("read plugin directory: %w", err)
	}

	var plugins []string
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		name, found := strings.CutPrefix(entry.Name(), pluginPrefix)
		if !found || name == "" {
			continue
		}
		plugins = append(plugins, name)
	}
	return plugins, nil
}
// Start launches the plugin called name, performs the handshake with it and
// registers the capabilities it announces.
//
// If a plugin with that name is already in the running set, its existing
// info is returned and no new process is started. The executable is looked
// up as pluginPrefix+name inside the plugin directory. The child's stderr is
// connected to the host's stderr.
//
// The manager lock is held for the whole call, including the handshake wait.
//
// On failure the error wraps ErrPluginNotFound, ErrPluginStartFailed or
// ErrHandshakeFailed. The underlying cause is wrapped as well, so errors.Is
// also matches it (for example ErrProtocolMismatch from the handshake).
//
// The returned pointer refers to the manager's own record of the plugin.
func (m *Manager) Start(name string) (*PluginInfo, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	// Already running: hand back the existing record.
	if p, exists := m.plugins[name]; exists {
		return &p.info, nil
	}

	pluginPath := filepath.Join(m.pluginDir, pluginPrefix+name)
	if _, err := os.Stat(pluginPath); os.IsNotExist(err) {
		return nil, fmt.Errorf("%w: %s", ErrPluginNotFound, name)
	}

	m.log.Info("starting plugin: %s", name)

	cmd := exec.Command(pluginPath)
	stdin, err := cmd.StdinPipe()
	if err != nil {
		return nil, fmt.Errorf("%w: stdin pipe: %w", ErrPluginStartFailed, err)
	}
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		stdin.Close()
		return nil, fmt.Errorf("%w: stdout pipe: %w", ErrPluginStartFailed, err)
	}

	// Redirect stderr to our stderr for debugging.
	cmd.Stderr = os.Stderr

	if err := cmd.Start(); err != nil {
		stdin.Close()
		return nil, fmt.Errorf("%w: %w", ErrPluginStartFailed, err)
	}

	plugin := &Plugin{
		info: PluginInfo{
			Name: name,
			Path: pluginPath,
		},
		cmd:    cmd,
		stdin:  stdin,
		stdout: bufio.NewReader(stdout),
		log:    m.log.WithPrefix(name),
	}

	// The process is up; it only joins the running set once the handshake
	// succeeds. On failure it is shut down again.
	if err := m.doHandshake(plugin); err != nil {
		plugin.stop()
		return nil, fmt.Errorf("%w: %w", ErrHandshakeFailed, err)
	}

	m.plugins[name] = plugin

	// Register capabilities. A capability name that is already registered is
	// taken over by this plugin.
	for _, c := range plugin.info.Capabilities {
		m.capabilities[c.Name] = plugin
		m.log.Debug("registered capability: %s (version %d) from plugin %s",
			c.Name, c.Version, name)
	}

	m.log.Info("plugin started: %s (version=%d, capabilities=%v)",
		name, plugin.info.Version, plugin.info.Capabilities)
	return &plugin.info, nil
}
// doHandshake performs the host side of the startup handshake with p.
//
// The plugin sends the first message. The host waits up to handshakeTimeout
// for it, checks that it is a handshake carrying the expected protocol
// version, records the version and announced capabilities in p.info, and
// answers with an acceptance. On a version mismatch a failure response is
// sent (its write result is ignored) and an error wrapping
// ErrProtocolMismatch is returned.
func (m *Manager) doHandshake(p *Plugin) error {
	// The blocking read runs on its own goroutine so the wait can be bounded.
	// The channel is buffered so the goroutine can finish after a timeout.
	type readResult struct {
		hs  *shep.Handshake
		err error
	}
	results := make(chan readResult, 1)
	go func() {
		kind, payload, err := protocol.ReadMessage(p.stdout)
		if err != nil {
			results <- readResult{err: err}
			return
		}
		if kind != protocol.KindHandshake {
			results <- readResult{err: fmt.Errorf("%w: expected handshake, got kind %d", ErrUnexpectedMessage, kind)}
			return
		}
		parsed, err := protocol.ParseHandshake(payload)
		results <- readResult{hs: parsed, err: err}
	}()

	var hs *shep.Handshake
	select {
	case res := <-results:
		if res.err != nil {
			return fmt.Errorf("read handshake: %w", res.err)
		}
		hs = res.hs
	case <-time.After(handshakeTimeout):
		return errors.New("handshake timeout")
	}

	// Refuse plugins that speak a different protocol version.
	if got := hs.ProtocolVersion(); got != protocol.ProtocolVersion {
		reason := fmt.Sprintf("protocol version mismatch: expected %d, got %d",
			protocol.ProtocolVersion, got)
		rejection := protocol.BuildHandshakeResponse(shep.StatusFailure, reason)
		// Best effort: the handshake is failing regardless of this write.
		_ = protocol.WriteMessage(p.stdin, protocol.KindHandshakeResponse, rejection)
		return fmt.Errorf("%w: expected %d, got %d",
			ErrProtocolMismatch, protocol.ProtocolVersion, got)
	}

	// Record what the plugin told us about itself.
	p.info.Version = int(hs.ProtocolVersion())
	count := hs.CapabilitiesLength()
	for i := 0; i < count; i++ {
		var entry shep.Capability
		if !hs.Capabilities(&entry, i) {
			continue
		}
		p.info.Capabilities = append(p.info.Capabilities, CapabilityInfo{
			Name:    string(entry.Name()),
			Version: int(entry.Version()),
		})
	}

	// Tell the plugin it was accepted.
	acceptance := protocol.BuildHandshakeResponse(shep.StatusOk, "")
	if err := protocol.WriteMessage(p.stdin, protocol.KindHandshakeResponse, acceptance); err != nil {
		return fmt.Errorf("send handshake response: %w", err)
	}
	return nil
}
// Stop shuts down the running plugin called name and removes it from the
// manager.
//
// It returns ErrPluginNotRunning when no plugin with that name is in the
// running set; otherwise it returns the result of the plugin's stop.
//
// Only capabilities that still point at this plugin are unregistered. Start
// lets a later plugin take over a capability name, and that plugin's
// registration must survive this one being stopped.
func (m *Manager) Stop(name string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	p, exists := m.plugins[name]
	if !exists {
		return ErrPluginNotRunning
	}

	for _, c := range p.info.Capabilities {
		if m.capabilities[c.Name] == p {
			delete(m.capabilities, c.Name)
		}
	}
	delete(m.plugins, name)

	return p.stop()
}
// StopAll shuts down every running plugin and forgets all plugins and
// capabilities. Each plugin is stopped in turn while the manager lock is
// held; the results of the individual stops are discarded.
func (m *Manager) StopAll() {
	m.mu.Lock()
	defer m.mu.Unlock()

	for pluginName, running := range m.plugins {
		m.log.Info("stopping plugin: %s", pluginName)
		_ = running.stop()
	}

	// Start over with fresh, empty registries.
	m.plugins, m.capabilities = map[string]*Plugin{}, map[string]*Plugin{}
}
// stop asks the plugin process to shut down and gives it shutdownTimeout to
// exit.
//
// It sends a shutdown message and closes the plugin's stdin (both results are
// ignored), then waits for the process. If the process has not exited when
// the timeout elapses it is killed. stop always returns nil.
func (p *Plugin) stop() error {
	// Polite request first; failures here do not change what follows.
	_ = protocol.WriteMessage(p.stdin, protocol.KindShutdown, protocol.BuildShutdown("host shutdown"))
	_ = p.stdin.Close()

	// Wait on a separate goroutine so the wait can be bounded. The channel is
	// buffered so that goroutine can finish even after a timeout.
	exited := make(chan error, 1)
	go func() { exited <- p.cmd.Wait() }()

	deadline := time.NewTimer(shutdownTimeout)
	defer deadline.Stop()

	select {
	case <-exited:
	case <-deadline.C:
		_ = p.cmd.Process.Kill()
	}
	return nil
}
// Invoke calls a capability by name with the given payload and returns the
// payload of the plugin's response.
// This is the primary way to interact with plugins - by capability, not by
// plugin name.
//
// The request is sent as an envelope tagged with correlationID. Invoke then
// reads messages from the plugin until it sees a response envelope with the
// same correlation ID. Log messages are forwarded to the plugin's logger and
// non-matching envelopes are skipped. Calls to the same plugin are
// serialized by the plugin's mutex.
//
// Errors:
//   - ErrCapabilityNotFound: no running plugin provides capability.
//   - ErrPluginCommunication: the request could not be written, the response
//     could not be read or parsed, a message of an unexpected kind arrived
//     (also matches ErrUnexpectedMessage), or no response arrived within
//     invokeTimeout. The underlying cause is wrapped as well.
//   - ErrInvokeFailed: the response envelope carried a non-empty error.
//
// NOTE(review): on timeout the reader goroutine is not cancelled; it keeps
// reading p.stdout until a matching response or a read error. A later Invoke
// on the same plugin then reads from the same bufio.Reader concurrently.
// Fixing this needs a single long-lived reader per plugin.
func (m *Manager) Invoke(capability, correlationID string, payload []byte) ([]byte, error) {
	m.mu.RLock()
	p, exists := m.capabilities[capability]
	m.mu.RUnlock()
	if !exists {
		return nil, fmt.Errorf("%w: %s", ErrCapabilityNotFound, capability)
	}

	p.mu.Lock()
	defer p.mu.Unlock()

	// Send envelope request.
	envelope := protocol.BuildEnvelope(capability, correlationID, false, "", payload)
	if err := protocol.WriteMessage(p.stdin, protocol.KindEnvelope, envelope); err != nil {
		return nil, fmt.Errorf("%w: send envelope: %w", ErrPluginCommunication, err)
	}

	// The blocking read loop runs on its own goroutine so the wait can be
	// bounded. The channel is buffered so the goroutine can deliver its
	// result and exit even when nobody is waiting any more.
	type outcome struct {
		payload []byte
		errMsg  string
		err     error
	}
	done := make(chan outcome, 1)
	go func() {
		for {
			kind, data, err := protocol.ReadMessage(p.stdout)
			if err != nil {
				done <- outcome{err: fmt.Errorf("%w: read response: %w", ErrPluginCommunication, err)}
				return
			}
			switch kind {
			case protocol.KindEnvelope:
				env, err := protocol.ParseEnvelope(data)
				if err != nil {
					done <- outcome{err: fmt.Errorf("%w: parse envelope: %w", ErrPluginCommunication, err)}
					return
				}
				// Skip envelopes that are not the response to this request.
				if string(env.CorrelationId()) != correlationID || !env.IsResponse() {
					continue
				}
				done <- outcome{payload: env.PayloadBytes(), errMsg: string(env.Error())}
				return
			case protocol.KindLog:
				// Unparseable log messages are dropped.
				if logMsg, err := protocol.ParseLog(data); err == nil {
					p.log.Info("[plugin] %s", string(logMsg.Message()))
				}
			default:
				done <- outcome{err: fmt.Errorf("%w: %w: expected envelope, got kind %d",
					ErrPluginCommunication, ErrUnexpectedMessage, kind)}
				return
			}
		}
	}()

	timer := time.NewTimer(invokeTimeout)
	defer timer.Stop()

	select {
	case res := <-done:
		if res.err != nil {
			return nil, res.err
		}
		if res.errMsg != "" {
			return nil, fmt.Errorf("%w: %s", ErrInvokeFailed, res.errMsg)
		}
		return res.payload, nil
	case <-timer.C:
		return nil, fmt.Errorf("%w: invoke timeout", ErrPluginCommunication)
	}
}
// HasCapability reports whether some running plugin has registered the
// capability called name.
func (m *Manager) HasCapability(name string) bool {
	m.mu.RLock()
	_, registered := m.capabilities[name]
	m.mu.RUnlock()
	return registered
}
// ListCapabilities returns the names of all registered capabilities.
// The order follows map iteration and is therefore unspecified. The result
// is never nil.
func (m *Manager) ListCapabilities() []string {
	m.mu.RLock()
	defer m.mu.RUnlock()

	names := make([]string, len(m.capabilities))
	next := 0
	for capName := range m.capabilities {
		names[next] = capName
		next++
	}
	return names
}
// GetPlugin looks up a running plugin by name. It returns the manager's own
// info record for the plugin and true, or nil and false when no plugin with
// that name is running.
func (m *Manager) GetPlugin(name string) (*PluginInfo, bool) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	if running, ok := m.plugins[name]; ok {
		return &running.info, true
	}
	return nil, false
}
// ListRunning returns the names of all plugins currently in the running set.
// The order follows map iteration and is therefore unspecified. The result
// is never nil.
func (m *Manager) ListRunning() []string {
	m.mu.RLock()
	defer m.mu.RUnlock()

	running := make([]string, len(m.plugins))
	idx := 0
	for pluginName := range m.plugins {
		running[idx] = pluginName
		idx++
	}
	return running
}