package host
import (
"bufio"
"errors"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"time"
"shep/internal/fbs/shep"
"shep/internal/logger"
"shep/internal/protocol"
)
// Plugin discovery and host-side protocol timing parameters.
const (
// pluginPrefix is the file-name prefix that marks a directory entry as a
// plugin executable; the plugin's name is the remainder of the file name.
pluginPrefix = "shep-plugin-"
// handshakeTimeout bounds how long the host waits for a plugin's handshake.
handshakeTimeout = 5 * time.Second
// shutdownTimeout bounds how long a stopping plugin may take to exit before
// its process is killed.
shutdownTimeout = 3 * time.Second
// invokeTimeout bounds how long Invoke waits for a capability response.
invokeTimeout = 5 * time.Second
)
// Sentinel errors returned by Manager methods, usually wrapped with %w and
// extra context; match them with errors.Is.
var (
ErrPluginNotFound = errors.New("plugin not found")
ErrHandshakeFailed = errors.New("handshake failed")
ErrPluginNotRunning = errors.New("plugin not running")
ErrUnexpectedMessage = errors.New("unexpected message type")
ErrProtocolMismatch = errors.New("protocol version mismatch")
// NOTE(review): ErrHandshakeRejected is not referenced by the code in this
// file; confirm whether anything returns it before relying on it.
ErrHandshakeRejected = errors.New("handshake rejected by host")
ErrPluginStartFailed = errors.New("plugin start failed")
ErrPluginCommunication = errors.New("plugin communication error")
ErrCapabilityNotFound = errors.New("capability not found")
ErrInvokeFailed = errors.New("capability invocation failed")
)
// CapabilityInfo describes a capability provided by a plugin.
type CapabilityInfo struct {
// Name is the key Invoke uses to route a request to the providing plugin.
Name string
// Version is the capability version advertised in the plugin's handshake.
Version int
}
// PluginInfo contains metadata about a started plugin, filled in by Start and
// the handshake.
type PluginInfo struct {
// Name is the plugin name, i.e. the executable's file name without pluginPrefix.
Name string
// Path is the executable path: the plugin directory joined with pluginPrefix+Name.
Path string
// Version is the protocol version the plugin reported in its handshake.
Version int
// Capabilities lists the capabilities the plugin advertised in its handshake.
Capabilities []CapabilityInfo
}
// Plugin represents a running plugin process.
type Plugin struct {
info PluginInfo
// cmd is the plugin process started by Manager.Start.
cmd *exec.Cmd
// stdin carries framed protocol messages from the host to the plugin.
stdin io.WriteCloser
// stdout is a buffered reader over the plugin's stdout; framed messages from
// the plugin are read from it.
stdout *bufio.Reader
log logger.Logger
// mu is held by Invoke for a whole request/response exchange, so at most one
// request is in flight per plugin.
mu sync.Mutex
}
// Manager handles plugin discovery, lifecycle, and communication.
type Manager struct {
// pluginDir is the directory Discover and Start search for plugins.
pluginDir string
// plugins maps plugin name to its running process.
plugins map[string]*Plugin
capabilities map[string]*Plugin // capability name -> plugin
log logger.Logger
// mu guards the plugins and capabilities maps.
mu sync.RWMutex
}
// NewManager creates a plugin manager with empty plugin and capability
// registries. Its plugin directory defaults to the directory containing the
// running executable, or "." when that path cannot be determined; use
// SetPluginDir to change it.
func NewManager(log logger.Logger) *Manager {
	dir := "."
	if exe, err := os.Executable(); err == nil {
		dir = filepath.Dir(exe)
	}

	m := &Manager{pluginDir: dir}
	m.plugins = make(map[string]*Plugin)
	m.capabilities = make(map[string]*Plugin)
	m.log = log.WithPrefix("plugin-manager")
	return m
}
// SetPluginDir sets the directory Discover and Start search for plugins.
// Plugins that are already running are not affected.
//
// The write happens under the manager lock because Start reads pluginDir
// while holding it; an unlocked write would race with a concurrent Start.
func (m *Manager) SetPluginDir(dir string) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.pluginDir = dir
}
// Discover lists the names of the plugins available in the plugin directory.
//
// A plugin is any non-directory entry whose file name starts with
// pluginPrefix. The returned name is the file name with that prefix removed,
// which is the form Start expects. Names are returned in the file-name order
// produced by os.ReadDir. An entry named exactly pluginPrefix is skipped
// because it would yield an empty plugin name.
func (m *Manager) Discover() ([]string, error) {
	// Snapshot the directory under the lock; SetPluginDir may change it concurrently.
	m.mu.RLock()
	dir := m.pluginDir
	m.mu.RUnlock()

	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil, fmt.Errorf("read plugin directory: %w", err)
	}

	var plugins []string
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		if name, found := strings.CutPrefix(entry.Name(), pluginPrefix); found && name != "" {
			plugins = append(plugins, name)
		}
	}
	return plugins, nil
}
// Start launches the plugin called name, performs the handshake and registers
// the capabilities the plugin advertises. The executable is pluginPrefix+name
// in the plugin directory. If the plugin is already running, Start returns its
// existing info without starting anything.
//
// A capability that is already routed to another running plugin is re-routed
// to the newly started one, and a warning is logged so the takeover is visible.
//
// The manager lock is held for the whole call, including the handshake (up to
// handshakeTimeout).
//
// Errors wrap ErrPluginNotFound, ErrPluginStartFailed or ErrHandshakeFailed.
func (m *Manager) Start(name string) (*PluginInfo, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	// Starting an already-running plugin is a no-op.
	if p, exists := m.plugins[name]; exists {
		return &p.info, nil
	}

	pluginPath := filepath.Join(m.pluginDir, pluginPrefix+name)
	if _, err := os.Stat(pluginPath); errors.Is(err, os.ErrNotExist) {
		return nil, fmt.Errorf("%w: %s", ErrPluginNotFound, name)
	}

	m.log.Info("starting plugin: %s", name)
	cmd := exec.Command(pluginPath)
	stdin, err := cmd.StdinPipe()
	if err != nil {
		return nil, fmt.Errorf("%w: stdin pipe: %v", ErrPluginStartFailed, err)
	}
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		stdin.Close()
		return nil, fmt.Errorf("%w: stdout pipe: %v", ErrPluginStartFailed, err)
	}
	// stdout carries the protocol; stderr is passed straight through for debugging.
	cmd.Stderr = os.Stderr
	if err := cmd.Start(); err != nil {
		stdin.Close()
		return nil, fmt.Errorf("%w: %v", ErrPluginStartFailed, err)
	}

	plugin := &Plugin{
		info: PluginInfo{
			Name: name,
			Path: pluginPath,
		},
		cmd:    cmd,
		stdin:  stdin,
		stdout: bufio.NewReader(stdout),
		log:    m.log.WithPrefix(name),
	}

	if err := m.doHandshake(plugin); err != nil {
		// Best-effort teardown; the handshake error is what the caller needs.
		_ = plugin.stop()
		return nil, fmt.Errorf("%w: %v", ErrHandshakeFailed, err)
	}

	m.plugins[name] = plugin
	for _, c := range plugin.info.Capabilities {
		if prev, taken := m.capabilities[c.Name]; taken && prev != plugin {
			m.log.Warn("capability %s from plugin %s overrides plugin %s",
				c.Name, name, prev.info.Name)
		}
		m.capabilities[c.Name] = plugin
		m.log.Debug("registered capability: %s (version %d) from plugin %s",
			c.Name, c.Version, name)
	}

	m.log.Info("plugin started: %s (version=%d, capabilities=%v)",
		name, plugin.info.Version, plugin.info.Capabilities)
	return &plugin.info, nil
}
// doHandshake performs the host side of the startup handshake with p.
//
// The plugin speaks first. The host waits up to handshakeTimeout for a
// KindHandshake frame and compares the protocol version it carries with
// protocol.ProtocolVersion. On a mismatch, a StatusFailure response with the
// reason is sent (best effort) and an ErrProtocolMismatch error is returned.
// Otherwise the version and advertised capabilities are stored in p.info and a
// StatusOk response is sent.
//
// Reading and decoding run on a separate goroutine so the timeout can fire even
// if the plugin never writes. That goroutine recovers from panics: flatbuffers
// accessors panic on malformed buffers, and an unrecovered panic there would
// crash the whole host because of one misbehaving plugin.
func (m *Manager) doHandshake(p *Plugin) error {
	type handshakeResult struct {
		version int
		caps    []CapabilityInfo
		err     error
	}

	done := make(chan handshakeResult, 1)
	go func() {
		defer func() {
			// Only reached by a panic before any send, so the buffered channel never blocks.
			if r := recover(); r != nil {
				done <- handshakeResult{err: fmt.Errorf("%w: malformed handshake: %v", protocol.ErrInvalidMessage, r)}
			}
		}()

		kind, data, err := protocol.ReadMessage(p.stdout)
		if err != nil {
			done <- handshakeResult{err: err}
			return
		}
		if kind != protocol.KindHandshake {
			done <- handshakeResult{err: fmt.Errorf("%w: expected handshake, got kind %d", ErrUnexpectedMessage, kind)}
			return
		}
		hs, err := protocol.ParseHandshake(data)
		if err != nil {
			done <- handshakeResult{err: err}
			return
		}

		res := handshakeResult{version: int(hs.ProtocolVersion())}
		// Only decode capabilities when the schema version is one we understand.
		if res.version == protocol.ProtocolVersion {
			c := new(shep.Capability)
			for i := 0; i < hs.CapabilitiesLength(); i++ {
				if hs.Capabilities(c, i) {
					res.caps = append(res.caps, CapabilityInfo{
						Name:    string(c.Name()),
						Version: int(c.Version()),
					})
				}
			}
		}
		done <- res
	}()

	var res handshakeResult
	select {
	case res = <-done:
		if res.err != nil {
			return fmt.Errorf("read handshake: %w", res.err)
		}
	case <-time.After(handshakeTimeout):
		return fmt.Errorf("handshake timeout")
	}

	if res.version != protocol.ProtocolVersion {
		reason := fmt.Sprintf("protocol version mismatch: expected %d, got %d",
			protocol.ProtocolVersion, res.version)
		// Best effort: the handshake fails whether or not the rejection is delivered.
		_ = protocol.WriteMessage(p.stdin, protocol.KindHandshakeResponse,
			protocol.BuildHandshakeResponse(shep.StatusFailure, reason))
		return fmt.Errorf("%w: expected %d, got %d",
			ErrProtocolMismatch, protocol.ProtocolVersion, res.version)
	}

	p.info.Version = res.version
	p.info.Capabilities = res.caps

	resp := protocol.BuildHandshakeResponse(shep.StatusOk, "")
	if err := protocol.WriteMessage(p.stdin, protocol.KindHandshakeResponse, resp); err != nil {
		return fmt.Errorf("send handshake response: %w", err)
	}
	return nil
}
// Stop stops the running plugin called name and unregisters its capabilities.
// It returns ErrPluginNotRunning if no such plugin is running.
//
// Only capability routes that still point at this plugin are removed. A
// capability that a later-started plugin took over keeps routing to that
// plugin.
func (m *Manager) Stop(name string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	p, exists := m.plugins[name]
	if !exists {
		return ErrPluginNotRunning
	}

	for _, c := range p.info.Capabilities {
		// Deleting unconditionally would drop a live route owned by another plugin.
		if m.capabilities[c.Name] == p {
			delete(m.capabilities, c.Name)
		}
	}
	delete(m.plugins, name)
	return p.stop()
}
// StopAll stops every running plugin, one after another, while holding the
// manager lock, and then resets both the plugin and capability registries.
// Per-plugin stop results are not reported.
func (m *Manager) StopAll() {
	m.mu.Lock()
	defer m.mu.Unlock()

	for pluginName, running := range m.plugins {
		m.log.Info("stopping plugin: %s", pluginName)
		_ = running.stop()
	}

	m.capabilities = make(map[string]*Plugin)
	m.plugins = make(map[string]*Plugin)
}
// stop shuts the plugin process down.
//
// It writes a Shutdown frame with reason "host shutdown" (best effort; a write
// failure is ignored), closes stdin, and waits up to shutdownTimeout for the
// process to exit. If the process is still running after that, it is killed,
// and the background Wait call reaps it afterwards. stop always returns nil.
func (p *Plugin) stop() error {
	_ = protocol.WriteMessage(p.stdin, protocol.KindShutdown, protocol.BuildShutdown("host shutdown"))
	p.stdin.Close()

	exited := make(chan error, 1)
	go func() { exited <- p.cmd.Wait() }()

	select {
	case <-exited:
	case <-time.After(shutdownTimeout):
		p.cmd.Process.Kill()
	}
	return nil
}
// Invoke calls a capability by name with the given payload.
// This is the primary way to interact with plugins - by capability, not by plugin name.
//
// The request is sent as an envelope tagged with correlationID. Invoke then
// waits up to invokeTimeout for the response envelope with the same correlation
// ID. Log frames received while waiting are forwarded to the plugin's logger. A
// response whose error field is non-empty is reported as ErrInvokeFailed.
// Errors wrap ErrCapabilityNotFound, ErrPluginCommunication, ErrInvokeFailed or
// ErrUnexpectedMessage, or carry a read/parse failure.
//
// If the plugin does not answer in time, its process is killed. Otherwise the
// abandoned reader goroutine would keep consuming the plugin's stdout, race the
// reader of the next Invoke over the same buffered stream, and swallow that
// call's response. After such a timeout, later calls to the plugin fail with
// ErrPluginCommunication.
func (m *Manager) Invoke(capability, correlationID string, payload []byte) ([]byte, error) {
	m.mu.RLock()
	p, exists := m.capabilities[capability]
	m.mu.RUnlock()
	if !exists {
		return nil, fmt.Errorf("%w: %s", ErrCapabilityNotFound, capability)
	}

	// p.mu makes this goroutine the only writer to p.stdin and the only
	// reader of p.stdout for the duration of the exchange.
	p.mu.Lock()
	defer p.mu.Unlock()

	envelope := protocol.BuildEnvelope(capability, correlationID, false, "", payload)
	if err := protocol.WriteMessage(p.stdin, protocol.KindEnvelope, envelope); err != nil {
		return nil, fmt.Errorf("%w: send envelope: %v", ErrPluginCommunication, err)
	}

	type invokeResult struct {
		payload []byte
		errMsg  string
		err     error
	}
	done := make(chan invokeResult, 1)
	go func() {
		defer func() {
			// flatbuffers accessors panic on malformed input; never let a plugin crash the host.
			if r := recover(); r != nil {
				done <- invokeResult{err: fmt.Errorf("%w: malformed message: %v", protocol.ErrInvalidMessage, r)}
			}
		}()
		resp, errMsg, err := p.readResponse(correlationID)
		done <- invokeResult{payload: resp, errMsg: errMsg, err: err}
	}()

	timer := time.NewTimer(invokeTimeout)
	defer timer.Stop()

	select {
	case res := <-done:
		if res.err != nil {
			return nil, res.err
		}
		if res.errMsg != "" {
			return nil, fmt.Errorf("%w: %s", ErrInvokeFailed, res.errMsg)
		}
		return res.payload, nil
	case <-timer.C:
		// Make the pending read fail (the pipe reaches EOF once the process is
		// gone) and wait, bounded, for the reader to exit before p.mu is
		// released, so two readers never share p.stdout.
		p.log.Warn("invoke of %s timed out; killing unresponsive plugin", capability)
		if err := p.cmd.Process.Kill(); err != nil {
			p.log.Warn("kill plugin: %v", err)
		}
		select {
		case <-done:
		case <-time.After(shutdownTimeout):
			p.log.Warn("response reader did not exit after kill")
		}
		return nil, fmt.Errorf("invoke timeout")
	}
}

