import PyPlot
export lsimplot, stepplot, impulseplot, bodeplot, marginplot, nyquistplot, sigmaplot, pzplot, rlocusplot

@doc """`lsimplot(sys, u, t[, x0, method])`

`lsimplot(LTISystem[sys1, sys2...], u, t[, x0, method])`

Plot the time response of the `LTISystem`(s) `sys` to the input `u` over the
time vector `t`, using one subplot row per output. If `x0` is omitted, a zero
vector is used.

Continuous time systems are discretized before simulation. By default, the
method is chosen based on the smoothness of the input signal. Optionally, the
`method` parameter can be specified as either `:zoh` or `:foh`.

Continuous systems are drawn as lines, all others as post-step stairs. The
figure is returned.""" ->
function lsimplot(systems::Vector{LTISystem}, u::AbstractVecOrMat,
        t::AbstractVector, x0::VecOrMat=zeros(systems[1].nx, 1),
        method::Symbol=_issmooth(u) ? :foh : :zoh)
    _same_io_dims(systems...) ||
        error("All systems must have the same input/output dimensions")
    ny = size(systems[1])[1]
    fig, handles = PyPlot.subplots(ny, 1, sharex=true)
    # For a single output the returned handle is wrapped in a vector so that it
    # can be indexed like the multi-output case.
    axes = ny == 1 ? [handles] : handles
    npts = size(t, 1)
    for sys in systems
        response = lsim(sys, u, t, x0, method)[1]
        for row in 1:ny
            trace = reshape(response[:, row], npts)
            if iscontinuous(sys)
                axes[row][:plot](t, trace)
            else
                axes[row][:step](t, trace, where="post")
            end
        end
    end
    # Titles and labels; per-output labels are only added when there are
    # several outputs.
    fig[:suptitle]("System Response", size=16)
    if ny != 1
        for row in 1:ny
            axes[row][:set_ylabel]("To: y($row)", size=12, color="0.30")
        end
    end
    fig[:text](0.5, 0.04, "Time (s)", ha="center", va="center", size=14)
    fig[:text](0.06, 0.5, "Amplitude", ha="center", va="center",
            rotation="vertical", size=14)
    PyPlot.draw()
    return fig
end
# Single-system convenience method: wrap the system in a vector and delegate.
lsimplot(sys::LTISystem, u::AbstractVecOrMat, t::AbstractVector, args...) =
        lsimplot(LTISystem[sys], u, t, args...)


# Generate `stepplot` and `impulseplot`. Both share the plotting code below and
# differ only in the response function that is called and in the figure title.
for (func, title) in ((:step, "Step Response"), (:impulse, "Impulse Response"))
    funcname = Symbol("$(func)plot")
    @eval begin
        # Core method: one sample time per system and a common final time.
        function $funcname(systems::Vector{LTISystem}, Ts_list::Vector, Tf::Real)
            _same_io_dims(systems...) ||
                error("All systems must have the same input/output dimensions")
            ny, nu = size(systems[1])
            fig, temp = PyPlot.subplots(ny, nu, sharex="col", sharey="row")
            # Ensure that `axes` can always be indexed as `axes[i, j]`
            axes = ny == 1 ? reshape(vcat(temp), ny, nu) : temp
            for (sys, Ts) in zip(systems, Ts_list)
                tvec = 0:Ts:Tf
                npts = size(tvec, 1)
                resp = ($func)(sys, tvec)[1]
                for i in 1:ny, j in 1:nu
                    ydata = reshape(resp[:, i, j], npts)
                    if iscontinuous(sys)
                        axes[i, j][:plot](tvec, ydata)
                    else
                        axes[i, j][:step](tvec, ydata, where="post")
                    end
                end
            end
            # Titles and labels; channel labels only when there is more than
            # one input/output pair.
            fig[:suptitle]($title, size=16)
            multichannel = ny*nu != 1
            if multichannel
                for i in 1:ny
                    axes[i, 1][:set_ylabel]("To: y($i)", size=12, color="0.30")
                end
                for j in 1:nu
                    axes[1, j][:set_title]("From: u($j)", size=12, color="0.30")
                end
            end
            fig[:text](0.5, 0.04, "Time (s)", ha="center", va="center", size=14)
            fig[:text](0.06, 0.5, "Amplitude", ha="center", va="center",
                    rotation="vertical", size=14)
            PyPlot.draw()
            return fig
        end
        # Final time given: use the default sample time of each system.
        $funcname(systems::Vector{LTISystem}, Tf::Real) =
                $funcname(systems, map(_default_Ts, systems), Tf)
        # Nothing given: default sample times and final time.
        $funcname(systems::Vector{LTISystem}) =
                $funcname(systems, _default_time_data(systems)...)
        # Time vector given: its first spacing is the sample time for every
        # system and its last entry the final time.
        $funcname(systems::Vector{LTISystem}, t::AbstractVector) =
                $funcname(systems, fill(t[2] - t[1], length(systems)), t[end])
        # Single system: wrap it in a vector and delegate.
        $funcname(sys::LTISystem, args...) = $funcname(LTISystem[sys], args...)
    end
