最近在使用 WebSocket 的过程中遇到一个问题:服务采用分布式多节点部署,前端连接到 ws 时,后端会为该连接创建一个 session,服务端可以通过这个 session 主动向前端推送消息。最先想到的方案是把 session 共享给其他节点,但查看 WsSession 源码后发现,WsSession 并没有实现序列化,因此无法共享:

1
2
// Excerpt of the container's WsSession declaration, body omitted (presumably
// org.apache.tomcat.websocket.WsSession — confirm against the server in use).
// As shown, it implements Session but not java.io.Serializable, so a session
// object cannot be serialized and shared with other nodes.
public class WsSession implements Session {
}

后来发现可以通过 Redis 广播(发布/订阅)的方式解决。前提是:前端连接时将对应的 key 放入 Redis 缓存,断开时再将 key 移出缓存,其余配置按正常的单节点 WebSocket 配置即可。消息到来时,先判断接收方 key 对应的连接是否在当前节点:如果在,则直接发送;如果不在,则查询缓存,缓存中存在该 key 时,通过 Redis 广播将消息转发给其他节点。其他节点收到广播消息后,检查对应的 key 是否连接在本节点,如果是,就取出对应的 session 发送给前端。

于是我自己写了一个工具 jar,方便后续复用,在这里分享出来。

pom中加入依赖

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
<dependencies>
    <!-- NOTE(review): no <version> elements below; this assumes versions are managed by a
         parent POM or imported BOM. hutool is not covered by the Spring Boot BOM - confirm. -->

    <!-- Redis client support (compile scope: packaged with this library's consumers). -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-data-redis</artifactId>
    </dependency>

    <!-- WebSocket support; provided scope: the consuming application must supply it. -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-websocket</artifactId>
        <scope>provided</scope>
    </dependency>

    <!-- Hutool utilities (JSON, Spring bean lookup); provided scope as well. -->
    <dependency>
        <groupId>cn.hutool</groupId>
        <artifactId>hutool-all</artifactId>
        <scope>provided</scope>
    </dependency>
</dependencies>

接下来添加 WebSocket 与 Redis 的配置,首先是 WebSocket 配置,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
// WebSocket configuration
@Configuration
public class WebSocketConfig {

    /**
     * Exposes a {@link ServerEndpointExporter} so that WebSocket endpoints declared with
     * {@code @ServerEndpoint} are registered automatically.
     *
     * @return a new exporter instance
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        ServerEndpointExporter exporter = new ServerEndpointExporter();
        return exporter;
    }
}

redis相关配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
// Redis configuration: RedisTemplate serialization plus the pub/sub listener that relays
// WebSocket messages between nodes.
@Configuration
public class RedisConfig {

    /**
     * RedisTemplate with String keys/hash keys and Jackson-serialized values/hash values.
     *
     * <p>A customized ObjectMapper with default typing (NON_FINAL + LaissezFaireSubTypeValidator)
     * is deliberately NOT used here. Such a mapper was previously constructed but never handed to the
     * serializer, so it had no effect. Wiring it in would embed "@class" type metadata in every
     * broadcast payload (whose data part RedisMessageListener forwards to clients). It would also
     * enable unrestricted polymorphic deserialization, a known Jackson gadget risk.
     *
     * @param redisConnectionFactory connection factory supplied by Spring Boot
     * @return the configured template
     */
    @Bean
    @ConditionalOnMissingBean
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
        RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(redisConnectionFactory);
        Jackson2JsonRedisSerializer<Object> jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer<>(Object.class);
        StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
        // Keys and hash keys are plain strings.
        redisTemplate.setKeySerializer(stringRedisSerializer);
        redisTemplate.setHashKeySerializer(stringRedisSerializer);
        // Values and hash values are JSON.
        redisTemplate.setValueSerializer(jackson2JsonRedisSerializer);
        redisTemplate.setHashValueSerializer(jackson2JsonRedisSerializer);
        redisTemplate.afterPropertiesSet();
        return redisTemplate;
    }

    /**
     * Adapter delegating pub/sub messages to {@link RedisMessageListener}.
     */
    @Bean
    @ConditionalOnMissingBean
    MessageListenerAdapter listenerAdapter() {
        return new MessageListenerAdapter(new RedisMessageListener(), "onMessage");
    }

    /**
     * Listener container for the broadcast relay.
     *
     * <p>It subscribes with the pattern {@code "*"}, i.e. to every channel on this Redis. Senders
     * publish on a channel named after the endpoint class (see AbstractWebSocket#sendMessage).
     * Messages on unrelated channels are delivered to the listener as well.
     */
    @Bean
    @ConditionalOnMissingBean
    RedisMessageListenerContainer container(RedisConnectionFactory connectionFactory) {
        RedisMessageListenerContainer container = new RedisMessageListenerContainer();
        container.setConnectionFactory(connectionFactory);
        container.addMessageListener(listenerAdapter(), new PatternTopic("*"));
        // Further message listeners can be registered on this container if needed.
        return container;
    }
}

