1 module tinyredis.subscriber;
2 
3 /**
4  * Authors: Ali Çehreli, acehreli@yahoo.com
5  */
6 
import std.algorithm : find, any, min;
import std.array : empty, front, popFront;
import std.conv : to;
import std.socket : TcpSocket, InternetAddress, SocketShutdown, Socket, SocketOSException,
	wouldHaveBlocked;
import std.stdio : stderr, writefln;

import tinyredis.connection : receiveResponses;
import tinyredis.encoder : toMultiBulk;
import tinyredis.response : Response;
15 
/// Callback for a regular (channel) subscription; invoked with the channel name and the
/// message payload.
alias Callback = void delegate(string channel, string message);

/// Callback for a pattern subscription; invoked with the subscribed pattern, the concrete
/// channel the message was published to, and the message payload.
alias PCallback = void delegate(string pattern, string channel, string message);
21 
/**
 * Whether a response is of a particular message type.
 *
 * The type is the value of the response's first element (e.g. "message", "subscribe",
 * "pong"). A response that has no elements is not of any type. This guard matters because
 * the function is used as a filter predicate over arbitrary replies.
 */
bool isType(string type)(Response r)
{
	return r.values.length > 0 && r.values[0].value == type;
}
29 
/**
 * A Redis pub/sub client.
 *
 * Channel and pattern subscriptions are registered together with callbacks. Responses that
 * arrive while waiting for a command's reply are queued. processMessages() later dispatches
 * them to the registered callbacks.
 */
class Subscriber
{
private:
	TcpSocket conn;
	Callback[string] callbacks;      // Regular subscription callbacks, keyed by channel
	PCallback[string] pCallbacks;    // Pattern subscription callbacks, keyed by pattern
	Response[][] queue;              // Responses collected but not yet processed

	/**
	 * Send a redis command.
	 *
	 * The socket may be non-blocking (see the host/port constructor). In that case a single
	 * Socket.send() can transmit only part of the data or fail because it would block. This
	 * keeps sending until all bytes have been handed to the OS.
	 *
	 * Throws: SocketOSException if sending fails for any reason other than "would block".
	 */
	void send(string cmd)
	{
		// XXX - Do we need toMultiBulk here? NOTE(review): the whole command line is passed
		// as one string, so presumably it is split on spaces. That would break channel
		// names or patterns that contain spaces. Confirm against tinyredis.encoder.
		auto data = toMultiBulk(cmd);

		while (data.length)
		{
			const sent = conn.send(data);

			if (sent == Socket.ERROR)
			{
				// Kernel send buffer is full on a non-blocking socket: try again
				if (wouldHaveBlocked())
					continue;

				throw new SocketOSException("Unable to send command to the redis server");
			}

			data = data[cast(size_t) sent .. $];
		}
	}

	/**
	 * Poll responses from the redis server and queue for later processing unless they match the
	 * predicate.
	 *
	 * This function is the workhorse behind all member functions of this type.
	 *
	 * Exactly 'expected' matching responses are consumed. Everything else, including any
	 * further matching responses received after the last expected one, is appended to the
	 * queue in arrival order.
	 *
	 * @param pred - The predicate function that determines whether a response is an expected one
	 * @param expected - The number of responses expected to match the predicate
	 * @return - The last response that matched the predicate (Response.init if expected is 0)
	 */
	Response queueUnless(bool delegate(Response) pred, size_t expected = 1)
	{
		Response lastMatch;
		size_t matched = 0;

		// TODO - Timeout?
		while (matched < expected)
		{
			Response[] responses = receiveResponses(conn, 1);

			// This group may have zero or many matching responses
			while (!responses.empty)
			{
				// All expected replies are in; the rest of the group is not ours to consume
				if (matched == expected)
				{
					queue ~= responses;
					break;
				}

				auto found = responses.find!pred;

				// Enqueue older, unrelated responses for later processing
				auto skipped = responses[0 .. $ - found.length];
				if (!skipped.empty)
					queue ~= skipped;

				if (found.empty)
					break;

				lastMatch = found.front;
				responses = found[1 .. $];
				++matched;
			}
		}

		return lastMatch;
	}

	/**
	 * Convenience wrapper for queueUnless(), which constructs a delegate from the provided message
	 * type.
	 */
	Response queueUnlessType(string type)(size_t expected = 1)
	{
		return queueUnless(r => r.isType!type, expected);
	}

	/**
	 * Process a single message: dispatch a "message" or "pmessage" response to the callback
	 * registered for its channel or pattern. Anything else is reported on stderr.
	 */
	void processMessage(Response resp)
	{
		auto elements = resp.values;

		/* Nested convenience function */
		void reportBadResponse()
		{
			stderr.writefln("Unexpected subscription response: %s", resp);
		}

		/* Nested convenience function returning response element at the specified index */
		string element(size_t index)
		{
			return elements[index].value;
		}

		// Without a type element there is nothing to dispatch on
		if (elements.empty)
		{
			reportBadResponse();
			return;
		}

		switch (element(0))
		{
		case "message":
			// Expected layout: [ "message", channel, payload ]
			if (elements.length != 3) {
				reportBadResponse();
				break;
			}
			string channel = element(1);
			auto callback = (channel in callbacks);

			if (callback)
			{
				string message = element(2);
				(*callback)(channel, message);
			}
			else
				stderr.writefln("No callback for message: %s", resp);
			break;

		case "pmessage":
			// Expected layout: [ "pmessage", pattern, channel, payload ]
			if (elements.length != 4) {
				reportBadResponse();
				break;
			}
			string pattern = element(1);
			auto callback = (pattern in pCallbacks);

			if (callback) {
				string channel = element(2);
				string message = element(3);

				(*callback)(pattern, channel, message);
			}
			else
				stderr.writefln("No callback for pattern message: %s", resp);
			break;

		default:
			reportBadResponse();
			break;
		}
	}

public:

