290 lines
7.1 KiB
Go
290 lines
7.1 KiB
Go
package transport
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
pb "github.com/jeremytregunna/kevo/proto/kevo"
|
|
"github.com/jeremytregunna/kevo/pkg/transport"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/grpc/keepalive"
|
|
)
|
|
|
|
// Default values used by DefaultGRPCTransportOptions.
const (
	// defaultDialTimeout is the default for GRPCTransportOptions.DialTimeout,
	// which bounds a single outbound dial in Connect.
	defaultDialTimeout time.Duration = 5 * time.Second

	// defaultConnectTimeout is the default for GRPCTransportOptions.ConnectionTimeout.
	defaultConnectTimeout time.Duration = 5 * time.Second

	// defaultKeepAliveTime is the default for GRPCTransportOptions.KeepAliveTime,
	// which is passed as the keepalive Time on both server and client.
	defaultKeepAliveTime time.Duration = 15 * time.Second

	// defaultKeepAlivePolicy is, despite its name, the default for
	// GRPCTransportOptions.KeepAliveTimeout.
	defaultKeepAlivePolicy time.Duration = 5 * time.Second

	// defaultMaxConnIdle is the default for GRPCTransportOptions.MaxConnectionIdle.
	defaultMaxConnIdle time.Duration = 60 * time.Second

	// defaultMaxConnAge is the default for GRPCTransportOptions.MaxConnectionAge.
	defaultMaxConnAge time.Duration = 5 * time.Minute
)
|
|
|
|
type GRPCTransportManager struct {
|
|
opts *GRPCTransportOptions
|
|
server *grpc.Server
|
|
listener net.Listener
|
|
connections sync.Map // map[string]*grpc.ClientConn
|
|
mu sync.RWMutex
|
|
metrics *transport.Metrics
|
|
}
|
|
|
|
// GRPCTransportOptions configures a GRPCTransportManager.
type GRPCTransportOptions struct {
	// ListenAddr is the TCP address the server binds to in Start.
	ListenAddr string

	// TLSConfig, when non-nil, enables TLS for both the server and outbound
	// client connections. When nil, client connections use insecure
	// credentials and the server is created without transport credentials.
	TLSConfig *tls.Config

	// DialTimeout bounds a single outbound dial performed by Connect.
	DialTimeout time.Duration

	// ConnectionTimeout is not read by any code visible in this file.
	ConnectionTimeout time.Duration

	// KeepAliveTime is used as the keepalive Time for both server and client;
	// half of it is used as the server's enforcement MinTime.
	KeepAliveTime time.Duration

	// KeepAliveTimeout is used as the keepalive Timeout for both server and client.
	KeepAliveTimeout time.Duration

	// MaxConnectionIdle is passed to the server's keepalive parameters.
	MaxConnectionIdle time.Duration

	// MaxConnectionAge is passed to the server's keepalive parameters.
	MaxConnectionAge time.Duration

	// MaxPoolConnections is not read by any code visible in this file.
	MaxPoolConnections int
}
|
|
|
|
// Compile-time check that *GRPCTransportManager implements transport.TransportManager.
var _ transport.TransportManager = (*GRPCTransportManager)(nil)
|
|
|
|
func DefaultGRPCTransportOptions() *GRPCTransportOptions {
|
|
return &GRPCTransportOptions{
|
|
ListenAddr: ":50051",
|
|
ConnectionTimeout: defaultConnectTimeout,
|
|
DialTimeout: defaultDialTimeout,
|
|
KeepAliveTime: defaultKeepAliveTime,
|
|
KeepAliveTimeout: defaultKeepAlivePolicy,
|
|
MaxConnectionIdle: defaultMaxConnIdle,
|
|
MaxConnectionAge: defaultMaxConnAge,
|
|
}
|
|
}
|
|
|
|
func NewGRPCTransportManager(opts *GRPCTransportOptions) (*GRPCTransportManager, error) {
|
|
if opts == nil {
|
|
opts = DefaultGRPCTransportOptions()
|
|
}
|
|
|
|
metrics := transport.NewMetrics("grpc")
|
|
|
|
return &GRPCTransportManager{
|
|
opts: opts,
|
|
metrics: metrics,
|
|
}, nil
|
|
}
|
|
|
|
func (g *GRPCTransportManager) Start(ctx context.Context) error {
|
|
g.mu.Lock()
|
|
defer g.mu.Unlock()
|
|
|
|
if g.server != nil {
|
|
return fmt.Errorf("gRPC transport already started")
|
|
}
|
|
|
|
var serverOpts []grpc.ServerOption
|
|
|
|
// Configure TLS if provided
|
|
if g.opts.TLSConfig != nil {
|
|
serverOpts = append(serverOpts, grpc.Creds(credentials.NewTLS(g.opts.TLSConfig)))
|
|
}
|
|
|
|
// Configure keepalive parameters
|
|
kaProps := keepalive.ServerParameters{
|
|
MaxConnectionIdle: g.opts.MaxConnectionIdle,
|
|
MaxConnectionAge: g.opts.MaxConnectionAge,
|
|
Time: g.opts.KeepAliveTime,
|
|
Timeout: g.opts.KeepAliveTimeout,
|
|
}
|
|
kaPolicy := keepalive.EnforcementPolicy{
|
|
MinTime: g.opts.KeepAliveTime / 2,
|
|
PermitWithoutStream: true,
|
|
}
|
|
serverOpts = append(serverOpts,
|
|
grpc.KeepaliveParams(kaProps),
|
|
grpc.KeepaliveEnforcementPolicy(kaPolicy),
|
|
)
|
|
|
|
// Create and start the gRPC server
|
|
g.server = grpc.NewServer(serverOpts...)
|
|
|
|
// Register service implementations
|
|
// This will be implemented later once we have the service implementation
|
|
// pb.RegisterKevoServiceServer(g.server, &kevoServiceServer{})
|
|
|
|
// Start listening
|
|
listener, err := net.Listen("tcp", g.opts.ListenAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to listen on %s: %w", g.opts.ListenAddr, err)
|
|
}
|
|
g.listener = listener
|
|
|
|
// Start server in a goroutine
|
|
go func() {
|
|
g.metrics.ServerStarted()
|
|
if err := g.server.Serve(listener); err != nil {
|
|
g.metrics.ServerErrored()
|
|
// Just log the error, as this is running in a goroutine
|
|
// and we can't return it
|
|
fmt.Printf("gRPC server stopped: %v\n", err)
|
|
}
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (g *GRPCTransportManager) Stop(ctx context.Context) error {
|
|
g.mu.Lock()
|
|
defer g.mu.Unlock()
|
|
|
|
if g.server == nil {
|
|
return nil
|
|
}
|
|
|
|
// Close all client connections
|
|
g.connections.Range(func(key, value interface{}) bool {
|
|
conn := value.(*grpc.ClientConn)
|
|
conn.Close()
|
|
g.connections.Delete(key)
|
|
return true
|
|
})
|
|
|
|
// Gracefully stop the server
|
|
stopped := make(chan struct{})
|
|
go func() {
|
|
g.server.GracefulStop()
|
|
close(stopped)
|
|
}()
|
|
|
|
// Wait for graceful stop or context deadline
|
|
select {
|
|
case <-stopped:
|
|
// Server stopped gracefully
|
|
case <-ctx.Done():
|
|
// Context deadline exceeded, force stop
|
|
g.server.Stop()
|
|
}
|
|
|
|
g.metrics.ServerStopped()
|
|
g.server = nil
|
|
return nil
|
|
}
|
|
|
|
func (g *GRPCTransportManager) Connect(ctx context.Context, address string) (transport.Connection, error) {
|
|
g.mu.RLock()
|
|
defer g.mu.RUnlock()
|
|
|
|
// Check if we already have a connection to this address
|
|
if conn, ok := g.connections.Load(address); ok {
|
|
return &GRPCConnection{
|
|
conn: conn.(*grpc.ClientConn),
|
|
address: address,
|
|
metrics: g.metrics,
|
|
}, nil
|
|
}
|
|
|
|
// Set connection options
|
|
dialOptions := []grpc.DialOption{
|
|
grpc.WithBlock(),
|
|
}
|
|
|
|
// Add TLS if configured
|
|
if g.opts.TLSConfig != nil {
|
|
dialOptions = append(dialOptions, grpc.WithTransportCredentials(
|
|
credentials.NewTLS(g.opts.TLSConfig),
|
|
))
|
|
} else {
|
|
dialOptions = append(dialOptions, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
|
}
|
|
|
|
// Add keepalive options
|
|
dialOptions = append(dialOptions, grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
|
Time: g.opts.KeepAliveTime,
|
|
Timeout: g.opts.KeepAliveTimeout,
|
|
PermitWithoutStream: true,
|
|
}))
|
|
|
|
// Set timeout for connection
|
|
dialCtx, cancel := context.WithTimeout(ctx, g.opts.DialTimeout)
|
|
defer cancel()
|
|
|
|
// Dial the server
|
|
conn, err := grpc.DialContext(dialCtx, address, dialOptions...)
|
|
if err != nil {
|
|
g.metrics.ConnectionFailed()
|
|
return nil, fmt.Errorf("failed to connect to %s: %w", address, err)
|
|
}
|
|
|
|
// Store the connection
|
|
g.connections.Store(address, conn)
|
|
g.metrics.ConnectionOpened()
|
|
|
|
return &GRPCConnection{
|
|
conn: conn,
|
|
address: address,
|
|
metrics: g.metrics,
|
|
}, nil
|
|
}
|
|
|
|
func (g *GRPCTransportManager) RegisterService(service interface{}) error {
|
|
g.mu.Lock()
|
|
defer g.mu.Unlock()
|
|
|
|
if g.server == nil {
|
|
return fmt.Errorf("server not started, cannot register service")
|
|
}
|
|
|
|
// Type assert to KevoServiceServer and register
|
|
kevoService, ok := service.(pb.KevoServiceServer)
|
|
if !ok {
|
|
return fmt.Errorf("service is not a valid KevoServiceServer implementation")
|
|
}
|
|
|
|
pb.RegisterKevoServiceServer(g.server, kevoService)
|
|
return nil
|
|
}
|
|
|
|
// GRPCConnection represents a gRPC client connection
|
|
type GRPCConnection struct {
|
|
conn *grpc.ClientConn
|
|
address string
|
|
metrics *transport.Metrics
|
|
}
|
|
|
|
func (c *GRPCConnection) Close() error {
|
|
err := c.conn.Close()
|
|
c.metrics.ConnectionClosed()
|
|
return err
|
|
}
|
|
|
|
func (c *GRPCConnection) GetClient() interface{} {
|
|
return pb.NewKevoServiceClient(c.conn)
|
|
}
|
|
|
|
func (c *GRPCConnection) Address() string {
|
|
return c.address
|
|
}
|
|
|
|
func init() {
|
|
transport.RegisterTransport("grpc", func(opts map[string]interface{}) (transport.TransportManager, error) {
|
|
// Convert generic options map to GRPCTransportOptions
|
|
options := DefaultGRPCTransportOptions()
|
|
|
|
if addr, ok := opts["listen_addr"].(string); ok {
|
|
options.ListenAddr = addr
|
|
}
|
|
|
|
if timeout, ok := opts["connection_timeout"].(time.Duration); ok {
|
|
options.ConnectionTimeout = timeout
|
|
}
|
|
|
|
if timeout, ok := opts["dial_timeout"].(time.Duration); ok {
|
|
options.DialTimeout = timeout
|
|
}
|
|
|
|
if tlsConfig, ok := opts["tls_config"].(*tls.Config); ok {
|
|
options.TLSConfig = tlsConfig
|
|
}
|
|
|
|
return NewGRPCTransportManager(options)
|
|
})
|
|
} |