redis广播监听

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
@Slf4j
public class RedisMessageListener implements MessageListener {

    /**
     * Delivers a message broadcast by another node to the matching session on this node, if any.
     *
     * <p>The channel name is expected to be the fully-qualified class name of the endpoint, and the
     * body the JSON form of a {@code WebSocketMsg}. The container subscribes to every channel
     * ({@code "*"}), so a channel that is not a loadable class is logged and skipped. Previously this
     * was rethrown as a RuntimeException.
     *
     * @param message the received pub/sub message
     * @param pattern the subscription pattern that matched (unused)
     */
    @Override
    public void onMessage(Message message, byte[] pattern) {
        log.info("收到广播消息");
        // Decode explicitly as UTF-8, the encoding StringRedisSerializer and Jackson write with.
        // new String(byte[]) would use the platform default charset, which may differ
        // (e.g. GBK on a Chinese Windows host before Java 18) and garble non-ASCII text.
        String channel = new String(message.getChannel(), java.nio.charset.StandardCharsets.UTF_8);
        String body = new String(message.getBody(), java.nio.charset.StandardCharsets.UTF_8);
        WebSocketUtil webSocketUtil = SpringUtil.getBean(WebSocketUtil.class);
        try {
            WebSocketMsg webSocketMsg = JSONUtil.toBean(body, WebSocketMsg.class);
            // A relayed message is delivered locally only; it must not be broadcast again.
            webSocketUtil.sendMessage(channel, webSocketMsg.getData(), webSocketMsg.getKey(), Boolean.TRUE);
        } catch (ClassNotFoundException e) {
            log.warn("Ignoring broadcast on channel {}: not a WebSocket endpoint class", channel);
        }
    }
}


websocket的抽象类,提供一些基础常用方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
@Slf4j
public class AbstractWebSocket {
    /**
     * Number of open connections on this node, across all endpoint subclasses. Every open increments
     * it and every close decrements it, including duplicate connections that were not registered.
     */
    private static final AtomicInteger ONLINE_COUNT = new AtomicInteger(0);

    /**
     * Sessions connected to this node, grouped by concrete endpoint class, then by client key.
     * Grouping by class mirrors the per-class Redis set and pub/sub channel
     * ({@code getClass().getName()}). The same key on two different endpoints therefore cannot
     * overwrite or receive each other's messages.
     */
    private static final Map<Class<?>, Map<Object, Session>> CLIENTS = new ConcurrentHashMap<>();

    /**
     * Returns the session map of this endpoint class, creating it on first use.
     */
    private Map<Object, Session> clients() {
        return CLIENTS.computeIfAbsent(getClass(), type -> new ConcurrentHashMap<>());
    }

    /**
     * Registers a newly opened connection; intended to be called from the subclass's {@code @OnOpen}.
     *
     * <p>If a session is already registered for {@code key} on this node, that session is kept and
     * the new one is not registered (it is still counted in the online count).
     *
     * @param redisTemplate template used to add {@code key} to the Redis set named after this class
     * @param session       the newly opened session
     * @param key           client key (e.g. a user id); must not be null
     */
    public void superOnOpen(RedisTemplate<String, Object> redisTemplate, Session session, Object key) {
        ONLINE_COUNT.incrementAndGet();
        // putIfAbsent is atomic, so two concurrent opens for the same key cannot both register.
        if (clients().putIfAbsent(key, session) == null) {
            redisTemplate.opsForSet().add(getClass().getName(), key);
            log.info("{}有新连接加入:{},当前在线人数为:{}", getClass(), session.getId(), ONLINE_COUNT.get());
        }
    }

