1// Copyright 2012 Google Inc. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package main
16
17import (
18 "fmt"
19 "log"
20 "net"
21 "net/http"
22 "net/http/httputil"
23 "os"
24 "os/exec"
25 "os/signal"
26 "sync"
27 "syscall"
28 "time"
29)
30
// Proxy runs two loopback HTTP reverse proxies and hands their addresses to
// clients that connect to the unix socket DefaultSocket.
type Proxy struct {
	// BuildLabel identifies the running build. closeOnUpdate compares it with
	// the label printed by `os.Args[0] --print_label` and begins a graceful
	// shutdown when they differ.
	BuildLabel string
	// MaxIdleDuration is how long the proxy may stay without serving a unix
	// socket connection before closeOnIdle shuts it down. It is measured from
	// the moment the most recent connection finished. A non-positive value
	// disables the idle shutdown.
	MaxIdleDuration time.Duration
	// PollUpdateInterval is the delay between closeOnUpdate's label checks.
	PollUpdateInterval time.Duration

	// ul is the unix socket listener. Closing it ends serveUnix's accept
	// loop, which is how every shutdown path stops the proxy.
	ul net.Listener
	// httpAddr is the address of the listener whose requests are proxied
	// with their URL unchanged.
	httpAddr string
	// httpsAddr is the address of the listener whose requests are proxied
	// with the URL scheme rewritten to "https".
	httpsAddr string
}
40
41func (p *Proxy) Run() error {
42 hl, err := net.Listen("tcp", "127.0.0.1:0")
43 if err != nil {
44 return fmt.Errorf("http listen failed: %v", err)
45 }
46 defer hl.Close()
47
48 hsl, err := net.Listen("tcp", "127.0.0.1:0")
49 if err != nil {
50 return fmt.Errorf("https listen failed: %v", err)
51 }
52 defer hsl.Close()
53
54 p.ul, err = DefaultSocket.Listen()
55 if err != nil {
56 c, derr := DefaultSocket.Dial()
57 if derr == nil {
58 c.Close()
59 fmt.Println("OK\nA proxy is already running... exiting")
60 return nil
61 } else if e, ok := derr.(*net.OpError); ok && e.Err == syscall.ECONNREFUSED {
62 // Nothing is listening on the socket, unlink it and try again.
63 syscall.Unlink(DefaultSocket.Path())
64 p.ul, err = DefaultSocket.Listen()
65 }
66 if err != nil {
67 return fmt.Errorf("unix listen failed on %v: %v", DefaultSocket.Path(), err)
68 }
69 }
70 defer p.ul.Close()
71 go p.closeOnSignal()
72 go p.closeOnUpdate()
73
74 p.httpAddr = hl.Addr().String()
75 p.httpsAddr = hsl.Addr().String()
76 fmt.Printf("OK\nListening on unix socket=%v http=%v https=%v\n",
77 p.ul.Addr(), p.httpAddr, p.httpsAddr)
78
79 result := make(chan error, 2)
80 go p.serveUnix(result)
81 go func() {
82 result <- http.Serve(hl, &httputil.ReverseProxy{
83 FlushInterval: 500 * time.Millisecond,
84 Director: func(r *http.Request) {},
85 })
86 }()
87 go func() {
88 result <- http.Serve(hsl, &httputil.ReverseProxy{
89 FlushInterval: 500 * time.Millisecond,
90 Director: func(r *http.Request) {
91 r.URL.Scheme = "https"
92 },
93 })
94 }()
95 return <-result
96}
97
// socketContext tracks the unix socket connections currently being served
// and the time the most recent one finished. serveUnix calls Add(1) for
// every accepted connection, handleUnixConn calls Done when it finishes, and
// Wait blocks until no connection is active.
//
// It deliberately avoids sync.WaitGroup. closeOnIdle calls Wait repeatedly
// while serveUnix keeps calling Add(1) from a zero count. WaitGroup forbids
// an Add from zero concurrently with Wait and can panic when it happens. A
// mutex-guarded counter plus a condition variable has no such restriction.
// The zero value is ready to use.
type socketContext struct {
	mutex  sync.Mutex // guards every field below
	last   time.Time  // when the most recent connection finished; zero if none has
	active int        // connections added but not yet Done
	idle   *sync.Cond // broadcast when active reaches zero; created lazily
}

// idleCond returns sc.idle, creating it on first use. Callers must hold sc.mutex.
func (sc *socketContext) idleCond() *sync.Cond {
	if sc.idle == nil {
		sc.idle = sync.NewCond(&sc.mutex)
	}
	return sc.idle
}