// readResponse reads frames from p.stdout until the response envelope for
// correlationID arrives. It returns that envelope's payload and its error text,
// which is empty when the plugin reported success.
//
// Log frames are forwarded to p.log. Envelopes that are not responses, or that
// carry another correlation ID, are skipped. Any other frame kind, or a
// read/parse failure, ends the wait with an error. The caller must hold p.mu so
// this is the only reader of p.stdout.
func (p *Plugin) readResponse(correlationID string) ([]byte, string, error) {
	for {
		kind, data, err := protocol.ReadMessage(p.stdout)
		if err != nil {
			return nil, "", err
		}
		switch kind {
		case protocol.KindEnvelope:
			env, err := protocol.ParseEnvelope(data)
			if err != nil {
				return nil, "", err
			}
			if string(env.CorrelationId()) == correlationID && env.IsResponse() {
				return env.PayloadBytes(), string(env.Error()), nil
			}
		case protocol.KindLog:
			if logMsg, err := protocol.ParseLog(data); err == nil {
				p.log.Info("[plugin] %s", string(logMsg.Message()))
			}
		default:
			return nil, "", fmt.Errorf("%w: expected envelope, got kind %d", ErrUnexpectedMessage, kind)
		}
	}
}
// HasCapability reports whether some running plugin currently provides the
// capability called name.
func (m *Manager) HasCapability(name string) bool {
	m.mu.RLock()
	_, found := m.capabilities[name]
	m.mu.RUnlock()
	return found
}
// ListCapabilities returns the names of all registered capabilities. The order
// follows map iteration and is therefore unspecified.
func (m *Manager) ListCapabilities() []string {
	m.mu.RLock()
	defer m.mu.RUnlock()

	names := make([]string, len(m.capabilities))
	i := 0
	for capName := range m.capabilities {
		names[i] = capName
		i++
	}
	return names
}
// GetPlugin returns the info of the running plugin called name, and true. If no
// such plugin is running, it returns nil and false. The pointer refers to the
// manager's own PluginInfo for that plugin, not a copy.
func (m *Manager) GetPlugin(name string) (*PluginInfo, bool) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	if running, ok := m.plugins[name]; ok {
		return &running.info, true
	}
	return nil, false
}
// ListRunning returns the names of all running plugins. The order follows map
// iteration and is therefore unspecified.
func (m *Manager) ListRunning() []string {
	m.mu.RLock()
	defer m.mu.RUnlock()

	running := make([]string, len(m.plugins))
	idx := 0
	for pluginName := range m.plugins {
		running[idx] = pluginName
		idx++
	}
	return running
}
package logger
import (
"fmt"
"io"
"os"
"sync"
"time"
)
// Level represents log severity levels.
type Level int
const (
LevelDebug Level = iota
LevelInfo
LevelWarn
LevelError
)
func (l Level) String() string {
switch l {
case LevelDebug:
return "DEBUG"
case LevelInfo:
return "INFO"
case LevelWarn:
return "WARN"
case LevelError:
return "ERROR"
default:
return "UNKNOWN"
}
}
// Logger defines the logging interface used throughout shep.
type Logger interface {
Debug(msg string, args ...any)
Info(msg string, args ...any)
Warn(msg string, args ...any)
Error(msg string, args ...any)
WithPrefix(prefix string) Logger
}
// StandardLogger implements Logger writing to an io.Writer.
type StandardLogger struct {
out io.Writer
prefix string
level Level
mu sync.Mutex
}
// Option configures a StandardLogger.
type Option func(*StandardLogger)
// WithOutput sets the output writer.
func WithOutput(w io.Writer) Option {
return func(l *StandardLogger) {
l.out = w
}
}
// WithLevel sets the minimum log level.
func WithLevel(level Level) Option {
return func(l *StandardLogger) {
l.level = level
}
}
// WithLogPrefix sets a prefix for log messages.
func WithLogPrefix(prefix string) Option {
return func(l *StandardLogger) {
l.prefix = prefix
}
}
// New creates a new StandardLogger with the given options.
func New(opts ...Option) *StandardLogger {
l := &StandardLogger{
out: os.Stderr,
level: LevelInfo,
}
for _, opt := range opts {
opt(l)
}
return l
}
func (l *StandardLogger) log(level Level, msg string, args ...any) {
if level < l.level {
return
}
l.mu.Lock()
defer l.mu.Unlock()
ts := time.Now().Format("15:04:05")
prefix := ""
if l.prefix != "" {
prefix = "[" + l.prefix + "] "
}
formatted := msg
if len(args) > 0 {
formatted = fmt.Sprintf(msg, args...)
}
fmt.Fprintf(l.out, "%s %s %s%s\n", ts, level.String(), prefix, formatted)
}
func (l *StandardLogger) Debug(msg string, args ...any) {
l.log(LevelDebug, msg, args...)
}
func (l *StandardLogger) Info(msg string, args ...any) {
l.log(LevelInfo, msg, args...)
}
func (l *StandardLogger) Warn(msg string, args ...any) {
l.log(LevelWarn, msg, args...)
}
func (l *StandardLogger) Error(msg string, args ...any) {
l.log(LevelError, msg, args...)
}
func (l *StandardLogger) WithPrefix(prefix string) Logger {
newPrefix := prefix
if l.prefix != "" {
newPrefix = l.prefix + "/" + prefix
}
return &StandardLogger{
out: l.out,
prefix: newPrefix,
level: l.level,
}
}
// NopLogger is a logger that does nothing.
type NopLogger struct{}
func (NopLogger) Debug(msg string, args ...any) {}
func (NopLogger) Info(msg string, args ...any) {}
func (NopLogger) Warn(msg string, args ...any) {}
func (NopLogger) Error(msg string, args ...any) {}
func (NopLogger) WithPrefix(prefix string) Logger {
return NopLogger{}
}
package protocol
import (
"encoding/binary"
"errors"
"fmt"
"io"
"time"
flatbuffers "github.com/google/flatbuffers/go"
"shep/internal/fbs/shep"
)
const (
	// ProtocolVersion is the current protocol version.
	ProtocolVersion = 1
	// MaxMessageSize is the maximum allowed message body size (1 MiB); the
	// frame header is not counted.
	MaxMessageSize = 1 << 20
	// frameHeaderSize is the length of the header preceding every message
	// body: one kind byte followed by the body length as a big-endian uint32.
	frameHeaderSize = 5
)