    /**
     * Unregisters a closed connection; intended to be called from the subclass's {@code @OnClose}.
     *
     * <p>The key is removed locally and from Redis only if {@code session} is the session registered
     * for it. A duplicate connection that was never registered therefore does not evict (or close)
     * the live session for the same key.
     *
     * @param redisTemplate template used to remove {@code key} from the Redis set named after this class
     * @param session       the session being closed
     * @param key           client key the session was opened with; must not be null
     */
    public void superOnClose(RedisTemplate<String, Object> redisTemplate, Session session, Object key) {
        ONLINE_COUNT.decrementAndGet();
        if (clients().remove(key, session)) {
            // try-with-resources closes the unregistered session after the Redis cleanup.
            try (Session ignored = session) {
                redisTemplate.opsForSet().remove(getClass().getName(), key);
                log.info("{}-->>离线", key);
            } catch (IOException e) {
                log.error(e.getMessage(), e);
            }
        }
        log.info("有一连接关闭:{},当前在线人数为:{}", session.getId(), ONLINE_COUNT.get());
    }

    /**
     * Sends {@code message} (serialized to JSON) to the client identified by {@code key}.
     *
     * <p>If the client is connected to this node, it is sent directly. Otherwise, if the key is
     * present in this class's Redis set, the message is published on the channel named after this
     * class so the owning node can deliver it. Failures are logged, not thrown.
     *
     * @param redisTemplate template used for the membership check and the broadcast
     * @param message       payload to send
     * @param key           target client key
     */
    public void sendMessage(RedisTemplate<String, Object> redisTemplate, Object message, Object key) {
        try {
            if (sendToLocalSession(message, key)) {
                return;
            }
            // SISMEMBER instead of fetching the whole set with SMEMBERS and scanning it.
            if (Boolean.TRUE.equals(redisTemplate.opsForSet().isMember(getClass().getName(), key))) {
                log.info("redis有{}", key);
                redisTemplate.convertAndSend(getClass().getName(), WebSocketMsg.of().setKey(key).setData(message));
            }
        } catch (Exception e) {
            log.error("服务端发送消息给客户端失败", e);
        }
    }

    /**
     * Sends {@code message} only if the client is connected to this node; never broadcasts.
     * Failures are logged, not thrown.
     *
     * @param message payload to send
     * @param key     target client key
     */
    public void sendMessageOnThis(Object message, Object key) {
        try {
            sendToLocalSession(message, key);
        } catch (Exception e) {
            log.error("服务端发送消息给客户端失败", e);
        }
    }

    /**
     * Sends {@code message} as JSON text to the local session registered for {@code key}.
     *
     * @return true if a local session existed and the send was issued, false otherwise
     */
    private boolean sendToLocalSession(Object message, Object key) {
        Session toSession = clients().get(key);
        if (toSession == null) {
            return false;
        }
        log.info("本地有{}", key);
        toSession.getAsyncRemote().sendText(JSONUtil.toJsonStr(message));
        return true;
    }

    /**
     * 获取session
     * Returns the session registered for {@code key} on this node for this endpoint class, or null.
     *
     * @param key key
     * @return jakarta.websocket.Session
     * @author HuangBoo
     * @since 2023/3/19 14:32
     */
    public Session getSessionByKey(Object key) {
        return clients().get(key);
    }
}

再提供一个bean用于消息发送

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
/**
 * Routes a message to the {@link AbstractWebSocket} endpoint bean identified by class name.
 */
@Component
public class WebSocketUtil {
    private final RedisTemplate<String, Object> redisTemplate;

