|
@@ -2,11 +2,16 @@ package main
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
+ "emqx.io/grpc/exhook/global"
|
|
|
+ "emqx.io/grpc/exhook/model"
|
|
|
pb "emqx.io/grpc/exhook/protobuf"
|
|
|
- "encoding/json"
|
|
|
+ "emqx.io/grpc/exhook/simple_zap"
|
|
|
+ exhook "emqx.io/grpc/exhook/utils"
|
|
|
"fmt"
|
|
|
+ "github.com/bytedance/sonic"
|
|
|
"github.com/patrickmn/go-cache"
|
|
|
"github.com/pkg/errors"
|
|
|
+ "go.uber.org/zap"
|
|
|
"google.golang.org/grpc"
|
|
|
"log"
|
|
|
"net"
|
|
@@ -14,11 +19,14 @@ import (
|
|
|
"time"
|
|
|
)
|
|
|
|
|
|
-const (
|
|
|
- port = ":9000"
|
|
|
-)
|
|
|
+func init() {
|
|
|
+ err := global.SetupSetting()
|
|
|
+ if err != nil {
|
|
|
+ log.Fatalf("init setting err: %v", err)
|
|
|
+ }
|
|
|
+}
|
|
|
|
|
|
-var cacheInstance = cache.New(5*time.Minute, 10*time.Minute)
|
|
|
+var cacheInstance = cache.New(1*time.Minute, 2*time.Minute)
|
|
|
|
|
|
type server struct {
|
|
|
pb.UnimplementedHookProviderServer
|
|
@@ -32,36 +40,35 @@ func (s *server) OnProviderLoaded(ctx context.Context, in *pb.ProviderLoadedRequ
|
|
|
}
|
|
|
|
|
|
func (s *server) OnMessagePublish(ctx context.Context, in *pb.MessagePublishRequest) (*pb.ValuedResponse, error) {
|
|
|
- log.Printf("[DEBUG] OnMessagePublish: %s", in.Message.Topic)
|
|
|
+ logger := simple_zap.WithCtx(ctx)
|
|
|
topic := strings.TrimSuffix(in.GetMessage().GetTopic(), "/")
|
|
|
payload := in.GetMessage().GetPayload()
|
|
|
|
|
|
if payload == nil || len(payload) >= 1000 {
|
|
|
+ logger.Info("消息体为空或大于等于1000字节", zap.String("key", topic))
|
|
|
return nil, errors.New("消息体为空或大于等于1000字节")
|
|
|
}
|
|
|
|
|
|
- var jsonPayload map[string]interface{}
|
|
|
- if err := json.Unmarshal(payload, &jsonPayload); err != nil {
|
|
|
- return nil, errors.Wrap(err, "json解析失败")
|
|
|
- }
|
|
|
-
|
|
|
- typeVal, ok := jsonPayload["type"].(float64)
|
|
|
- if !ok {
|
|
|
- return nil, errors.New("json中'type'字段解析失败")
|
|
|
+ var jsonPayload model.T
|
|
|
+ err := sonic.Unmarshal(payload, &jsonPayload)
|
|
|
+ if err != nil {
|
|
|
+ logger.Info("json解析失败", zap.String("key", topic))
|
|
|
+ return nil, errors.Wrap(err, topic+"json解析失败")
|
|
|
}
|
|
|
+ if int(jsonPayload.Type) == 2 {
|
|
|
+ key := fmt.Sprintf("%s-%v", topic, jsonPayload.Data)
|
|
|
+ data := fmt.Sprintf("%s-%v", topic, jsonPayload)
|
|
|
|
|
|
- if int(typeVal) == 2 {
|
|
|
- key := fmt.Sprintf("%s-%v", topic, jsonPayload["data"])
|
|
|
- log.Printf("缓存中键的值: %s", key)
|
|
|
- if _, found := cacheInstance.Get(key); found {
|
|
|
+ //md5加密key缩短缓存中的数据
|
|
|
+ md5 := exhook.MD5(key)
|
|
|
+ if _, found := cacheInstance.Get(md5); found {
|
|
|
return discardMessagePublish(ctx, in, func(response *pb.ValuedResponse) error {
|
|
|
- log.Printf("丢弃重复消息")
|
|
|
+ logger.Warn("消息重复被丢弃", zap.String("key", data))
|
|
|
return nil
|
|
|
})
|
|
|
}
|
|
|
- cacheInstance.Set(key, "alarm", cache.DefaultExpiration)
|
|
|
+ cacheInstance.Set(md5, "alarm", cache.DefaultExpiration)
|
|
|
}
|
|
|
-
|
|
|
// 正常发送消息
|
|
|
return &pb.ValuedResponse{
|
|
|
Type: pb.ValuedResponse_STOP_AND_RETURN,
|
|
@@ -72,6 +79,10 @@ func (s *server) OnMessagePublish(ctx context.Context, in *pb.MessagePublishRequ
|
|
|
}
|
|
|
|
|
|
func discardMessagePublish(ctx context.Context, in *pb.MessagePublishRequest, responseWriter func(*pb.ValuedResponse) error) (*pb.ValuedResponse, error) {
|
|
|
+ // 增加对输入参数的非空验证,防止空指针异常
|
|
|
+ if in == nil || in.Message == nil {
|
|
|
+ return nil, errors.New("输入参数不能为空")
|
|
|
+ }
|
|
|
emptyPayload := []byte{}
|
|
|
newMsg := &pb.Message{
|
|
|
Id: in.Message.Id,
|
|
@@ -90,20 +101,21 @@ func discardMessagePublish(ctx context.Context, in *pb.MessagePublishRequest, re
|
|
|
}
|
|
|
|
|
|
if err := responseWriter(reply); err != nil {
|
|
|
+ simple_zap.Logger.Error("发送响应时出错", zap.String("key", reply.GetMessage().GetTopic()))
|
|
|
return nil, errors.Wrap(err, "发送响应时出错")
|
|
|
}
|
|
|
-
|
|
|
+ // 正常发送消息之后
|
|
|
return reply, nil
|
|
|
}
|
|
|
|
|
|
func main() {
|
|
|
- lis, err := net.Listen("tcp", port)
|
|
|
+ lis, err := net.Listen("tcp", global.ServerSetting.Port)
|
|
|
if err != nil {
|
|
|
log.Fatalf("failed to listen: %v", err)
|
|
|
}
|
|
|
s := grpc.NewServer()
|
|
|
pb.RegisterHookProviderServer(s, &server{})
|
|
|
- log.Println("Started gRPC server on", port)
|
|
|
+ log.Println("Started gRPC server on", global.ServerSetting.Port)
|
|
|
if err := s.Serve(lis); err != nil {
|
|
|
log.Fatalf("failed to serve: %v", err)
|
|
|
}
|