123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122 |
- package main
- import (
- "context"
- "emqx.io/grpc/exhook/global"
- "emqx.io/grpc/exhook/model"
- pb "emqx.io/grpc/exhook/protobuf"
- "emqx.io/grpc/exhook/simple_zap"
- exhook "emqx.io/grpc/exhook/utils"
- "fmt"
- "github.com/bytedance/sonic"
- "github.com/patrickmn/go-cache"
- "github.com/pkg/errors"
- "go.uber.org/zap"
- "google.golang.org/grpc"
- "log"
- "net"
- "strings"
- "time"
- )
- func init() {
- err := global.SetupSetting()
- if err != nil {
- log.Fatalf("init setting err: %v", err)
- }
- }
- var cacheInstance = cache.New(1*time.Minute, 2*time.Minute)
- type server struct {
- pb.UnimplementedHookProviderServer
- }
- func (s *server) OnProviderLoaded(ctx context.Context, in *pb.ProviderLoadedRequest) (*pb.LoadedResponse, error) {
- hooks := []*pb.HookSpec{
- {Name: "message.publish"},
- }
- return &pb.LoadedResponse{Hooks: hooks}, nil
- }
- func (s *server) OnMessagePublish(ctx context.Context, in *pb.MessagePublishRequest) (*pb.ValuedResponse, error) {
- logger := simple_zap.WithCtx(ctx)
- topic := strings.TrimSuffix(in.GetMessage().GetTopic(), "/")
- payload := in.GetMessage().GetPayload()
- if payload == nil || len(payload) >= 1000 {
- logger.Info("消息体为空或大于等于1000字节", zap.String("key", topic))
- return nil, errors.New("消息体为空或大于等于1000字节")
- }
- var jsonPayload model.T
- err := sonic.Unmarshal(payload, &jsonPayload)
- if err != nil {
- logger.Info("json解析失败", zap.String("key", topic))
- return nil, errors.Wrap(err, topic+"json解析失败")
- }
- if int(jsonPayload.Type) == 2 {
- key := fmt.Sprintf("%s-%v", topic, jsonPayload.Data)
- data := fmt.Sprintf("%s-%v", topic, jsonPayload)
- //md5加密key缩短缓存中的数据
- md5 := exhook.MD5(key)
- if _, found := cacheInstance.Get(md5); found {
- return discardMessagePublish(ctx, in, func(response *pb.ValuedResponse) error {
- logger.Warn("消息重复被丢弃", zap.String("key", data))
- return nil
- })
- }
- cacheInstance.Set(md5, "alarm", cache.DefaultExpiration)
- }
- // 正常发送消息
- return &pb.ValuedResponse{
- Type: pb.ValuedResponse_STOP_AND_RETURN,
- Value: &pb.ValuedResponse_Message{
- Message: in.GetMessage(),
- },
- }, nil
- }
- func discardMessagePublish(ctx context.Context, in *pb.MessagePublishRequest, responseWriter func(*pb.ValuedResponse) error) (*pb.ValuedResponse, error) {
- // 增加对输入参数的非空验证,防止空指针异常
- if in == nil || in.Message == nil {
- return nil, errors.New("输入参数不能为空")
- }
- emptyPayload := []byte{}
- newMsg := &pb.Message{
- Id: in.Message.Id,
- Node: in.Message.Node,
- From: in.Message.From,
- Topic: in.Message.Topic,
- Payload: emptyPayload,
- Headers: map[string]string{"allow_publish": "false"},
- }
- reply := &pb.ValuedResponse{
- Type: pb.ValuedResponse_STOP_AND_RETURN,
- Value: &pb.ValuedResponse_Message{
- Message: newMsg,
- },
- }
- if err := responseWriter(reply); err != nil {
- simple_zap.Logger.Error("发送响应时出错", zap.String("key", reply.GetMessage().GetTopic()))
- return nil, errors.Wrap(err, "发送响应时出错")
- }
- // 正常发送消息之后
- return reply, nil
- }
- func main() {
- lis, err := net.Listen("tcp", global.ServerSetting.Port)
- if err != nil {
- log.Fatalf("failed to listen: %v", err)
- }
- s := grpc.NewServer()
- pb.RegisterHookProviderServer(s, &server{})
- log.Println("Started gRPC server on", global.ServerSetting.Port)
- if err := s.Serve(lis); err != nil {
- log.Fatalf("failed to serve: %v", err)
- }
- }
|