    public WebSocketUtil(RedisTemplate<String, Object> redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    /**
     * Looks up the endpoint bean for {@code clazzName} and sends {@code message} to {@code key}.
     * Classes that are not {@link AbstractWebSocket} subclasses are silently ignored.
     *
     * @param clazzName         fully-qualified name of the endpoint class
     * @param message           payload to send
     * @param key               target client key
     * @param sendCurrentServer {@code Boolean.TRUE} to deliver on this node only; any other value
     *                          (including null) allows the Redis broadcast fallback
     * @throws ClassNotFoundException if {@code clazzName} cannot be loaded
     */
    public void sendMessage(String clazzName, Object message, Object key, Boolean sendCurrentServer) throws ClassNotFoundException {
        Class<?> endpointType = Class.forName(clazzName);
        if (!AbstractWebSocket.class.isAssignableFrom(endpointType)) {
            return;
        }
        AbstractWebSocket endpoint = (AbstractWebSocket) SpringUtil.getBean(endpointType);
        boolean localOnly = Boolean.TRUE.equals(sendCurrentServer);
        if (localOnly) {
            endpoint.sendMessageOnThis(message, key);
            return;
        }
        endpoint.sendMessage(redisTemplate, message, key);
    }
}

此外,为了让引入该 jar 的项目能够自动加载上述 bean,

需要在 resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports 文件中加入以下内容:

1
2
3
4
net.oiyou.websocket.config.WebSocketConfig
net.oiyou.websocket.redis.RedisConfig
cn.hutool.extra.spring.SpringUtil
net.oiyou.websocket.WebSocketUtil

然后就可以使用了

新建项目引入此依赖jar包

配置redis相关信息

1
2
3
spring.data.redis.host=vm.oiyou.net
spring.data.redis.port=6379
spring.data.redis.password=****

让自己的 WebSocket 端点类继承 AbstractWebSocket,

这样就可以发送消息了,以下为一个实例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
package net.oiyou.demo.websocket.server;

import cn.hutool.extra.spring.SpringUtil;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import jakarta.websocket.*;
import jakarta.websocket.server.PathParam;
import jakarta.websocket.server.ServerEndpoint;
import lombok.extern.slf4j.Slf4j;
import net.oiyou.websocket.AbstractWebSocket;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;

/**
* @author HuangBoo
* @since 2023年03月19日 13:08
*/
@Slf4j
@ServerEndpoint(value = "/myWebSocket/{userId}")
@Component
public class MyWebSocket extends AbstractWebSocket {

@OnOpen
public void onOpen(Session session, @PathParam("userId") String userId) {
RedisTemplate<String,Object> redisTemplate = SpringUtil.getBean("redisTemplate");
superOnOpen(redisTemplate, session, userId);
}

@OnClose
public void onClose(Session session, @PathParam("userId") String userId) {
RedisTemplate<String,Object> redisTemplate = SpringUtil.getBean("redisTemplate");
super.superOnClose(redisTemplate, session, userId);
}

@OnError
public void onError(Session session, Throwable error) {
error.printStackTrace();
}

@OnMessage
public void onMessage(Session session, String message, @PathParam("userId") String userId) {
JSONObject data = JSONUtil.parseObj(message);
log.info("{}发送消息给{},内容是:{}", userId, data.get("userId"), message);
RedisTemplate<String,Object> redisTemplate = SpringUtil.getBean("redisTemplate");
sendMessage(redisTemplate, data, data.getStr("userId"));
}
}

另外再贴一下nginx负载均衡配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
http {
    # Derive the Connection header from the client's Upgrade header, as in the nginx WebSocket
    # proxying guide: "upgrade" for WebSocket handshakes, "close" for ordinary HTTP requests
    # (instead of sending "upgrade" on every proxied request).
    map $http_upgrade $connection_upgrade {
        default upgrade;
        ''      close;
    }

    upstream ws_stream {
        server 127.0.0.1:8080;
        server 127.0.0.1:8081;
    }

    server {
        listen 81;
        location / {
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
            proxy_set_header X-Forwarded-Proto $scheme;
            proxy_pass http://ws_stream;

            # WebSocket essentials: HTTP/1.1 and forwarding of the hop-by-hop Upgrade/Connection headers.
            proxy_http_version 1.1;
            proxy_set_header Upgrade $http_upgrade;
            proxy_set_header Connection $connection_upgrade;
            proxy_connect_timeout 4s;
            # nginx closes the proxied connection if the upstream sends nothing for this long, so an
            # idle WebSocket is dropped after 60s; raise this or send periodic heartbeats.
            proxy_read_timeout 60s;
            proxy_send_timeout 12s;
        }
    }
}