// addLocked adjusts the active count by delta and wakes all waiters when the
// count reaches zero. It panics, as sync.WaitGroup does, if the count would
// become negative. Callers must hold sc.mutex.
func (sc *socketContext) addLocked(delta int) {
	if sc.active+delta < 0 {
		panic("socketContext: negative active connection count")
	}
	sc.active += delta
	if sc.active == 0 {
		sc.idleCond().Broadcast()
	}
}

// Add registers delta additional active connections.
func (sc *socketContext) Add(delta int) {
	sc.mutex.Lock()
	defer sc.mutex.Unlock()
	sc.addLocked(delta)
}

// Done marks one connection as finished and records the finish time in last.
// Both happen under the mutex, so a Wait that returns afterwards, followed
// by a locked read of last, observes the new time.
func (sc *socketContext) Done() {
	sc.mutex.Lock()
	defer sc.mutex.Unlock()
	sc.last = time.Now()
	sc.addLocked(-1)
}

// Wait blocks until no connections are active. It must not be called while
// holding sc.mutex.
func (sc *socketContext) Wait() {
	sc.mutex.Lock()
	defer sc.mutex.Unlock()
	for sc.active > 0 {
		sc.idleCond().Wait()
	}
}
110
111func (p *Proxy) serveUnix(result chan<- error) {
112 sockCtx := &socketContext{}
113 go p.closeOnIdle(sockCtx)
114
115 var err error
116 for {
117 var uconn net.Conn
118 uconn, err = p.ul.Accept()
119 if err != nil {
120 err = fmt.Errorf("accept failed: %v", err)
121 break
122 }
123 sockCtx.Add(1)
124 go p.handleUnixConn(sockCtx, uconn)
125 }
126 sockCtx.Wait()
127 result <- err
128}
129
130func (p *Proxy) handleUnixConn(sockCtx *socketContext, uconn net.Conn) {
131 defer sockCtx.Done()
132 defer uconn.Close()
133 data := []byte(fmt.Sprintf("%v\n%v", p.httpsAddr, p.httpAddr))
134 uconn.SetDeadline(time.Now().Add(5 * time.Second))
135 for i := 0; i < 2; i++ {
136 if n, err := uconn.Write(data); err != nil {
137 log.Printf("error sending http addresses: %+v\n", err)
138 return
139 } else if n != len(data) {
140 log.Printf("sent %d data bytes, wanted %d\n", n, len(data))
141 return
142 }
143 if _, err := uconn.Read([]byte{0, 0, 0, 0}); err != nil {
144 log.Printf("error waiting for Ack: %+v\n", err)
145 return
146 }
147 }
148 // Wait without a deadline for the client to finish via EOF
149 uconn.SetDeadline(time.Time{})
150 uconn.Read([]byte{0, 0, 0, 0})
151}
152
153func (p *Proxy) closeOnIdle(sockCtx *socketContext) {
154 for d := p.MaxIdleDuration; d > 0; {
155 time.Sleep(d)
156 sockCtx.Wait()
157 sockCtx.mutex.Lock()
158 if d = sockCtx.last.Add(p.MaxIdleDuration).Sub(time.Now()); d <= 0 {
159 log.Println("graceful shutdown from idle timeout")
160 p.ul.Close()
161 }
162 sockCtx.mutex.Unlock()
163 }
164}
165
166func (p *Proxy) closeOnUpdate() {
167 for {
168 time.Sleep(p.PollUpdateInterval)
169 if out, err := exec.Command(os.Args[0], "--print_label").Output(); err != nil {
170 log.Printf("error polling for updated binary: %v\n", err)
171 } else if s := string(out[:len(out)-1]); p.BuildLabel != s {
172 log.Printf("graceful shutdown from updated binary: %q --> %q\n", p.BuildLabel, s)
173 p.ul.Close()
174 break
175 }
176 }
177}
178
179func (p *Proxy) closeOnSignal() {
180 ch := make(chan os.Signal, 10)
181 signal.Notify(ch, os.Interrupt, os.Kill, os.Signal(syscall.SIGTERM), os.Signal(syscall.SIGHUP))
182 sig := <-ch
183 p.ul.Close()
184 switch sig {
185 case os.Signal(syscall.SIGHUP):
186 log.Printf("graceful shutdown from signal: %v\n", sig)
187 default:
188 log.Fatalf("exiting from signal: %v\n", sig)
189 }
190}