// MessageKind identifies the type of framed message (lifecycle vs capability).
type MessageKind byte

// Message kinds carried in the first byte of every frame.
const (
	KindHandshake         MessageKind = 1
	KindHandshakeResponse MessageKind = 2
	KindEnvelope          MessageKind = 3
	KindLog               MessageKind = 4
	KindShutdown          MessageKind = 5
)

var (
	// ErrMessageTooLarge is returned when a body exceeds MaxMessageSize.
	ErrMessageTooLarge = errors.New("message too large")
	// ErrInvalidMessage is returned by the Parse* helpers for buffers that
	// cannot hold a message.
	ErrInvalidMessage = errors.New("invalid message")
	// ErrUnknownKind is not returned by WriteMessage or ReadMessage, which pass
	// any kind byte through unchanged.
	ErrUnknownKind = errors.New("unknown message kind")
)

// CapabilityInfo describes a capability declared by a plugin.
type CapabilityInfo struct {
	Name    string
	Version int
}

// WriteMessage writes one framed message: kind (1 byte), body length (4 bytes,
// big-endian), then data. Bodies larger than MaxMessageSize are rejected with
// ErrMessageTooLarge and nothing is written.
//
// The whole frame goes out in a single Write call. With two calls, another
// goroutine writing to the same stream between the header and the body would
// corrupt the framing. Writers that serialize whole Write calls, as *os.File
// does, therefore never interleave frames.
func WriteMessage(w io.Writer, kind MessageKind, data []byte) error {
	if len(data) > MaxMessageSize {
		return ErrMessageTooLarge
	}

	frame := make([]byte, frameHeaderSize+len(data))
	frame[0] = byte(kind)
	binary.BigEndian.PutUint32(frame[1:frameHeaderSize], uint32(len(data)))
	copy(frame[frameHeaderSize:], data)

	if _, err := w.Write(frame); err != nil {
		return fmt.Errorf("write frame: %w", err)
	}
	return nil
}

