在 Go 语言下载文件 http.Get() 和 io.Copy() 章节我们讲解了如何使用 Go 实现一个下载器,可以实现下载任何大小的任意文件。
一般情况下,这个下载器已经足够使用了,但是,在下载文件时,我们往往想知道当前进度是多少。这时候,直接使用 io.Copy()
就无能为了。
实现下载文件进度的方法其实很简单,就是实现分段读取响应流。
因为响应是一个流 ( stream ) ,就像水管里的水一样,源源不竭,直到响应流读取完毕。所以实现下载进度的方法只有两种
-
使用 HTTP
Content-Range
请求头断点续传的思维,实现多段分别下载。 -
从返回的响应流中分段读取。
第一种方式是断点续传,有空我们以后再讲,本章节我们来讲讲如何实现分段读取。
分段读取
分段读取的思想就像拿着一个桶从水龙头接水一样。当水桶满了就先倒出来,然后继续接水。
io.TeeReader() 方法
Go 语言的 io
包下的 io.TeeReader()
方法实现了分段读取的思想。该方法的原型如下
func TeeReader(r Reader, w Writer) Reader
io.TeeReader()
方法返回一个 Reader
,用于写入从 r
分段读取的内容。
很多人应该会对 w
这个参数疑惑。其实它就是一个中间水桶的作用。它从一个输入流中读取指定大小字节的数据,并把数据写入到另一个输出流中。
该方法没有内部缓冲 - 写入必须在读取完成之前完成。写入时遇到的任何错误都会报告为读取错误。
我们写一个小的范例演示下这个方法的使用
package main import ( "bytes" "fmt" "io" "io/ioutil" "log" "strings" ) func main() { r := strings.NewReader("some io.Reader stream to be read\n") var buf bytes.Buffer tee := io.TeeReader(r, &buf) printall := func(r io.Reader) { b, err := ioutil.ReadAll(r) if err != nil { log.Fatal(err) } fmt.Printf("%s", b) } printall(tee) printall(&buf) }
运行结果如下
[yufei@www.twle.cn helloworld]$ go run down.go some io.Reader stream to be read some io.Reader stream to be read
io.TeeReader() 实现下载进度
我们模拟下载一个比较大的文件,比如 https://dl.google.com/go/go1.11.1.src.tar.gz
我们使用 io.TeeReader()
实现一个下载计数器来跟踪进度。
package main import ( "fmt" "io" "net/http" "os" "strings" ) // WriteCounter counts the number of bytes written to it. It implements to the io.Writer // interface and we can pass this into io.TeeReader() which will report progress on each // write cycle. type WriteCounter struct { Total uint64 } func (wc *WriteCounter) Write(p []byte) (int, error) { n := len(p) wc.Total += uint64(n) wc.PrintProgress() return n, nil } func (wc WriteCounter) PrintProgress() { // Clear the line by using a character return to go back to the start and remove // the remaining characters by filling it with spaces fmt.Printf("\r%s", strings.Repeat(" ", 35)) // Return again and print current status of download // We use the humanize package to print the bytes in a meaningful way (e.g. 10 MB) fmt.Printf("\rDownloading... %d B complete", wc.Total) } func main() { fmt.Println("Download Started") fileUrl := "https://dl.google.com/go/go1.11.1.src.tar.gz" err := DownloadFile("go1.11.1.src.tar.gz", fileUrl) if err != nil { panic(err) } fmt.Println("Download Finished") } // DownloadFile will download a url to a local file. It's efficient because it will // write as it downloads and not load the whole file into memory. We pass an io.TeeReader // into Copy() to report progress on the download. func DownloadFile(filepath string, url string) error { // Create the file, but give it a tmp file extension, this means we won't overwrite a // file until it's downloaded, but we'll remove the tmp extension once downloaded. out, err := os.Create(filepath + ".tmp") if err != nil { return err } //defer out.Close() 看评论 // Get the data resp, err := http.Get(url) if err != nil { return err } defer resp.Body.Close() // Create our progress reporter and pass it to be used alongside our writer counter := &WriteCounter{} _, err = io.Copy(out, io.TeeReader(resp.Body, counter)) if err != nil { return err } out.Close() // 看评论 // The progress use the same line so print a new line once it's finished downloading fmt.Print("\n") err = os.Rename(filepath+".tmp", filepath) if err != nil { return err } return nil }
运行结果如下
[yufei@www.twle.cn helloworld]$ go run down.go Download Started Downloading... 21097206 B complete Download Finished [yufei@www.twle.cn helloworld]$
你自己试一下,那个数字是会自己跳动的。
整个实现中,最重要的是 WriteCounter
这个结构体,这个结构体下的 Write(p []byte)
方法的参数 p
就是每段读取的内容。