main.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. package main
  2. import (
  3. "context"
  4. "emqx.io/grpc/exhook/global"
  5. "emqx.io/grpc/exhook/model"
  6. pb "emqx.io/grpc/exhook/protobuf"
  7. "emqx.io/grpc/exhook/simple_zap"
  8. exhook "emqx.io/grpc/exhook/utils"
  9. "fmt"
  10. "github.com/bytedance/sonic"
  11. "github.com/patrickmn/go-cache"
  12. "github.com/pkg/errors"
  13. "go.uber.org/zap"
  14. "google.golang.org/grpc"
  15. "log"
  16. "net"
  17. "strings"
  18. "time"
  19. )
  20. func init() {
  21. err := global.SetupSetting()
  22. if err != nil {
  23. log.Fatalf("init setting err: %v", err)
  24. }
  25. }
  26. var cacheInstance = cache.New(1*time.Minute, 2*time.Minute)
  27. type server struct {
  28. pb.UnimplementedHookProviderServer
  29. }
  30. func (s *server) OnProviderLoaded(ctx context.Context, in *pb.ProviderLoadedRequest) (*pb.LoadedResponse, error) {
  31. hooks := []*pb.HookSpec{
  32. {Name: "message.publish"},
  33. }
  34. return &pb.LoadedResponse{Hooks: hooks}, nil
  35. }
  36. func (s *server) OnMessagePublish(ctx context.Context, in *pb.MessagePublishRequest) (*pb.ValuedResponse, error) {
  37. logger := simple_zap.WithCtx(ctx)
  38. topic := strings.TrimSuffix(in.GetMessage().GetTopic(), "/")
  39. payload := in.GetMessage().GetPayload()
  40. if payload == nil || len(payload) >= 1000 {
  41. logger.Info("消息体为空或大于等于1000字节", zap.String("key", topic))
  42. return nil, errors.New("消息体为空或大于等于1000字节")
  43. }
  44. var jsonPayload model.T
  45. err := sonic.Unmarshal(payload, &jsonPayload)
  46. if err != nil {
  47. logger.Info("json解析失败", zap.String("key", topic))
  48. return nil, errors.Wrap(err, topic+"json解析失败")
  49. }
  50. if int(jsonPayload.Type) == 2 {
  51. key := fmt.Sprintf("%s-%v", topic, jsonPayload.Data)
  52. data := fmt.Sprintf("%s-%v", topic, jsonPayload)
  53. //md5加密key缩短缓存中的数据
  54. md5 := exhook.MD5(key)
  55. if _, found := cacheInstance.Get(md5); found {
  56. return discardMessagePublish(ctx, in, func(response *pb.ValuedResponse) error {
  57. logger.Warn("消息重复被丢弃", zap.String("key", data))
  58. return nil
  59. })
  60. }
  61. cacheInstance.Set(md5, "alarm", cache.DefaultExpiration)
  62. }
  63. // 正常发送消息
  64. return &pb.ValuedResponse{
  65. Type: pb.ValuedResponse_STOP_AND_RETURN,
  66. Value: &pb.ValuedResponse_Message{
  67. Message: in.GetMessage(),
  68. },
  69. }, nil
  70. }
  71. func discardMessagePublish(ctx context.Context, in *pb.MessagePublishRequest, responseWriter func(*pb.ValuedResponse) error) (*pb.ValuedResponse, error) {
  72. // 增加对输入参数的非空验证,防止空指针异常
  73. if in == nil || in.Message == nil {
  74. return nil, errors.New("输入参数不能为空")
  75. }
  76. emptyPayload := []byte{}
  77. newMsg := &pb.Message{
  78. Id: in.Message.Id,
  79. Node: in.Message.Node,
  80. From: in.Message.From,
  81. Topic: in.Message.Topic,
  82. Payload: emptyPayload,
  83. Headers: map[string]string{"allow_publish": "false"},
  84. }
  85. reply := &pb.ValuedResponse{
  86. Type: pb.ValuedResponse_STOP_AND_RETURN,
  87. Value: &pb.ValuedResponse_Message{
  88. Message: newMsg,
  89. },
  90. }
  91. if err := responseWriter(reply); err != nil {
  92. simple_zap.Logger.Error("发送响应时出错", zap.String("key", reply.GetMessage().GetTopic()))
  93. return nil, errors.Wrap(err, "发送响应时出错")
  94. }
  95. // 正常发送消息之后
  96. return reply, nil
  97. }
  98. func main() {
  99. lis, err := net.Listen("tcp", global.ServerSetting.Port)
  100. if err != nil {
  101. log.Fatalf("failed to listen: %v", err)
  102. }
  103. s := grpc.NewServer()
  104. pb.RegisterHookProviderServer(s, &server{})
  105. log.Println("Started gRPC server on", global.ServerSetting.Port)
  106. if err := s.Serve(lis); err != nil {
  107. log.Fatalf("failed to serve: %v", err)
  108. }
  109. }