end

@doc """`stepplot(sys, args...)`, `stepplot(LTISystem[sys1, sys2...], args...)`

Plot the `step` response of the `LTISystem`(s) `sys` and return the figure. A
final time `Tf` or a time vector `t` can be optionally provided.""" -> stepplot

@doc """`impulseplot(sys, args...)`, `impulseplot(LTISystem[sys1, sys2...], args...)`

Plot the `impulse` response of the `LTISystem`(s) `sys` and return the figure. A
final time `Tf` or a time vector `t` can be optionally provided.""" -> impulseplot


## FREQUENCY PLOTS ##
@doc """`bodeplot(sys, args...)`, `bodeplot(LTISystem[sys1, sys2...], args...)`

Create a Bode plot of the `LTISystem`(s) `sys`. A frequency vector `w` can be
optionally provided.

Every input/output pair gets two stacked subplots: the magnitude from `bode`
converted to dB, and the phase below it. Pairs whose magnitude is `-Inf` dB at
every frequency are not drawn. Returns the figure.""" ->
function bodeplot(systems::Vector{LTISystem}, w::AbstractVector)
    if !_same_io_dims(systems...)
        error("All systems must have the same input/output dimensions")
    end
    ny, nu = size(systems[1])
    # Two subplot rows per output: magnitude on row 2i-1, phase on row 2i
    fig, axes = PyPlot.subplots(2*ny, nu, sharex="col", sharey="row")
    for s = systems
        mag, phase = bode(s, w)[1:2]
        mag = 20*log10(mag)
        for j=1:nu
            for i=1:ny
                magdata = vec(mag[i, j, :])
                if all(magdata .== -Inf)
                    # 0 system, don't plot anything
                    continue
                end
                phasedata = vec(phase[i, j, :])
                axes[2*i - 1, j][:semilogx](w, magdata)
                axes[2*i, j][:semilogx](w, phasedata)
            end
        end
    end
    # Add labels and titles
    fig[:suptitle]("Bode Plot", size=16)
    if ny*nu != 1
        for i=1:2*ny
            # Rows 2k-1 and 2k both belong to output k
            axes[i, 1][:set_ylabel]("To: y($(div(i + 1, 2)))",
                    size=12, color="0.30")
        end
        for j=1:nu
            axes[1, j][:set_title]("From: u($j)", size=12, color="0.30")
        end
        fig[:text](0.06, 0.5, "Phase (deg), Magnitude (dB)", ha="center",
                va="center", rotation="vertical", size=14)
    else
        axes[1, 1][:set_ylabel]("Magnitude (dB)", size=14)
        axes[2, 1][:set_ylabel]("Phase (deg)", size=14)
    end
    fig[:text](0.5, 0.04, "Frequency (rad/s)", ha="center",
            va="center", size=14)
    PyPlot.draw()
    return fig
end
# No frequency vector given: use the default one for Bode plots.
bodeplot(systems::Vector{LTISystem}) =
    bodeplot(systems, _default_freq_vector(systems, :bode))
# Single system: wrap it in a vector and delegate.
bodeplot(sys::LTISystem, args...) = bodeplot(LTISystem[sys], args...)