// ReadMessage reads one framed message written by WriteMessage and returns its
// kind and body. A clean end of stream before the header is reported as a
// wrapped io.EOF; a truncated frame is reported as a wrapped io.ErrUnexpectedEOF.
//
// If the announced length exceeds MaxMessageSize, ErrMessageTooLarge is
// returned without consuming the body. The stream is then out of sync and
// should not be read further.
func ReadMessage(r io.Reader) (MessageKind, []byte, error) {
	var header [frameHeaderSize]byte
	if _, err := io.ReadFull(r, header[:]); err != nil {
		return 0, nil, fmt.Errorf("read header: %w", err)
	}

	kind := MessageKind(header[0])
	msgLen := binary.BigEndian.Uint32(header[1:])
	if msgLen > MaxMessageSize {
		return 0, nil, ErrMessageTooLarge
	}

	data := make([]byte, msgLen)
	if _, err := io.ReadFull(r, data); err != nil {
		return 0, nil, fmt.Errorf("read data: %w", err)
	}
	return kind, data, nil
}
// BuildHandshake encodes the Handshake a plugin sends first. It carries the
// plugin name, ProtocolVersion and the declared capabilities.
func BuildHandshake(name string, capabilities []CapabilityInfo) []byte {
	builder := flatbuffers.NewBuilder(256)

	// Strings and capability tables are finished before HandshakeStart, as
	// FlatBuffers requires child objects to be complete before their parent
	// table is started.
	capOffsets := make([]flatbuffers.UOffsetT, len(capabilities))
	for i, c := range capabilities {
		nameOffset := builder.CreateString(c.Name)
		shep.CapabilityStart(builder)
		shep.CapabilityAddName(builder, nameOffset)
		shep.CapabilityAddVersion(builder, int32(c.Version))
		capOffsets[i] = shep.CapabilityEnd(builder)
	}
	// Vector elements are prepended, so iterate backwards to keep declaration order.
	shep.HandshakeStartCapabilitiesVector(builder, len(capabilities))
	for i := len(capOffsets) - 1; i >= 0; i-- {
		builder.PrependUOffsetT(capOffsets[i])
	}
	capsVec := builder.EndVector(len(capabilities))

	nameOffset := builder.CreateString(name)
	shep.HandshakeStart(builder)
	shep.HandshakeAddProtocolVersion(builder, ProtocolVersion)
	shep.HandshakeAddPluginName(builder, nameOffset)
	shep.HandshakeAddCapabilities(builder, capsVec)
	offset := shep.HandshakeEnd(builder)
	builder.Finish(offset)
	return builder.FinishedBytes()
}

