diff --git a/lib/async/redis/client.rb b/lib/async/redis/client.rb index 5b5af20..8af83ae 100755 --- a/lib/async/redis/client.rb +++ b/lib/async/redis/client.rb @@ -10,7 +10,7 @@ require_relative 'context/pipeline' require_relative 'context/transaction' require_relative 'context/subscribe' -require_relative 'protocol/resp2' +require_relative 'endpoint' require 'io/endpoint/host_endpoint' require 'async/pool/controller' @@ -23,14 +23,10 @@ module Redis # Legacy. ServerError = ::Protocol::Redis::ServerError - def self.local_endpoint(port: 6379) - ::IO::Endpoint.tcp('localhost', port) - end - class Client include ::Protocol::Redis::Methods - def initialize(endpoint = Redis.local_endpoint, protocol: Protocol::RESP2, **options) + def initialize(endpoint = Endpoint.local, protocol: endpoint.protocol, **options) @endpoint = endpoint @protocol = protocol diff --git a/lib/async/redis/endpoint.rb b/lib/async/redis/endpoint.rb new file mode 100644 index 0000000..a60a9f8 --- /dev/null +++ b/lib/async/redis/endpoint.rb @@ -0,0 +1,252 @@ +# frozen_string_literal: true + +# Released under the MIT License. +# Copyright, 2024, by Samuel Williams. + +require 'io/endpoint' +require 'io/endpoint/host_endpoint' +require 'io/endpoint/ssl_endpoint' + +require_relative 'protocol/resp2' +require_relative 'protocol/authenticated' +require_relative 'protocol/selected' + +module Async + module Redis + def self.local_endpoint(**options) + Endpoint.local(**options) + end + + # Represents a way to connect to a remote Redis server. + class Endpoint < ::IO::Endpoint::Generic + LOCALHOST = URI.parse("redis://localhost").freeze + + def self.local(**options) + self.new(LOCALHOST, **options) + end + + SCHEMES = { + 'redis' => URI::Generic, + 'rediss' => URI::Generic, + } + + def self.parse(string, endpoint = nil, **options) + url = URI.parse(string).normalize + + return self.new(url, endpoint, **options) + end + + # Construct an endpoint with a specified scheme, hostname, optional path, and options. + # + # @parameter scheme [String] The scheme to use, e.g. "redis" or "rediss". + # @parameter hostname [String] The hostname to connect to (or bind to). + # @parameter *options [Hash] Additional options, passed to {#initialize}. + def self.for(scheme, hostname, credentials: nil, port: nil, database: nil, **options) + uri_klass = SCHEMES.fetch(scheme.downcase) do + raise ArgumentError, "Unsupported scheme: #{scheme.inspect}" + end + + if database + path = "/#{database}" + end + + self.new( + uri_klass.new(scheme, credentials&.join(":"), hostname, port, nil, path, nil, nil, nil).normalize, + **options + ) + end + + # Coerce the given object into an endpoint. + # @parameter url [String | Endpoint] The URL or endpoint to convert. + def self.[](object) + if object.is_a?(self) + return object + else + self.parse(object.to_s) + end + end + + # Create a new endpoint. + # + # @parameter url [URI] The URL to connect to. + # @parameter endpoint [Endpoint] The underlying endpoint to use. + # @parameter scheme [String] The scheme to use, e.g. "redis" or "rediss". + # @parameter hostname [String] The hostname to connect to (or bind to), overrides the URL hostname (used for SNI). + # @parameter port [Integer] The port to bind to, overrides the URL port. + def initialize(url, endpoint = nil, **options) + super(**options) + + raise ArgumentError, "URL must be absolute (include scheme, host): #{url}" unless url.absolute? + + @url = url + + if endpoint + @endpoint = self.build_endpoint(endpoint) + else + @endpoint = nil + end + end + + def to_url + url = @url.dup + + unless default_port? + url.port = self.port + end + + return url + end + + def to_s + "\#<#{self.class} #{self.to_url} #{@options}>" + end + + def inspect + "\#<#{self.class} #{self.to_url} #{@options.inspect}>" + end + + attr :url + + def address + endpoint.address + end + + def secure? + ['rediss'].include?(self.scheme) + end + + def protocol + protocol = @options.fetch(:protocol, Protocol::RESP2) + + if database = self.database + protocol = Protocol::Selected.new(database, protocol) + end + + if credentials = self.credentials + protocol = Protocol::Authenticated.new(credentials, protocol) + end + + return protocol + end + + def default_port + 6379 + end + + def default_port? + port == default_port + end + + def port + @options[:port] || @url.port || default_port + end + + # The hostname is the server we are connecting to: + def hostname + @options[:hostname] || @url.hostname + end + + def scheme + @options[:scheme] || @url.scheme + end + + def database + @options[:database] || @url.path[1..-1].to_i + end + + def credentials + @options[:credentials] || @url.userinfo&.split(":") + end + + def localhost? + @url.hostname =~ /^(.*?\.)?localhost\.?$/ + end + + # We don't try to validate peer certificates when talking to localhost because they would always be self-signed. + def ssl_verify_mode + if self.localhost? + OpenSSL::SSL::VERIFY_NONE + else + OpenSSL::SSL::VERIFY_PEER + end + end + + def ssl_context + @options[:ssl_context] || OpenSSL::SSL::SSLContext.new.tap do |context| + context.set_params( + verify_mode: self.ssl_verify_mode + ) + end + end + + def build_endpoint(endpoint = nil) + endpoint ||= tcp_endpoint + + if secure? + # Wrap it in SSL: + return ::IO::Endpoint::SSLEndpoint.new(endpoint, + ssl_context: self.ssl_context, + hostname: @url.hostname, + timeout: self.timeout, + ) + end + + return endpoint + end + + def endpoint + @endpoint ||= build_endpoint + end + + def endpoint=(endpoint) + @endpoint = build_endpoint(endpoint) + end + + def bind(*arguments, &block) + endpoint.bind(*arguments, &block) + end + + def connect(&block) + endpoint.connect(&block) + end + + def each + return to_enum unless block_given? + + self.tcp_endpoint.each do |endpoint| + yield self.class.new(@url, endpoint, **@options) + end + end + + def key + [@url, @options] + end + + def eql? other + self.key.eql? other.key + end + + def hash + self.key.hash + end + + protected + + def tcp_options + options = @options.dup + + options.delete(:scheme) + options.delete(:port) + options.delete(:hostname) + options.delete(:ssl_context) + options.delete(:protocol) + + return options + end + + def tcp_endpoint + ::IO::Endpoint.tcp(self.hostname, port, **tcp_options) + end + end + end +end diff --git a/lib/async/redis/protocol/authenticated.rb b/lib/async/redis/protocol/authenticated.rb index 49b8f9d..2e7177b 100644 --- a/lib/async/redis/protocol/authenticated.rb +++ b/lib/async/redis/protocol/authenticated.rb @@ -18,7 +18,7 @@ class AuthenticationError < StandardError # # @parameter credentials [Array] The credentials to use for authentication. # @parameter protocol [Object] The delegated protocol for connecting. - def initialize(credentials, protocol: Async::Redis::Protocol::RESP2) + def initialize(credentials, protocol = Async::Redis::Protocol::RESP2) @credentials = credentials @protocol = protocol end diff --git a/lib/async/redis/protocol/selected.rb b/lib/async/redis/protocol/selected.rb index 1880ad4..c9f84da 100644 --- a/lib/async/redis/protocol/selected.rb +++ b/lib/async/redis/protocol/selected.rb @@ -18,7 +18,7 @@ class SelectionError < StandardError # # @parameter index [Integer] The database index to select. # @parameter protocol [Object] The delegated protocol for connecting. - def initialize(index, protocol: Async::Redis::Protocol::RESP2) + def initialize(index, protocol = Async::Redis::Protocol::RESP2) @index = index @protocol = protocol end diff --git a/test/async/redis/disconnect.rb b/test/async/redis/disconnect.rb index e7bb003..cc9e966 100644 --- a/test/async/redis/disconnect.rb +++ b/test/async/redis/disconnect.rb @@ -10,18 +10,30 @@ describe Async::Redis::Client do include Sus::Fixtures::Async::ReactorContext - - let(:endpoint) {::IO::Endpoint.tcp('localhost', 5555)} - + + # Intended to not be connected: + let(:endpoint) {Async::Redis::Endpoint.local(port: 5555)} + + before do + @server_endpoint = ::IO::Endpoint.tcp("localhost").bound + end + + after do + @server_endpoint&.close + end + it "should raise error on unexpected disconnect" do - server_task = reactor.async do - endpoint.accept do |connection| + server_task = Async do + @server_endpoint.accept do |connection| connection.read(8) connection.close end end - - client = Async::Redis::Client.new(endpoint) + + client = Async::Redis::Client.new( + @server_endpoint.local_address_endpoint, + protocol: Async::Redis::Protocol::RESP2, + ) expect do client.call("GET", "test") diff --git a/test/async/redis/endpoint.rb b/test/async/redis/endpoint.rb new file mode 100644 index 0000000..d5c2a2f --- /dev/null +++ b/test/async/redis/endpoint.rb @@ -0,0 +1,43 @@ +# frozen_string_literal: true + +# Released under the MIT License. +# Copyright, 2024, by Samuel Williams. + +require 'async/redis/client' +require 'async/redis/protocol/authenticated' +require 'sus/fixtures/async' + +describe Async::Redis::Protocol::Authenticated do + include Sus::Fixtures::Async::ReactorContext + + let(:endpoint) {Async::Redis.local_endpoint} + let(:credentials) {["testuser", "testpassword"]} + let(:protocol) {subject.new(credentials)} + let(:client) {Async::Redis::Client.new(endpoint, protocol: protocol)} + + before do + # Setup ACL user with limited permissions for testing. + admin_client = Async::Redis::Client.new(endpoint) + admin_client.call("ACL", "SETUSER", "testuser", "on", ">" + credentials[1], "+ping", "+auth") + ensure + admin_client.close + end + + after do + # Cleanup ACL user after tests. + admin_client = Async::Redis::Client.new(endpoint) + admin_client.call("ACL", "DELUSER", "testuser") + admin_client.close + end + + it "can authenticate and send allowed commands" do + response = client.call("PING") + expect(response).to be == "PONG" + end + + it "rejects commands not allowed by ACL" do + expect do + client.call("SET", "key", "value") + end.to raise_exception(Protocol::Redis::ServerError, message: be =~ /NOPERM/) + end +end