前提

最近对网络编程方面比较有兴趣,在微服务实践上也用到了相对主流的RPC框架如Spring Cloud Gateway底层也切换为Reactor-Netty,像Redisson底层也是使用Netty封装通讯协议,最近调研和准备使用的SOFARpc也是基于Netty封装实现了多种协议的兼容。因此,基于Netty造一个轮子,在SpringBoot的加持下,实现一个轻量级的RPC框架。这篇博文介绍的是RPC框架协议的定义以及对应的编码解码处理的实现。

依赖引入

截止本文(2020-01-12)编写完成之时,Netty的最新版本为4.1.44.Final,而SpringBoot的最新版本为2.2.2.RELEASE,因此引入这两个版本的依赖,加上其他工具包和序列化等等的支持,pom文件的核心内容如下:

<dependencyManagement>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-dependencies</artifactId>
<version>${spring.boot.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
<version>${netty.version}</version>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>1.18.10</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
<version>1.2.61</version>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>28.1-jre</version>
</dependency>
</dependencies>

部分参数的序列化会依赖到FastJson或者Jackson,具体看偏好而定。

自定义协议的定义

为了提高协议传输的效率,需要定制一套高效的RPC协议,设计协议所需的字段和类型。

基础Packet字段

字段名 字段类型 字节数(byte) 字段功能 备注
magicNumber int 2 魔数,类似于Java的字节码文件的魔数是0xcafebase
version int 2 版本号 预留字段,默认为1
serialNumber java.lang.String 4 请求流水号 十分重要,每个请求的唯一标识
messageType MessageType 1 消息类型 自定义的枚举类型,见下面的MessageType
attachments Map<String, String> 视实际情况而定 附件 K-V形式,类似于HTTP协议中的Header
// 消息枚举类型
@RequiredArgsConstructor
public enum MessageType {

/**
* 请求
*/
REQUEST((byte) 1),

/**
* 响应
*/
RESPONSE((byte) 2),

/**
* PING
*/
PING((byte) 3),

/**
* PONG
*/
PONG((byte) 4),

/**
* NULL
*/
NULL((byte) 5),

;

@Getter
private final Byte type;

public static MessageType fromValue(byte value) {
for (MessageType type : MessageType.values()) {
if (type.getType() == value) {
return type;
}
}
throw new IllegalArgumentException(String.format("value = %s", value));
}
}

// 基础Packet
@Data
public abstract class BaseMessagePacket implements Serializable {

/**
* 魔数
*/
private int magicNumber;

/**
* 版本号
*/
private int version;

/**
* 流水号
*/
private String serialNumber;

/**
* 消息类型
*/
private MessageType messageType;

/**
* 附件 - K-V形式
*/
private Map<String, String> attachments = new HashMap<>();

/**
* 添加附件
*/
public void addAttachment(String key, String value) {
attachments.put(key, value);
}
}

请求Packet扩展字段

字段名 字段类型 字节数(byte) 字段功能 备注
interfaceName java.lang.String 视实际情况而定 接口全类名
methodName java.lang.String 视实际情况而定 方法名
methodArgumentSignatures java.lang.String[] 视实际情况而定 方法参数签名字符串数组 存放方法参数类型全类名字符串数组
methodArguments java.lang.Object[] 视实际情况而定 方法参数数组 因为未知方法参数类型,所以用Object表示
@EqualsAndHashCode(callSuper = true)
@Data
public class RequestMessagePacket extends BaseMessagePacket {

/**
* 接口全类名
*/
private String interfaceName;

/**
* 方法名
*/
private String methodName;

/**
* 方法参数签名
*/
private String[] methodArgumentSignatures;

/**
* 方法参数
*/
private Object[] methodArguments;
}

响应Packet扩展字段

字段名 字段类型 字节数(byte) 字段功能 备注
errorCode java.lang.Long 4 响应码
message java.lang.String 视实际情况而定 响应消息 如果出现异常,message就是对应的异常信息
payload java.lang.Object 视实际情况而定 消息载荷 业务处理返回的消息载荷,定义为Object类型
@EqualsAndHashCode(callSuper = true)
@Data
public class ResponseMessagePacket extends BaseMessagePacket {

/**
* error code
*/
private Long errorCode;

/**
* 消息描述
*/
private String message;

/**
* 消息载荷
*/
private Object payload;
}

需要注意以下几点

  • 非基本类型在序列化和反序列化的时候,一定注意要先写入或者先读取序列的长度,以java.lang.String类型为例:
// 序列化 - 流水号
out.writeInt(packet.getSerialNumber().length());
out.writeCharSequence(packet.getSerialNumber(), ProtocolConstant.UTF_8);

// 反序列化 - 流水号
int serialNumberLength = in.readInt();
packet.setSerialNumber(in.readCharSequence(serialNumberLength, ProtocolConstant.UTF_8).toString());
  • 特殊编码的字符串在序列化的时候,要注意字符串编码的长度,例如UTF-8编码下一个中文字符占3个字节,这一点可以抽取一个工具类专门处理字符串的序列化:
public enum ByteBufferUtils {

// 单例
X;

public void encodeUtf8CharSequence(ByteBuf byteBuf, CharSequence charSequence) {
int writerIndex = byteBuf.writerIndex();
byteBuf.writeInt(0);
int length = ByteBufUtil.writeUtf8(byteBuf, charSequence);
byteBuf.setInt(writerIndex, length);
}
}
  • 方法参数数组的序列化和反序列化方案需要定制,笔者为了简化自定义协议,定义了方法参数签名数组,长度和方法参数数组一致,这样做方便后面编写服务端代码的时候,简化对方法参数数组进行反序列化以及宿主类目标方法的查找。注意一下Object[]的序列化和反序列化相对特殊,因为ByteBuf无法处理自定义类型的写入和读取(这个很好理解,网络编程就是面向01的编程):
write Object --> ByteBuf#writeInt() && ByteBuf#writeBytes()

read Object --> ByteBuf#readInt() && ByteBuf#readBytes() [<== 这个方法返回值是ByteBuf实例]
  • 最后注意释放ByteBuf的引用,否则有可能导致内存泄漏。

自定义协议编码解码实现

自定义协议编码解码主要包括四个部分的编码解码器:

  • 请求Packet编码器:RequestMessagePacketEncoder,主要用于客户端RequestMessagePacket实例序列化为二进制序列。
  • 请求Packet解码器:RequestMessagePacketDecoder,主要用于服务端把二进制序列反序列化为RequestMessagePacket实例。
  • 响应Packet编码器:ResponseMessagePacketEncoder,主要用于服务端ResponseMessagePacket实例序列化为二进制序列。
  • 响应Packet解码器:ResponseMessagePacketDecoder,主要用于客户端把二进制序列反序列化为ResponseMessagePacket实例。

画个图描述一下几个组件的交互流程(省略了部分入站和出站处理器):

序列化器Serializer的代码如下:

public interface Serializer {

byte[] encode(Object target);

Object decode(byte[] bytes, Class<?> targetClass);
}

// FastJson实现
public enum FastJsonSerializer implements Serializer {

// 单例
X;

@Override
public byte[] encode(Object target) {
return JSON.toJSONBytes(target);
}

@Override
public Object decode(byte[] bytes, Class<?> targetClass) {
return JSON.parseObject(bytes, targetClass);
}
}

请求Packet编码器RequestMessagePacketEncoder的代码如下:

@RequiredArgsConstructor
public class RequestMessagePacketEncoder extends MessageToByteEncoder<RequestMessagePacket> {

private final Serializer serializer;

@Override
protected void encode(ChannelHandlerContext context, RequestMessagePacket packet, ByteBuf out) throws Exception {
// 魔数
out.writeInt(packet.getMagicNumber());
// 版本
out.writeInt(packet.getVersion());
// 流水号
out.writeInt(packet.getSerialNumber().length());
out.writeCharSequence(packet.getSerialNumber(), ProtocolConstant.UTF_8);
// 消息类型
out.writeByte(packet.getMessageType().getType());
// 附件size
Map<String, String> attachments = packet.getAttachments();
out.writeInt(attachments.size());
// 附件内容
attachments.forEach((k, v) -> {
out.writeInt(k.length());
out.writeCharSequence(k, ProtocolConstant.UTF_8);
out.writeInt(v.length());
out.writeCharSequence(v, ProtocolConstant.UTF_8);
});
// 接口全类名
out.writeInt(packet.getInterfaceName().length());
out.writeCharSequence(packet.getInterfaceName(), ProtocolConstant.UTF_8);
// 方法名
out.writeInt(packet.getMethodName().length());
out.writeCharSequence(packet.getMethodName(), ProtocolConstant.UTF_8);
// 方法参数签名(String[]类型) - 非必须
if (null != packet.getMethodArgumentSignatures()) {
int len = packet.getMethodArgumentSignatures().length;
// 方法参数签名数组长度
out.writeInt(len);
for (int i = 0; i < len; i++) {
String methodArgumentSignature = packet.getMethodArgumentSignatures()[i];
out.writeInt(methodArgumentSignature.length());
out.writeCharSequence(methodArgumentSignature, ProtocolConstant.UTF_8);
}
} else {
out.writeInt(0);
}
// 方法参数(Object[]类型) - 非必须
if (null != packet.getMethodArguments()) {
int len = packet.getMethodArguments().length;
// 方法参数数组长度
out.writeInt(len);
for (int i = 0; i < len; i++) {
byte[] bytes = serializer.encode(packet.getMethodArguments()[i]);
out.writeInt(bytes.length);
out.writeBytes(bytes);
}
} else {
out.writeInt(0);
}
}
}

请求Packet解码器RequestMessagePacketDecoder的代码如下:

@RequiredArgsConstructor
public class RequestMessagePacketDecoder extends ByteToMessageDecoder {

@Override
protected void decode(ChannelHandlerContext context, ByteBuf in, List<Object> list) throws Exception {
RequestMessagePacket packet = new RequestMessagePacket();
// 魔数
packet.setMagicNumber(in.readInt());
// 版本
packet.setVersion(in.readInt());
// 流水号
int serialNumberLength = in.readInt();
packet.setSerialNumber(in.readCharSequence(serialNumberLength, ProtocolConstant.UTF_8).toString());
// 消息类型
byte messageTypeByte = in.readByte();
packet.setMessageType(MessageType.fromValue(messageTypeByte));
// 附件
Map<String, String> attachments = Maps.newHashMap();
packet.setAttachments(attachments);
int attachmentSize = in.readInt();
if (attachmentSize > 0) {
for (int i = 0; i < attachmentSize; i++) {
int keyLength = in.readInt();
String key = in.readCharSequence(keyLength, ProtocolConstant.UTF_8).toString();
int valueLength = in.readInt();
String value = in.readCharSequence(valueLength, ProtocolConstant.UTF_8).toString();
attachments.put(key, value);
}
}
// 接口全类名
int interfaceNameLength = in.readInt();
packet.setInterfaceName(in.readCharSequence(interfaceNameLength, ProtocolConstant.UTF_8).toString());
// 方法名
int methodNameLength = in.readInt();
packet.setMethodName(in.readCharSequence(methodNameLength, ProtocolConstant.UTF_8).toString());
// 方法参数签名
int methodArgumentSignatureArrayLength = in.readInt();
if (methodArgumentSignatureArrayLength > 0) {
String[] methodArgumentSignatures = new String[methodArgumentSignatureArrayLength];
for (int i = 0; i < methodArgumentSignatureArrayLength; i++) {
int methodArgumentSignatureLength = in.readInt();
methodArgumentSignatures[i] = in.readCharSequence(methodArgumentSignatureLength, ProtocolConstant.UTF_8).toString();
}
packet.setMethodArgumentSignatures(methodArgumentSignatures);
}
// 方法参数
int methodArgumentArrayLength = in.readInt();
if (methodArgumentArrayLength > 0) {
// 这里的Object[]实际上是ByteBuf[] - 后面需要二次加工为对应类型的实例
Object[] methodArguments = new Object[methodArgumentArrayLength];
for (int i = 0; i < methodArgumentArrayLength; i++) {
int byteLength = in.readInt();
methodArguments[i] = in.readBytes(byteLength);
}
packet.setMethodArguments(methodArguments);
}
list.add(packet);
}
}

响应Packet编码器ResponseMessagePacketEncoder的代码如下:

@RequiredArgsConstructor
public class ResponseMessagePacketEncoder extends MessageToByteEncoder<ResponseMessagePacket> {

private final Serializer serializer;

@Override
protected void encode(ChannelHandlerContext ctx, ResponseMessagePacket packet, ByteBuf out) throws Exception {
// 魔数
out.writeInt(packet.getMagicNumber());
// 版本
out.writeInt(packet.getVersion());
// 流水号
out.writeInt(packet.getSerialNumber().length());
out.writeCharSequence(packet.getSerialNumber(), ProtocolConstant.UTF_8);
// 消息类型
out.writeByte(packet.getMessageType().getType());
// 附件size
Map<String, String> attachments = packet.getAttachments();
out.writeInt(attachments.size());
// 附件内容
attachments.forEach((k, v) -> {
out.writeInt(k.length());
out.writeCharSequence(k, ProtocolConstant.UTF_8);
out.writeInt(v.length());
out.writeCharSequence(v, ProtocolConstant.UTF_8);
});
// error code
out.writeLong(packet.getErrorCode());
// message
String message = packet.getMessage();
ByteBufferUtils.X.encodeUtf8CharSequence(out, message);
// payload
byte[] bytes = serializer.encode(packet.getPayload());
out.writeInt(bytes.length);
out.writeBytes(bytes);
}
}

响应Packet解码器ResponseMessagePacketDecoder的代码如下:

public class ResponseMessagePacketDecoder extends ByteToMessageDecoder {

@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
ResponseMessagePacket packet = new ResponseMessagePacket();
// 魔数
packet.setMagicNumber(in.readInt());
// 版本
packet.setVersion(in.readInt());
// 流水号
int serialNumberLength = in.readInt();
packet.setSerialNumber(in.readCharSequence(serialNumberLength, ProtocolConstant.UTF_8).toString());
// 消息类型
byte messageTypeByte = in.readByte();
packet.setMessageType(MessageType.fromValue(messageTypeByte));
// 附件
Map<String, String> attachments = Maps.newHashMap();
packet.setAttachments(attachments);
int attachmentSize = in.readInt();
if (attachmentSize > 0) {
for (int i = 0; i < attachmentSize; i++) {
int keyLength = in.readInt();
String key = in.readCharSequence(keyLength, ProtocolConstant.UTF_8).toString();
int valueLength = in.readInt();
String value = in.readCharSequence(valueLength, ProtocolConstant.UTF_8).toString();
attachments.put(key, value);
}
}
// error code
packet.setErrorCode(in.readLong());
// message
int messageLength = in.readInt();
packet.setMessage(in.readCharSequence(messageLength, ProtocolConstant.UTF_8).toString());
// payload - ByteBuf实例
int payloadLength = in.readInt();
packet.setPayload(in.readBytes(payloadLength));
out.add(packet);
}
}

