一、GroupBy: split-apply-combine
这部分同pandas的gorupby函数基本相同,实现对数据的分组归类等等。
split·将数据分为多个独立的组。
apply·对各个组进行操作。
combine·将各个组合并为一个数据对象。
1.split
创建一个dataset
ds = xr.Dataset({"u": (("lat", "lon"), np.random.rand(4, 3))},coords={"latitude": [10, 20, 30, 40], "country": ("x", list("abba"))})
print(ds)
#<xarray.Dataset>
#Dimensions: (lat: 4, latitude: 4, lon: 3, x: 4)
#Coordinates:
# * latitude (lat) int64 10 20 30 40
# country (lat) <U1 'a' 'b' 'b' 'a'
#Dimensions without coordinates: lat, lon, x
#Data variables:
# u (lat, lon) float64 0.7656 0.3545 0.7117 ... 0.3729 0.4142 0.7367
我对官网的例子加以修改以便更好的理解。
解释下数据结构,创建了一个二维数据u(lat, lon),坐标数据为latitude 和country ,强调一下这里创建的是dataset,而不是dataArray,分不清的可以再看看本系列的第一篇文章。坐标数据不等于u的坐标。创建coords部分都指明了latitude 和 country 都是针对lat的扩展。
我们可以这样理解,对于纬度的分类,我们可以按纬度的大小分,也就是"latitude": [10, 20, 30, 40] ; 我们也可以对纬度所在的国家分,"country": ("x", list("abba") ,那比如我们想求某个国家的数据的平均时就十分方便。
下边我们进行分组:
print(ds.groupby('country').groups)
#{'a': [0, 3], 'b': [1, 2]}
说明第0和第4个数是国家a的,第2和第3是国家b的。
.groups换成.mean() 则就是对分组求平均,以此类推。
print(list(ds.groupby('country')))
#[('a', <xarray.Dataset>
#Dimensions: (lat: 2, lon: 3)
#Coordinates:
# * lat (lat) int64 10 40
# country (lat) <U1 'a' 'a'
#Dimensions without coordinates: lon
#Data variables:
# u (lat, lon) float64 0.4561 0.1734 0.8716 0.529 0.867 0.8497), ('b', <xarray.Dataset>
#Dimensions: (lat: 2, lon: 3)
#Coordinates:
# * lat (lat) int64 20 30
# country (lat) <U1 'b' 'b'
#Dimensions without coordinates: lon
#Data variables:
# u (lat, lon) float64 0.4886 0.5533 0.9398 0.1959 0.1999 0.3381)]
必须添加一个list才可以将其分类结果打印出来。直接打印DatasetGroupBy object是不能输出结果的。
那么针对经纬度的坐标的分组怎么实现呢,比如说选出区间在多少到多少之间的?
.groupby_bins() 函数可以解决这一问题。
还是这个数据,"latitude": [10, 20, 30, 40]
那我们想以25为界,分为两组,0-25,25-50
x_bins = [0,25,50]
print(ds.groupby_bins('lat', x_bins).groups)
#{Interval(0, 25, closed='right'): [0, 1],
#Interval(25, 50, closed='right'): [2, 3]}
2. Apply
在进行了分组后,要对各个分组进行计算。
我们先从dataset 中取出 u 这个dataarray
u = ds["u"]
print(u)
#<xarray.DataArray 'u' (lat: 4, lon: 3)>
#array([[0.45610017, 0.17344946, 0.87160651],
# [0.48857117, 0.5533467 , 0.93982257],
# [0.19588396, 0.19992985, 0.33810041],
# [0.52899729, 0.86697401, 0.84972525]])
#Coordinates:
# * lat (lat) int64 10 20 30 40
# country (lat) <U1 'a' 'b' 'b' 'a'
#Dimensions without coordinates: lon
比如是实现前边提到的按国家进行数据平均,或者标准化
u.groupby('country').mean(dim='lat')
也可以通过map()函数使用一些自定义的函数,比如说标准化,
def standardize(x):
return (x - x.mean()) / x.std()
print(u.groupby('country').map(standardize))
这个用法是官方提供的,但是我的Xarray版本过低,还不支持这种用法(Xarray会定期更新,以至于可能我介绍过的一些方法有了更简便的操作,大家可以在评论区留言)。
强调一句,Xarray官方的更新是比较快的,很可能我写在这里的函数官方又给出了更新的版本,但是我没办法做到时刻与官方最新同步,所以如果遇到问题,最好的解决办法还是去查阅官方文档的对应部分。