goProject/trunk/framework/webServer/context.go

375 lines
9.0 KiB
Go
Raw Normal View History

2025-01-06 16:01:02 +08:00
package webServer
import (
"bytes"
"compress/zlib"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"time"
"goutil/logUtil"
"goutil/netUtil"
"goutil/typeUtil"
"goutil/zlibUtil"
)
const (
defaultMaxMemory = 32 << 20 // 32 MB
)
// 请求上下文对象
type Context struct {
// 用户自定义数据(注册回调时设置通过Context传到用户回调) Add:raojianhua Time:2022-02-16 17:00
userData interface{}
// 请求对象
request *http.Request
// 响应对象
responseWriter http.ResponseWriter
// 数据是否已经解析数据
ifBodyParsed bool
// 请求数据
bodyContent []byte
// Form的数据是否已经解析
ifFormParsed bool
// MultipleForm的数据是否已经解析
ifMultipartFormParsed bool
// 是否经过了负载均衡的中转如果是则需要通过header获取客户端实际的请求地址
ifDelegate bool
// 请求的开始时间
StartTime time.Time
// 请求的结束时间
EndTime time.Time
// 处理请求数据的处理器(例如压缩、解密等)
requestDataHandler func(*Context, []byte) ([]byte, error)
// 处理响应数据的处理器(例如压缩、加密等)
responseDataHandler func(*Context, []byte) ([]byte, error)
}
// 获取用户自定义数据(注册回调时设置通过Context传到用户回调) Add:raojianhua Time:2022-02-16 17:00
func (this *Context) GetUserData() interface{} {
return this.userData
}
// 获取请求对象
func (this *Context) GetRequest() *http.Request {
return this.request
}
// 获取响应对象
func (this *Context) GetResponseWriter() http.ResponseWriter {
return this.responseWriter
}
// 获取请求路径(不带参数)
func (this *Context) GetRequestPath() string {
return this.request.URL.Path
}
// 获取请求的客户端的IP地址
func (this *Context) GetRequestIP() string {
if this.ifDelegate {
return netUtil.GetHttpAddr2(this.request).Host
}
return netUtil.GetHttpAddr(this.request).Host
}
// 获取请求执行的秒数
func (this *Context) GetExecuteSeconds() int64 {
return this.EndTime.Unix() - this.StartTime.Unix()
}
// 格式化context对象
func (this *Context) String() string {
var bodyContent string
if bytes, exist, err := this.GetRequestBytes(); err == nil && exist {
bodyContent = string(bytes)
}
return fmt.Sprintf("IP:%s, URL:%s, FormValue:%#v, BodyContent:%s",
this.GetRequestIP(), this.GetRequestPath(), this.GetFormValueData(), bodyContent)
}
func (this *Context) parseForm() {
if !this.ifFormParsed {
// 先保存一份body
this.parseBodyContent()
this.request.ParseMultipartForm(32 << 20)
this.ifFormParsed = true
}
}
// 获取请求的参数值包括GET/POST/PUT/DELETE等所有参数
func (this *Context) FormValue(key string) (value string) {
this.parseForm()
return this.request.FormValue(key)
}
// 获取POST的参数值
func (this *Context) PostFormValue(key string) (value string) {
this.parseForm()
return this.request.PostFormValue(key)
}
// 获取所有参数的MapData类型包括GET/POST/PUT/DELETE等所有参数
func (this *Context) GetFormValueData() typeUtil.MapData {
this.parseForm()
valueMap := make(map[string]interface{})
for k, v := range this.request.Form {
valueMap[k] = v[0]
}
return typeUtil.MapData(valueMap)
}
// 获取POST参数的MapData类型
func (this *Context) GetPostFormValueData() typeUtil.MapData {
this.parseForm()
valueMap := make(map[string]interface{})
for k, v := range this.request.PostForm {
valueMap[k] = v[0]
}
return typeUtil.MapData(valueMap)
}
func (this *Context) parseMultipartForm() {
if !this.ifMultipartFormParsed {
this.request.ParseMultipartForm(defaultMaxMemory)
this.ifMultipartFormParsed = true
}
}
// 获取MultipartForm的MapData类型
func (this *Context) GetMultipartFormValueData() typeUtil.MapData {
this.parseMultipartForm()
valueMap := make(map[string]interface{})
if this.request.MultipartForm != nil {
for k, v := range this.request.MultipartForm.Value {
valueMap[k] = v[0]
}
}
return typeUtil.MapData(valueMap)
}
func (this *Context) parseBodyContent() (err error) {
if this.ifBodyParsed {
return
}
defer func() {
this.request.Body.Close()
this.ifBodyParsed = true
if len(this.bodyContent) > 0 {
this.request.Body = ioutil.NopCloser(bytes.NewBuffer(this.bodyContent))
}
}()
this.bodyContent, err = ioutil.ReadAll(this.request.Body)
if err != nil {
logUtil.ErrorLog(fmt.Sprintf("url:%s,read body failed. Err%s", this.GetRequestPath(), err))
return
}
return
}
// 获取请求字节数据
// 返回值:
// []byte:请求字节数组
// exist:是否存在数据
// error:错误信息
func (this *Context) GetRequestBytes() (result []byte, exist bool, err error) {
if err = this.parseBodyContent(); err != nil {
return
}
result = this.bodyContent
if result == nil || len(result) == 0 {
return
}
// handle request data
if this.requestDataHandler != nil {
if result, err = this.requestDataHandler(this, result); err != nil {
return
}
}
exist = true
return
}
// 获取请求字符串数据
// 返回值:
// result:请求字符串数据
// exist:是否存在数据
// error:错误信息
func (this *Context) GetRequestString() (result string, exist bool, err error) {
var data []byte
if data, exist, err = this.GetRequestBytes(); err != nil || !exist {
return
}
result = string(data)
exist = true
return
}
// 反序列化为对象JSON
// obj:反序列化结果数据
// isCompressed:数据是否已经被压缩
// 返回值:
// 错误对象
func (this *Context) Unmarshal(obj interface{}) (exist bool, err error) {
var data []byte
if data, exist, err = this.GetRequestBytes(); err != nil || !exist {
return
}
// Unmarshal
if err = json.Unmarshal(data, &obj); err != nil {
logUtil.ErrorLog(fmt.Sprintf("Unmarshal %s failed. Err:%s", string(data), err))
return
}
exist = true
return
}
// 反序列化为对象zlib+JSON
// obj:反序列化结果数据
// isCompressed:数据是否已经被压缩
// 返回值:
// 错误对象
func (this *Context) UnmarshalZlib(obj interface{}) (exist bool, err error) {
var data []byte
if data, exist, err = this.GetRequestBytes(); err != nil || !exist {
return
}
zlibBytes, err := zlibUtil.Decompress(data)
if err != nil {
logUtil.ErrorLog(fmt.Sprintf("Decompress %v failed. Err:%s", string(data), err))
return
}
// Unmarshal
if err = json.Unmarshal(zlibBytes, &obj); err != nil {
logUtil.ErrorLog(fmt.Sprintf("Unmarshal %s failed. Err:%s", string(data), err))
return
}
exist = true
return
}
func (this *Context) writeBytes(data []byte) error {
var err error
if this.responseDataHandler != nil {
if data, err = this.responseDataHandler(this, data); err != nil {
return err
}
}
_, err = this.responseWriter.Write(data)
return err
}
// 输出字符串给客户端
func (this *Context) WriteString(result string) error {
return this.writeBytes([]byte(result))
}
// 输出json数据给客户端
func (this *Context) WriteJson(result interface{}) error {
data, err := json.Marshal(result)
if err != nil {
logUtil.ErrorLog(fmt.Sprintf("Marshal %v failed. Err:%s", result, err))
return err
}
return this.writeBytes(data)
}
// 输出zlib+json数据给客户端
func (this *Context) WriteZlibJson(result interface{}) error {
data, err := json.Marshal(result)
if err != nil {
logUtil.ErrorLog(fmt.Sprintf("Marshal %v failed. Err:%s", result, err))
return err
}
zlibBytes, err := zlibUtil.Compress(data, zlib.DefaultCompression)
if err != nil {
logUtil.ErrorLog(fmt.Sprintf("Compress %v failed. Err:%s", string(data), err))
return err
}
return this.writeBytes(zlibBytes)
}
/*
Http状态码301和302概念简单区别
302重定向表示临时性转移(Temporarily Moved )当一个网页URL需要短期变化时使用 301重定向是永久的重定向搜索引擎在抓取新内容的同时也将旧的网址替换为重定向之后的网址 302重定向是临时的重定向搜索引擎会抓取新的内容而保留旧的网址
如果是短链需要每次都访问原来的短链地址以便于进行统计则使用302重定向
*/
// 重定向到其它页面(301)
func (this *Context) RedirectTo(url string) {
this.responseWriter.Header().Set("Location", url)
this.responseWriter.WriteHeader(301)
}
// 重定向到其它页面(302)
func (this *Context) RedirectTo302(url string) {
this.responseWriter.Header().Set("Location", url)
this.responseWriter.WriteHeader(302)
}
// 新建API上下文对象
// request:请求对象
// responseWriter:应答写对象
// _ifDelegate:是否经过负载均衡
// 返回值:
// *Context:上下文
func newContext(request *http.Request, responseWriter http.ResponseWriter,
requestDataHandler func(*Context, []byte) ([]byte, error),
responseDataHandler func(*Context, []byte) ([]byte, error),
_ifDelegate bool) *Context {
return &Context{
request: request,
responseWriter: responseWriter,
StartTime: time.Now(),
EndTime: time.Now(),
requestDataHandler: requestDataHandler,
responseDataHandler: responseDataHandler,
ifDelegate: _ifDelegate,
}
}