Skip to content

Commit

Permalink
[CAPPL-20] Support per-method handlers in GatewayConnector
Browse files Browse the repository at this point in the history
Making GatewayConnector compatible with the new design, where each capability is able to add its own handler independently.
  • Loading branch information
bolekk committed Sep 6, 2024
1 parent 72f4cc8 commit 8138054
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 35 deletions.
5 changes: 5 additions & 0 deletions .changeset/thick-jobs-kneel.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"chainlink": patch
---

Support per-method handlers in GatewayConnector
8 changes: 7 additions & 1 deletion core/scripts/gateway/connector/run_connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,13 @@ func main() {
sampleKey, _ := crypto.HexToECDSA("cd47d3fafdbd652dd2b66c6104fa79b372c13cb01f4a4fbfc36107cce913ac1d")
lggr, _ := logger.NewLogger()
client := &client{privateKey: sampleKey, lggr: lggr}
connector, _ := connector.NewGatewayConnector(&cfg, client, client, clockwork.NewRealClock(), lggr)
// client acts as a signer here
connector, _ := connector.NewGatewayConnector(&cfg, client, clockwork.NewRealClock(), lggr)
err = connector.AddHandler([]string{"test_method"}, client)
if err != nil {
fmt.Println("error adding handler:", err)
return
}
client.connector = connector

ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt)
Expand Down
40 changes: 30 additions & 10 deletions core/services/gateway/connector/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type GatewayConnector interface {
job.ServiceCtx
network.ConnectionInitiator

AddHandler(methods []string, handler GatewayConnectorHandler) error
SendToGateway(ctx context.Context, gatewayId string, msg *api.Message) error
}