@doc """`marginplot(sys[, w])`

Create a bode plot of the `LTISystem` `sys` with gain and phase margins marked
out. A frequency vector `w` can be optionally provided.

Only SISO systems are supported. Systems that are not `TransferFunction`s are
converted with `tf` first. Returns the figure.""" ->
function marginplot(
    sys::TransferFunction,
    w::AbstractVector=_default_freq_vector(LTISystem[sys], :bode)
    )
    !issiso(sys) && error("marginplot only defined for siso systems")

    # Create the bode plot
    mag, phase = bode(sys, w)
    mag = 20*log10(mag)
    magdata = vec(mag) # deal with all(magdata .== Inf)?
    phasedata = vec(phase)

    fig, axes = PyPlot.subplots(2, 1, sharex="col")
    axes[1][:semilogx](w, magdata)
    axes[2][:semilogx](w, phasedata)

    # Make sure axis limits won't change by explicitly setting them to their current values
    axes[1][:axis](axes[1][:axis]())
    axes[2][:axis](axes[2][:axis]())

    # Calculate margins. The first two values come from `phasemargin`, the last
    # two from `gainmargin`; as used in the title below, `wgm` is the frequency
    # reported with the phase margin and `wpm` the one reported with the gain
    # margin.
    pm, wgm, gm, wpm = phasemargin(sys, deg=true)..., gainmargin(sys, dB=true)...
    fig[:suptitle]("Bode Plot\nGm = $(round(gm,1)) dB (at $(round(wpm,1)) rad/s), Pm = $(round(pm,1)) (at $(round(wgm,1)) rad/s)")

    # Calculate the gain at wpm and phase at wgm
    gw180 = 20*log10(abs(evalfr(sys, im*wpm)[1]))
    pwc   = rad2deg(angle(evalfr(sys, im*wgm)[1]))

    # Plot the margins ("?:" is used to skip plotting certain features if gm/pm is undefined)
    # Magnitude axes: dotted 0 dB line, dotted guide lines at wgm and wpm, and a
    # solid segment between 0 dB and the gain at wpm.
    glow, ghigh = min(0, gw180), max(0, gw180)
    axes[1][:plot](
                    w[[1,end]], [0, 0],       ":k",
    (!isnan(wgm) ? ([wgm, wgm], [0, -1e6],    ":k") : () )...,
    (!isnan(wpm) ? ([wpm, wpm], [-1e6, glow], ":k",
                    [wpm, wpm], [glow, ghigh], "k") : () )...
    )

    # Phase axes: dotted +/-180 degree lines, a solid segment between 180
    # degrees and the phase at wgm, and dotted guide lines at wgm and wpm.
    axes[2][:plot](
                    w[[1,end]], [180, 180],   ":k",
                    w[[1,end]],-[180, 180],   ":k",
    (!isnan(wgm) ? ([wgm, wgm], [180, pwc],    "k",
                    [wgm, wgm], [pwc, 1e6],   ":k") : () )...,
    (!isnan(wpm) ? ([wpm, wpm], [180, 1e6],   ":k") : () )...
    )

    PyPlot.draw()
    return fig
end
# Any other LTISystem is converted to a transfer function first. A frequency
# vector, when given, is forwarded to the plotting method instead of being
# dropped.
marginplot(sys::LTISystem, w::AbstractVector) = marginplot(tf(sys), w)
marginplot(sys::LTISystem) = marginplot(tf(sys))


