Пишем DNS proxy на Go

wt_k9byeikdbemab1xqi7dyhfwa.jpeg

Давно хотел решить проблему с рекламой. Наиболее простым способом сделать это на всех устройствах оказалось поднятие своего DNS сервера с блокированием запросов на получений IP адресов рекламных доменов.

Первым делом я начал с использования dnsmasq, но мне захотелось грузить списки из интернета и получать какую-нибудь статистику по использованию. Поэтому я и решил писать свой сервер.
Конечно, он написан не полностью с нуля, вся работа с DNS взята из библиотеки github.com/miekg/dns

Конфигурация
Работать программа начинает, конечно же, с загрузки конфигурационного файла. Сразу подумал о необходимости автоматической подгрузки конфига при его изменении дабы избежать рестарта сервера. Для этого пригодился пакет fsnotify.

Структура конфига:

type Config struct {
	Nameservers    []string      `yaml:"nameservers"`
	Blocklist      []string      `yaml:"blocklist"`
	BlockAddress4  string        `yaml:"blockAddress4"`
	BlockAddress6  string        `yaml:"blockAddress6"`
	ConfigUpdate   bool          `yaml:"configUpdate"`
	UpdateInterval time.Duration `yaml:"updateInterval"`
}

Тут самым интересным моментом является слежение за обновлениями файла конфигурации. С помощью библиотеки делается это довольно просто: мы создаем Watcher, цепляем к нему файл и слушаем события из канала. True Go!

Код
func configWatcher() {
	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		log.Fatal(err)
	}
	defer watcher.Close()

	err = watcher.Add(*configFile)
	if err != nil {
		log.Fatal(err)
	}

	for {
		select {
		case event := <-watcher.Events:
			if event.Op&fsnotify.Write == fsnotify.Write {
				log.Println("Config file updated, reload config")
				c, err := loadConfig()
				if err != nil {
					log.Println("Bad config: ", err)
				} else {
					log.Println("Config successfuly updated")
					config = c
					if !c.ConfigUpdate {
						return
					}
				}
			}
		case err := <-watcher.Errors:
			log.Println("error:", err)
		}
	}
}

BlackList
Конечно, по-скольку целью стоит блокировка неугодных сайтов, то их необходимо где-то хранить. Для этого при небольшой нагрузке подойдет простая хэш-таблица пустых структур, где в качестве ключа используется блокируемый домен. Хочу заметить, что необходимо наличие точки на конце.
Но так-как у нас нет одновременного read/write, то можно обойтись без мьютексов.

Код
type BlackList struct {
	data map[string]struct{}
}

func (b *BlackList) Add(server string) bool {
	server = strings.Trim(server, " ")
	if len(server) == 0 {
		return false
	}

	if !strings.HasSuffix(server, ".") {
		server += "."
	}
	b.data[server] = struct{}{}

	return true
}

func (b *BlackList) Contains(server string) bool {
	_, ok := b.data[server]
	return ok
}

Кэширование
Изначально я думал обойтись без него, все-таки все мои устройства не создают существенного количества запросов. Но в один прекрасный вечер мой сервер каким-то образом обнаружили и начали флудить его одним и тем же запросом с частотой ~ 100 rps. Да, это немного, но ведь запросы проксируются на реальные namespace-сервера (в моем случае Google) и было бы очень неприятно получить блокировку.
Основной проблемой кэширования является большое количество различных запросов и их нужно хранить отдельно, поэтому получилась двухуровневая хеш-таблица.

Код
type Cache interface {
	Get(reqType uint16, domain string) dns.RR
	Set(reqType uint16, domain string, ip dns.RR)
}

type CacheItem struct {
	Ip dns.RR
	Die time.Time
}

type MemoryCache struct {
	cache map[uint16]map[string]*CacheItem
	locker sync.RWMutex
}

func (c *MemoryCache) Get(reqType uint16, domain string) dns.RR {
	c.locker.RLock()
	defer c.locker.RUnlock()

	if m, ok := c.cache[reqType]; ok {
		if ip, ok := m[domain]; ok {
			if ip.Die.After(time.Now()) {
				return ip.Ip
			}
		}
	}

	return nil
}

func (c *MemoryCache) Set(reqType uint16, domain string, ip dns.RR) {
	c.locker.Lock()
	defer c.locker.Unlock()

	var m map[string]*CacheItem

	m, ok := c.cache[reqType]
	if !ok {
		m = make(map[string]*CacheItem)
		c.cache[reqType] = m
	}

	m[domain] = &CacheItem{
		Ip: ip,
		Die: time.Now().Add(time.Duration(ip.Header().Ttl) * time.Second),
	}
}

Обработчик
Конечно, основной частью программы является обработчик входящих запросов, поэтому я его оставил на десерт. Основная логика примерно такая: мы получаем запрос, проверяем его наличие в блек-листе, проверяем наличие в кэше, проксируем запрос на реальные сервера.
Основной интерес представляет функция лукапа. В ней мы отправляем одновременно запрос сразу на все сервера (если успеем до прихода ответа) и ждем успешного ответа хотя бы от одного из них.

Код
func Lookup(req *dns.Msg) (*dns.Msg, error) {
	c := &dns.Client{
		Net:          "tcp",
		ReadTimeout:  time.Second * 5,
		WriteTimeout: time.Second * 5,
	}

	qName := req.Question[0].Name

	res := make(chan *dns.Msg, 1)
	var wg sync.WaitGroup
	L := func(nameserver string) {
		defer wg.Done()
		r, _, err := c.Exchange(req, nameserver)
		totalRequestsToGoogle.Inc()
		if err != nil {
			log.Printf("%s socket error on %s", qName, nameserver)
			log.Printf("error:%s", err.Error())
			return
		}
		if r != nil && r.Rcode != dns.RcodeSuccess {
			if r.Rcode == dns.RcodeServerFailure {
				return
			}
		}
		select {
		case res <- r:
		default:
		}
	}

	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	// Start lookup on each nameserver top-down, in every second
	for _, nameserver := range config.Nameservers {
		wg.Add(1)
		go L(nameserver)
		// but exit early, if we have an answer
		select {
		case r := <-res:
			return r, nil
		case <-ticker.C:
			continue
		}
	}

	// wait for all the namservers to finish
	wg.Wait()
	select {
	case r := <-res:
		return r, nil
	default:
		return nil, errors.New("can't resolve ip for" + qName)
	}
}

Метрики
Для метрики будем использовать клиент от prometheus. Используется он очень просто, сначала необходимо объявить счетчик, затем его зарегистрировать и в нужном месте вызвать метод Inc (). Главное не забыть запустить вебсервер с prometheus handler, чтобы он смог считывать метрики.

Код
var (
       totalRequestsTcp = prometheus.NewCounter(prometheus.CounterOpts(prometheus.Opts{
		Namespace: "dns",
		Subsystem: "requests",
		Name:      "total",
		Help:      "total requests",

		ConstLabels: map[string]string{
			"type": "tcp",
		},
	}))
)

func runPrometheus() {
	prometheus.MustRegister(totalRequestsTcp)

        http.Handle("/metrics", promhttp.Handler())
	log.Fatal(http.ListenAndServe(":9970", nil))
}

Думаю main не нуждается в представлении и описании. В данной статье код представлен в сокращенном формате

Полный код можно посмотреть в репозитории: github.com/GoWebProd/goDNS (конечно же приветствуются фиксы и дополнения). Также в репозитории есть файл для Docker и примерная конфигурация CI для Gitlab.

Спасибо за внимание.

© Habrahabr.ru