ContinuousTimePolicyGradients.jl is a package for the development and implementation of continuous-time policy gradient (CTPG) methods.
- This package is a work in progress; it may include verbose tutorials for Julia, DifferentialEquations.jl, etc.
- Thanks to Namhoon Cho for the shared materials and the initial efforts to investigate CTPG methods.
- Similar packages written in Julia focusing on control policy optimisation based on the continuous-time adjoint sensitivity method include:
    - ctpg developed by Samuel Ainsworth
    - control_neuralode developed by Ilya Orson
ContinuousTimePolicyGradients.jl has been developed with control tasks in mind as the main application. The CTPG method for cost gradient computation can be utilised to perform policy optimisation in the setup of either background planning (offline control law learning) or decision-time planning (online control profile optimisation). The following paper addresses the optimisation of a structured neural controller using CTPG, with a flight vehicle normal acceleration tracking controller as an illustrative example. Please consider citing the paper if you find this package useful.
- Namhoon Cho, and Hyo-Sang Shin, "Optimisation of Structured Neural Controller Based on Continuous-Time Policy Gradient," arXiv:2201.06262, January 2022.
```julia
CTPG_train(dynamics_plant::Function, dynamics_controller::Function, dynamics_sensor::Function,
    cost_running::Function, cost_terminal::Function, cost_regularisor::Function, policy_NN, scenario;
    solve_alg = Tsit5(), sense_alg = InterpolatingAdjoint(autojacvec = ZygoteVJP()),
    ensemble_alg = EnsembleThreads(), opt_1 = ADAM(0.01), opt_2 = LBFGS(),
    maxiters_1 = 100, maxiters_2 = 100, progress_plot = true,
    i_nominal = nothing, p_NN_0 = nothing, solve_kwargs...)
```
`CTPG_train()` provides a high-level interface for the optimisation of neural networks embedded inside ODE-represented dynamics, based on Continuous-Time Policy Gradient (CTPG) methods, which belong to the family of adjoint sensitivity analysis techniques. The implementation and the default values for the keyword arguments are chosen with the training of a neural controller as the main application. In this context, a neural controller refers to a dynamic controller that incorporates neural-network-represented components at some points in its mathematical description.

The code utilises the functionalities provided by the DiffEqFlux.jl and DiffEqSensitivity.jl packages, and the Automatic Differentiation (AD) capabilities provided by the Zygote.jl package integrated in DiffEqFlux.jl. `CTPG_train()` presumes that the functions provided as its input arguments are compatible with the AD tool; hence, the dynamics and cost functions should remain transparent to AD (e.g., they should avoid mutating arrays).
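For example, Zygote's reverse-mode AD cannot differentiate through array mutation, so the dynamics functions should be written out-of-place. In the sketch below, the signature `(t, x, u)` and the matrices `A` and `B` are illustrative assumptions, not definitions from the package:

```julia
# assumed linear plant ẋ = A x + B u, with illustrative matrices
A = [0.0 1.0; 0.0 0.0]
B = reshape([0.0, 1.0], 2, 1)

# NOT AD-transparent: mutates its output buffer, which Zygote rejects
function dynamics_plant!(dx, x, u)
    dx .= A * x + B * u
    return nothing
end

# AD-transparent: purely functional, returns a freshly constructed Vector
function dynamics_plant(t, x, u)
    return A * x + B * u
end
```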
The optimisation (training) problem minimises the cost function defined over deterministic samples of the initial plant state `x₀` and the reference `r` by performing ensemble simulation based on parallelised computation.

The signals are defined as described below:
- `t`: time
- `x`: plant state
- `y`: plant output (= sensor output)
- `x_c`: controller state
- `u`: plant input (= controller output)
- `r`: exogenous reference
- `x_aug`: augmented forward dynamics state (= `[x; x_c; ∫cost_running]`); see the slicing sketch after this list
- `p_NN`: neural network parameters
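As a concrete picture of this layout, the components can be recovered from `x_aug` by slicing. The dimensions below are hypothetical, and the slicing only illustrates the ordering of the augmented state, not an API of the package:

```julia
dim_x, dim_x_c = 4, 2                # hypothetical dimensions
x_aug = rand(dim_x + dim_x_c + 1)    # stand-in for a solver-produced state

x = x_aug[1:dim_x]                        # plant state
x_c = x_aug[dim_x+1:dim_x+dim_x_c]        # controller state
running_cost = x_aug[end]                 # accumulated value of ∫cost_running
```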
The arguments should be provided as explained below:
- `dynamics_plant`: Describes the dynamics of the plant to be controlled. Input arguments `x` and `u` should be of Vector type.
- `dynamics_controller`: Describes the dynamics of the controller that includes neural network components. Input arguments `x_c`, `y`, `r`, and `p_NN` should be of Vector type.
- `dynamics_sensor`: Describes the dynamics of the sensor that measures the output variables fed to the controller. Input argument `x` should be of Vector type.
- `cost_running`: Describes the running cost defined as the integrand of the Lagrange-form continuous functional. Input arguments `x`, `y`, `u`, and `r` should be of Vector type.
- `cost_terminal`: Describes the terminal cost defined as the Mayer-form problem cost function. Defines a Bolza-form problem along with `cost_running`. Input arguments `x_f` and `r` should be of Vector type.
- `cost_regularisor`: Describes the regularisation term appended to the cost (loss) function. Input argument `p_NN` should be of Vector type.
- `policy_NN`: The neural networks entering into the controller dynamics. DiffEqFlux-based `FastChain` is recommended for its construction.
- `scenario`: Contains the parameters related to the ensemble-based training scenarios. An example construction is sketched after this list.
    - `ensemble`: A vector of the initial plant state `x₀` and the reference `r` constituting the trajectory realisations.
    - `t_span`: Time span for forward-pass integration.
    - `t_save`: Array of time points to be saved while solving the ODE. Typically defined as `t_save = t_span[1]:Δt_save:t_span[2]`.
    - `dim_x`: `length(x)`
    - `dim_x_c`: `length(x_c)`
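As a minimal sketch of how `policy_NN` and `scenario` might be constructed (the dimensions, the NamedTuple field names, and the sampling grid below are illustrative assumptions, not requirements imposed by the package):

```julia
using DiffEqFlux

# policy network built with DiffEqFlux's FastChain, as recommended above
policy_NN = FastChain(
    FastDense(3, 32, tanh),   # assumed input [y; r] with dim(y) = 2, dim(r) = 1
    FastDense(32, 1),         # assumed single control channel
)

# deterministic (x₀, r) samples constituting the trajectory realisations
ensemble = [(x₀ = [x1_0, 0.0], r = [1.0]) for x1_0 in -1.0:0.5:1.0]

t_span = (0.0, 10.0)
Δt_save = 0.01
scenario = (ensemble = ensemble, t_span = t_span,
            t_save = t_span[1]:Δt_save:t_span[2],
            dim_x = 2, dim_x_c = 2)
```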
The keyword arguments should be provided as explained below:
- `solve_alg`: The algorithm used for solving ODEs. Default value is `Tsit5()`.
- `sense_alg`: The algorithm used for adjoint sensitivity analysis. Default value is `InterpolatingAdjoint(autojacvec = ZygoteVJP())`, because control problems usually render `BacksolveAdjoint()` unstable. The vjp choice `autojacvec = ReverseDiffVJP(true)` is usually faster than `ZygoteVJP()` when the ODE function does not have any branching inside. Please refer to the DiffEqFlux documentation for further details.
- `ensemble_alg`: The algorithm used for handling the ensemble of ODEs. Default value is `EnsembleThreads()` for multi-threaded computation on the CPU.
- `opt_1`: The algorithm used for the first phase of optimisation, which rapidly delivers the parameters to a favourable region around a local minimum. Default value is `ADAM(0.01)`.
- `opt_2`: The algorithm used for the second phase of optimisation, which refines the result of the first phase to find a more precise minimum. Default value is `LBFGS()`. Please refer to the DiffEqFlux documentation for further details about the two-phase composition of optimisers.
- `maxiters_1`: The maximum number of iterations allowed for the first phase of optimisation with `opt_1`. Default value is `100`.
- `maxiters_2`: The maximum number of iterations allowed for the second phase of optimisation with `opt_2`. Default value is `100`.
- `progress_plot`: The indicator to plot the state history for a nominal condition among the ensemble during the learning process. Default value is `true`.
- `i_nominal`: The index selecting the case from the `ensemble` defined in `scenario` to plot with `progress_plot` during the optimisation process. Default value is `nothing`.
- `p_NN_0`: Initial value of the NN parameters supplied by the user to bypass random initialisation of `p_NN` or to continue optimisation from a previous result. Default value is `nothing`.
- `solve_kwargs...`: Additional keyword arguments that are passed on to the ODE solver. A full call with several keyword arguments overridden is sketched after this list.
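Putting the pieces together, a full call might look like the following sketch. It assumes the argument functions have been defined consistently with the conventions above, that the three outputs are returned as a tuple, and uses `reltol`/`abstol` merely as stand-ins for arbitrary solver keyword arguments:

```julia
result, fwd_ensemble_sol, loss_history = CTPG_train(
    dynamics_plant, dynamics_controller, dynamics_sensor,
    cost_running, cost_terminal, cost_regularisor,
    policy_NN, scenario;
    sense_alg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true)),  # faster vjp if no branching
    maxiters_1 = 200,                # longer first (ADAM) phase
    progress_plot = false,           # suppress plotting during training
    reltol = 1e-6, abstol = 1e-6,    # forwarded to the ODE solver via solve_kwargs...
)
```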
`CTPG_train()` returns the following outputs:
- `result`: The final result of parameter optimisation.
- `fwd_ensemble_sol`: The ensemble solution of forward simulation using the final neural network parameters.
- `loss_history`: The history of the loss function evaluated at each iteration.
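For instance, the outputs can be inspected and reused for a warm start as below. Note that the field holding the optimised parameters inside `result` depends on the optimiser interface version, so `result.u` (or `result.minimizer`) is an assumption to verify against the installed packages:

```julia
using Plots

# visualise training progress
plot(loss_history; xlabel = "iteration", ylabel = "loss", yscale = :log10)

# warm-start a continued optimisation from the previous result
p_NN_opt = result.u   # assumed field name; may be result.minimizer in older versions
result_2, fwd_ensemble_sol_2, loss_history_2 = CTPG_train(
    dynamics_plant, dynamics_controller, dynamics_sensor,
    cost_running, cost_terminal, cost_regularisor,
    policy_NN, scenario; p_NN_0 = p_NN_opt)
```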