require 'getopts'
require 'numru/gphys'

module NumRu

  class GPhys
    module EP_Flux

      module_function

      #<<< calculation method >>> --------------------------------------------- 
      # Computes the Eliassen-Palm (EP) flux and the residual mean meridional
      # circulation on the sphere, using pressure as the vertical coordinate.
      #
      # Arguments (GPhys objects on a common grid):
      #   gp_u, gp_v, gp_omega :: zonal wind, meridional wind, pressure velocity
      #   gp_t                 :: temperature if flag_temp_or_theta is true
      #                           (converted to potential temperature here),
      #                           otherwise potential temperature
      #   flag_temp_or_theta   :: see gp_t
      #   xypdims              :: dimension indices of [longitude, latitude,
      #                           pressure] in gp_u
      #
      # With phi = latitude, theta = potential temperature, f = Coriolis
      # parameter, a = planetary radius (@@radius), and [..] the eddy
      # products returned by eddy_products (presumably zonal means of u'v',
      # v'theta', u'omega' -- see that helper), this computes:
      #   F_phi  = cos(phi) * ( -[u'v']     + du/dp * [v'theta'] / dtheta/dp )
      #   F_p    = cos(phi) * ( -[u'omega'] + avort * [v'theta'] / dtheta/dp )
      #   avort  = f - d(u cos(phi))/dphi / (a cos(phi))
      #   v*     = v     - d([v'theta'] / dtheta/dp)/dp
      #   omega* = omega + d(cos(phi) [v'theta'] / dtheta/dp)/dphi / (a cos(phi))
      # The flux components are NOT multiplied by a.
      #
      # Returns an Array of GPhys in this order:
      #   [epflx_phi, epflx_p, v_rmean, w_rmean, gp_lat, gp_p, u_mean,
      #    theta_mean, uv_dash, vt_dash, uw_dash, dtheta_dp]
      # Each result whose rank equals that of the zonally averaged input grid
      # is put back on that original (unconverted) grid; the others (e.g. the
      # 1-D gp_lat, in radian, and gp_p) are returned as computed.
      #
      # Raises ArgumentError unless xypdims has exactly 3 elements.
      def ep_full_sphere_in_pcoord(gp_u, gp_v, gp_omega, gp_t,
                                   flag_temp_or_theta=true, xypdims=[0,1,2])
        if xypdims.size != 3
          raise ArgumentError,"xypdims's size (#{xypdims.size}) must be 3."
        end

        ## get axes and their names
        ax_lon = gp_u.axis(xypdims[0])   # axis of longitude
        ax_lat = gp_u.axis(xypdims[1])   # axis of latitude
        ax_p   = gp_u.axis(xypdims[2])   # vertical (pressure) axis
        lon_nm, lat_nm, p_nm = ax_lon.pos.name, ax_lat.pos.name, ax_p.pos.name
        gp_lon, gp_lat, gp_p = make_gphys(ax_lon, ax_lat, ax_p)

        ## convert axes
        gp_lon = to_rad_if_deg(gp_lon)   # deg => rad (unit conversion)
        gp_lat = to_rad_if_deg(gp_lat)   # deg => rad (unit conversion)
        # temperature => potential temperature (only if the flag is true)
        gp_t = to_theta_if_temperature(gp_t, gp_p, flag_temp_or_theta)

        ## replace grid (without duplicating data)
        grid = gp_u.grid_copy
        old_grid = gp_u.grid_copy              # restored on the outputs
        grid.axis(lon_nm).pos = gp_lon.data    # in radian
        grid.axis(lat_nm).pos = gp_lat.data    # in radian
        grid.axis(p_nm).pos = gp_p.data
        gp_u     = GPhys.new(grid, gp_u.data)
        gp_v     = GPhys.new(grid, gp_v.data)
        gp_omega = GPhys.new(grid, gp_omega.data)
        gp_t     = GPhys.new(grid, gp_t.data)

        ## terms needed in both F_phi and F_p
        uv_dash, vt_dash, uw_dash = eddy_products(gp_u, gp_v, gp_omega, gp_t, lon_nm)
        theta_mean = gp_t.mean(lon_nm)
        dtheta_dp = deriv(theta_mean, p_nm)
        cos_lat = cos(gp_lat)
        a_cos_lat = @@radius * cos_lat
        a_cos_lat.data.rename!('a_cos_lat')
        a_cos_lat.data.set_att('long_name', 'radius * cos_lat')
        # presumably (by its name) keeps a_cos_lat from being 0 at the poles,
        # since it is used as a divisor below
        remove_0_at_poles(a_cos_lat)

        ## needed in F_phi only
        u_mean = gp_u.mean(lon_nm)
        du_dp  = deriv(u_mean, p_nm)

        ## needed in F_p only
        f_cor = 2 * (2 * PI / @@rot_period) * sin(gp_lat)   # 2 * Omega * sin(phi)
        f_cor.data.rename!('f_cor')
        f_cor.data.set_att('long_name', 'Coriolis parameter')
        ducos_dphi = deriv( u_mean * cos_lat, lat_nm)
        avort = (-ducos_dphi/a_cos_lat) + f_cor        # -- absolute vorticity
        avort.data.units = "s-1"
        avort.data.rename!('avort')
        avort.data.set_att('long_name', 'zonal mean absolute vorticity')

        ## F_phi, F_p
        epflx_phi = ( - uv_dash + du_dp*vt_dash/dtheta_dp ) * cos_lat
        epflx_p   = ( - uw_dash + avort*vt_dash/dtheta_dp ) * cos_lat
        epflx_phi.data.name = "epflx_phi"
        epflx_p.data.name = "epflx_p"
        epflx_phi.data.set_att("long_name", "EP flux phi component")
        epflx_p.data.set_att("long_name", "EP flux p component")

        ## residual mean meridional circulation: v_rmean, w_rmean
        v_mean = gp_v.mean(lon_nm)
        w_mean = gp_omega.mean(lon_nm)
        v_rmean = ( v_mean - deriv( (vt_dash/dtheta_dp), p_nm ) )
        w_rmean = ( w_mean + deriv( (vt_dash/dtheta_dp*cos_lat), lat_nm )/a_cos_lat )
        v_rmean.data.name = "v_rmean"
        w_rmean.data.name = "w_rmean"
        v_rmean.data.set_att("long_name", "residual zonal mean V")
        w_rmean.data.set_att("long_name", "residual zonal mean W")

        ## put results of matching rank back on the original zonal-mean grid
        grid_xmean = old_grid.delete_axes(lon_nm)
        [epflx_phi, epflx_p, v_rmean, w_rmean, gp_lat, gp_p, u_mean, theta_mean,
         uv_dash, vt_dash, uw_dash, dtheta_dp].map do |gp|
          if grid_xmean.shape.size == gp.shape.size
            GPhys.new(grid_xmean, gp.data)   # back to the original grid
          else
            gp                               # e.g. the 1-D gp_lat / gp_p
          end
        end
      end

      # Divergence of an EP flux given on the sphere in pressure coordinates:
      #
      #   div F = d(F_phi * cos(phi))/dphi / (a cos(phi)) + dF_p/dp
      #
      # where phi is latitude (converted to radian when given in degrees) and
      # a is the planetary radius (@@radius).
      #
      # gp_fphi, gp_fp :: meridional and vertical flux components (GPhys on
      #                   the same grid, e.g. as returned by
      #                   ep_full_sphere_in_pcoord)
      # ypdims         :: dimension indices of [latitude, pressure]
      #
      # Returns a GPhys named "epflx_div" on the original (unconverted) grid
      # of gp_fphi. Raises ArgumentError unless ypdims has two elements.
      def div_sphere_in_pcoord(gp_fphi, gp_fp, ypdims=[0,1])
        unless ypdims.size == 2
          raise ArgumentError,"ypdims's size (#{ypdims.size}) must be 2."
        end

        lat_axis = gp_fphi.axis(ypdims[0])
        p_axis   = gp_fphi.axis(ypdims[1])
        lat_nm = lat_axis.pos.name
        p_nm   = p_axis.pos.name
        gp_lat, gp_p = make_gphys(lat_axis, p_axis)
        gp_lat = to_rad_if_deg(gp_lat)          # latitude in radian

        # Compute on a grid copy whose latitude is in radian; a second,
        # untouched copy carries the result back to the caller's grid.
        work_grid = gp_fphi.grid_copy
        out_grid  = gp_fphi.grid_copy
        work_grid.axis(lat_nm).pos = gp_lat.data
        work_grid.axis(p_nm).pos   = gp_p.data
        fphi = GPhys.new(work_grid, gp_fphi.data)
        fp   = GPhys.new(work_grid, gp_fp.data)

        a_cos_lat = @@radius * cos(gp_lat)
        remove_0_at_poles(a_cos_lat)

        meridional_part = deriv(fphi * cos(gp_lat), lat_nm)   # d(F_phi cos)/dphi
        vertical_part   = deriv(fp, p_nm)                     # dF_p/dp
        divergence = ( meridional_part / a_cos_lat ) + vertical_part

        divergence.data.name = "epflx_div"
        divergence.data.set_att("long_name", "EP Flux divergence")
        GPhys.new(out_grid, divergence.data)
      end


      # Converts temperature to potential temperature,
      #
      #   theta = T * (p00 / p) ** (R / cp),
      #
      # when flag_temp_or_theta is true; otherwise returns gp_t unchanged.
      #
      # gp_t :: temperature (GPhys); converted to K before the conversion
      # gp_p :: pressure (GPhys), presumably 1-D along the pressure axis as
      #         built by make_gphys in ep_full_sphere_in_pcoord
      #
      # NOTE(review): gp_p is not unit-converted here, so (@@p00 / gp_p) is
      #   only the intended ratio if gp_p is in the same units as @@p00
      #   (defined outside this view) -- confirm for hPa/millibar inputs.
      def to_theta_if_temperature(gp_t, gp_p, flag_temp_or_theta=true)
        return gp_t unless flag_temp_or_theta

        kappa = (@@gas_const/@@cp).to_f   # R/cp as a plain Float exponent
        gp_t = gp_t.convert_units(Units.new("K"))
        gp_t = gp_t*(@@p00/gp_p)**kappa
        gp_t.data.set_att('long_name', "Potential Temperature")
        gp_t
      end
      
      
    end  
  end

  class NetCDFVar

    # Reads the variable and returns its scaled (unpacked) values: as an
    # NArrayMiss when valid_* or missing_value is specified, otherwise as the
    # plain result of scaled_get.
    #
    # The mask is built from the PACKED values (read with simple_get):
    #   * valid_min / valid_max (@vmin / @vmax) are treated as unpacked values
    #     and mapped into packed space with scale_factor / add_offset (1 and 0
    #     when the respective attribute is absent). For a negative
    #     scale_factor the mapped bounds are swapped so the mask keeps its
    #     meaning.
    #   * otherwise, points whose packed value is within 1e-6 of
    #     missing_value (@missval[0]) are masked.
    # Note: the data are read twice (simple_get and scaled_get).
    def get_with_miss_and_scaling2(*args) # make mask before scaling
      __interpret_missing_params if !defined?(@missval)
      packed_data = simple_get(*args)
      scaled_data = scaled_get(*args)
      if @vmin || @vmax
        lo, hi = @vmin, @vmax
        sf = att('scale_factor')
        ao = att('add_offset')
        if sf || ao
          # attribute values may come back as 1-element NArrays
          csf = sf ? sf.get : 1.0
          csf = csf[0] if csf.is_a?(NArray)
          cao = ao ? ao.get : 0.0
          cao = cao[0] if cao.is_a?(NArray)
          csf = csf.to_f                 # avoid integer division below
          lo = (lo-cao)/csf if lo
          hi = (hi-cao)/csf if hi
          lo, hi = hi, lo if csf < 0     # packing with sf < 0 reverses order
        end
        mask = (packed_data >= lo) if lo
        if hi
          below_hi = (packed_data <= hi)
          mask = mask ? mask.and(below_hi) : below_hi
        end
        NArrayMiss.to_nam(scaled_data, mask)
      elsif @missval	# only missing_value is present.
        eps = 1e-6
        missval = @missval[0].to_f
        vmin = missval - eps
        vmax = missval + eps
        mask = (packed_data <= vmin)
        mask = mask.or(packed_data >= vmax)
        NArrayMiss.to_nam(scaled_data, mask)
      else
        scaled_data
      end
    end

    # Writes data, applying the variable's scale_factor / add_offset.
    #
    # Plain (non-NArrayMiss) data go straight to scaled_put. For an
    # NArrayMiss, the missing-value parameters are interpreted once (on first
    # use). If a missing_value exists, the masked points are written through
    # scaled_put_without_missval; otherwise the mask is dropped via to_na and
    # the values go to scaled_put.
    def put_with_miss_after_scaling(data, *args)
      return scaled_put(data, *args) unless data.is_a?( NArrayMiss )

      __interpret_missing_params unless defined?(@missval)
      return scaled_put(data.to_na, *args) unless @missval

      scaled_put_without_missval(data, *args)
    end

    # Writes var (an NArrayMiss) so that its invalid elements are stored as
    # missing_value (@missval[0], taken to be in packed units).
    #
    # With scale_factor and/or add_offset (defaults 1.0 / 0.0 when absent),
    # the data are packed as (var - add_offset) / scale_factor. Masked
    # elements are first filled with @missval[0]*csf + cao, so that they
    # become exactly @missval[0] after packing. Without either attribute,
    # masked elements are filled with @missval[0] and written unpacked.
    #
    # Requires @missval to be set (see put_with_miss_after_scaling).
    # Raises TypeError if scale_factor is a string, and NetcdfError if
    # scale_factor is zero or add_offset is a string.
    def scaled_put_without_missval(var,hash=nil)
      sf = att('scale_factor')
      ao = att('add_offset')
      if sf.nil? && ao.nil?
        # no scaling --> fill the missing points, then just call put
        var = var.to_na(@missval[0]) if var.is_a?(NArrayMiss)
        simple_put(var,hash)
      else
        if sf
          csf = sf.get
          if csf.is_a?(NArray)       # --> should be a numeric
            csf = csf[0]
          elsif csf.is_a?(String)
            raise TypeError, "scale_factor is not a numeric"
          end
          raise NetcdfError, "zero scale_factor" if csf == 0
        else
          csf = 1.0      # assume 1 if not defined
        end
        if ao
          cao = ao.get
          if cao.is_a?(NArray)       # --> should be a numeric
            cao = cao[0]
          elsif cao.is_a?(String)
            raise NetcdfError, "add_offset is not a numeric"
          end
        else
          cao = 0.0      # assume 0 if not defined
        end
        # fill value chosen so it packs to exactly @missval[0]
        var = var.to_na( @missval[0]*csf + cao)
        simple_put( (var-cao)/csf, hash )
      end
    end
    
  end