核心的编码解码器已经编写完,接着要注意一下TCP协议二进制包发送的时候只保证了包的发送顺序、确认发送以及重传,无法保证二进制包是否完整(有些博客也称此类场景为粘包、半包等等,其实网络协议里面并没有定义这些术语,估计是有人杜撰出来),因此这里采取了定长帧编码和解码器LengthFieldPrependerLengthFieldBasedFrameDecoder,简单来说就是在消息帧的开头几位定义了整个帧的长度,读取到整个长度的消息帧才认为是一个完整的二进制报文。举个几个例子:

|<--------packet frame--------->|
| Length Field | Actual Content |
序号 Length Field Actual Content
0 4 abcd
1 9 throwable
2 14 {“name”:”doge”}

编写测试客户端和服务端

客户端代码如下:

@Slf4j
public class TestProtocolClient {

public static void main(String[] args) throws Exception {
int port = 9092;
EventLoopGroup workerGroup = new NioEventLoopGroup();
Bootstrap bootstrap = new Bootstrap();
try {
bootstrap.group(workerGroup);
bootstrap.channel(NioSocketChannel.class);
bootstrap.option(ChannelOption.SO_KEEPALIVE, Boolean.TRUE);
bootstrap.option(ChannelOption.TCP_NODELAY, Boolean.TRUE);
bootstrap.handler(new ChannelInitializer<SocketChannel>() {

@Override
protected void initChannel(SocketChannel ch) throws Exception {
ch.pipeline().addLast(new LengthFieldBasedFrameDecoder(1024, 0, 4, 0, 4));
ch.pipeline().addLast(new LengthFieldPrepender(4));
ch.pipeline().addLast(new RequestMessagePacketEncoder(FastJsonSerializer.X));
ch.pipeline().addLast(new ResponseMessagePacketDecoder());
ch.pipeline().addLast(new SimpleChannelInboundHandler<ResponseMessagePacket>() {
@Override
protected void channelRead0(ChannelHandlerContext ctx, ResponseMessagePacket packet) throws Exception {
Object targetPayload = packet.getPayload();
if (targetPayload instanceof ByteBuf) {
ByteBuf byteBuf = (ByteBuf) targetPayload;
int readableByteLength = byteBuf.readableBytes();
byte[] bytes = new byte[readableByteLength];
byteBuf.readBytes(bytes);
targetPayload = FastJsonSerializer.X.decode(bytes, String.class);
byteBuf.release();
}
packet.setPayload(targetPayload);
log.info("接收到来自服务端的响应消息,消息内容:{}", JSON.toJSONString(packet));
}
});
}
});
ChannelFuture future = bootstrap.connect("localhost", port).sync();
log.info("启动NettyClient[{}]成功...", port);
Channel channel = future.channel();
RequestMessagePacket packet = new RequestMessagePacket();
packet.setMagicNumber(ProtocolConstant.MAGIC_NUMBER);
packet.setVersion(ProtocolConstant.VERSION);
packet.setSerialNumber(SerialNumberUtils.X.generateSerialNumber());
packet.setMessageType(MessageType.REQUEST);
packet.setInterfaceName("club.throwable.contract.HelloService");
packet.setMethodName("sayHello");
packet.setMethodArgumentSignatures(new String[]{"java.lang.String"});
packet.setMethodArguments(new Object[]{"doge"});
channel.writeAndFlush(packet);
future.channel().closeFuture().sync();
} finally {
workerGroup.shutdownGracefully();
}
}
}