@doc """`nyquistplot(sys, args...)`, `nyquistplot(LTISystem[sys1, sys2...], args...)`

Create a Nyquist plot of the `LTISystem`(s) `sys`. A frequency vector `w` can be
optionally provided.

For every input/output pair the curve from `nyquist` is drawn together with its
mirror image about the real axis. Direction arrows are added at the middle of
the frequency vector when it has enough points on both sides of the midpoint.
Returns the figure.""" ->
function nyquistplot(systems::Vector{LTISystem}, w::AbstractVector)
    if !_same_io_dims(systems...)
        error("All systems must have the same input/output dimensions")
    end
    ny, nu = size(systems[1])
    nw = length(w)
    fig, temp = PyPlot.subplots(ny, nu, sharex="col", sharey="row")
    # Ensure that `axes` is always a matrix of handles. `vcat` is used instead
    # of `[temp]` so the handles end up in a flat vector before the reshape.
    axes = ny == 1 ? reshape(vcat(temp), ny, nu) : temp
    # Index of the point where the direction arrows are drawn
    mp = div(nw, 2)
    # The arrows need a neighbour on each side of `mp`
    drawarrows = 1 < mp < nw
    for s = systems
        re_resp, im_resp = nyquist(s, w)[1:2]
        for j=1:nu
            for i=1:ny
                ax = axes[i, j]
                redata = reshape(re_resp[i, j, :], nw)
                imdata = reshape(im_resp[i, j, :], nw)
                line = ax[:plot](redata, imdata)[1]
                color = line[:get_color]()
                # Plot the mirror in the same color
                ax[:plot](redata, -imdata, color=color)
                # Add arrows at the midpoint: forward along the curve, and
                # backward along the mirrored curve
                if drawarrows
                    ax[:arrow](redata[mp], imdata[mp], redata[mp + 1] - redata[mp],
                            imdata[mp + 1] - imdata[mp], color=color, width=0.003)
                    ax[:arrow](redata[mp], -imdata[mp], redata[mp - 1] - redata[mp],
                            -imdata[mp - 1] + imdata[mp], color=color, width=0.003)
                end
            end
        end
    end
    # Add labels and titles
    fig[:suptitle]("Nyquist Plot", size=16)
    if ny*nu != 1
        for i=1:ny
            axes[i, 1][:set_ylabel]("To: y($i)", size=12, color="0.30")
        end
        for j=1:nu
            ax = axes[1, j]
            ax[:set_title]("From: u($j)", size=12, color="0.30")
            # Ensure the x axis includes -1
            xlims = ax[:get_xlim]()
            ax[:set_xlim]([min(-1, xlims[1]), xlims[2]])
        end
    end
    fig[:text](0.06, 0.5, "Imaginary Axis", ha="center", va="center",
            rotation="vertical", size=14)
    fig[:text](0.5, 0.04, "Real Axis", ha="center", va="center", size=14)
    # Add a minor tick with a grid line at 0 on both axes of every subplot
    for ax in axes
        ax[:set_yticks]([0.0], minor=true)
        ax[:yaxis][:grid](true, which="minor")
        ax[:set_xticks]([0.0], minor=true)
        ax[:xaxis][:grid](true, which="minor")
    end
    PyPlot.draw()
    return fig
end
# No frequency vector given: use the default one for Nyquist plots.
nyquistplot(systems::Vector{LTISystem}) =
    nyquistplot(systems, _default_freq_vector(systems, :nyquist))
# Single system: wrap it in a vector and delegate.
nyquistplot(sys::LTISystem, args...) = nyquistplot(LTISystem[sys], args...)

@doc """`sigmaplot(sys, args...)`, `sigmaplot(LTISystem[sys1, sys2...], args...)`

Plot the singular values of the frequency response of the `LTISystem`(s) `sys`. A
frequency vector `w` can be optionally provided.

The singular values from `sigma` are converted to dB and drawn against a
logarithmic frequency axis, with all singular values of one system sharing a
color. Returns the figure.""" ->
function sigmaplot(systems::Vector{LTISystem}, w::AbstractVector)
    if !_same_io_dims(systems...)
        error("All systems must have the same input/output dimensions")
    end
    fig, ax = PyPlot.subplots(1, 1)
    for s = systems
        sv = 20*log10(sigma(s, w)[1])
        # Plot the first singular value, grab the line color, then plot the
        # remaining values all in the same color. Every curve, including the
        # first, uses `semilogx` so the frequency axis is logarithmic even when
        # there is only one singular value. `vec` turns each row into a plain
        # vector of the same length as `w`.
        line = ax[:semilogx](w, vec(sv[1, :]))[1]
        color = line[:get_color]()
        for i in 2:size(sv, 1)
            ax[:semilogx](w, vec(sv[i, :]), color=color)
        end
    end
    ax[:set_title]("Sigma Plot", size=16)
    ax[:set_xlabel]("Frequency (rad/s)", size=14)
    ax[:set_ylabel]("Singular Values (dB)", size=14)
    PyPlot.draw()
    return fig