// ParseHandshake wraps data as a Handshake table without copying it.
//
// It returns ErrInvalidMessage when data is too short to hold the root offset,
// or when that offset points past the end of data. Either case would otherwise
// panic in GetRootAsHandshake or on the first field access. Deeper corruption
// is not detected here, and field accessors may still panic on it.
func ParseHandshake(data []byte) (*shep.Handshake, error) {
	if len(data) < flatbuffers.SizeUOffsetT ||
		uint64(flatbuffers.GetUOffsetT(data))+uint64(flatbuffers.SizeSOffsetT) > uint64(len(data)) {
		return nil, ErrInvalidMessage
	}
	return shep.GetRootAsHandshake(data, 0), nil
}
// BuildHandshakeResponse encodes the host's reply to a Handshake. It carries
// ProtocolVersion and status, plus errMsg when errMsg is non-empty.
func BuildHandshakeResponse(status shep.Status, errMsg string) []byte {
	builder := flatbuffers.NewBuilder(256)

	// The optional string must be created before the table is started.
	var errOffset flatbuffers.UOffsetT
	if errMsg != "" {
		errOffset = builder.CreateString(errMsg)
	}

	shep.HandshakeResponseStart(builder)
	shep.HandshakeResponseAddProtocolVersion(builder, ProtocolVersion)
	shep.HandshakeResponseAddStatus(builder, status)
	if errMsg != "" {
		shep.HandshakeResponseAddErrorMessage(builder, errOffset)
	}
	offset := shep.HandshakeResponseEnd(builder)
	builder.Finish(offset)
	return builder.FinishedBytes()
}

