// Source: syndilights/open-lighting-architecture/ola-0.8.4/common/rpc/StreamRpcChannel.cpp
// (608 lines, 16 KiB, C++)

/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Library General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
*
* StreamRpcChannel.cpp
* Implementation of the stream-based RPC Channel
* Copyright (C) 2005-2008 Simon Newton
*/
#include <errno.h>
#include <google/protobuf/service.h>
#include <google/protobuf/message.h>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/dynamic_message.h>
#include <string>
#include "common/rpc/Rpc.pb.h"
#include "common/rpc/SimpleRpcController.h"
#include "common/rpc/StreamRpcChannel.h"
#include "ola/Closure.h"
#include "ola/Logging.h"
namespace ola {
namespace rpc {

using google::protobuf::ServiceDescriptor;

// Names of the ExportMap variables this channel updates when an export map
// is supplied to the constructor.
const char StreamRpcChannel::K_RPC_RECEIVED_TYPE_VAR[] = "rpc-received-type";
const char StreamRpcChannel::K_RPC_RECEIVED_VAR[] = "rpc-received";
const char StreamRpcChannel::K_RPC_SENT_ERROR_VAR[] = "rpc-send-errors";
const char StreamRpcChannel::K_RPC_SENT_VAR[] = "rpc-sent";
// Output message type name that marks a method as streaming: such calls are
// sent as STREAM_REQUEST and no reply is ever sent back.
const char StreamRpcChannel::STREAMING_NO_RESPONSE[] = "STREAMING_NO_RESPONSE";
/*
 * Create a new RPC channel on top of a connected stream socket.
 * @param service the service that incoming requests are dispatched to, may
 *   be NULL if this end only makes calls.
 * @param socket the connected socket to read from / write to.
 * @param export_map optional map used to export send/receive counters.
 */
StreamRpcChannel::StreamRpcChannel(Service *service,
                                   ola::network::ConnectedSocket *socket,
                                   ExportMap *export_map)
    : m_service(service),
      m_on_close(NULL),
      m_socket(socket),
      m_seq(0),
      m_buffer(NULL),
      m_buffer_size(0),
      m_expected_size(0),
      m_current_size(0),
      m_export_map(export_map),
      m_recv_type_map(NULL) {
  // Have the socket call back into us whenever there is data to read.
  socket->SetOnData(NewClosure(this, &StreamRpcChannel::SocketReady));

  if (!m_export_map)
    return;

  // Initialise the counters up front, before any RPC traffic.
  const char *const counter_names[] = {
    K_RPC_RECEIVED_VAR,
    K_RPC_SENT_ERROR_VAR,
    K_RPC_SENT_VAR,
  };
  const unsigned int counter_count =
      sizeof(counter_names) / sizeof(counter_names[0]);
  for (unsigned int idx = 0; idx != counter_count; ++idx)
    m_export_map->GetCounterVar(string(counter_names[idx]));

  m_recv_type_map = m_export_map->GetUIntMapVar(K_RPC_RECEIVED_TYPE_VAR,
                                                "type");
}
/*
 * Release the receive buffer.
 * NOTE(review): m_on_close and any entries still in m_requests / m_responses
 * are not released here; confirm whether that is intended.
 */
StreamRpcChannel::~StreamRpcChannel() {
  // m_buffer only ever comes from realloc(), so it is released with free().
  if (m_buffer != NULL)
    free(m_buffer);
}
/*
 * Receive a message for this RPCChannel. Called when data is available on the
 * socket.
 *
 * Each message is framed by a 4 byte header (version + size) followed by the
 * serialized RpcMessage. The body may arrive over several calls, so
 * m_expected_size / m_current_size track progress through the current
 * message; m_expected_size == 0 means we're waiting for a new header.
 */
void StreamRpcChannel::SocketReady() {
  if (!m_expected_size) {
    // this is a new msg
    unsigned int version;
    if (ReadHeader(&version, &m_expected_size) < 0)
      return;

    if (!m_expected_size)
      return;

    if (version != PROTOCOL_VERSION) {
      OLA_WARN << "protocol mismatch " << version << " != " <<
        PROTOCOL_VERSION;
      // The size field can't be trusted either, so framing is lost. Reset the
      // state so the next call never reads a body using this bogus size, and
      // shut the channel down.
      m_expected_size = 0;
      m_socket->Close();
      return;
    }

    m_current_size = 0;
    m_buffer_size = AllocateMsgBuffer(m_expected_size);

    if (m_buffer_size < m_expected_size) {
      OLA_WARN << "buffer size to small " << m_buffer_size << " < " <<
        m_expected_size;
      // We can't store (or skip) this body, so the stream can't be re-synced.
      // Leaving m_expected_size set would make the next call read past the
      // end of m_buffer.
      m_expected_size = 0;
      m_socket->Close();
      return;
    }
  }

  unsigned int data_read;
  if (m_socket->Receive(m_buffer + m_current_size,
                        m_expected_size - m_current_size,
                        data_read) < 0) {
    OLA_WARN << "something went wrong in socket recv\n";
    return;
  }

  m_current_size += data_read;

  if (m_current_size == m_expected_size) {
    // we've got all of this message so parse it.
    if (!HandleNewMsg(m_buffer, m_expected_size)) {
      // this probably means we've messed the framing up, close the channel
      OLA_WARN << "Errors detected on RPC channel, closing";
      m_socket->Close();
    }
    m_expected_size = 0;
  }
}
/*
 * Set the Closure to be run if a write on this channel fails. This is
 * different from the Socket on-close handler, which is called when reads hit
 * EOF.
 * The channel takes ownership of closure; any previously set closure is
 * deleted. Passing the closure that is already set is a no-op.
 */
void StreamRpcChannel::SetOnClose(SingleUseClosure<void> *closure) {
  if (closure == m_on_close)
    return;
  delete m_on_close;
  m_on_close = closure;
}
/*
* Call a method with the given request and reply
* TODO(simonn): reduce the number of copies here
*/
void StreamRpcChannel::CallMethod(
const MethodDescriptor *method,
RpcController *controller,
const Message *request,
Message *reply,
google::protobuf::Closure *done) {
string output;
RpcMessage message;
bool is_streaming = false;
// Streaming methods are those with a reply set to STREAMING_NO_RESPONSE and
// no controller, request or closure provided
if (method->output_type()->name() == STREAMING_NO_RESPONSE) {
if (controller || reply || done) {
OLA_FATAL << "Calling streaming method " << method->name() <<
" but a controller, reply or closure in non-NULL";
return;
}
is_streaming = true;
}
message.set_type(is_streaming ? STREAM_REQUEST : REQUEST);
message.set_id(m_seq++);
message.set_name(method->name());
request->SerializeToString(&output);
message.set_buffer(output);
bool r = SendMsg(&message);
if (is_streaming)
return;
if (!r) {
// send failed, call the handler now
controller->SetFailed("Failed to send request");
done->Run();
return;
}
OutstandingResponse *response = GetOutstandingResponse(message.id());
if (response) {
// fail any outstanding response with the same id
OLA_WARN << "response " << response->id << " already pending, failing " <<
"now";
response->controller->SetFailed("Duplicate request found");
InvokeCallbackAndCleanup(response);
}
response = new OutstandingResponse();
response->id = message.id();
response->controller = controller;
response->callback = done;
response->reply = reply;
m_responses[message.id()] = response;
}
/*
 * Run by the service once a request has been handled: send either the
 * response or a RESPONSE_FAILED message back to the peer, then free the
 * OutstandingRequest.
 */
void StreamRpcChannel::RequestComplete(OutstandingRequest *request) {
  if (request->controller->Failed()) {
    SendRequestFailed(request);
    return;
  }

  string serialized_response;
  request->response->SerializeToString(&serialized_response);

  RpcMessage message;
  message.set_type(RESPONSE);
  message.set_id(request->id);
  message.set_buffer(serialized_response);
  SendMsg(&message);

  DeleteOutstandingRequest(request);
}
// private
//-----------------------------------------------------------------------------
/*
 * Frame and write an RpcMessage to the socket.
 *
 * The 4 byte header is written first, followed by the serialized message. If
 * either write fails or is short, the stream framing is lost. In that case the
 * socket is closed, the send-error counter is incremented and the on-close
 * closure (if any) is run.
 * @param msg the message to send.
 * @returns true if the whole frame was written, false otherwise.
 */
bool StreamRpcChannel::SendMsg(RpcMessage *msg) {
  if (m_socket->ReadDescriptor() == ola::network::Socket::INVALID_SOCKET) {
    OLA_WARN << "RPC Socket closed, not sending messages";
    return false;
  }

  string output;
  msg->SerializeToString(&output);
  int length = output.length();
  uint32_t header;
  StreamRpcHeader::EncodeHeader(&header, PROTOCOL_VERSION, length);

  ssize_t ret = m_socket->Send(reinterpret_cast<const uint8_t*>(&header),
                               sizeof(header));
  bool sent_ok = (ret == static_cast<ssize_t>(sizeof(header)));
  if (sent_ok) {
    // Only write the body once the header went out intact; otherwise the
    // peer would misinterpret it.
    ret = m_socket->Send(reinterpret_cast<const uint8_t*>(output.data()),
                         length);
    sent_ok = (ret == length);
  }

  if (!sent_ok) {
    if (ret == -1)
      OLA_WARN << "Send failed " << strerror(errno);
    else
      OLA_WARN << "Failed to send full datagram, closing channel";
    // At the point framing is screwed and we should shut the channel down
    m_socket->Close();
    // Update the counter before running the close handler, which may tear
    // this channel down.
    if (m_export_map)
      (*m_export_map->GetCounterVar(K_RPC_SENT_ERROR_VAR))++;
    if (m_on_close) {
      // A SingleUseClosure frees itself when run, so drop our pointer first
      // to avoid a dangling m_on_close (and a later double delete in
      // SetOnClose).
      SingleUseClosure<void> *on_close = m_on_close;
      m_on_close = NULL;
      on_close->Run();
    }
    return false;
  }

  if (m_export_map)
    (*m_export_map->GetCounterVar(K_RPC_SENT_VAR))++;
  return true;
}
/*
* Allocate an incomming message buffer
* @param size the size of the new buffer to allocate
* @returns the size of the new buffer
*/
int StreamRpcChannel::AllocateMsgBuffer(unsigned int size) {
unsigned int requested_size = size;
uint8_t *new_buffer;
if (size < m_buffer_size)
return size;
if (m_buffer_size == 0 && size < INITIAL_BUFFER_SIZE)
requested_size = INITIAL_BUFFER_SIZE;
if (requested_size > MAX_BUFFER_SIZE)
return m_buffer_size;
new_buffer = static_cast<uint8_t*>(realloc(m_buffer, requested_size));
if (new_buffer < 0)
return m_buffer_size;
m_buffer = new_buffer;
m_buffer_size = requested_size;
return requested_size;
}
/*
 * Read the 4 byte frame header from the socket and decode it.
 * @param version set to the protocol version from the header, or 0.
 * @param size set to the message size from the header, or 0.
 * @returns -1 if the read failed or only part of a header arrived, 0
 *   otherwise. If no data was available, 0 is returned and version and size
 *   are both 0.
 * NOTE(review): a partial header is not reassembled across calls, so a short
 * read leaves the stream out of sync.
 */
int StreamRpcChannel::ReadHeader(unsigned int *version,
                                 unsigned int *size) const {
  uint32_t header = 0;
  unsigned int data_read = 0;
  *version = *size = 0;

  if (m_socket->Receive(reinterpret_cast<uint8_t*>(&header),
                        sizeof(header), data_read)) {
    OLA_WARN << "read header error: " << strerror(errno);
    return -1;
  }

  if (!data_read)
    return 0;

  if (data_read != sizeof(header)) {
    // Decoding a partially filled header would produce a garbage version and
    // size, so report it instead.
    OLA_WARN << "short read on RPC header, got " << data_read << " of " <<
      sizeof(header) << " bytes";
    return -1;
  }

  StreamRpcHeader::DecodeHeader(header, version, size);
  return 0;
}
/*
* Parse a new message and handle it.
*/
bool StreamRpcChannel::HandleNewMsg(uint8_t *data, unsigned int size) {
RpcMessage msg;
if (!msg.ParseFromArray(data, size)) {
OLA_WARN << "Failed to parse RPC";
return false;
}
if (m_export_map)
(*m_export_map->GetCounterVar(K_RPC_RECEIVED_VAR))++;
switch (msg.type()) {
case REQUEST:
if (m_recv_type_map)
(*m_recv_type_map)["request"]++;
HandleRequest(&msg);
break;
case RESPONSE:
if (m_recv_type_map)
(*m_recv_type_map)["response"]++;
HandleResponse(&msg);
break;
case RESPONSE_CANCEL:
if (m_recv_type_map)
(*m_recv_type_map)["cancelled"]++;
HandleCanceledResponse(&msg);
break;
case RESPONSE_FAILED:
if (m_recv_type_map)
(*m_recv_type_map)["failed"]++;
HandleFailedResponse(&msg);
break;
case RESPONSE_NOT_IMPLEMENTED:
if (m_recv_type_map)
(*m_recv_type_map)["not-implemented"]++;
HandleNotImplemented(&msg);
break;
case STREAM_REQUEST:
if (m_recv_type_map)
(*m_recv_type_map)["stream_request"]++;
HandleStreamRequest(&msg);
break;
default:
OLA_WARN << "not sure of msg type " << msg.type();
break;
}
return true;
}
/*
 * Handle a new RPC method call from the peer.
 *
 * Looks the method up on m_service, parses the request and invokes the
 * service. The reply is sent from RequestComplete() when the service runs the
 * callback. If the method is unknown, RESPONSE_NOT_IMPLEMENTED is sent
 * instead.
 * @param msg the REQUEST message received from the peer.
 */
void StreamRpcChannel::HandleRequest(RpcMessage *msg) {
  if (!m_service) {
    OLA_WARN << "no service registered";
    return;
  }

  const ServiceDescriptor *service = m_service->GetDescriptor();
  if (!service) {
    OLA_WARN << "failed to get service descriptor";
    return;
  }

  const MethodDescriptor *method = service->FindMethodByName(msg->name());
  if (!method) {
    OLA_WARN << "failed to get method descriptor";
    SendNotImplemented(msg->id());
    return;
  }

  Message* request_pb = m_service->GetRequestPrototype(method).New();
  Message* response_pb = m_service->GetResponsePrototype(method).New();

  if (!request_pb || !response_pb) {
    OLA_WARN << "failed to get request or response objects";
    // Free whichever one was created; deleting NULL is a no-op.
    delete request_pb;
    delete response_pb;
    return;
  }

  if (!request_pb->ParseFromString(msg->buffer())) {
    OLA_WARN << "parsing of request pb failed";
    delete request_pb;
    delete response_pb;
    return;
  }

  OutstandingRequest *request = new OutstandingRequest();
  request->id = msg->id();
  request->controller = new SimpleRpcController();
  request->response = response_pb;

  if (m_requests.find(msg->id()) != m_requests.end()) {
    OLA_WARN << "dup sequence number for request " << msg->id();
    // NOTE(review): this frees the earlier OutstandingRequest, but the
    // callback created when that request was handled still points at it; if
    // the service runs that callback later, RequestComplete() will use freed
    // memory. Confirm whether services always complete synchronously.
    SendRequestFailed(m_requests[msg->id()]);
  }

  m_requests[msg->id()] = request;
  google::protobuf::Closure *callback = NewCallback(
      this, &StreamRpcChannel::RequestComplete, request);
  m_service->CallMethod(method, request->controller, request_pb, response_pb,
                        callback);
  // NOTE(review): request_pb is freed as soon as CallMethod returns, which
  // assumes the service has finished reading the request by then.
  delete request_pb;
}
/*
 * Handle a streaming RPC call. This doesn't return any response to the client.
 *
 * The method must exist on m_service and have STREAMING_NO_RESPONSE as its
 * output type; it is invoked with a NULL controller, response and closure.
 * @param msg the STREAM_REQUEST message received from the peer.
 */
void StreamRpcChannel::HandleStreamRequest(RpcMessage *msg) {
  if (!m_service) {
    OLA_WARN << "no service registered";
    return;
  }

  const ServiceDescriptor *service = m_service->GetDescriptor();
  if (!service) {
    OLA_WARN << "failed to get service descriptor";
    return;
  }

  const MethodDescriptor *method = service->FindMethodByName(msg->name());
  if (!method) {
    OLA_WARN << "failed to get method descriptor";
    SendNotImplemented(msg->id());
    return;
  }

  if (method->output_type()->name() != STREAMING_NO_RESPONSE) {
    OLA_WARN << "Streaming request recieved for " << method->name() <<
      ", but the output type isn't STREAMING_NO_RESPONSE";
    return;
  }

  Message* request_pb = m_service->GetRequestPrototype(method).New();
  if (!request_pb) {
    OLA_WARN << "failed to get request or response objects";
    return;
  }

  if (!request_pb->ParseFromString(msg->buffer())) {
    OLA_WARN << "parsing of request pb failed";
    delete request_pb;  // don't leak the request on a bad payload
    return;
  }

  m_service->CallMethod(method, NULL, request_pb, NULL, NULL);
  delete request_pb;
}
// server side
/*
 * Tell the caller that its request failed, passing the controller's error
 * text back as the message buffer, then free the OutstandingRequest.
 */
void StreamRpcChannel::SendRequestFailed(OutstandingRequest *request) {
  RpcMessage failure;
  failure.set_type(RESPONSE_FAILED);
  failure.set_id(request->id);
  failure.set_buffer(request->controller->ErrorText());
  SendMsg(&failure);
  DeleteOutstandingRequest(request);
}
/*
* Sent if we get a request for a non-existant method.
*/
void StreamRpcChannel::SendNotImplemented(int msg_id) {
RpcMessage message;
message.set_type(RESPONSE_NOT_IMPLEMENTED);
message.set_id(msg_id);
SendMsg(&message);
}
/*
 * Clean up an outstanding request once its response has been sent: free the
 * controller and response message, drop it from m_requests and free it.
 */
void StreamRpcChannel::DeleteOutstandingRequest(OutstandingRequest *request) {
  delete request->response;
  delete request->controller;
  m_requests.erase(request->id);
  delete request;
}
// client side methods
/*
 * Handle a RESPONSE by parsing it into the caller's reply and running the
 * caller's callback. Responses with no matching outstanding call are ignored.
 * If the payload can't be parsed, the controller is marked as failed so the
 * caller doesn't treat a partial reply as a success.
 */
void StreamRpcChannel::HandleResponse(RpcMessage *msg) {
  OutstandingResponse *response = GetOutstandingResponse(msg->id());
  if (response) {
    if (!response->reply->ParseFromString(msg->buffer()))
      response->controller->SetFailed("Failed to parse RPC response");
    InvokeCallbackAndCleanup(response);
  }
}
/*
 * Handle a RESPONSE_FAILED by marking the caller's controller as failed, using
 * the message buffer as the error text, and running the callback.
 */
void StreamRpcChannel::HandleFailedResponse(RpcMessage *msg) {
  OutstandingResponse *response = GetOutstandingResponse(msg->id());
  if (!response)
    return;  // nobody is waiting on this id
  response->controller->SetFailed(msg->buffer());
  InvokeCallbackAndCleanup(response);
}
/*
 * Handle a RESPONSE_CANCEL by marking the caller's controller as failed, using
 * the message buffer as the error text, and running the callback.
 */
void StreamRpcChannel::HandleCanceledResponse(RpcMessage *msg) {
  OLA_INFO << "Received a canceled response";
  OutstandingResponse *response = GetOutstandingResponse(msg->id());
  if (!response)
    return;  // nobody is waiting on this id
  response->controller->SetFailed(msg->buffer());
  InvokeCallbackAndCleanup(response);
}
/*
 * Handle a RESPONSE_NOT_IMPLEMENTED by failing the caller's controller with
 * "Not Implemented" and running the callback.
 */
void StreamRpcChannel::HandleNotImplemented(RpcMessage *msg) {
  OLA_INFO << "Received a non-implemented response";
  OutstandingResponse *response = GetOutstandingResponse(msg->id());
  if (!response)
    return;  // nobody is waiting on this id
  response->controller->SetFailed("Not Implemented");
  InvokeCallbackAndCleanup(response);
}
/*
 * Find the outstanding response with id msg_id.
 * @returns the pending response, or NULL if no call is waiting on msg_id.
 */
OutstandingResponse *StreamRpcChannel::GetOutstandingResponse(int msg_id) {
  if (m_responses.find(msg_id) == m_responses.end())
    return NULL;
  return m_responses[msg_id];
}
/*
 * Run the callback for a completed call and free its OutstandingResponse.
 * A NULL response is ignored.
 *
 * The entry is removed from m_responses before the callback runs. This way
 * the callback never sees a stale entry for this id, and no member of this
 * channel is touched after the callback returns (the callback may tear the
 * channel down).
 */
void StreamRpcChannel::InvokeCallbackAndCleanup(OutstandingResponse *response) {
  if (!response)
    return;
  m_responses.erase(response->id);
  response->callback->Run();
  delete response;
}
// StreamRpcHeader
//--------------------------------------------------------------
/**
 * Encode a header: the version goes in the bits covered by VERSION_MASK
 * (shifted up by 28) and the size in the bits covered by SIZE_MASK.
 */
void StreamRpcHeader::EncodeHeader(uint32_t *header, unsigned int version,
                                   unsigned int size) {
  const uint32_t version_bits = (version << 28) & VERSION_MASK;
  const uint32_t size_bits = size & SIZE_MASK;
  *header = version_bits | size_bits;
}
/**
 * Decode a header into its version (VERSION_MASK bits, shifted down by 28)
 * and size (SIZE_MASK bits).
 */
void StreamRpcHeader::DecodeHeader(uint32_t header, unsigned int *version,
                                   unsigned int *size) {
  const unsigned int decoded_version = (header & VERSION_MASK) >> 28;
  const unsigned int decoded_size = header & SIZE_MASK;
  *version = decoded_version;
  *size = decoded_size;
}
}  // namespace rpc
}  // namespace ola