end
# No frequency vector given: use the default one for sigma plots.
sigmaplot(systems::Vector{LTISystem}) =
    sigmaplot(systems, _default_freq_vector(systems, :sigma))
# Single system: wrap it in a vector and delegate.
sigmaplot(sys::LTISystem, args...) = sigmaplot(LTISystem[sys], args...)


### NOTE: rlocus currently can't convert from state space to TF ###

@doc """`rlocusplot(sys or LTISystem[sys1, sys2...]; k=..., axlimits=..., asymptotes=true/false)`

Create a Root-locus plot of the `LTISystem`(s) `sys` and return the figure and
axes as PyPlot objects. A gain vector `k` can be optionally provided. The plot
axes can be set manually by providing `axlimits = [xmin, xmax, ymin, ymax]`.
Set `asymptotes=false` to skip drawing the asymptotes.

The poles and zeros are drawn by `pzplot`, which also chooses the axis limits;
the locus returned by `rlocus` is then drawn on top for every input/output pair.""" ->
function rlocusplot(
    systems::Vector{TransferFunction};
    k::AbstractVector=[-1.],
    axlimits::AbstractVector=[NaN, NaN, NaN, NaN],
    asymptotes::Bool=true
    )
    if !_same_io_dims(systems...)
        error("All systems must have the same input/output dimensions")
    end
    ny, nu = size(systems[1],1,2)

    # Get the plot of poles and zeros for the system. Axis limits are currently
    # calculated based on pz-locations. Might want to base it on size of locus
    # instead.
    fig, axes = pzplot(systems, axlimits=axlimits, dodraw=false)

    for s=systems, j=1:nu, i=1:ny
        ax = axes[i,j]
        scurr = s[i,j]
        # Only the root locations are needed; the gains returned alongside them
        # are not used here.
        p = rlocus(scurr, k)[1]
        ax[:plot](real(p), imag(p))

        if asymptotes
            # Draw asymptotes as dotted grey rays starting on the real axis
            asy, midpoint = rlasymptotes(scurr.matrix[1].num, scurr.matrix[1].den)
            if !isempty(asy)
                # Ray length: largest magnitude in the first and last rows of
                # `p`, capped so the plotted coordinates stay finite. It is the
                # same for every asymptote, so compute it once.
                pmax = min(realmax(Float64)*1e-3, maxabs(p[[1, end],:]))
                for theta in asy
                    ax[:plot]([midpoint, midpoint+cos(theta)*pmax], [0, sin(theta)*pmax], ":", color="#C0C0C0")
                end
            end
        end
    end

    # Add labels and titles
    fig[:suptitle]("Root Locus Plot", size=16)
    if ny*nu != 1
        for i=1:ny
            axes[i, 1][:set_ylabel]("To: y($i)", size=12, color="0.30")
        end
        for j=1:nu
            ax = axes[1, j]
            ax[:set_title]("From: u($j)", size=12, color="0.30")
            ### Set axis limits here ###
        end
    end
    PyPlot.draw()
    return fig, axes
end
# Single system: convert it into a one-element vector of transfer functions.
rlocusplot(sys::LTISystem; args...) = rlocusplot(TransferFunction[sys]; args...)
# Vector of generic systems: the typed comprehension guarantees a
# `Vector{TransferFunction}`, so the call always reaches the method above.
rlocusplot(systems::Vector{LTISystem}; args...) =
    rlocusplot(TransferFunction[TransferFunction(s) for s in systems]; args...)


# Merge a newly computed bound into a stored axis limit. A stored `NaN` means
# "not set yet" and is replaced by the new bound; this is written out
# explicitly so it does not depend on how `min`/`max` treat `NaN`.
_pz_lower(stored, bound) = isnan(stored) ? bound : min(stored, bound)
_pz_upper(stored, bound) = isnan(stored) ? bound : max(stored, bound)