// ParseHandshakeResponse wraps data as a HandshakeResponse table without
// copying it.
//
// It returns ErrInvalidMessage when data is too short to hold the root offset,
// or when that offset points past the end of data. Either case would otherwise
// panic in GetRootAsHandshakeResponse or on the first field access. Deeper
// corruption is not detected here.
func ParseHandshakeResponse(data []byte) (*shep.HandshakeResponse, error) {
	if len(data) < flatbuffers.SizeUOffsetT ||
		uint64(flatbuffers.GetUOffsetT(data))+uint64(flatbuffers.SizeSOffsetT) > uint64(len(data)) {
		return nil, ErrInvalidMessage
	}
	return shep.GetRootAsHandshakeResponse(data, 0), nil
}
// BuildEnvelope encodes a capability request or response envelope.
//
// capability and correlationID are always written. errMsg and payload are
// written only when non-empty. isResponse distinguishes replies from requests.
func BuildEnvelope(capability, correlationID string, isResponse bool, errMsg string, payload []byte) []byte {
	// Size the initial buffer for the payload to avoid regrowth.
	builder := flatbuffers.NewBuilder(256 + len(payload))

	// All strings and vectors must exist before EnvelopeStart.
	capOffset := builder.CreateString(capability)
	corrOffset := builder.CreateString(correlationID)
	var errOffset flatbuffers.UOffsetT
	if errMsg != "" {
		errOffset = builder.CreateString(errMsg)
	}
	var payloadOffset flatbuffers.UOffsetT
	if len(payload) > 0 {
		payloadOffset = builder.CreateByteVector(payload)
	}

	shep.EnvelopeStart(builder)
	shep.EnvelopeAddProtocolVersion(builder, ProtocolVersion)
	shep.EnvelopeAddCapability(builder, capOffset)
	shep.EnvelopeAddCorrelationId(builder, corrOffset)
	shep.EnvelopeAddIsResponse(builder, isResponse)
	if errMsg != "" {
		shep.EnvelopeAddError(builder, errOffset)
	}
	if len(payload) > 0 {
		shep.EnvelopeAddPayload(builder, payloadOffset)
	}
	offset := shep.EnvelopeEnd(builder)
	builder.Finish(offset)
	return builder.FinishedBytes()
}

