|
|
|
@ -5,6 +5,53 @@ |
|
|
|
|
|
|
|
#include "rpc_transport.h" |
|
|
|
|
|
|
|
#ifdef _WIN32 |
|
|
|
/*
|
|
|
|
* 初始化Windows套接字库 |
|
|
|
*/ |
|
|
|
int rpc_winsock_init() { |
|
|
|
WSADATA wsaData; |
|
|
|
int result = WSAStartup(MAKEWORD(2, 2), &wsaData); |
|
|
|
if (result != 0) { |
|
|
|
fprintf(stderr, "WSAStartup failed: %d\n", result); |
|
|
|
return RPC_NET_ERROR; |
|
|
|
} |
|
|
|
return RPC_SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
/*
|
|
|
|
* 清理Windows套接字库 |
|
|
|
*/ |
|
|
|
void rpc_winsock_cleanup() { |
|
|
|
WSACleanup(); |
|
|
|
} |
|
|
|
|
|
|
|
/*
|
|
|
|
* Windows平台的错误打印函数 |
|
|
|
*/ |
|
|
|
static void print_windows_error(const char* message) { |
|
|
|
int error_code = WSAGetLastError(); |
|
|
|
char* error_text = NULL; |
|
|
|
FormatMessageA( |
|
|
|
FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, |
|
|
|
NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), |
|
|
|
(LPSTR)&error_text, 0, NULL); |
|
|
|
if (error_text) { |
|
|
|
fprintf(stderr, "%s: %s\n", message, error_text); |
|
|
|
LocalFree(error_text); |
|
|
|
} else { |
|
|
|
fprintf(stderr, "%s: Error code %d\n", message, error_code); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
#define PRINT_ERROR(msg) print_windows_error(msg) |
|
|
|
#else |
|
|
|
/*
|
|
|
|
* Linux/Unix平台的错误打印函数 |
|
|
|
*/ |
|
|
|
#define PRINT_ERROR(msg) perror(msg) |
|
|
|
#endif |
|
|
|
|
|
|
|
/*
|
|
|
|
* 初始化RPC服务器 |
|
|
|
*/ |
|
|
|
@ -15,17 +62,17 @@ int rpc_server_init(rpc_server_t* server, const char* host, uint16_t port, int b |
|
|
|
|
|
|
|
// 创建套接字
|
|
|
|
server->server_fd = socket(AF_INET, SOCK_STREAM, 0); |
|
|
|
if (server->server_fd < 0) { |
|
|
|
perror("socket creation failed"); |
|
|
|
if (server->server_fd == INVALID_SOCKET_VALUE) { |
|
|
|
PRINT_ERROR("socket creation failed"); |
|
|
|
return RPC_NET_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
// 设置套接字选项,允许地址重用
|
|
|
|
int opt = 1; |
|
|
|
if (setsockopt(server->server_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, |
|
|
|
&opt, sizeof(opt))) { |
|
|
|
perror("setsockopt failed"); |
|
|
|
close(server->server_fd); |
|
|
|
&opt, sizeof(opt)) != 0) { |
|
|
|
PRINT_ERROR("setsockopt failed"); |
|
|
|
CLOSE_SOCKET(server->server_fd); |
|
|
|
return RPC_NET_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
@ -36,16 +83,16 @@ int rpc_server_init(rpc_server_t* server, const char* host, uint16_t port, int b |
|
|
|
server->address.sin_port = htons(port); |
|
|
|
|
|
|
|
// 绑定地址到套接字
|
|
|
|
if (bind(server->server_fd, (struct sockaddr*)&server->address, sizeof(server->address)) < 0) { |
|
|
|
perror("bind failed"); |
|
|
|
close(server->server_fd); |
|
|
|
if (bind(server->server_fd, (struct sockaddr*)&server->address, sizeof(server->address)) != 0) { |
|
|
|
PRINT_ERROR("bind failed"); |
|
|
|
CLOSE_SOCKET(server->server_fd); |
|
|
|
return RPC_NET_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
// 开始监听连接
|
|
|
|
if (listen(server->server_fd, backlog) < 0) { |
|
|
|
perror("listen failed"); |
|
|
|
close(server->server_fd); |
|
|
|
if (listen(server->server_fd, backlog) != 0) { |
|
|
|
PRINT_ERROR("listen failed"); |
|
|
|
CLOSE_SOCKET(server->server_fd); |
|
|
|
return RPC_NET_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
@ -63,8 +110,8 @@ int rpc_server_accept(rpc_server_t* server, rpc_transport_t* transport) { |
|
|
|
|
|
|
|
socklen_t addrlen = sizeof(transport->address); |
|
|
|
transport->socket_fd = accept(server->server_fd, (struct sockaddr*)&transport->address, &addrlen); |
|
|
|
if (transport->socket_fd < 0) { |
|
|
|
perror("accept failed"); |
|
|
|
if (transport->socket_fd == INVALID_SOCKET_VALUE) { |
|
|
|
PRINT_ERROR("accept failed"); |
|
|
|
return RPC_NET_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
@ -78,12 +125,12 @@ int rpc_server_accept(rpc_server_t* server, rpc_transport_t* transport) { |
|
|
|
* 关闭RPC服务器 |
|
|
|
*/ |
|
|
|
void rpc_server_close(rpc_server_t* server) { |
|
|
|
if (!server || server->server_fd < 0) { |
|
|
|
if (!server || server->server_fd == INVALID_SOCKET_VALUE) { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
close(server->server_fd); |
|
|
|
server->server_fd = -1; |
|
|
|
CLOSE_SOCKET(server->server_fd); |
|
|
|
server->server_fd = INVALID_SOCKET_VALUE; |
|
|
|
printf("RPC Server closed\n"); |
|
|
|
} |
|
|
|
|
|
|
|
@ -97,8 +144,8 @@ int rpc_client_init(rpc_transport_t* transport, const char* server_host, uint16_ |
|
|
|
|
|
|
|
// 创建套接字
|
|
|
|
transport->socket_fd = socket(AF_INET, SOCK_STREAM, 0); |
|
|
|
if (transport->socket_fd < 0) { |
|
|
|
perror("socket creation failed"); |
|
|
|
if (transport->socket_fd == INVALID_SOCKET_VALUE) { |
|
|
|
PRINT_ERROR("socket creation failed"); |
|
|
|
return RPC_NET_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
@ -109,15 +156,15 @@ int rpc_client_init(rpc_transport_t* transport, const char* server_host, uint16_ |
|
|
|
|
|
|
|
// 将主机名转换为IP地址
|
|
|
|
if (inet_pton(AF_INET, server_host, &transport->address.sin_addr) <= 0) { |
|
|
|
perror("invalid address"); |
|
|
|
close(transport->socket_fd); |
|
|
|
PRINT_ERROR("invalid address"); |
|
|
|
CLOSE_SOCKET(transport->socket_fd); |
|
|
|
return RPC_NET_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
// 连接到服务器
|
|
|
|
if (connect(transport->socket_fd, (struct sockaddr*)&transport->address, sizeof(transport->address)) < 0) { |
|
|
|
perror("connection failed"); |
|
|
|
close(transport->socket_fd); |
|
|
|
if (connect(transport->socket_fd, (struct sockaddr*)&transport->address, sizeof(transport->address)) != 0) { |
|
|
|
PRINT_ERROR("connection failed"); |
|
|
|
CLOSE_SOCKET(transport->socket_fd); |
|
|
|
return RPC_NET_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
@ -136,9 +183,19 @@ int rpc_transport_send(rpc_transport_t* transport, const void* data, size_t data |
|
|
|
// 发送所有数据
|
|
|
|
size_t sent_bytes = 0; |
|
|
|
while (sent_bytes < data_size) { |
|
|
|
#ifdef _WIN32 |
|
|
|
int bytes = send(transport->socket_fd, (const char*)data + sent_bytes, ( |
|
|
|
#ifdef _WIN64 |
|
|
|
int |
|
|
|
#else |
|
|
|
int |
|
|
|
#endif |
|
|
|
)(data_size - sent_bytes), 0); |
|
|
|
#else |
|
|
|
ssize_t bytes = send(transport->socket_fd, (const char*)data + sent_bytes, data_size - sent_bytes, 0); |
|
|
|
if (bytes < 0) { |
|
|
|
perror("send failed"); |
|
|
|
#endif |
|
|
|
if (bytes <= 0) { |
|
|
|
PRINT_ERROR("send failed"); |
|
|
|
return RPC_NET_ERROR; |
|
|
|
} |
|
|
|
sent_bytes += bytes; |
|
|
|
@ -158,9 +215,19 @@ int rpc_transport_recv(rpc_transport_t* transport, void* buffer, size_t buffer_s |
|
|
|
// 接收所有数据
|
|
|
|
size_t recv_bytes = 0; |
|
|
|
while (recv_bytes < buffer_size) { |
|
|
|
#ifdef _WIN32 |
|
|
|
int bytes = recv(transport->socket_fd, (char*)buffer + recv_bytes, ( |
|
|
|
#ifdef _WIN64 |
|
|
|
int |
|
|
|
#else |
|
|
|
int |
|
|
|
#endif |
|
|
|
)(buffer_size - recv_bytes), 0); |
|
|
|
#else |
|
|
|
ssize_t bytes = recv(transport->socket_fd, (char*)buffer + recv_bytes, buffer_size - recv_bytes, 0); |
|
|
|
#endif |
|
|
|
if (bytes < 0) { |
|
|
|
perror("recv failed"); |
|
|
|
PRINT_ERROR("recv failed"); |
|
|
|
return RPC_NET_ERROR; |
|
|
|
} else if (bytes == 0) { |
|
|
|
// 连接关闭
|
|
|
|
@ -176,10 +243,10 @@ int rpc_transport_recv(rpc_transport_t* transport, void* buffer, size_t buffer_s |
|
|
|
* 关闭传输连接 |
|
|
|
*/ |
|
|
|
void rpc_transport_close(rpc_transport_t* transport) { |
|
|
|
if (!transport || transport->socket_fd < 0) { |
|
|
|
if (!transport || transport->socket_fd == INVALID_SOCKET_VALUE) { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
close(transport->socket_fd); |
|
|
|
transport->socket_fd = -1; |
|
|
|
CLOSE_SOCKET(transport->socket_fd); |
|
|
|
transport->socket_fd = INVALID_SOCKET_VALUE; |
|
|
|
} |