服务端代码如下:

@Slf4j
public class TestProtocolServer {

public static void main(String[] args) throws Exception {
int port = 9092;
ServerBootstrap bootstrap = new ServerBootstrap();
EventLoopGroup bossGroup = new NioEventLoopGroup();
EventLoopGroup workerGroup = new NioEventLoopGroup();
try {
bootstrap.group(bossGroup, workerGroup)
.channel(NioServerSocketChannel.class)
.childHandler(new ChannelInitializer<SocketChannel>() {

@Override
protected void initChannel(SocketChannel ch) throws Exception {
ch.pipeline().addLast(new LengthFieldBasedFrameDecoder(1024, 0, 4, 0, 4));
ch.pipeline().addLast(new LengthFieldPrepender(4));
ch.pipeline().addLast(new RequestMessagePacketDecoder());
ch.pipeline().addLast(new ResponseMessagePacketEncoder(FastJsonSerializer.X));
ch.pipeline().addLast(new SimpleChannelInboundHandler<RequestMessagePacket>() {

@Override
protected void channelRead0(ChannelHandlerContext ctx, RequestMessagePacket packet) throws Exception {
log.info("接收到来自客户端的请求消息,消息内容:{}", JSON.toJSONString(packet));
ResponseMessagePacket response = new ResponseMessagePacket();
response.setMagicNumber(packet.getMagicNumber());
response.setVersion(packet.getVersion());
response.setSerialNumber(packet.getSerialNumber());
response.setAttachments(packet.getAttachments());
response.setMessageType(MessageType.RESPONSE);
response.setErrorCode(200L);
response.setMessage("Success");
response.setPayload("{\"name\":\"throwable\"}");
ctx.writeAndFlush(response);
}
});
}
});
ChannelFuture future = bootstrap.bind(port).sync();
log.info("启动NettyServer[{}]成功...", port);
future.channel().closeFuture().sync();
} finally {
workerGroup.shutdownGracefully();
bossGroup.shutdownGracefully();
}
}
}