// ParseEnvelope wraps data as an Envelope table without copying it.
//
// It returns ErrInvalidMessage when data is too short to hold the root offset,
// or when that offset points past the end of data. Either case would otherwise
// panic in GetRootAsEnvelope or on the first field access. Deeper corruption
// is not detected here.
func ParseEnvelope(data []byte) (*shep.Envelope, error) {
	if len(data) < flatbuffers.SizeUOffsetT ||
		uint64(flatbuffers.GetUOffsetT(data))+uint64(flatbuffers.SizeSOffsetT) > uint64(len(data)) {
		return nil, ErrInvalidMessage
	}
	return shep.GetRootAsEnvelope(data, 0), nil
}
// BuildLog encodes a log message with the given level and text, stamped with
// the current time in Unix milliseconds.
func BuildLog(level shep.LogLevel, message string) []byte {
	builder := flatbuffers.NewBuilder(256)
	msgOffset := builder.CreateString(message)

	shep.LogMessageStart(builder)
	shep.LogMessageAddLevel(builder, level)
	shep.LogMessageAddMessage(builder, msgOffset)
	shep.LogMessageAddTimestamp(builder, time.Now().UnixMilli())
	offset := shep.LogMessageEnd(builder)
	builder.Finish(offset)
	return builder.FinishedBytes()
}