	/**
	 * Create a new non-blocking subscriber using a Redis host and port
	 */
	this(string host = "127.0.0.1", ushort port = 6379)
	{
		conn = new TcpSocket(new InternetAddress(host, port));
		conn.blocking = false;
	}

	/**
	 * Create a new subscriber using an existing socket. The socket's blocking mode is left
	 * as it is.
	 */
	this(TcpSocket conn) { this.conn = conn; }

	/**
	 * Subscribe to a channel
	 *
	 * Returns the number of channels and patterns currently subscribed to, as reported by
	 * the server
	 */
	size_t subscribe(string channel, Callback callback)
	{
		send("SUBSCRIBE " ~ channel);

		Response resp = queueUnlessType!"subscribe"();
		callbacks[channel] = callback;

		return resp.values[2].to!int;
	}

	/**
	 * Subscribe to a channel pattern
	 *
	 * Returns the number of channels and patterns currently subscribed to, as reported by
	 * the server
	 */
	size_t psubscribe(string pattern, PCallback callback)
	{
		send("PSUBSCRIBE " ~ pattern);

		Response resp = queueUnlessType!"psubscribe"();
		pCallbacks[pattern] = callback;

		return resp.values[2].to!int;
	}

	/**
	 * Unsubscribe from a channel
	 *
	 * Returns the number of channels and patterns currently subscribed to, as reported by
	 * the server
	 */
	size_t unsubscribe(string channel)
	{
		send("UNSUBSCRIBE " ~ channel);

		Response resp = queueUnlessType!"unsubscribe"();
		callbacks.remove(channel);

		return resp.values[2].to!int;
	}

	/**
	 * Unsubscribe from all channels
	 *
	 * Returns the number of channels and patterns currently subscribed to, as reported by
	 * the server
	 */
	size_t unsubscribe()
	{
		send("UNSUBSCRIBE");

		// Redis replies once per dropped channel, or once when there was nothing to drop.
		// Waiting for zero replies would yield Response.init and fail on values[2] below.
		immutable replies = callbacks.length ? callbacks.length : 1;
		Response resp = queueUnlessType!"unsubscribe"(replies);
		callbacks = null;

		return resp.values[2].to!int;
	}

	/**
	 * Unsubscribe from a channel pattern
	 *
	 * Returns the number of channels and patterns currently subscribed to, as reported by
	 * the server
	 */
	size_t punsubscribe(string pattern)
	{
		send("PUNSUBSCRIBE " ~ pattern);

		Response resp = queueUnlessType!"punsubscribe"();
		pCallbacks.remove(pattern);

		return resp.values[2].to!int;
	}

	/**
	 * Unsubscribe from all channel patterns
	 *
	 * Returns the number of channels and patterns currently subscribed to, as reported by
	 * the server
	 */
	size_t punsubscribe()
	{
		send("PUNSUBSCRIBE");

		// Redis replies once per dropped pattern, or once when there was nothing to drop.
		// Waiting for zero replies would yield Response.init and fail on values[2] below.
		immutable replies = pCallbacks.length ? pCallbacks.length : 1;
		Response resp = queueUnlessType!"punsubscribe"(replies);
		pCallbacks = null;

		return resp.values[2].to!int;
	}

	/**
	 * Send QUIT and wait for the server's "OK" reply. The local socket is not closed here.
	 */
	Response quit()
	{
		send("QUIT");

		return queueUnless(r => r.value == "OK");
	}

	/**
	 * Send a PING command, with the optional argument, and wait for the "pong" reply
	 */
	Response ping(string argument = null)
	{
		// Only add the separator when there is an argument to separate
		send(argument.empty ? "PING" : "PING " ~ argument);

		return queueUnless(r => r.isType!"pong");
	}

	/**
	 * Poll for queued messages on the redis server and call their callbacks
	 *
	 * Only responses queued when dispatching starts are handled in this call. Responses that
	 * callbacks cause to be queued (e.g. by calling subscribe() or ping()) are kept for the
	 * next call.
	 */
	void processMessages()
	{
		queue ~= receiveResponses(conn, 0);

		// Each response is removed from the queue *before* its callback runs. Callbacks may
		// append to 'queue' through queueUnless(), and those additions must not be wiped out.
		// If a callback throws, the responses already dispatched are not delivered again.
		size_t groupsLeft = queue.length;

		while (groupsLeft != 0)
		{
			if (queue[0].empty)
			{
				queue.popFront();
				--groupsLeft;
				continue;
			}

			Response resp = queue[0][0];
			queue[0].popFront();
			processMessage(resp);
		}
	}
}