Dask + xarray vertical interpolation (hybrid model levels to pressure levels)

PHOTO EMBED

Thu Feb 08 2024 01:11:37 GMT+0000 (Coordinated Universal Time)

Saved by @diptish #dask #xarray #parallel

import xarray as xr
import numpy as np
from typing import *
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
#import proplot as pplt
import datetime as dt
import warnings
warnings.filterwarnings('ignore')
import os
import sys
import geocat.comp.interpolation
import dask


from dask.diagnostics import ProgressBar
from datetime import datetime

from dask.distributed import Client, LocalCluster
# Start a dask.distributed cluster on this machine with LocalCluster's default
# settings and connect a client to it.
# NOTE(review): creating the Client presumably registers it as the default
# scheduler for the dask-backed xarray work further down — confirm against the
# dask.distributed documentation.
cluster = LocalCluster()
client = Client(cluster)

# Repeats the warnings import and "ignore" filter already applied earlier in
# the file.
import warnings
warnings.filterwarnings("ignore")

# Bare expression: it has no effect when the file runs as a script. Presumably
# left over from a notebook cell, where it would display the client summary.
client

def func(obj, ps, hyam, hybm, orog, p0=100000.0, new_levels=None):
    """Interpolate one block of hybrid-level temperature onto pressure levels.

    Thin wrapper around
    ``geocat.comp.interpolation.interp_hybrid_to_pressure`` using linear
    interpolation with temperature extrapolation enabled. The positional
    signature ``(obj, ps, hyam, hybm, orog)`` is what the script passes through
    ``DataArray.map_blocks``; the two trailing keyword parameters are optional
    and default to the values that used to be hard-coded here.

    Parameters
    ----------
    obj
        Temperature on hybrid levels. Must carry a ``lev`` dimension.
    ps
        Surface pressure matching ``obj``; forwarded unchanged.
    hyam, hybm
        Hybrid coefficients; forwarded unchanged.
    orog
        Forwarded as ``phi_sfc``.
        NOTE(review): the caller multiplies orography by 9.80616 first, so this
        is presumably surface geopotential — confirm.
    p0
        Reference pressure forwarded as ``p0`` (presumably Pa, following
        geocat's convention — confirm).
    new_levels
        Target pressure levels. ``None`` selects the built-in 37-level set
        running from 100000 down to 100.

    Returns
    -------
    The object returned by ``interp_hybrid_to_pressure`` for this block.
    """
    if new_levels is None:
        new_levels = np.array(
            [100000., 97500., 95000., 92500., 90000., 87500., 85000., 82500., 80000.,
             77500., 75000., 70000., 65000., 60000., 55000., 50000., 45000., 40000.,
             35000., 30000., 25000., 22500., 20000., 17500., 15000., 12500., 10000.,
             7000., 5000., 3000., 2000., 1000., 700., 500., 300., 200., 100.],
            dtype=float,
        )

    # Level index 0 supplies the bottom-level temperature used for
    # extrapolation. Selecting by dimension name keeps this consistent with
    # lev_dim='lev' instead of relying on the axis order of obj.
    # NOTE(review): assumes 'lev' is ordered surface-first — confirm for the
    # input model.
    t_bottom = obj.isel(lev=0)

    return geocat.comp.interpolation.interp_hybrid_to_pressure(
        obj,
        ps,
        hyam,
        hybm,
        p0=p0,
        new_levels=new_levels,
        lev_dim='lev',
        method='linear',
        extrapolate=True,
        variable='temperature',
        t_bot=t_bottom,
        phi_sfc=orog,
    )
### directories
OroDir = '/blue/dhingmire/scripts/regridCMIP6ToERA5/vertical_interp/'
InDir = '/blue/dhingmire/CMIP6_WRFIn/CESM2/historical/ta/'
OutDir = '/blue/dhingmire/CMIP6_WRFIn/CESM2/historical/remapped/ta/'

### constants
# Reference pressure handed to interp_hybrid_to_pressure as p0
# (presumably Pa, following geocat's convention — confirm).
p0Con = 100000.0
# Target pressure levels for the interpolation: 37 values from 100000 down to
# 100. Defined once here; the script previously repeated this definition.
newLevs = np.array([100000., 97500., 95000., 92500., 90000., 87500., 85000., 82500., 80000.,
                    77500., 75000., 70000., 65000., 60000., 55000., 50000., 45000., 40000.,
                    35000., 30000., 25000., 22500., 20000., 17500., 15000., 12500., 10000.,
                    7000., 5000., 3000., 2000., 1000., 700., 500., 300., 200., 100.], dtype=float)

# Orography scaled by 9.80616; the result is later passed as phi_sfc.
# NOTE(review): 9.80616 is presumably gravitational acceleration in m s^-2,
# turning surface height into surface geopotential — confirm.
inOroF = xr.open_dataset(OroDir+'orog_fx_CESM2_historical_r11i1p1f1_gn.nc')
orog = inOroF.orog*9.80616


# Names of the output files that already existed when the run started. A year
# whose file name is in this set is skipped, so an interrupted run can resume.
outfiles = set(os.listdir(OutDir))
# Number of time steps per dask block handed to func.
chunk = 10

for file in os.listdir(InDir):
    print(file)

    # The context manager closes the input file once all of its years have
    # been written, instead of leaving one open handle per input file.
    with xr.open_dataset(os.path.join(InDir, file)) as inTemp:
        ps = inTemp.ps
        hyam = inTemp.a
        hybm = inTemp.b
        ta = inTemp.ta

        # Template for map_blocks, built from the first time step of this file.
        # Computed lazily so files whose years are all done cost nothing.
        ta_ref = None

        for t in np.unique(inTemp.time.dt.year.values):
            year = str(t)
            print(year)

            outFname = 'ta_6hrPlev_' + year + '.nc'
            if outFname in outfiles:
                continue
            outF = os.path.join(OutDir, outFname)

            if ta_ref is None:
                ta_ref = geocat.comp.interpolation.interp_hybrid_to_pressure(
                    ta[0, :, :, :], ps[0, :, :], hyam, hybm,
                    p0=p0Con, new_levels=newLevs,
                    lev_dim='lev', method='linear', extrapolate=True,
                    variable='temperature',
                    t_bot=ta[0, 0, :, :], phi_sfc=orog)

            yearSlice = slice(year + '-01-01', year + '-12-31')
            taYear = ta.sel(time=yearSlice)

            # Chunk only along time; every other dimension stays whole so each
            # block holds complete columns for the vertical interpolation.
            taIn = taYear.chunk({'lat': -1, 'lon': -1, 'time': chunk, 'lev': -1})
            psIn = ps.sel(time=yearSlice).chunk({'lat': -1, 'lon': -1, 'time': chunk})
            # Broadcast the single-step reference result over this year's time
            # axis so map_blocks knows the shape and chunking of the output.
            ta_Sample = ta_ref.expand_dims(dim={"time": taYear.time}).chunk(
                {'lat': -1, 'lon': -1, 'time': chunk, 'plev': -1})

            mapped = taIn.map_blocks(func, args=[psIn, hyam, hybm, orog], template=ta_Sample)
            ta_ = mapped.persist()
            fb = ta_.chunk(time=100)
            fb.to_netcdf(outF)

            # Drop the references before the next year is processed.
            del mapped, ta_, fb
content_copyCOPY