这里在测试的环境中,最大的消息帧长度暂时定义为1024。先启动服务端,再启动客户端,见控制台输出如下:

// 服务端
22:29:32.596 [main] INFO club.throwable.protocol.TestProtocolServer - 启动NettyServer[9092]成功...
...省略其他日志...
22:29:53.538 [nioEventLoopGroup-3-1] INFO club.throwable.protocol.TestProtocolServer - 接收到来自客户端的请求消息,消息内容:{"attachments":{},"interfaceName":"club.throwable.contract.HelloService","magicNumber":10086,"messageType":"REQUEST","methodArgumentSignatures":["java.lang.String"],"methodArguments":[{"contiguous":true,"direct":true,"readOnly":false,"readable":true,"writable":false}],"methodName":"sayHello","serialNumber":"7f992c7cf9f445258601def1cac9bec0","version":1}

// 客户端
22:31:28.360 [main] INFO club.throwable.protocol.TestProtocolClient - 启动NettyClient[9092]成功...
...省略其他日志...
22:31:39.320 [nioEventLoopGroup-2-1] INFO club.throwable.protocol.TestProtocolClient - 接收到来自服务端的响应消息,消息内容:{"attachments":{},"errorCode":200,"magicNumber":10086,"message":"Success","messageType":"RESPONSE","payload":"{\"name\":\"throwable\"}","serialNumber":"320808e709b34edbb91ba557780b58ad","version":1}

小结

一个基于Netty实现的简单的自定义协议基本完成,但是要编写一个优秀的RPC框架,还需要做服务端的宿主类和目标方法查询、调用,客户端的动态代理,NettyNIO模式下的同步调用改造,心跳处理,异常处理等等。后面会使用多篇文章逐个问题解决,网络编程其实挺好玩了,就是编码量会比较大(゜-゜)つロ

Demo项目:

(e-a-20200112 c-1-d)