From 101f176ae4e15d019b570ad5b37794e4bb1fd8ce Mon Sep 17 00:00:00 2001 From: Eric Blake Date: Jan 09 2014 21:13:01 +0000 Subject: maint: improve VIR_ERR_INVALID_STREAM usage For streams validation, we weren't consistent on whether to use VIR_FROM_NONE or VIR_FROM_STREAMS. Furthermore, in many API, we want to ensure that a stream is tied to the same connection as the other object we are operating on; while other API failed to validate the stream at all. And the difference between VIR_IS_STREAM and VIR_IS_CONNECTED_STREAM is moot; as in commit 6e130ddc, we know that reference counting means a valid stream will always be tied to a valid connection. Similar to previous patches, use a common macro to make it nicer. * src/datatypes.h (virCheckStreamReturn, virCheckStreamGoto): New macros. (VIR_IS_STREAM, VIR_IS_CONNECTED_STREAM): Drop unused macros. * src/libvirt.c: Use macro throughout. (virLibStreamError): Drop unused macro. Signed-off-by: Eric Blake --- diff --git a/src/datatypes.h b/src/datatypes.h index 29a1096..74b4a97 100644 --- a/src/datatypes.h +++ b/src/datatypes.h @@ -192,10 +192,31 @@ extern virClassPtr virStoragePoolClass; } \ } while (0) -# define VIR_IS_STREAM(obj) \ - (virObjectIsClass((obj), virStreamClass)) -# define VIR_IS_CONNECTED_STREAM(obj) \ - (VIR_IS_STREAM(obj) && virObjectIsClass((obj)->conn, virConnectClass)) +# define virCheckStreamReturn(obj, retval) \ + do { \ + virStreamPtr _st = (obj); \ + if (!virObjectIsClass(_st, virStreamClass) || \ + !virObjectIsClass(_st->conn, virConnectClass)) { \ + virReportErrorHelper(VIR_FROM_STREAMS, \ + VIR_ERR_INVALID_STREAM, \ + __FILE__, __FUNCTION__, __LINE__, \ + __FUNCTION__); \ + virDispatchError(NULL); \ + return retval; \ + } \ + } while (0) +# define virCheckStreamGoto(obj, label) \ + do { \ + virStreamPtr _st = (obj); \ + if (!virObjectIsClass(_st, virStreamClass) || \ + !virObjectIsClass(_st->conn, virConnectClass)) { \ + virReportErrorHelper(VIR_FROM_STREAMS, \ + VIR_ERR_INVALID_STREAM, \ + __FILE__, __FUNCTION__, __LINE__, \ + __FUNCTION__); \ + goto label; \ + } \ + } while (0) # define VIR_IS_NWFILTER(obj) \ (virObjectIsClass((obj), virNWFilterClass)) diff --git a/src/libvirt.c b/src/libvirt.c index 5be360d..153c152 100644 --- a/src/libvirt.c +++ b/src/libvirt.c @@ -528,9 +528,6 @@ DllMain(HINSTANCE instance ATTRIBUTE_UNUSED, #define virLibDomainError(code, ...) \ virReportErrorHelper(VIR_FROM_DOM, code, __FILE__, \ __FUNCTION__, __LINE__, __VA_ARGS__) -#define virLibStreamError(code, ...) \ - virReportErrorHelper(VIR_FROM_STREAMS, code, __FILE__, \ - __FUNCTION__, __LINE__, __VA_ARGS__) #define virLibNWFilterError(code, ...) \ virReportErrorHelper(VIR_FROM_NWFILTER, code, __FILE__, \ __FUNCTION__, __LINE__, __VA_ARGS__) @@ -3087,14 +3084,18 @@ virDomainScreenshot(virDomainPtr domain, virResetLastError(); virCheckDomainReturn(domain, NULL); - if (!VIR_IS_STREAM(stream)) { - virLibConnError(VIR_ERR_INVALID_STREAM, __FUNCTION__); - return NULL; + virCheckStreamGoto(stream, error); + virCheckReadOnlyGoto(domain->conn->flags, error); + + if (domain->conn != stream->conn) { + virReportInvalidArg(stream, + _("stream in %s must match connection of domain '%s'"), + __FUNCTION__, domain->name); + goto error; } - virCheckReadOnlyGoto(domain->conn->flags | stream->conn->flags, error); if (domain->conn->driver->domainScreenshot) { - char * ret; + char *ret; ret = domain->conn->driver->domainScreenshot(domain, stream, screen, flags); @@ -13664,14 +13665,16 @@ virStorageVolDownload(virStorageVolPtr vol, virResetLastError(); virCheckStorageVolReturn(vol, -1); + virCheckStreamGoto(stream, error); + virCheckReadOnlyGoto(vol->conn->flags, error); - if (!VIR_IS_STREAM(stream)) { - virLibConnError(VIR_ERR_INVALID_STREAM, __FUNCTION__); - return -1; + if (vol->conn != stream->conn) { + virReportInvalidArg(stream, + _("stream in %s must match connection of volume '%s'"), + __FUNCTION__, vol->name); + goto error; } - virCheckReadOnlyGoto(vol->conn->flags | stream->conn->flags, error); - if (vol->conn->storageDriver && vol->conn->storageDriver->storageVolDownload) { int ret; @@ -13728,14 +13731,16 @@ virStorageVolUpload(virStorageVolPtr vol, virResetLastError(); virCheckStorageVolReturn(vol, -1); + virCheckStreamGoto(stream, error); + virCheckReadOnlyGoto(vol->conn->flags, error); - if (!VIR_IS_STREAM(stream)) { - virLibConnError(VIR_ERR_INVALID_STREAM, __FUNCTION__); - return -1; + if (vol->conn != stream->conn) { + virReportInvalidArg(stream, + _("stream in %s must match connection of volume '%s'"), + __FUNCTION__, vol->name); + goto error; } - virCheckReadOnlyGoto(vol->conn->flags | stream->conn->flags, error); - if (vol->conn->storageDriver && vol->conn->storageDriver->storageVolUpload) { int ret; @@ -15661,11 +15666,8 @@ virStreamRef(virStreamPtr stream) virResetLastError(); - if ((!VIR_IS_CONNECTED_STREAM(stream))) { - virLibConnError(VIR_ERR_INVALID_STREAM, __FUNCTION__); - virDispatchError(NULL); - return -1; - } + virCheckStreamReturn(stream, -1); + virObjectRef(stream); return 0; } @@ -15744,12 +15746,7 @@ virStreamSend(virStreamPtr stream, virResetLastError(); - if (!VIR_IS_CONNECTED_STREAM(stream)) { - virLibConnError(VIR_ERR_INVALID_CONN, __FUNCTION__); - virDispatchError(NULL); - return -1; - } - + virCheckStreamReturn(stream, -1); virCheckNonNullArgGoto(data, error); if (stream->driver && @@ -15842,12 +15839,7 @@ virStreamRecv(virStreamPtr stream, virResetLastError(); - if (!VIR_IS_CONNECTED_STREAM(stream)) { - virLibConnError(VIR_ERR_INVALID_CONN, __FUNCTION__); - virDispatchError(NULL); - return -1; - } - + virCheckStreamReturn(stream, -1); virCheckNonNullArgGoto(data, error); if (stream->driver && @@ -15921,12 +15913,7 @@ virStreamSendAll(virStreamPtr stream, virResetLastError(); - if (!VIR_IS_CONNECTED_STREAM(stream)) { - virLibConnError(VIR_ERR_INVALID_CONN, __FUNCTION__); - virDispatchError(NULL); - return -1; - } - + virCheckStreamReturn(stream, -1); virCheckNonNullArgGoto(handler, cleanup); if (stream->flags & VIR_STREAM_NONBLOCK) { @@ -16019,12 +16006,7 @@ virStreamRecvAll(virStreamPtr stream, virResetLastError(); - if (!VIR_IS_CONNECTED_STREAM(stream)) { - virLibConnError(VIR_ERR_INVALID_CONN, __FUNCTION__); - virDispatchError(NULL); - return -1; - } - + virCheckStreamReturn(stream, -1); virCheckNonNullArgGoto(handler, cleanup); if (stream->flags & VIR_STREAM_NONBLOCK) { @@ -16093,11 +16075,7 @@ virStreamEventAddCallback(virStreamPtr stream, virResetLastError(); - if (!VIR_IS_CONNECTED_STREAM(stream)) { - virLibConnError(VIR_ERR_INVALID_CONN, __FUNCTION__); - virDispatchError(NULL); - return -1; - } + virCheckStreamReturn(stream, -1); if (stream->driver && stream->driver->streamEventAddCallback) { @@ -16136,11 +16114,7 @@ virStreamEventUpdateCallback(virStreamPtr stream, virResetLastError(); - if (!VIR_IS_CONNECTED_STREAM(stream)) { - virLibConnError(VIR_ERR_INVALID_CONN, __FUNCTION__); - virDispatchError(NULL); - return -1; - } + virCheckStreamReturn(stream, -1); if (stream->driver && stream->driver->streamEventUpdateCallback) { @@ -16174,11 +16148,7 @@ virStreamEventRemoveCallback(virStreamPtr stream) virResetLastError(); - if (!VIR_IS_CONNECTED_STREAM(stream)) { - virLibConnError(VIR_ERR_INVALID_CONN, __FUNCTION__); - virDispatchError(NULL); - return -1; - } + virCheckStreamReturn(stream, -1); if (stream->driver && stream->driver->streamEventRemoveCallback) { @@ -16219,11 +16189,7 @@ virStreamFinish(virStreamPtr stream) virResetLastError(); - if (!VIR_IS_CONNECTED_STREAM(stream)) { - virLibConnError(VIR_ERR_INVALID_CONN, __FUNCTION__); - virDispatchError(NULL); - return -1; - } + virCheckStreamReturn(stream, -1); if (stream->driver && stream->driver->streamFinish) { @@ -16262,11 +16228,7 @@ virStreamAbort(virStreamPtr stream) virResetLastError(); - if (!VIR_IS_CONNECTED_STREAM(stream)) { - virLibConnError(VIR_ERR_INVALID_CONN, __FUNCTION__); - virDispatchError(NULL); - return -1; - } + virCheckStreamReturn(stream, -1); if (!stream->driver) { VIR_DEBUG("aborting unused stream"); @@ -16310,11 +16272,7 @@ virStreamFree(virStreamPtr stream) virResetLastError(); - if (!VIR_IS_CONNECTED_STREAM(stream)) { - virLibConnError(VIR_ERR_INVALID_CONN, __FUNCTION__); - virDispatchError(NULL); - return -1; - } + virCheckStreamReturn(stream, -1); /* XXX Enforce shutdown before free'ing resources ? */ @@ -19332,8 +19290,16 @@ virDomainOpenConsole(virDomainPtr dom, virCheckDomainReturn(dom, -1); conn = dom->conn; + virCheckStreamGoto(st, error); virCheckReadOnlyGoto(conn->flags, error); + if (conn != st->conn) { + virReportInvalidArg(st, + _("stream in %s must match connection of domain '%s'"), + __FUNCTION__, dom->name); + goto error; + } + if (conn->driver->domainOpenConsole) { int ret; ret = conn->driver->domainOpenConsole(dom, dev_name, st, flags); @@ -19388,8 +19354,16 @@ virDomainOpenChannel(virDomainPtr dom, virCheckDomainReturn(dom, -1); conn = dom->conn; + virCheckStreamGoto(st, error); virCheckReadOnlyGoto(conn->flags, error); + if (conn != st->conn) { + virReportInvalidArg(st, + _("stream in %s must match connection of domain '%s'"), + __FUNCTION__, dom->name); + goto error; + } + if (conn->driver->domainOpenChannel) { int ret; ret = conn->driver->domainOpenChannel(dom, name, st, flags);