main.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. package main
  2. import (
  3. "context"
  4. pb "emqx.io/grpc/exhook/protobuf"
  5. "encoding/json"
  6. "fmt"
  7. "github.com/patrickmn/go-cache"
  8. "github.com/pkg/errors"
  9. "google.golang.org/grpc"
  10. "log"
  11. "net"
  12. "strings"
  13. "time"
  14. )
  // port is the TCP listen address of the exhook gRPC server. With an empty
// host part, net.Listen binds every local address on port 9000.
const port = ":9000"
  18. var cacheInstance = cache.New(5*time.Minute, 10*time.Minute)
  19. type server struct {
  20. pb.UnimplementedHookProviderServer
  21. }
  22. func (s *server) OnProviderLoaded(ctx context.Context, in *pb.ProviderLoadedRequest) (*pb.LoadedResponse, error) {
  23. hooks := []*pb.HookSpec{
  24. {Name: "message.publish"},
  25. }
  26. return &pb.LoadedResponse{Hooks: hooks}, nil
  27. }
  28. func (s *server) OnMessagePublish(ctx context.Context, in *pb.MessagePublishRequest) (*pb.ValuedResponse, error) {
  29. log.Printf("[DEBUG] OnMessagePublish: %s", in.Message.Topic)
  30. topic := strings.TrimSuffix(in.GetMessage().GetTopic(), "/")
  31. payload := in.GetMessage().GetPayload()
  32. if payload == nil || len(payload) >= 1000 {
  33. return nil, errors.New("消息体为空或大于等于1000字节")
  34. }
  35. var jsonPayload map[string]interface{}
  36. if err := json.Unmarshal(payload, &jsonPayload); err != nil {
  37. return nil, errors.Wrap(err, "json解析失败")
  38. }
  39. typeVal, ok := jsonPayload["type"].(float64)
  40. if !ok {
  41. return nil, errors.New("json中'type'字段解析失败")
  42. }
  43. if int(typeVal) == 2 {
  44. key := fmt.Sprintf("%s-%v", topic, jsonPayload["data"])
  45. log.Printf("缓存中键的值: %s", key)
  46. if _, found := cacheInstance.Get(key); found {
  47. return discardMessagePublish(ctx, in, func(response *pb.ValuedResponse) error {
  48. log.Printf("丢弃重复消息")
  49. return nil
  50. })
  51. }
  52. cacheInstance.Set(key, "alarm", cache.DefaultExpiration)
  53. }
  54. // 正常发送消息
  55. return &pb.ValuedResponse{
  56. Type: pb.ValuedResponse_STOP_AND_RETURN,
  57. Value: &pb.ValuedResponse_Message{
  58. Message: in.GetMessage(),
  59. },
  60. }, nil
  61. }
  62. func discardMessagePublish(ctx context.Context, in *pb.MessagePublishRequest, responseWriter func(*pb.ValuedResponse) error) (*pb.ValuedResponse, error) {
  63. emptyPayload := []byte{}
  64. newMsg := &pb.Message{
  65. Id: in.Message.Id,
  66. Node: in.Message.Node,
  67. From: in.Message.From,
  68. Topic: in.Message.Topic,
  69. Payload: emptyPayload,
  70. Headers: map[string]string{"allow_publish": "false"},
  71. }
  72. reply := &pb.ValuedResponse{
  73. Type: pb.ValuedResponse_STOP_AND_RETURN,
  74. Value: &pb.ValuedResponse_Message{
  75. Message: newMsg,
  76. },
  77. }
  78. if err := responseWriter(reply); err != nil {
  79. return nil, errors.Wrap(err, "发送响应时出错")
  80. }
  81. return reply, nil
  82. }
  83. func main() {
  84. lis, err := net.Listen("tcp", port)
  85. if err != nil {
  86. log.Fatalf("failed to listen: %v", err)
  87. }
  88. s := grpc.NewServer()
  89. pb.RegisterHookProviderServer(s, &server{})
  90. log.Println("Started gRPC server on", port)
  91. if err := s.Serve(lis); err != nil {
  92. log.Fatalf("failed to serve: %v", err)
  93. }
  94. }