Expand All @@ -51,7 +52,7 @@ type gatewayConnector struct {
clock clockwork.Clock
nodeAddress []byte
signer Signer
handler GatewayConnectorHandler
handlers map[string]GatewayConnectorHandler
gateways map[string]*gatewayState
urlToId map[string]string
closeWait sync.WaitGroup
Expand All @@ -76,8 +77,8 @@ type gatewayState struct {
wsClient network.WebSocketClient
}

func NewGatewayConnector(config *ConnectorConfig, signer Signer, handler GatewayConnectorHandler, clock clockwork.Clock, lggr logger.Logger) (GatewayConnector, error) {
if config == nil || signer == nil || handler == nil || clock == nil || lggr == nil {
func NewGatewayConnector(config *ConnectorConfig, signer Signer, clock clockwork.Clock, lggr logger.Logger) (GatewayConnector, error) {
if config == nil || signer == nil || clock == nil || lggr == nil {
return nil, errors.New("nil dependency")
}
if len(config.DonId) == 0 || len(config.DonId) > network.HandshakeDonIdLen {
Expand All @@ -93,7 +94,7 @@ func NewGatewayConnector(config *ConnectorConfig, signer Signer, handler Gateway
clock: clock,
nodeAddress: addressBytes,
signer: signer,
handler: handler,
handlers: make(map[string]GatewayConnectorHandler),
shutdownCh: make(chan struct{}),
lggr: lggr.Named("GatewayConnector"),
}
Expand Down Expand Up @@ -125,6 +126,22 @@ func NewGatewayConnector(config *ConnectorConfig, signer Signer, handler Gateway
return connector, nil
}

func (c *gatewayConnector) AddHandler(methods []string, handler GatewayConnectorHandler) error {
if handler == nil {
return errors.New("cannot add a nil handler")
}
for _, method := range methods {
if _, exists := c.handlers[method]; exists {
return fmt.Errorf("handler for method %s already exists", method)
}
}
// add all or nothing
for _, method := range methods {
c.handlers[method] = handler
}
return nil
}

func (c *gatewayConnector) SendToGateway(ctx context.Context, gatewayId string, msg *api.Message) error {
data, err := c.codec.EncodeResponse(msg)
if err != nil {
Expand Down Expand Up @@ -159,7 +176,12 @@ func (c *gatewayConnector) readLoop(gatewayState *gatewayState) {
c.lggr.Errorw("failed to validate message signature", "id", gatewayState.config.Id, "err", err)
break
}
c.handler.HandleGatewayMessage(ctx, gatewayState.config.Id, msg)
handler, exists := c.handlers[msg.Body.Method]
if !exists {
c.lggr.Errorw("no handler for method", "id", gatewayState.config.Id, "method", msg.Body.Method)
break
}
handler.HandleGatewayMessage(ctx, gatewayState.config.Id, msg)
}
}
}
Expand Down Expand Up @@ -194,9 +216,6 @@ func (c *gatewayConnector) reconnectLoop(gatewayState *gatewayState) {
func (c *gatewayConnector) Start(ctx context.Context) error {
return c.StartOnce("GatewayConnector", func() error {
c.lggr.Info("starting gateway connector")
if err := c.handler.Start(ctx); err != nil {
return err
}
for _, gatewayState := range c.gateways {
gatewayState := gatewayState
if err := gatewayState.conn.Start(ctx); err != nil {
Expand All @@ -214,11 +233,12 @@ func (c *gatewayConnector) Close() error {
return c.StopOnce("GatewayConnector", func() (err error) {
c.lggr.Info("closing gateway connector")
close(c.shutdownCh)
var errs error
for _, gatewayState := range c.gateways {
gatewayState.conn.Close()
errs = errors.Join(errs, gatewayState.conn.Close())
}
c.closeWait.Wait()
return c.handler.Close()
return errs
})
}

Expand Down
35 changes: 23 additions & 12 deletions core/services/gateway/connector/connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ import (
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/network"
)

const defaultConfig = `
const (
defaultConfig = `
NodeAddress = "0x68902d681c28119f9b2531473a417088bf008e59"
DonId = "example_don"
AuthMinChallengeLen = 10
Expand All @@ -32,6 +33,9 @@ URL = "ws://localhost:8081/node"
Id = "another_one"
URL = "wss://example.com:8090/node_endpoint"
`
testMethod1 = "test_method_1"
testMethod2 = "test_method_2"
)

func parseTOMLConfig(t *testing.T, tomlConfig string) *connector.ConnectorConfig {
var cfg connector.ConnectorConfig
Expand All @@ -40,12 +44,13 @@ func parseTOMLConfig(t *testing.T, tomlConfig string) *connector.ConnectorConfig
return &cfg
}

func newTestConnector(t *testing.T, config *connector.ConnectorConfig, now time.Time) (connector.GatewayConnector, *mocks.Signer, *mocks.GatewayConnectorHandler) {
func newTestConnector(t *testing.T, config *connector.ConnectorConfig) (connector.GatewayConnector, *mocks.Signer, *mocks.GatewayConnectorHandler) {
signer := mocks.NewSigner(t)
handler := mocks.NewGatewayConnectorHandler(t)
clock := clockwork.NewFakeClock()
connector, err := connector.NewGatewayConnector(config, signer, handler, clock, logger.TestLogger(t))
connector, err := connector.NewGatewayConnector(config, signer, clock, logger.TestLogger(t))
require.NoError(t, err)
require.NoError(t, connector.AddHandler([]string{testMethod1}, handler))
return connector, signer, handler
}

Expand All @@ -61,7 +66,7 @@ Id = "example_gateway"
URL = "ws://localhost:8081/node"
`)

newTestConnector(t, tomlConfig, time.Now())
newTestConnector(t, tomlConfig)
}

func TestGatewayConnector_NewGatewayConnector_InvalidConfig(t *testing.T) {
Expand Down Expand Up @@ -103,12 +108,11 @@ URL = "ws://localhost:8081/node"
}

signer := mocks.NewSigner(t)
handler := mocks.NewGatewayConnectorHandler(t)
clock := clockwork.NewFakeClock()
for name, config := range invalidCases {
config := config
t.Run(name, func(t *testing.T) {
_, err := connector.NewGatewayConnector(parseTOMLConfig(t, config), signer, handler, clock, logger.TestLogger(t))
_, err := connector.NewGatewayConnector(parseTOMLConfig(t, config), signer, clock, logger.TestLogger(t))
require.Error(t, err)
})
}
Expand All @@ -117,17 +121,15 @@ URL = "ws://localhost:8081/node"
func TestGatewayConnector_CleanStartAndClose(t *testing.T) {
t.Parallel()

connector, signer, handler := newTestConnector(t, parseTOMLConfig(t, defaultConfig), time.Now())
handler.On("Start", mock.Anything).Return(nil)
handler.On("Close").Return(nil)
connector, signer, _ := newTestConnector(t, parseTOMLConfig(t, defaultConfig))
signer.On("Sign", mock.Anything).Return(nil, errors.New("cannot sign"))
servicetest.Run(t, connector)
}

func TestGatewayConnector_NewAuthHeader_SignerError(t *testing.T) {
t.Parallel()

connector, signer, _ := newTestConnector(t, parseTOMLConfig(t, defaultConfig), time.Now())
connector, signer, _ := newTestConnector(t, parseTOMLConfig(t, defaultConfig))
signer.On("Sign", mock.Anything).Return(nil, errors.New("cannot sign"))

url, err := url.Parse("ws://localhost:8081/node")
Expand All @@ -141,7 +143,7 @@ func TestGatewayConnector_NewAuthHeader_Success(t *testing.T) {

testSignature := make([]byte, network.HandshakeSignatureLen)
testSignature[1] = 0xfa
connector, signer, _ := newTestConnector(t, parseTOMLConfig(t, defaultConfig), time.Now())
connector, signer, _ := newTestConnector(t, parseTOMLConfig(t, defaultConfig))
signer.On("Sign", mock.Anything).Return(testSignature, nil)
url, err := url.Parse("ws://localhost:8081/node")
require.NoError(t, err)
Expand All @@ -157,7 +159,7 @@ func TestGatewayConnector_ChallengeResponse(t *testing.T) {
testSignature := make([]byte, network.HandshakeSignatureLen)
testSignature[1] = 0xfa
now := time.Now()
connector, signer, _ := newTestConnector(t, parseTOMLConfig(t, defaultConfig), now)
connector, signer, _ := newTestConnector(t, parseTOMLConfig(t, defaultConfig))
signer.On("Sign", mock.Anything).Return(testSignature, nil)
url, err := url.Parse("ws://localhost:8081/node")
require.NoError(t, err)
Expand Down Expand Up @@ -191,3 +193,12 @@ func TestGatewayConnector_ChallengeResponse(t *testing.T) {
_, err = connector.ChallengeResponse(url, network.PackChallenge(&badChallenge))
require.Equal(t, network.ErrAuthInvalidGateway, err)
}

func TestGatewayConnector_AddHandler(t *testing.T) {
t.Parallel()

connector, _, _ := newTestConnector(t, parseTOMLConfig(t, defaultConfig))
// testMethod1 already exists
require.Error(t, connector.AddHandler([]string{testMethod1}, mocks.NewGatewayConnectorHandler(t)))
require.NoError(t, connector.AddHandler([]string{testMethod2}, mocks.NewGatewayConnectorHandler(t)))
}
48 changes: 48 additions & 0 deletions core/services/gateway/connector/mocks/gateway_connector.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,10 @@ func TestIntegration_Gateway_NoFullNodes_BasicConnectionAndMessage(t *testing.T)

// Launch Connector
client := &client{privateKey: nodeKeys.PrivateKey}
connector, err := connector.NewGatewayConnector(parseConnectorConfig(t, nodeConfigTemplate, nodeKeys.Address, nodeUrl), client, client, clockwork.NewRealClock(), lggr)
// client acts as a signer here
connector, err := connector.NewGatewayConnector(parseConnectorConfig(t, nodeConfigTemplate, nodeKeys.Address, nodeUrl), client, clockwork.NewRealClock(), lggr)
require.NoError(t, err)
require.NoError(t, connector.AddHandler([]string{"test"}, client))
client.connector = connector
servicetest.Run(t, connector)

Expand Down
25 changes: 16 additions & 9 deletions core/services/ocr2/plugins/functions/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/smartcontractkit/chainlink/v2/core/services/functions"
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector"
hc "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common"
hf "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions"
gwAllowlist "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions/allowlist"
gwSubscriptions "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions/subscriptions"
"github.com/smartcontractkit/chainlink/v2/core/services/job"
Expand Down Expand Up @@ -174,11 +175,12 @@ func NewFunctionsServices(ctx context.Context, functionsOracleArgs, thresholdOra
return nil, errors.Wrap(err, "failed to create a OnchainSubscriptions")
}
connectorLogger := conf.Logger.Named("GatewayConnector").With("jobName", conf.Job.PipelineSpec.JobName)
connector, err2 := NewConnector(ctx, &pluginConfig, conf.EthKeystore, conf.Chain.ID(), s4Storage, allowlist, rateLimiter, subscriptions, functionsListener, offchainTransmitter, connectorLogger)
connector, handler, err2 := NewConnector(ctx, &pluginConfig, conf.EthKeystore, conf.Chain.ID(), s4Storage, allowlist, rateLimiter, subscriptions, functionsListener, offchainTransmitter, connectorLogger)
if err2 != nil {
return nil, errors.Wrap(err, "failed to create a GatewayConnector")
}
allServices = append(allServices, connector)
allServices = append(allServices, handler)
} else {
listenerLogger.Warn("Insufficient config, GatewayConnector will not be enabled")
}
Expand All @@ -201,29 +203,34 @@ func NewFunctionsServices(ctx context.Context, functionsOracleArgs, thresholdOra
return allServices, nil
}

func NewConnector(ctx context.Context, pluginConfig *config.PluginConfig, ethKeystore keystore.Eth, chainID *big.Int, s4Storage s4.Storage, allowlist gwAllowlist.OnchainAllowlist, rateLimiter *hc.RateLimiter, subscriptions gwSubscriptions.OnchainSubscriptions, listener functions.FunctionsListener, offchainTransmitter functions.OffchainTransmitter, lggr logger.Logger) (connector.GatewayConnector, error) {
func NewConnector(ctx context.Context, pluginConfig *config.PluginConfig, ethKeystore keystore.Eth, chainID *big.Int, s4Storage s4.Storage, allowlist gwAllowlist.OnchainAllowlist, rateLimiter *hc.RateLimiter, subscriptions gwSubscriptions.OnchainSubscriptions, listener functions.FunctionsListener, offchainTransmitter functions.OffchainTransmitter, lggr logger.Logger) (connector.GatewayConnector, connector.GatewayConnectorHandler, error) {
enabledKeys, err := ethKeystore.EnabledKeysForChain(ctx, chainID)
if err != nil {
return nil, err
return nil, nil, err
}
configuredNodeAddress := common.HexToAddress(pluginConfig.GatewayConnectorConfig.NodeAddress)
idx := slices.IndexFunc(enabledKeys, func(key ethkey.KeyV2) bool { return key.Address == configuredNodeAddress })
if idx == -1 {
return nil, errors.New("key for configured node address not found")
return nil, nil, errors.New("key for configured node address not found")
}
signerKey := enabledKeys[idx].ToEcdsaPrivKey()
if enabledKeys[idx].ID() != pluginConfig.GatewayConnectorConfig.NodeAddress {
return nil, errors.New("node address mismatch")
return nil, nil, errors.New("node address mismatch")
}

handler, err := functions.NewFunctionsConnectorHandler(pluginConfig, signerKey, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, lggr)
if err != nil {
return nil, err
return nil, nil, err
}
connector, err := connector.NewGatewayConnector(pluginConfig.GatewayConnectorConfig, handler, handler, clockwork.NewRealClock(), lggr)
// handler acts as a signer here
connector, err := connector.NewGatewayConnector(pluginConfig.GatewayConnectorConfig, handler, clockwork.NewRealClock(), lggr)
if err != nil {
return nil, err
return nil, nil, err
}
err = connector.AddHandler([]string{hf.MethodSecretsSet, hf.MethodSecretsList, hf.MethodHeartbeat}, handler)
if err != nil {
return nil, nil, err
}
handler.SetConnector(connector)
return connector, nil
return connector, handler, nil
}
4 changes: 2 additions & 2 deletions core/services/ocr2/plugins/functions/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestNewConnector_Success(t *testing.T) {
config := &config.PluginConfig{
GatewayConnectorConfig: gwcCfg,
}
_, err = functions.NewConnector(ctx, config, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, logger.TestLogger(t))
_, _, err = functions.NewConnector(ctx, config, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, logger.TestLogger(t))
require.NoError(t, err)
}

Expand Down Expand Up @@ -78,6 +78,6 @@ func TestNewConnector_NoKeyForConfiguredAddress(t *testing.T) {
config := &config.PluginConfig{
GatewayConnectorConfig: gwcCfg,
}
_, err = functions.NewConnector(ctx, config, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, logger.TestLogger(t))
_, _, err = functions.NewConnector(ctx, config, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, logger.TestLogger(t))
require.Error(t, err)
}

0 comments on commit 8138054

Please sign in to comment.