// ParseLog wraps data as a LogMessage table without copying it.
//
// It returns ErrInvalidMessage when data is too short to hold the root offset,
// or when that offset points past the end of data. Either case would otherwise
// panic in GetRootAsLogMessage or on the first field access. Deeper corruption
// is not detected here.
func ParseLog(data []byte) (*shep.LogMessage, error) {
	if len(data) < flatbuffers.SizeUOffsetT ||
		uint64(flatbuffers.GetUOffsetT(data))+uint64(flatbuffers.SizeSOffsetT) > uint64(len(data)) {
		return nil, ErrInvalidMessage
	}
	return shep.GetRootAsLogMessage(data, 0), nil
}
// BuildShutdown encodes a Shutdown message carrying a human-readable reason.
func BuildShutdown(reason string) []byte {
	builder := flatbuffers.NewBuilder(128)
	reasonOffset := builder.CreateString(reason)

	shep.ShutdownStart(builder)
	shep.ShutdownAddReason(builder, reasonOffset)
	offset := shep.ShutdownEnd(builder)
	builder.Finish(offset)
	return builder.FinishedBytes()
}

// ParseShutdown wraps data as a Shutdown table without copying it.
//
// It returns ErrInvalidMessage when data is too short to hold the root offset,
// or when that offset points past the end of data. Either case would otherwise
// panic in GetRootAsShutdown or on the first field access. Deeper corruption
// is not detected here.
func ParseShutdown(data []byte) (*shep.Shutdown, error) {
	if len(data) < flatbuffers.SizeUOffsetT ||
		uint64(flatbuffers.GetUOffsetT(data))+uint64(flatbuffers.SizeSOffsetT) > uint64(len(data)) {
		return nil, ErrInvalidMessage
	}
	return shep.GetRootAsShutdown(data, 0), nil
}