@doc """`pzplot(sys; axlimits=...)`, `pzplot(LTISystem[sys1, sys2...]; axlimits=...)`

Create a Pole-Zero map of the `LTISystem`(s) `sys` and return the figure and
axes as PyPlot objects. The plot axes can be set manually by providing
`axlimits = [xmin, xmax, ymin, ymax]`. If any entry of `axlimits` is `NaN`, the
limits are chosen from the pole and zero locations, and entries that were given
are only ever widened. The vector passed as `axlimits` is not modified.

With `dodraw=false` the figure is built but `PyPlot.draw()` is not called.""" ->
function pzplot(systems::Vector{TransferFunction}; axlimits::AbstractVector=[NaN, NaN, NaN, NaN], dodraw::Bool=true)
    if !_same_io_dims(systems...)
        error("All systems must have the same input/output dimensions")
    end
    ny, nu = size(systems[1], 1, 2)
    fig, temp = PyPlot.subplots(ny, nu, sharex="col", sharey="row")
    axes = (ny == 1 ? reshape(vcat(temp), ny, nu) : temp)
    picklimits = any(isnan(axlimits))
    # Work on a copy so the caller's vector is left untouched
    limits = copy(axlimits)
    # Half-length of the dotted lines drawn along the real and imaginary axes
    bignum = 1e-3*realmax(Float64)
    for s = systems
        z, p, k = zpkdata(s)
        for i=1:ny, j=1:nu
            ax = axes[i,j]
            pcur, zcur = p[i,j], z[i,j]
            # Plot poles
            ax[:plot](real(pcur), imag(pcur), "xb")

            # Plot zeros
            ax[:plot](real(zcur), imag(zcur), "ob")

            # Plot axes
            ax[:plot]([-bignum 0; bignum  0], [0 -bignum; 0 bignum], ":k")

            # Determine appropriate axis limits if none are provided
            if picklimits
                # NOTE(review): `center` is the sum of the mean real parts of
                # the poles and of the zeros, not their average - confirm this
                # is intended.
                center = (!isempty(pcur) ? mean(real(pcur)) : 0) + (!isempty(zcur) ? mean(real(zcur)) : 0)
                axismax = 2*max(abs(pcur-center)..., abs(zcur-center)..., 1)
                limits[1] = _pz_lower(limits[1], center - axismax)
                limits[2] = _pz_upper(limits[2], center + axismax)
                limits[3] = _pz_lower(limits[3], -axismax)
                limits[4] = _pz_upper(limits[4], axismax)
            end

            # Widen a range to [-1, 1] if both of its ends are (close to) zero.
            # This is done before the limits are applied so that it also takes
            # effect for the last system.
            abs(limits[1])<1e-3 && abs(limits[2])<1e-3 && (limits[1]=-1; limits[2]=1)
            abs(limits[3])<1e-3 && abs(limits[4])<1e-3 && (limits[3]=-1; limits[4]=1)
            ax[:axis](limits)
        end
    end
    fig[:suptitle]("Pole-Zero map", size=16)
    fig[:text](0.06, 0.5, "Imaginary Axis", ha="center", va="center", rotation="vertical", size=14)
    fig[:text](0.5, 0.04, "Real Axis", ha="center", va="center", size=14)
    dodraw && PyPlot.draw()
    return fig, axes
end
# Single system: convert it into a one-element vector of transfer functions.
pzplot(sys::LTISystem; args...) = pzplot(TransferFunction[sys]; args...)
# Vector of generic systems: the typed comprehension guarantees a
# `Vector{TransferFunction}`, so the call always reaches the method above.
pzplot(sys::Vector{LTISystem}; args...) =
    pzplot(TransferFunction[TransferFunction(s) for s in sys]; args...)


# HELPERS:

# Return `true` when every system in `systems` reports the same `size` as the
# first one. With no systems at all the answer is (vacuously) `true`.
function _same_io_dims(systems::LTISystem...)
    sizes = map(size, systems)
    return all(s -> s == sizes[1], sizes)
end

# Default time data for a set of systems: one sample time per system (from
# `_default_Ts`) and a single final time, the largest `_default_Tf` over all
# systems evaluated with their own sample time.
function _default_time_data(systems::Vector{LTISystem})
    sample_times = [_default_Ts(s) for s in systems]
    longest = maximum(map(_default_Tf, systems, sample_times))
    return sample_times, longest
end
# Single system: wrap it in a vector and delegate.
_default_time_data(sys::LTISystem) = _default_time_data(LTISystem[sys])