end

################################################################################
include NumRu
include Misc::EMath

# Command-line driver: computes the EP flux, its divergence and the residual
# mean circulation in pressure coordinates and writes them to a NetCDF file.
#
# Options (as declared to getopts below):
#   u, v, omega, temp   :: input variables, each given as FILE@VARIABLE
#   temp_is_temperature :: "true" (default) if temp is temperature; one of
#                          false/f/no/n/0 if it is already potential temperature
#   output              :: output file name (default epflx.nc)
unless getopts("u:", "v:", "omega:", "temp:",
               "temp_is_temperature:true", "output:epflx.nc")
  $stderr.print "#{$0}:illegal options.\n"
  exit 1
end

z_axs = 'level'          # name of the vertical axis in the input files
z_range = [100, 1000]    # vertical range used, in that axis' own units

# The temp variable is converted to potential temperature unless
# temp_is_temperature is given as false/f/no/n/0 (case-insensitive).
temp_is_temperature =
  (/\A\s*(false|f|no|n|0)\s*\z/i !~ $OPT_temp_is_temperature.to_s)

# Opens the variable given as "FILE@VARIABLE" and cuts it to z_range.
open_input = lambda do |optname, spec|
  abort "#{$0}: option '#{optname}' (FILE@VARIABLE) is required" if spec.nil?
  file, var = spec.to_s.split(/\s*@\s*/)
  if file.nil? || var.nil?
    abort "#{$0}: option '#{optname}' must be given as FILE@VARIABLE"
  end
  GPhys::IO.open(file, var).cut(z_axs=>z_range[0]..z_range[-1])
end

gp_u     = open_input.call('u',     $OPT_u)
gp_v     = open_input.call('v',     $OPT_v)
gp_omega = open_input.call('omega', $OPT_omega)
gp_t     = open_input.call('temp',  $OPT_temp)

# Validate before creating the output file so no empty file is left behind.
if gp_u.rank < 3
  abort "#{$0}: input must be at least 3-dimensional " +
        "(lon, lat, level[, ...]); got rank #{gp_u.rank}"
end

# Removes valid_* attributes inherited from the inputs; they would otherwise
# mask derived values outside the inputs' valid range. (This will not be
# needed in future, and is not needed now if the valid range is wide enough.)
strip_valid_atts = lambda do |gps|
  gps.each do |gp|
    gp.data.att_names.each { |nm| gp.data.del_att(nm) if /^valid_/ =~ nm }
  end
end

a = GPhys::EP_Flux::radius   # planetary radius; the EP flux is returned without it

ofile = NetCDF.create($OPT_output)
begin
  if gp_u.rank == 3
    epflx_y, epflx_z, v_rmean, w_rmean, gp_lat, gp_p, = ary =
      GPhys::EP_Flux::ep_full_sphere_in_pcoord(gp_u, gp_v, gp_omega, gp_t,
                                               temp_is_temperature)
    gp_lat.rename('phi')
    strip_valid_atts.call(ary)

    epflx_div = GPhys::EP_Flux::div_sphere_in_pcoord(epflx_y, epflx_z)
    cos_phi = cos(gp_lat)                          # cosine phi

    GPhys::IO.write(ofile, epflx_y*a)
    GPhys::IO.write(ofile, epflx_z*a)
    GPhys::IO.write(ofile, v_rmean)
    GPhys::IO.write(ofile, w_rmean)
    GPhys::IO.write(ofile, epflx_div/cos_phi)
    GPhys::IO.write(ofile, gp_lat)
    GPhys::IO.write(ofile, gp_p)
  else
    # rank > 3: process one slice along the last dimension at a time
    nt = gp_u.shape[-1]
    i = 0
    GPhys::IO.each_along_dims_write([gp_u, gp_v, gp_omega, gp_t], ofile, -1) do |u, v, omega, t|
      i += 1
      print "processing #{i} / #{nt} ..\n" if (i % (nt/20+1))==1
      epflx_y, epflx_z, v_rmean, w_rmean, gp_lat, gp_p, u_mean, theta_mean,
        uv_dash, vt_dash, uw_dash, dtheta_dp = ary =
        GPhys::EP_Flux::ep_full_sphere_in_pcoord(u, v, omega, t,
                                                 temp_is_temperature)
      strip_valid_atts.call(ary)   # before the divergence, as in the rank-3 case

      epflx_div = GPhys::EP_Flux::div_sphere_in_pcoord(epflx_y, epflx_z)
      cos_phi = cos(gp_lat)                        # cosine phi

      if i == 1    # time independent => write only once
        gp_lat.rename('phi')
        GPhys::IO.write(ofile, gp_lat)
      end
      [ epflx_y*a, epflx_z*a, v_rmean, w_rmean, epflx_div/cos_phi, u_mean,
        theta_mean, uv_dash, vt_dash, uw_dash, dtheta_dp ]
    end
  end
ensure
  ofile.close
end
